Using Shapley Values to explain your ML models

You work for a fitness centre. Let’s say you’ve recently deployed a machine learning model to predict whether a customer will churn at the end of their current contract. Your input features to the model are:

  • Average times visited per week over the last month
  • Average times visited per week over the last 6 months
  • Cost of membership
  • Number of fitness classes attended per week
  • etc…

Using these features, your model achieves 78% accuracy. When you deploy it, you expect the marketing department to love it, since they’ll know which customers to actively work to retain. However, the first thing they ask is: ‘Why did you predict that this person was going to churn?’

We could look at the feature importances of the trained model, but they describe the model as a whole and tell us little about any individual prediction. For per-prediction explanations, we can use Shapley values.

Shapley values tell us how much each feature contributed to an individual prediction. The intuition is to predict the likelihood of churn with all of the features, then ‘remove’ a feature (replace it with a representative value from the underlying dataset) and measure how much the prediction changes. Because a feature’s effect can depend on which other features are present, the Shapley value averages this change over every possible subset (coalition) of the other features, rather than only removing one feature at a time from the full set.
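To make the idea concrete, here is a minimal brute-force sketch that computes exact Shapley values for a tiny, made-up churn model. It assumes that ‘removing’ a feature means swapping in its value from a single baseline customer (a simplification; shap’s explainers use background data or, for trees, a dedicated algorithm), and the function names, features and numbers are all hypothetical. Because it evaluates the model on every coalition, its cost grows as 2^n in the number of features, which is why libraries such as shap use approximations or model-specific algorithms instead.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x (hypothetical helper).

    Assumption: a feature outside the coalition takes its value from `baseline`.
    """
    n = len(x)

    def value(coalition):
        # Features in the coalition come from x, the rest from the baseline row
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for each coalition S without i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


# Hypothetical churn score: fewer recent visits and a higher cost raise it,
# and an interaction term links recent visits with classes attended
def churn_score(z):
    visits_month, visits_6m, cost, classes = z
    return 0.5 - 0.1 * visits_month - 0.05 * visits_6m + 0.01 * cost - 0.02 * visits_month * classes


x = [1.0, 3.0, 40.0, 0.0]         # the customer being explained
baseline = [3.0, 3.0, 30.0, 1.0]  # a "typical" customer used for removed features
phi = shapley_values(churn_score, x, baseline)
print([round(p, 3) for p in phi])
print(round(sum(phi), 3), round(churn_score(x) - churn_score(baseline), 3))
```

The two numbers on the second printed line match: the Shapley values always add up to the difference between this customer’s prediction and the baseline prediction. The interaction term’s effect is shared between the two features involved rather than assigned to either one alone.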

Effectively, this lets us decompose any prediction into a baseline value (the average prediction) plus the sum of its feature contributions. Let’s look at an example. Below, I have trained a random forest classifier on my training features and labels, and then made predictions on the test features.

from sklearn.ensemble import RandomForestClassifier

# A random forest of 20 trees; every other parameter uses scikit-learn's defaults.
# (max_features='auto' and min_impurity_split are rejected by recent scikit-learn
# versions; for classifiers the old 'auto' meant 'sqrt', which is now the default.)
model = RandomForestClassifier(n_estimators=20, random_state=None)
model.fit(train_features, train_labels)

# Predict churn (True/False) for each row of the test set
predictions = model.predict(test_features)
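The snippet above assumes train_features, train_labels and test_features already exist. To follow along without the fitness centre’s data, a sketch like the one below builds a purely synthetic stand-in; the column names, the churn rule and all numbers are made up for illustration and say nothing about real customers.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n = 1000

# Synthetic, hypothetical customer data loosely mirroring the features listed earlier
features = pd.DataFrame({
    "visits_per_week_1m": rng.poisson(2.0, n),
    "visits_per_week_6m": rng.poisson(2.5, n),
    "membership_cost": rng.choice([25.0, 35.0, 50.0], n),
    "classes_per_week": rng.poisson(1.0, n),
})

# Made-up rule: low recent attendance and an expensive plan make churn more likely
churn_prob = 1 / (1 + np.exp(1.5 * features["visits_per_week_1m"] - 0.05 * features["membership_cost"]))
labels = rng.random(n) < churn_prob  # boolean labels: True = churned

train_features, test_features, train_labels, test_labels = train_test_split(
    features, labels, test_size=0.25, random_state=0)

model = RandomForestClassifier(n_estimators=20, random_state=0)
model.fit(train_features, train_labels)
predictions = model.predict(test_features)
print("test accuracy:", model.score(test_features, test_labels))
```

Because the labels are booleans, model.classes_ is [False, True], so class index 1 corresponds to True (churn), which is the class selected with shap_values[1] in the rest of this post.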

Next, I ran the code below. It creates a SHAP explainer for the model we trained above and computes SHAP values for the test_features from our dataset. In the final line, you will notice I have selected shap_values[1]. This is because, for classifiers, there is a separate array of SHAP values for each class we are predicting, so shap_values[1] holds the SHAP values for the ‘True’ (churn) class. Note that this list-per-class format is what older shap releases return; in more recent releases, shap_values(...) returns a single array of shape (samples, features, classes) instead, in which case the equivalent selection is shap_values[:, :, 1].

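import shap

# TreeExplainer computes SHAP values efficiently for tree ensembles such as random forests
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(test_features)
# Summary plot of the SHAP values for class 1 (churn = True)
shap.summary_plot(shap_values[1], test_features)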
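Because the shape returned by explainer.shap_values has changed between shap releases (a list with one array per class in older versions, a single (samples, features, classes) array in newer ones), a small helper can pick out one class either way. The class_shap_values function below is a hypothetical sketch, not part of shap; check the shape your installed version actually returns.

```python
import numpy as np


def class_shap_values(shap_values, class_index):
    """Return a (samples, features) array of SHAP values for one class.

    Handles both the older list-of-arrays format and the newer single
    3-D array with the class on the last axis.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[class_index])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, class_index]
    # Already 2-D (e.g. a single-output model): nothing to select
    return values


# Usage with the summary plot code:
# shap.summary_plot(class_shap_values(shap_values, 1), test_features)
```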

The shap.summary_plot call produces a summary plot. Each row of the plot is one feature, and each dot on that row is one row of the data (one prediction). Here:

  • Red dots mean the feature value was high for that prediction; blue dots mean it was low
  • The further right a dot sits, the more that feature value pushed the prediction towards True (churn)
  • The further left a dot sits, the more that feature value pushed the prediction towards False (no churn)
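By default, summary_plot orders the features from top to bottom by overall impact, that is, by their total absolute SHAP value across all plotted predictions. A minimal NumPy sketch of the same ranking follows; the rank_features helper and the toy numbers are hypothetical, not part of shap. Note that this ranking ignores direction: a feature can rank highly whether it mostly pushes towards True or towards False.

```python
import numpy as np


def rank_features(shap_matrix, feature_names):
    """Rank features by mean absolute SHAP value, largest first."""
    importance = np.abs(np.asarray(shap_matrix)).mean(axis=0)
    order = np.argsort(importance)[::-1]
    return [(feature_names[i], float(importance[i])) for i in order]


# Toy SHAP values for 3 predictions x 3 features (made up for illustration)
toy = np.array([[0.20, -0.01, 0.05],
                [-0.30, 0.02, 0.00],
                [0.10, 0.00, -0.10]])
ranking = rank_features(toy, ["visits_1m", "visits_6m", "cost"])
for name, score in ranking:
    print(f"{name}: {score:.3f}")
```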

The summary plot is a great way to understand feature impact at a global level (across many or all predictions). To drill down to a single prediction, we can use a force plot, as below.

Here, we use:

  • shap.initjs() – initialises the JavaScript plotting library needed to render force plots in a notebook
  • explainer.expected_value[1] – the baseline (average model output) for class 1 (True)
  • shap_values[1][j] – the SHAP values for a given class and prediction. The first index [1] selects class 1, and the second index j is the prediction (row) number.
  • pd.DataFrame(test_features.iloc[j]).T – the feature values of that row, as a one-row dataframe
import pandas as pd
test_features = pd.DataFrame(test_features)

def shap_plot(j):
    shap.initjs()  # load the JavaScript that renders force plots in a notebook
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(test_features)
    # Force plot for prediction j, starting from the class-1 (True) baseline
    return shap.force_plot(explainer.expected_value[1], shap_values[1][j], pd.DataFrame(test_features.iloc[j]).T, feature_names=test_features.columns)

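The result is a force plot for that single prediction. It starts from the baseline (the expected value) and shows how each feature moved this prediction: features in red push it higher (towards churn) and features in blue push it lower, ending at the model’s output for that customer.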
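Numerically, the force plot is the decomposition described earlier applied to one row: the baseline explainer.expected_value[1] plus that row’s class-1 SHAP values should add up to the model’s output for the row (for a scikit-learn random forest explained with TreeExplainer, that output is the predicted probability of class 1). The hypothetical explain_row helper below prints the same information as text; the numbers in the usage line are made up, not taken from a real model.

```python
def explain_row(base_value, row_shap, feature_names):
    """Hypothetical text version of a force plot: baseline + contributions = output."""
    # Sort contributions by size, regardless of direction
    contributions = sorted(zip(feature_names, row_shap), key=lambda item: abs(item[1]), reverse=True)
    output = base_value + sum(row_shap)
    lines = [f"baseline (expected value): {base_value:.3f}"]
    for name, value in contributions:
        if value > 0:
            direction = "towards churn"
        elif value < 0:
            direction = "away from churn"
        else:
            direction = "no effect"
        lines.append(f"  {name}: {value:+.3f} ({direction})")
    lines.append(f"model output: {output:.3f}")
    return output, lines


# Made-up numbers standing in for explainer.expected_value[1] and shap_values[1][j]
output, lines = explain_row(0.30, [0.25, -0.05, 0.10], ["visits_1m", "visits_6m", "cost"])
print("\n".join(lines))
```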

Hopefully this was a useful introduction to SHAP and Shapley values.
