Compute (Target) Permutation Importances of a machine learning model
Target Permutation Importances
Overview
This method aims to lower the feature attribution that is due to a feature's variance: if a feature still appears important after the target vector is shuffled, the model is fitting it to noise.
Overall, this package:
- Fits the given model class $M$ times to get $M$ actual feature importances ($A$).
- Fits the given model class with shuffled targets $N$ times to get $N$ random feature importances ($R$).
- Computes the final importances by various methods, such as:
- $A - R$
- $A / (MinMaxScale(R) + 1)$
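The steps above can be sketched with scikit-learn and NumPy. This is only an illustration of the procedure (not the package's actual code): the model, run counts, and random seed are arbitrary, and the $A - R$ variant is shown.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Illustrative sketch of the target-permutation procedure described above.
X, y = load_breast_cancer(return_X_y=True)
rng = np.random.default_rng(0)
M, N = 2, 5  # actual runs and random (shuffled-target) runs

def mean_importances(num_runs: int, shuffle: bool) -> np.ndarray:
    """Fit the model num_runs times and average its feature importances."""
    runs = []
    for i in range(num_runs):
        target = rng.permutation(y) if shuffle else y  # shuffle the TARGET, not a feature
        model = RandomForestClassifier(n_estimators=10, random_state=i).fit(X, target)
        runs.append(model.feature_importances_)
    return np.mean(runs, axis=0)

A = mean_importances(M, shuffle=False)  # actual importances
R = mean_importances(N, shuffle=True)   # random (null) importances
final = A - R  # features that fit noise get low or negative scores
```

Features whose importance survives target shuffling contribute little to `final`, while genuinely predictive features keep a positive score.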
Not to be confused with sklearn.inspection.permutation_importance: that sklearn method permutes features, whereas this package permutes the target.
This method was originally proposed/implemented in:
- Permutation importance: a corrected feature importance measure
- Feature Selection with Null Importances
Basic Usage
```python
# Import the function and an importance calculator
from target_permutation_importances import (
    compute,
    compute_permutation_importance_by_subtraction,
)

# Prepare a dataset
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
import pandas as pd

data = load_breast_cancer()
Xpd = pd.DataFrame(data.data, columns=data.feature_names)

# Compute permutation importances with default settings
result_df = compute(
    # RandomForestClassifier, XGBClassifier, CatBoostClassifier, LGBMClassifier...
    model_cls=RandomForestClassifier,
    model_cls_params={  # The params for the model class construction
        "n_estimators": 1,
    },
    model_fit_params={},  # The params for model.fit
    X=Xpd,
    y=data.target,
    num_actual_runs=2,
    num_random_runs=10,
    permutation_importance_calculator=compute_permutation_importance_by_subtraction,
)
```
You can find more detailed examples in the "Feature Selection Examples" section.
Advanced Usage / Customization
Instead of calling compute, this package also exposes generic_compute to allow customization. Read target_permutation_importances.__init__ for details.
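As an illustration of the kind of customization possible, the $A / (MinMaxScale(R) + 1)$ variant from the overview can be written as a plain function. Note this is a hypothetical sketch: the exact callable signature generic_compute expects is defined in target_permutation_importances.__init__, and the function name here is invented.

```python
import numpy as np

# Hypothetical importance calculator implementing A / (MinMaxScale(R) + 1).
# The real signature expected by generic_compute may differ; see the package's
# __init__ for the actual interface.
def importance_by_division(actual: np.ndarray, random: np.ndarray) -> np.ndarray:
    r_min, r_max = random.min(), random.max()
    # Min-max scale the random (null) importances into [0, 1]
    if r_max > r_min:
        scaled = (random - r_min) / (r_max - r_min)
    else:
        scaled = np.zeros_like(random)
    # Dividing by (scaled + 1) penalizes features with high null importance
    return actual / (scaled + 1)

A = np.array([0.5, 0.3, 0.2])
R = np.array([0.1, 0.4, 0.5])
final = importance_by_division(A, R)
```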
Feature Selection Examples
TODO
Benchmarks
Benchmarks have been run with tabular datasets from the Tabular data learning benchmark, which is also hosted on Hugging Face.
The following models with their default params are used in the benchmark:
sklearn.ensemble.RandomForestClassifier
sklearn.ensemble.RandomForestRegressor
xgboost.XGBClassifier
xgboost.XGBRegressor
catboost.CatBoostClassifier
catboost.CatBoostRegressor
lightgbm.LGBMClassifier
lightgbm.LGBMRegressor
For binary classification tasks, sklearn.metrics.f1_score is used for evaluation; for regression tasks, sklearn.metrics.mean_squared_error is used.
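For reference, the two evaluation metrics behave as follows on toy predictions (the values below are illustrative, not benchmark results):

```python
from sklearn.metrics import f1_score, mean_squared_error

# Binary classification: F1 is the harmonic mean of precision and recall.
# Here precision = 1.0 and recall = 0.5, so F1 = 2/3.
f1 = f1_score([0, 1, 1, 0], [0, 1, 0, 0])

# Regression: mean squared error of the predictions.
# ((2.0 - 1.5)^2 + (1.0 - 1.0)^2) / 2 = 0.125
mse = mean_squared_error([2.0, 1.0], [1.5, 1.0])
```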
The downloaded datasets are divided into 3 sections: train: 50%, val: 10%, test: 40%.
Feature importance is calculated from the train set. Feature selection is done on the val set.
The final benchmark is evaluated on the test set. Therefore the test set is unseen to both the feature importance and selection processes.
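The 50% / 10% / 40% split described above can be reproduced with two calls to scikit-learn's train_test_split. The dataset and random_state here are illustrative, not the benchmark's exact setup:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)

# First split off 50% for train; the remaining half is split
# 20% / 80%, which is 10% / 40% of the full dataset.
X_train, X_rest, y_train, y_rest = train_test_split(
    X, y, train_size=0.5, random_state=0
)
X_val, X_test, y_val, y_test = train_test_split(
    X_rest, y_rest, train_size=0.2, random_state=0
)
```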
Raw result data are in target-permutation-importances/benchmarks/results/tabular_benchmark.csv.
Kaggle Competitions
Many Kaggle competition top solutions involve this method. Here are some examples:
Year | Competition | Medal | Link
---|---|---|---
2023 | Predict Student Performance from Game Play | Gold | 3rd place solution
2019 | Elo Merchant Category Recommendation | Gold | 16th place solution
2018 | Home Credit Default Risk | Gold | 10th place solution
Development Setup and Contribution Guide
Python Version
You can find the suggested development Python version in .python-version.
You might consider setting up Pyenv if you want to have multiple Python versions on your machine.
Python packages
This repository is set up with Poetry. If you are not familiar with Poetry, the package requirements are listed in pyproject.toml.
Otherwise, you can simply set up with poetry install.
Run Benchmarks
To run the benchmark locally on your machine, run make run_tabular_benchmark or python -m benchmarks.run_tabular_benchmark.
Make Changes
Follow the Make Changes Guide from GitHub.
Before committing or merging, please run the linters defined in make lint and the tests defined in make test.