
Compute (Target) Permutation Importances of a machine learning model


Target Permutation Importances

Overview

This method aims to lower the importance attributed to a feature merely because of its variance. If a feature is still important after the target vector has been shuffled, the model is using it to fit noise.

Overall, this package:

  1. Fits the given model class $M$ times to get $M$ actual feature importances ($A$).
  2. Fits the given model class $N$ times on shuffled targets to get $N$ random feature importances ($R$).
  3. Computes the final importances by one of several methods, such as:
    • $A - R$
    • $A / (MinMaxScale(R) + 1)$
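The three steps above can be sketched as follows. This is a minimal illustration, not the package's implementation. It assumes a scikit-learn style model that accepts `random_state` and exposes `feature_importances_`, averages the $M$ actual and $N$ random importances per feature before combining them, and applies min-max scaling across features. The helper names `combine_importances` and `target_permutation_sketch` are hypothetical.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier


def combine_importances(actual_imp, random_imp):
    """Step 3 (hypothetical helper): combine mean A and mean R per feature."""
    actual_imp = np.asarray(actual_imp, dtype=float)
    random_imp = np.asarray(random_imp, dtype=float)
    by_subtraction = actual_imp - random_imp  # A - R
    # MinMaxScale(R): rescale R to [0, 1] across features (all zeros if R is constant).
    span = random_imp.max() - random_imp.min()
    if span > 0:
        scaled = (random_imp - random_imp.min()) / span
    else:
        scaled = np.zeros_like(random_imp)
    by_division = actual_imp / (scaled + 1)  # A / (MinMaxScale(R) + 1)
    return by_subtraction, by_division


def target_permutation_sketch(model_cls, model_params, X, y,
                              num_actual_runs, num_random_runs, seed=0):
    """Hypothetical helper, not the package API. Assumes `feature_importances_`."""
    rng = np.random.default_rng(seed)

    def fit_importances(target, run):
        # A different random_state per run, so repeated fits are not identical.
        model = model_cls(**model_params, random_state=run)
        model.fit(X, target)
        return model.feature_importances_

    # Step 1: M fits on the real target, averaged into A.
    actual_imp = np.mean(
        [fit_importances(y, r) for r in range(num_actual_runs)], axis=0
    )
    # Step 2: N fits, each on a freshly shuffled target, averaged into R.
    random_imp = np.mean(
        [fit_importances(rng.permutation(y), num_actual_runs + r)
         for r in range(num_random_runs)],
        axis=0,
    )
    return combine_importances(actual_imp, random_imp)


data = load_breast_cancer()
by_sub, by_div = target_permutation_sketch(
    RandomForestClassifier, {"n_estimators": 10}, data.data, data.target,
    num_actual_runs=2, num_random_runs=10,
)
# Indices of the five features that score highest under A - R.
print(np.argsort(by_sub)[::-1][:5])
```

Because each shuffled target breaks any real relationship with the features, $R$ estimates how much importance a feature collects by chance; a feature that scores high in both $A$ and $R$ is pushed down by either formula.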

This method should not be confused with sklearn.inspection.permutation_importance: that sklearn function permutes features, whereas this package permutes the target.
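To make the distinction concrete, the sketch below (plain NumPy, with toy arrays chosen only for illustration) shows what each approach shuffles. Target permutation shuffles `y` and refits the model; feature permutation, as in `sklearn.inspection.permutation_importance`, shuffles one column of `X` and re-scores an already fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.arange(12, dtype=float).reshape(4, 3)  # toy feature matrix: 4 rows, 3 features
y = np.array([0, 1, 0, 1])                    # toy target

# Target permutation (this package): shuffle y, keep X intact, then refit the model.
y_shuffled = rng.permutation(y)

# Feature permutation (sklearn.inspection.permutation_importance): keep y intact,
# shuffle a single column j of X, then re-score the already fitted model.
j = 1
X_shuffled = X.copy()
X_shuffled[:, j] = rng.permutation(X[:, j])
```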

This method was originally proposed/implemented by:

Basic Usage

# Import the function and the importance calculator
from target_permutation_importances import (
    compute,
    compute_permutation_importance_by_subtraction,
)

# Prepare a dataset
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
import pandas as pd

data = load_breast_cancer()
# Wrap the features in a DataFrame so that each feature keeps its name
Xpd = pd.DataFrame(data.data, columns=data.feature_names)

# Compute permutation importances
result_df = compute(
    # RandomForestClassifier, XGBClassifier, CatBoostClassifier, LGBMClassifier...
    model_cls=RandomForestClassifier,
    model_cls_params={  # The params for the model class construction
        "n_estimators": 1,
    },
    model_fit_params={},  # The params for model.fit
    X=Xpd,
    y=data.target,
    num_actual_runs=2,   # M: fits on the real target
    num_random_runs=10,  # N: fits on shuffled targets
    permutation_importance_calculator=compute_permutation_importance_by_subtraction,
)

You can find more detailed examples in the "Feature Selection Examples" section.

Advanced Usage / Customization

Instead of calling compute, you can call generic_compute, which this package also exposes to allow customization. Read target_permutation_importances.__init__ for details.

Feature Selection Examples

TODO

Benchmarks

Benchmarks were run on a selection of tabular datasets from the Tabular data learning benchmark, which is also hosted on Hugging Face.

The following models, with their default parameters, are used in the benchmark:

  • sklearn.ensemble.RandomForestClassifier
  • sklearn.ensemble.RandomForestRegressor
  • xgboost.XGBClassifier
  • xgboost.XGBRegressor
  • catboost.CatBoostClassifier
  • catboost.CatBoostRegressor
  • lightgbm.LGBMClassifier
  • lightgbm.LGBMRegressor

For binary classification tasks, sklearn.metrics.f1_score is used for evaluation; for regression tasks, sklearn.metrics.mean_squared_error is used.

Each downloaded dataset is split into three parts: train (50%), val (10%), and test (40%). Feature importances are computed on the train set, and feature selection is done on the val set. The final benchmark is evaluated on the test set, so the test set is unseen by both the feature importance and the feature selection processes.
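One way to produce such a 50/10/40 split with scikit-learn is two successive calls to `train_test_split`. This is a sketch under our own assumptions (a fixed random seed, no stratification, and the hypothetical helper name `split_50_10_40`); the benchmark's actual splitting code may differ.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_50_10_40(X, y, seed=0):
    """Hypothetical helper: split into train (50%), val (10%), test (40%)."""
    n = len(y)
    # Integer sizes avoid rounding surprises from fractional test_size values.
    n_test = round(0.4 * n)
    n_val = round(0.1 * n)
    # First carve out the held-out test set, unseen by importance and selection.
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=n_test, random_state=seed
    )
    # Then split the remainder into train (importances) and val (feature selection).
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=n_val, random_state=seed
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


X = np.random.default_rng(0).normal(size=(1000, 5))
y = np.arange(1000)
X_train, X_val, X_test, y_train, y_val, y_test = split_50_10_40(X, y)
print(len(y_train), len(y_val), len(y_test))  # 500 100 400
```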

Raw result data are in target-permutation-importances/benchmarks/results/tabular_benchmark.csv.

Kaggle Competitions

Many top solutions in Kaggle competitions use this method. Here are some examples:

Year | Competition                                | Medal | Link
2023 | Predict Student Performance from Game Play | Gold  | 3rd place solution
2019 | Elo Merchant Category Recommendation       | Gold  | 16th place solution
2018 | Home Credit Default Risk                   | Gold  | 10th place solution

Development Setup and Contribution Guide

Python Version

You can find the suggested development Python version in .python-version. You might consider setting up Pyenv if you want to have multiple Python versions on your machine.

Python packages

This repository is set up with Poetry. If you are not familiar with Poetry, the package requirements are listed in pyproject.toml. Otherwise, simply run poetry install.

Run Benchmarks

To run the benchmark locally on your machine, run make run_tabular_benchmark or python -m benchmarks.run_tabular_benchmark.

Make Changes

Follow the Make Changes guide from GitHub. Before committing or merging, please run the linters defined in make lint and the tests defined in make test.


Download files

Source Distribution

target_permutation_importances-1.0.0.tar.gz (4.4 kB)
