.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_inspection_plot_permutation_importance.py>` to download the full example code
    .. rst-class:: sphx-glr-example-title

    .. _sphx_glr_auto_examples_inspection_plot_permutation_importance.py:


================================================================
Permutation Importance vs Random Forest Feature Importance (MDI)
================================================================

In this example, we will compare the impurity-based feature importance of
:class:`~sklearn.ensemble.RandomForestClassifier` with the
permutation importance on the Titanic dataset using
:func:`~sklearn.inspection.permutation_importance`. We will show that the
impurity-based feature importance can inflate the importance of numerical
features.

Furthermore, the impurity-based feature importance of random forests suffers
from being computed on statistics derived from the training dataset: the
importances can be high even for features that are not predictive of the target
variable, as long as the model has the capacity to use them to overfit.

This example shows how to use permutation importance as an alternative that
can mitigate those limitations.
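
The idea behind permutation importance fits in a few lines: shuffle one
column at a time, re-score the model, and record how much the score drops.
The sketch below is a from-scratch illustration using only NumPy, with a
hypothetical hand-written ``predict`` function standing in for a fitted
model. It is not how :func:`~sklearn.inspection.permutation_importance` is
implemented internally, although that function rests on the same principle
and additionally supports custom scorers, repeated shuffles and parallel
jobs.

.. code-block:: python

    import numpy as np


    def permutation_importance_sketch(predict, X, y, n_repeats=5, seed=0):
        """Return the mean accuracy drop when each column of X is shuffled.

        Illustrative only: this is not scikit-learn's implementation.
        """
        rng = np.random.RandomState(seed)
        baseline = np.mean(predict(X) == y)
        drops = np.zeros((X.shape[1], n_repeats))
        for col in range(X.shape[1]):
            for rep in range(n_repeats):
                X_perm = X.copy()
                # Shuffling one column breaks its link with y while
                # keeping its marginal distribution unchanged.
                X_perm[:, col] = rng.permutation(X_perm[:, col])
                drops[col, rep] = baseline - np.mean(predict(X_perm) == y)
        return drops.mean(axis=1)


    # Hypothetical toy data: the label depends only on column 0,
    # column 1 is pure noise.
    data_rng = np.random.RandomState(42)
    X = data_rng.randn(1000, 2)
    y = (X[:, 0] > 0).astype(int)


    def predict(X):
        """Hypothetical stand-in for a fitted model: thresholds column 0."""
        return (X[:, 0] > 0).astype(int)


    importances = permutation_importance_sketch(predict, X, y)
    print(importances)

Because ``predict`` ignores column 1, shuffling it leaves every prediction
unchanged and its importance is exactly 0, while shuffling column 0 turns
the predictions into roughly a coin flip.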


.. code-block:: default

    print(__doc__)
    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn.datasets import fetch_openml
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.impute import SimpleImputer
    from sklearn.inspection import permutation_importance
    from sklearn.compose import ColumnTransformer
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OneHotEncoder



Data Loading and Feature Engineering
------------------------------------
Let's load a copy of the Titanic dataset from OpenML as a pandas DataFrame.
The following shows how to apply separate preprocessing on numerical and
categorical features.

We further include two random variables that are not correlated in any way
with the target variable (``survived``):

- ``random_num`` is a high cardinality numerical variable (as many unique
  values as records).
- ``random_cat`` is a low cardinality categorical variable (3 possible
  values).


.. code-block:: default

    X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
    rng = np.random.RandomState(seed=42)
    X['random_cat'] = rng.randint(3, size=X.shape[0])
    X['random_num'] = rng.randn(X.shape[0])

    categorical_columns = ['pclass', 'sex', 'embarked', 'random_cat']
    numerical_columns = ['age', 'sibsp', 'parch', 'fare', 'random_num']

    X = X[categorical_columns + numerical_columns]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=42)

    categorical_pipe = Pipeline([
        ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])
    numerical_pipe = Pipeline([
        ('imputer', SimpleImputer(strategy='mean'))
    ])

    preprocessing = ColumnTransformer(
        [('cat', categorical_pipe, categorical_columns),
         ('num', numerical_pipe, numerical_columns)])

    rf = Pipeline([
        ('preprocess', preprocessing),
        ('classifier', RandomForestClassifier(random_state=42))
    ])
    rf.fit(X_train, y_train)
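
The two random columns differ only in cardinality. This standalone sketch
repeats the NumPy calls used above on a hypothetical ``n_rows`` (an
arbitrary value chosen for illustration, not the size of the Titanic
dataset) and counts the distinct values in each column.

.. code-block:: python

    import numpy as np

    n_rows = 1000  # hypothetical row count, for illustration only
    rng = np.random.RandomState(seed=42)
    random_cat = rng.randint(3, size=n_rows)  # values drawn from {0, 1, 2}
    random_num = rng.randn(n_rows)            # continuous Gaussian values

    print("random_cat distinct values:", np.unique(random_cat).size)
    print("random_num distinct values:", np.unique(random_num).size)
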




Accuracy of the Model
---------------------
Prior to inspecting the feature importances, it is important to check that
the model's predictive performance is high enough. Indeed, there would be
little interest in inspecting the important features of a non-predictive
model.

Here one can observe that the train accuracy is very high (the forest model
has enough capacity to completely memorize the training set) but it can still
generalize well enough to the test set thanks to the built-in bagging of
random forests.

It might be possible to trade some accuracy on the training set for a
slightly better accuracy on the test set by limiting the capacity of the
trees (for instance by setting ``min_samples_leaf=5`` or
``min_samples_leaf=10``) so as to limit overfitting while not introducing too
much underfitting.

However, let's keep our high capacity random forest model for now so as to
illustrate some pitfalls with feature importance on variables with many
unique values.


.. code-block:: default

    print("RF train accuracy: %0.3f" % rf.score(X_train, y_train))
    print("RF test accuracy: %0.3f" % rf.score(X_test, y_test))
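
A quick way to explore the train/test trade-off discussed above is to sweep
``min_samples_leaf`` and compare both accuracies. The sketch below is only
illustrative: it uses hypothetical synthetic data from
:func:`~sklearn.datasets.make_classification` rather than the Titanic data,
so its numbers will differ from this example. With noisy labels, the train
accuracy should decrease as the trees are constrained; whether the test
accuracy improves depends on the data.

.. code-block:: python

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split

    # Hypothetical synthetic data: flip_y=0.2 reassigns the class of 20% of
    # the samples at random, so perfect train accuracy requires memorizing
    # noise.
    X, y = make_classification(n_samples=1000, n_features=8,
                               n_informative=3, flip_y=0.2, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=0)

    scores = {}
    for min_samples_leaf in (1, 5, 10):
        clf = RandomForestClassifier(min_samples_leaf=min_samples_leaf,
                                     random_state=0)
        clf.fit(X_train, y_train)
        train_acc = clf.score(X_train, y_train)
        test_acc = clf.score(X_test, y_test)
        scores[min_samples_leaf] = (train_acc, test_acc)
        print(f"min_samples_leaf={min_samples_leaf:2d}  "
              f"train: {train_acc:.3f}  test: {test_acc:.3f}")
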



Tree's Feature Importance from Mean Decrease in Impurity (MDI)
--------------------------------------------------------------
The impurity-based feature importance ranks the numerical features as the
most important features. In particular, the non-predictive ``random_num``
variable is ranked the most important!

This problem stems from two limitations of impurity-based feature
importances:

- impurity-based importances are biased towards high cardinality features;
- impurity-based importances are computed on training set statistics and
  therefore do not reflect the ability of a feature to be useful to make
  predictions that generalize to the test set (when the model has enough
  capacity).
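
Both limitations can be reproduced on a small synthetic problem. In this
sketch (hypothetical data, not the Titanic dataset), one binary feature
agrees with the label 80% of the time, while a 3-valued column and a
continuous column are unrelated to it. Because fully grown trees keep
splitting on the continuous column to fit the flipped labels of the training
set, its MDI is far from zero and larger than that of the low cardinality
noise column.

.. code-block:: python

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.RandomState(0)
    n_samples = 1000

    # Hypothetical data: x_signal agrees with the label 80% of the time,
    # the two other columns are unrelated to the label.
    x_signal = rng.randint(2, size=n_samples)
    x_noise_cat = rng.randint(3, size=n_samples)  # low cardinality noise
    x_noise_num = rng.randn(n_samples)            # high cardinality noise
    flipped = rng.rand(n_samples) < 0.2
    y = np.where(flipped, 1 - x_signal, x_signal)

    X = np.column_stack([x_signal, x_noise_cat, x_noise_num])
    rf = RandomForestClassifier(random_state=0).fit(X, y)

    # MDI is computed from impurity decreases on this training data only.
    mdi = rf.feature_importances_
    for name, value in zip(["x_signal", "x_noise_cat", "x_noise_num"], mdi):
        print(f"{name:>12}: {value:.3f}")
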


.. code-block:: default

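    # Recover the one-hot encoded column names so that they line up with the
    # columns seen by the classifier: the ColumnTransformer outputs the
    # categorical block first, then the numerical columns.
    # Note: newer scikit-learn releases replace get_feature_names with
    # get_feature_names_out.
    ohe = (rf.named_steps['preprocess']
             .named_transformers_['cat']
             .named_steps['onehot'])
    feature_names = ohe.get_feature_names(input_features=categorical_columns)
    feature_names = np.r_[feature_names, numerical_columns]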

    tree_feature_importances = (
        rf.named_steps['classifier'].feature_importances_)
    sorted_idx = tree_feature_importances.argsort()

    y_ticks = np.arange(0, len(feature_names))
    fig, ax = plt.subplots()
    ax.barh(y_ticks, tree_feature_importances[sorted_idx])
    # Fix the tick positions before labelling them.
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(feature_names[sorted_idx])
    ax.set_title("Random Forest Feature Importances (MDI)")
    fig.tight_layout()
    plt.show()
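
The values plotted above are derived entirely from the fitted trees, with no
held-out data involved. In current scikit-learn releases, a forest's
``feature_importances_`` is the average of the per-tree impurity-based
importances (over trees that have more than one node), renormalized to sum
to 1. The sketch below checks that relationship on hypothetical synthetic
data; it relies on an implementation detail that could change between
versions.

.. code-block:: python

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    # Hypothetical synthetic data, for illustration only.
    X, y = make_classification(n_samples=500, n_features=6, random_state=0)
    forest = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    # Each fitted tree exposes its own normalized impurity-based importances,
    # computed from the training samples that reached each node.
    per_tree = np.array([tree.feature_importances_
                         for tree in forest.estimators_
                         if tree.tree_.node_count > 1])
    averaged = per_tree.mean(axis=0)
    averaged /= averaged.sum()

    print(np.allclose(averaged, forest.feature_importances_))
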



As an alternative, the permutation importances of ``rf`` are computed on a
held-out test set. This shows that the low cardinality categorical feature
``sex`` is the most important feature.

Also note that, as expected, both random features have very low importances
(close to 0).


.. code-block:: default

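    # Shuffle each column 10 times on the held-out test set and record the
    # drop in the pipeline's default score (accuracy for this classifier).
    # Columns are permuted before preprocessing, so the importances refer to
    # the original (not one-hot encoded) columns.
    result = permutation_importance(rf, X_test, y_test, n_repeats=10,
                                    random_state=42, n_jobs=2)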
    sorted_idx = result.importances_mean.argsort()

    fig, ax = plt.subplots()
    ax.boxplot(result.importances[sorted_idx].T,
               vert=False, labels=X_test.columns[sorted_idx])
    ax.set_title("Permutation Importances (test set)")
    fig.tight_layout()
    plt.show()
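
The object returned by :func:`~sklearn.inspection.permutation_importance` is
a ``Bunch`` whose ``importances`` array has one row per feature and one
column per repeat; ``importances_mean`` and ``importances_std`` summarize
each row, and the boxplots above are drawn from the raw ``importances``. The
sketch below inspects that structure on hypothetical synthetic data with one
predictive column and one continuous noise column.

.. code-block:: python

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance
    from sklearn.model_selection import train_test_split

    # Hypothetical data: x_signal agrees with the label 80% of the time and
    # x_noise is a continuous column unrelated to the label.
    rng = np.random.RandomState(0)
    n_samples = 1000
    x_signal = rng.randint(2, size=n_samples)
    x_noise = rng.randn(n_samples)
    y = np.where(rng.rand(n_samples) < 0.2, 1 - x_signal, x_signal)
    X = np.column_stack([x_signal, x_noise])

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=0)
    rf = RandomForestClassifier(random_state=0).fit(X_train, y_train)

    result = permutation_importance(rf, X_test, y_test, n_repeats=10,
                                    random_state=0)
    # One row per feature, one column per repeat.
    print(result.importances.shape)
    for name, mean, std in zip(["x_signal", "x_noise"],
                               result.importances_mean,
                               result.importances_std):
        print(f"{name:>8}: {mean:.3f} +/- {std:.3f}")
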


It is also possible to compute the permutation importances on the training
set. This reveals that ``random_num`` gets a significantly higher importance
ranking than when computed on the test set. The difference between those two
plots is a confirmation that the RF model has enough capacity to use that
random numerical feature to overfit. You can further confirm this by
re-running this example with a constrained RF, for instance with
``min_samples_leaf=10``.


.. code-block:: default

    # Same computation on the training set, to compare with the test set.
    result = permutation_importance(rf, X_train, y_train, n_repeats=10,
                                    random_state=42, n_jobs=2)
    sorted_idx = result.importances_mean.argsort()

    fig, ax = plt.subplots()
    ax.boxplot(result.importances[sorted_idx].T,
               vert=False, labels=X_train.columns[sorted_idx])
    ax.set_title("Permutation Importances (train set)")
    fig.tight_layout()
    plt.show()
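
The train/test gap described above, and the effect of constraining the
forest, can be reproduced on hypothetical synthetic data where one
continuous column is pure noise. This sketch compares the permutation
importance of that noise column on the training and test sets for
``min_samples_leaf=1`` (the default) and ``min_samples_leaf=10``.

.. code-block:: python

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance
    from sklearn.model_selection import train_test_split

    # Hypothetical data: x_signal agrees with the label 80% of the time and
    # x_noise is a continuous column unrelated to the label.
    rng = np.random.RandomState(0)
    n_samples = 1000
    x_signal = rng.randint(2, size=n_samples)
    x_noise = rng.randn(n_samples)
    y = np.where(rng.rand(n_samples) < 0.2, 1 - x_signal, x_signal)
    X = np.column_stack([x_signal, x_noise])
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=0)

    gap = {}
    for min_samples_leaf in (1, 10):
        rf = RandomForestClassifier(min_samples_leaf=min_samples_leaf,
                                    random_state=0).fit(X_train, y_train)
        # Importance of the noise column (index 1) on each split.
        train_imp = permutation_importance(
            rf, X_train, y_train, n_repeats=10,
            random_state=0).importances_mean[1]
        test_imp = permutation_importance(
            rf, X_test, y_test, n_repeats=10,
            random_state=0).importances_mean[1]
        gap[min_samples_leaf] = train_imp - test_imp
        print(f"min_samples_leaf={min_samples_leaf:2d}  noise importance: "
              f"train={train_imp:.3f}  test={test_imp:.3f}")

With fully grown trees, shuffling the noise column on the training set
undoes the memorized fit to the flipped labels, so its training-set
importance is large; on the test set, or with larger leaves, it stays much
closer to 0.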


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.010 seconds)


.. _sphx_glr_download_auto_examples_inspection_plot_permutation_importance.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_permutation_importance.py <plot_permutation_importance.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_permutation_importance.ipynb <plot_permutation_importance.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
