Interactive Individual Conditional Expectation (ICE) plots

This post is not about a new technique or package. Instead, it combines existing functionality from interpretable machine learning and data visualization in a way that makes model results easier to analyze. We'll use two packages, DALEX and plotly, to create interactive Individual Conditional Expectation (ICE) plots and show how to use them to find interesting behavior. As an example, let's take a random forest (RF) trained on an imputed version of the Titanic data and create a DALEX explainer for it.

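library(archivist)     # aread() loads archived objects
library(DALEX)         # explain(), model_profile(), predict_profile(), predict_parts()
library(randomForest)  # predict method for the archived model
library(ggplot2)       # ggtitle()
library(plotly)        # ggplotly()
library(gridExtra)     # grid.arrange()
titanic_imputed <- aread("pbiecek/models/27e5c")
titanic_rf <- aread("pbiecek/models/4e0fc")
# survived is a factor, so as.integer() returns the level codes 1/2 rather than 0/1
target_as_int <- as.integer(titanic_imputed$survived)
explainer_rf <- DALEX::explain(model = titanic_rf,
                               data = titanic_imputed[, -9],
                               y = target_as_int,
                               label = "Random Forest")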
## Preparation of a new explainer is initiated
##   -> model label       :  Random Forest 
##   -> data              :  2207  rows  8  cols 
##   -> target variable   :  2207  values 
##   -> predict function  :  yhat.randomForest  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package randomForest , ver. 4.6.14 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0 , mean =  0.2353095 , max =  1  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  0.108 , mean =  1.086847 , max =  2  
##   A new explainer has been created! 

A typical way to understand the dependence of the prediction on a single variable, say age, is a partial dependence plot (PDP).

plot(model_profile(explainer = explainer_rf, variables = "age", N = nrow(titanic_imputed)))

While a PDP provides a reasonable average summary, it suppresses the variation between single observations. An ICE plot visualizes the dependence separately for each observation in a subset of the data. In other words, it shows the result of the data perturbation BEFORE averaging, so it is essentially an intermediate step on the way to a PDP. Note that these plots are called ceteris paribus profiles in DALEX.

dep_age <- model_profile(explainer = explainer_rf, variables = "age")
plot(dep_age, geom = "profiles") + ggtitle("ICE plot of RF for age")
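To make the "before averaging" idea concrete, here is a minimal hand-rolled sketch in base R. It is not the DALEX implementation and does not use the Titanic random forest: as a stand-in, it assumes a small linear model with an interaction fitted on the built-in mtcars data, and the names `fit`, `hp_grid`, `ice` and `pdp` are made up for this illustration.

```r
# Toy stand-in model (an assumption for illustration, not the RF from this post).
# The wt:hp interaction makes the individual curves differ in slope.
fit <- lm(mpg ~ wt * hp, data = mtcars)

# Grid of values for the variable of interest.
hp_grid <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 25)

# ICE: overwrite hp with each grid value for ALL observations and predict.
# Result: one row per observation, one column per grid value.
ice <- sapply(hp_grid, function(h) {
  perturbed <- mtcars
  perturbed$hp <- h
  predict(fit, newdata = perturbed)
})

# PDP: the column-wise average of the ICE curves.
pdp <- colMeans(ice)

# What the average hides: the change from the lowest to the highest grid
# value differs from observation to observation.
ice_change <- ice[, ncol(ice)] - ice[, 1]
range(ice_change)        # spread across single observations
pdp[length(pdp)] - pdp[1]  # the single number the PDP reports
```

Each row of `ice` is one ICE curve, and averaging the rows gives the PDP. The DALEX plot above shows the same kind of curves for a sample of rows from the Titanic data.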

Unfortunately, the wealth of information makes it hard to draw any direct conclusions from this. This is where the plotly package comes into play. By looking at an interactive version of the plot for a reasonable number of observations, we can find interesting cases. For example, note how the observation in row 1230 has a high peak at age 52 that is more pronounced than on average.

dep_age <- model_profile(explainer = explainer_rf, variables = "age", N = 10)
ggplotly(plot(dep_age, geom = "profiles") + ggtitle("ICE plot for 10 randomly selected rows"))

The underlying passenger was 61 years old and did not survive, which matches the very low survival probability the model predicts for this person. The model is thus saying that the predicted survival probability would have been much higher (about 28 percentage points) had this person been 9 years younger.

##      gender age class    embarked       country fare sibsp parch survived
## 1230   male  61   1st Southampton United States 33.1     0     0       no
cb_1230 <- predict_profile(explainer_rf, titanic_imputed[1230, ])
ggplotly(plot(cb_1230, variables = "age") + ggtitle("Ceteris paribus plot for observation 1230"))
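The "what if this person had been younger" reading can also be done numerically. The sketch below is only an illustration under stated assumptions: it again uses a toy linear model on mtcars instead of the Titanic random forest, and `obs`, `what_if`, `profile` and `gain` are hypothetical names. It varies one variable for a single observation, keeps everything else fixed, and compares the prediction at the observed value with the highest prediction on the grid.

```r
# Toy stand-in model (an assumption, not the RF from this post).
fit <- lm(mpg ~ wt * hp, data = mtcars)

# Pick one observation and vary a single variable, keeping the rest fixed.
obs <- mtcars["Valiant", ]
hp_grid <- seq(min(mtcars$hp), max(mtcars$hp), by = 1)
what_if <- obs[rep(1, length(hp_grid)), ]  # one copy of the row per grid value
what_if$hp <- hp_grid
profile <- unname(predict(fit, newdata = what_if))

# Prediction at the observed value versus the highest prediction on the grid.
pred_observed <- unname(predict(fit, newdata = obs))
best <- which.max(profile)
gain <- profile[best] - pred_observed

c(observed_hp = obs$hp, best_hp = hp_grid[best], gain = gain)
```

For a classifier, `gain` would be a difference in predicted probabilities, which is how a statement like "about 28 percentage points" can be read off a ceteris paribus profile.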

While this is already an interesting observation, we can break the effects down further by creating a hypothetical person who is identical to passenger 1230 except for being 52 years old.

person_1230_age_52 <- titanic_imputed[1230, ]
person_1230_age_52$age <- 52
predict(titanic_rf, person_1230_age_52, type = "prob")
##         no   yes
## 1230 0.684 0.316
## attr(,"class")
## [1] "matrix" "array"  "votes"

A breakdown plot now shows that indeed, age is the main driver for the change in predicted survival probability.

grid.arrange(plot(predict_parts(explainer_rf, person_1230_age_52)) + ggtitle("Age 52"),
             plot(predict_parts(explainer_rf, titanic_imputed[1230, ])) + ggtitle("Age 61"),
             nrow = 1)
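For readers who want to see what a break-down attribution does under the hood, here is a rough sketch of the basic idea in base R. It rests on several assumptions: it uses a toy linear model on mtcars rather than the Titanic random forest, the variable order in `vars` is fixed by hand (the ordering heuristic that DALEX applies is not reproduced here), and all names are made up for this example. Starting from the average prediction, each variable is fixed in turn to the value of the observation of interest, and the resulting change in the average prediction is recorded as that variable's contribution.

```r
# Toy stand-in model (an assumption, not the RF from this post).
fit <- lm(mpg ~ wt * hp + qsec, data = mtcars)

vars <- c("wt", "hp", "qsec")   # hand-picked order, an assumption
obs  <- mtcars["Valiant", ]     # observation to explain

current  <- mtcars
baseline <- mean(predict(fit, newdata = current))  # average prediction
previous <- baseline

contrib <- setNames(numeric(length(vars)), vars)
for (v in vars) {
  current[[v]] <- obs[[v]]                     # fix variable v for all rows
  m <- mean(predict(fit, newdata = current))   # new average prediction
  contrib[v] <- m - previous                   # contribution of v
  previous <- m
}

contrib

# Once every variable is fixed, the average equals the prediction for obs.
all.equal(baseline + sum(contrib), unname(predict(fit, newdata = obs)))
```

Because the toy model contains a wt:hp interaction, the contributions of `wt` and `hp` depend on the order in `vars`. That is why comparing the plain break-down with a variant that accounts for interactions can be informative.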

However, some other (smaller) effects change as well. For embarked, a break-down variant that also quantifies interactions between variables shows that this change, too, is mainly driven by age:

grid.arrange(plot(predict_parts(explainer_rf, person_1230_age_52, type = "break_down_interactions")) + ggtitle("Age 52"),
             plot(predict_parts(explainer_rf, titanic_imputed[1230, ], type = "break_down_interactions")) + ggtitle("Age 61"),
             nrow = 1)

So, even though we did not reinvent the wheel and simply combined well-known existing approaches, I still hope this technique is helpful for some of you.