Captum: A unified and generic model interpretability library for PyTorch

Narine Kokhlikyan, Vivek Miglani, Miguel Martin, Edward Wang, Bilal Alsallakh, Jonathan Reynolds, Alexander Melnikov, Natalia Kliushkina, Carlos Araya, Siqi Yan, Orion Reblitz-Richardson

Introduction

Given the complexity and black-box nature of NN models, there is a strong demand for a clear understanding of how those models reason. Model interpretability aims to describe model internals in human-understandable terms and is an important field of Explainable AI. While building interpretable models is encouraged, many existing state-of-the-art NNs are not designed to be interpretable, so the development of algorithms that explain black-box models is highly desirable. This is particularly important when AI is used in domains such as healthcare, finance or self-driving vehicles, where establishing trust in AI-driven systems is critical.

Some of the fundamental approaches that interpret black-box models are feature, neuron and layer importance algorithms, also known as attribution algorithms. Existing frameworks such as DeepExplain, Alibi and InterpretML have been developed to unify those algorithms in one framework and make them accessible to all machine learning model developers and practitioners. These frameworks, however, have insufficient support for PyTorch models. In Captum, we provide generic implementations of a number of gradient- and perturbation-based attribution algorithms that can be applied to any PyTorch model of any modality. The library is easily extensible, lets users scale computations across multiple GPUs, and handles large inputs by dividing them into smaller pieces, thereby preventing out-of-memory situations.

Another important aspect is that both qualitative and quantitative evaluation of attributions is difficult. Visual explanations can be misleading, and evaluation metrics are subjective or domain-specific. To address these issues, we provide generic implementations of two previously proposed evaluation metrics, infidelity and max-sensitivity. These metrics can be used in combination with any PyTorch model and most attribution algorithms.

Lastly, model understanding research largely focuses on the Computer Vision (CV) domain, whereas many NN applications remain unexplored and urgently need model understanding tools. Adapting CV-specific implementations for those applications is not always straightforward, hence the need for a well-tested and generic library that can be easily applied to multiple domains across research and production.

An Overview of the Algorithms

The attribution algorithms in Captum can be grouped into three main categories: primary, neuron and layer attribution, as shown in Figure 1. Primary attribution algorithms are the traditional feature importance algorithms that allow us to attribute output predictions to model inputs. Layer attribution variants allow us to attribute output predictions to all neurons in a hidden layer, and neuron attribution methods allow us to attribute an internal, hidden neuron to the inputs of the model. In most cases, both neuron and layer variants are slight modifications of the primary attribution algorithms.

Most attribution algorithms in Captum can be categorized into gradient-based and perturbation-based approaches, as depicted in Figure 2. Some of these algorithms, such as GradCam, GuidedGradCam, GuidedBackProp, Deconvolution and Occlusion, are more popular in the CV community, where they originated. However, our implementations of those algorithms are generic and can be applied to any model that meets the requirements dictated by those approaches. For example, GradCam and GuidedGradCam only make sense for convolutional models.

NoiseTunnel includes generic implementations of the SmoothGrad, SmoothGrad Square and VarGrad smoothing techniques proposed in prior work. These methods help to mitigate noise in the attributions and can be used in combination with most attribution algorithms depicted in Figure 2.
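The three smoothing variants can be illustrated with a small, library-free sketch. It is not Captum's implementation: the toy model, the function names `gradient` and `noise_tunnel`, and the parameters `n_samples`, `stdev` and `seed` are all assumptions made for illustration. The sketch assumes the usual definitions, namely that SmoothGrad averages the gradients over noisy copies of the input, SmoothGrad Square averages the squared gradients, and VarGrad takes their variance.

```python
import random


def gradient(x):
    """Analytic gradient of the toy model f(x) = sum(x_i ** 3) (an assumed model)."""
    return [3.0 * v * v for v in x]


def noise_tunnel(x, grad_fn, n_samples=50, stdev=0.1, seed=0):
    """Return (SmoothGrad, SmoothGrad Square, VarGrad) attributions for input x."""
    rng = random.Random(seed)
    n = len(x)
    sum_g = [0.0] * n   # running sum of gradients
    sum_g2 = [0.0] * n  # running sum of squared gradients
    for _ in range(n_samples):
        # Perturb the input with Gaussian noise and recompute the attribution.
        noisy = [v + rng.gauss(0.0, stdev) for v in x]
        g = grad_fn(noisy)
        for i in range(n):
            sum_g[i] += g[i]
            sum_g2[i] += g[i] * g[i]
    smoothgrad = [s / n_samples for s in sum_g]
    smoothgrad_sq = [s / n_samples for s in sum_g2]
    # Population variance: E[g^2] - (E[g])^2.
    vargrad = [sq - m * m for sq, m in zip(smoothgrad_sq, smoothgrad)]
    return smoothgrad, smoothgrad_sq, vargrad


sg, sg_sq, vg = noise_tunnel([1.0, -2.0, 0.5], gradient)
```

Here a plain gradient stands in for the wrapped attribution method; any function that maps an input to attributions could be passed as `grad_fn`.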

The attribution quality of a number of algorithms, such as Integrated Gradients, DeepLift, SHAP variants, Feature Ablation and Occlusion, depends on the choice of baseline, also known as reference, which needs to be carefully chosen by the user. Baselines express the absence of some input feature and are an integral part of many feature importance equations. For example, black and white images, or the average of the two, are common baselines for image classification tasks.

From the implementation and usage perspective, all algorithms follow a unified API and signature. This makes it easy to compare the algorithms and switch from one attribution approach to another. The code snippets in Figure 1 demonstrate examples of how to use primary, neuron and layer attribution algorithms in Captum.

Many state-of-the-art Neural Networks with a large number of model parameters use large inputs, which ultimately lead to computationally expensive forward and backward passes. We want to make sure we leverage available memory and CPU/GPU resources efficiently when performing attribution. To avoid out-of-memory situations for certain algorithms, especially the ones with internal input expansion, we slice inputs into smaller chunks, perform the computations sequentially on each chunk and aggregate the resulting attributions. This is especially useful for algorithms such as Integrated Gradients and Layer Conductance, because they internally expand the inputs based on the number of integral approximation steps. Being able to chunk the inputs into smaller sizes means that, in principle, the number of integral approximation steps is no longer limited by available memory.
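The chunking idea can be sketched for Integrated Gradients in plain Python. This is a minimal illustration under stated assumptions, not Captum's code: the toy model, its analytic gradient, the midpoint-rule approximation and the parameter name `chunk_size` are all hypothetical choices. Only `chunk_size` interpolated inputs are materialized at a time, yet the aggregated result matches the unchunked computation.

```python
def model(x):
    """Toy differentiable model (assumed): f(x) = x0 * x1 + x2 ** 2."""
    return x[0] * x[1] + x[2] ** 2


def model_grad(x):
    """Analytic gradient of the toy model."""
    return [x[1], x[0], 2.0 * x[2]]


def integrated_gradients(x, baseline, grad_fn, n_steps=50, chunk_size=None):
    """Approximate Integrated Gradients with a midpoint Riemann sum.

    The interpolation steps are processed in chunks of `chunk_size`, so the
    expanded input never holds more than `chunk_size` points at once.
    """
    n = len(x)
    alphas = [(k + 0.5) / n_steps for k in range(n_steps)]
    chunk = chunk_size if chunk_size else n_steps
    grad_sum = [0.0] * n
    for start in range(0, n_steps, chunk):
        # Expand the input only for the current slice of integration steps.
        points = [
            [b + a * (v - b) for v, b in zip(x, baseline)]
            for a in alphas[start:start + chunk]
        ]
        for p in points:
            g = grad_fn(p)
            for i in range(n):
                grad_sum[i] += g[i]
    # Scale the averaged gradients by (input - baseline).
    return [(v - b) * s / n_steps for v, b, s in zip(x, baseline, grad_sum)]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
full = integrated_gradients(x, baseline, model_grad, n_steps=60)
chunked = integrated_gradients(x, baseline, model_grad, n_steps=60, chunk_size=7)
```

In this sketch the attributions sum to `model(x) - model(baseline)`, which is a convenient sanity check on the approximation.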

In the case of feature perturbation, if the inputs are small and enough memory is available, we can perturb multiple features together in one input batch. This requires expanding the inputs by the number of features that we perturb together, and it improves the runtime performance of all our feature perturbation algorithms.
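A small sketch of this batching, assuming a toy linear model that accepts a batch of inputs, is given below. The class `CountingLinearModel`, the function `feature_ablation` and the parameter name `perturbations_per_eval` are illustrative choices for this sketch and do not reproduce Captum's implementation.

```python
class CountingLinearModel:
    """Toy batched model f(x) = w . x that counts its forward passes."""

    def __init__(self, weights):
        self.weights = weights
        self.forward_calls = 0

    def __call__(self, batch):
        self.forward_calls += 1
        return [sum(w * v for w, v in zip(self.weights, row)) for row in batch]


def feature_ablation(model, x, baseline, perturbations_per_eval=1):
    """Attribute f(x) to each feature by replacing it with its baseline value.

    Up to `perturbations_per_eval` ablated copies of the input are stacked
    into one batch and evaluated in a single forward pass.
    """
    original = model([x])[0]
    attributions = [0.0] * len(x)
    for start in range(0, len(x), perturbations_per_eval):
        idxs = range(start, min(start + perturbations_per_eval, len(x)))
        batch = []
        for i in idxs:
            row = list(x)
            row[i] = baseline[i]  # ablate exactly one feature per row
            batch.append(row)
        outputs = model(batch)
        for i, out in zip(idxs, outputs):
            attributions[i] = original - out
    return attributions


weights = [2.0, -1.0, 0.5, 3.0]
x = [1.0, 2.0, 3.0, 4.0]
baseline = [0.0, 0.0, 0.0, 0.0]

sequential_model = CountingLinearModel(weights)
sequential = feature_ablation(sequential_model, x, baseline, perturbations_per_eval=1)

batched_model = CountingLinearModel(weights)
batched = feature_ablation(batched_model, x, baseline, perturbations_per_eval=4)
```

Both calls yield the same attributions; the batched variant trades memory (a four-row batch) for fewer forward passes.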

In addition to this, all algorithms support PyTorch DataParallel, which performs model forward and backward passes simultaneously on multiple GPUs and improves attribution runtime significantly. We made this available for all layer and neuron attribution algorithms, including Layer Activation and NoiseTunnel.

In one of our experiments, we used Integrated Gradients on a pre-trained VGG19 model for a single 3 x 224 x 224 input image in an environment with 8 GPUs, each with 16 GB of memory, and gradually increased the number of GPUs while keeping the number of integral approximation steps constant (in this case 2990). The second column of Table 1 shows that the execution time decreases substantially as we increase the number of GPUs. In the second experiment, we used the same execution environment, pre-trained model and input image, and performed feature ablation by ablating multiple features in one batch using a single forward pass of the model. The results in the third column of Table 1 show that increasing the number of GPUs from 1 to 8 reduces the execution time by approximately 85%.

Evaluation

Global attribution functions describe marginal effects of the inputs on the outputs of the model with respect to a chosen baseline. Examples of such attribution functions are Integrated Gradients, DeepLift, SHAP variant of DeepLift and Gradient SHAP aka Expected Gradients. The mathematical formulations of those methods multiply the resulting saliency maps by $x - x_{0}$.
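As a worked instance of this multiplication, the standard formulation of Integrated Gradients can be written as follows. The symbols $F$ (the model), $\alpha$ (the interpolation coefficient) and the feature index $i$ are notation introduced here for illustration; $x$ is the input and $x_{0}$ the baseline, as above.

```latex
\mathrm{IG}_{i}(x) = \left(x_{i} - x_{0,i}\right)
  \int_{0}^{1} \frac{\partial F\bigl(x_{0} + \alpha\,(x - x_{0})\bigr)}{\partial x_{i}} \, d\alpha ,
\qquad
\sum_{i} \mathrm{IG}_{i}(x) = F(x) - F(x_{0}).
```

The integral is the path-averaged gradient (the saliency term), and the leading factor $x_{i} - x_{0,i}$ is what makes the attribution global: the scores sum to the difference between the model output at the input and at the baseline.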

In the next section we demonstrate results of these two metrics in different applications.
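The infidelity and max-sensitivity metrics named in the introduction can be sketched in plain Python. The sketch assumes the commonly used definitions: infidelity is the mean squared difference between the dot product of a perturbation with the attribution and the resulting change in model output, and max-sensitivity is the largest change in the attribution under small input perturbations, estimated here by Monte Carlo sampling. The function names, the uniform sampling scheme and the parameters `radius`, `n_samples` and `seed` are assumptions for illustration, not Captum's API.

```python
import math
import random


def infidelity(model, attribution, x, perturbations):
    """Mean squared gap between delta . attribution and f(x) - f(x - delta)."""
    fx = model(x)
    total = 0.0
    for delta in perturbations:
        predicted = sum(d * a for d, a in zip(delta, attribution))
        actual = fx - model([v - d for v, d in zip(x, delta)])
        total += (predicted - actual) ** 2
    return total / len(perturbations)


def max_sensitivity(explain, x, radius=0.1, n_samples=20, seed=0):
    """Monte Carlo estimate of the largest L2 change in the attribution
    when each input feature is perturbed by at most `radius`."""
    rng = random.Random(seed)
    base = explain(x)
    worst = 0.0
    for _ in range(n_samples):
        noisy = [v + rng.uniform(-radius, radius) for v in x]
        diff = [a - b for a, b in zip(explain(noisy), base)]
        worst = max(worst, math.sqrt(sum(d * d for d in diff)))
    return worst


# Toy linear model: its gradient is an exact explanation, so infidelity is ~0.
w = [1.0, -2.0, 3.0]
linear = lambda x: sum(wi * v for wi, v in zip(w, x))
x = [1.0, 2.0, 3.0]
deltas = [[0.1, 0.0, -0.2], [0.05, 0.3, 0.0], [-0.1, -0.1, 0.1]]
infid = infidelity(linear, w, x, deltas)

# Gradient of f(x) = sum(x_i ** 2) changes with the input, so it is sensitive.
sens_quadratic = max_sensitivity(lambda p: [2.0 * v for v in p], x)
sens_constant = max_sensitivity(lambda p: list(w), x)
```

Lower values are better for both quantities: low infidelity means the attribution predicts output changes well, and low sensitivity means the explanation is stable under small input changes.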

Applications

Classification models are not the only types of models that Captum supports. As an example, we built a regression model using the Boston house prices dataset and a simple four-layer NN consisting of linear layers and ReLUs. We attributed the output predictions to the last linear layer using the layer conductance algorithm and plotted them together with the learned weights of that same layer, as shown in Figure 4. Here we can see that the weights and the attribution scores are aligned with each other. Both scores were normalized using the $L_{1}$ norm.

To improve the model debugging experience, we developed an interactive visualization tool called Captum Insights. The tool allows users to sub-sample input examples and interactively attribute different output classes to the inputs of the model using different types of attribution algorithms, as depicted in Figure 5.

Support for multi-modal Neural Networks is one of the core motivations of the Captum library. It allows us to apply Captum to machine-learning models that are built using features stemming from different sources such as audio, video, image, text, categorical or dense features. Aggregated feature importance scores for each input modality can reveal which modalities are most impactful. In the example of attributions computed for a multi-modal Visual Question Answering model depicted in Figure 5, we can tell whether the stronger predictive signal is coming from text or image.
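One possible way to aggregate attributions per modality is sketched below. The aggregation rule (sum of absolute attributions, normalized across modalities), the function name `modality_importance` and the attribution values are all assumptions made up for illustration; they are not taken from the Visual Question Answering example.

```python
def modality_importance(attributions_by_modality):
    """Share of total absolute attribution carried by each input modality."""
    totals = {
        name: sum(abs(a) for a in attrs)
        for name, attrs in attributions_by_modality.items()
    }
    norm = sum(totals.values())
    if norm == 0.0:
        # No signal at all: report zero importance for every modality.
        return {name: 0.0 for name in totals}
    return {name: total / norm for name, total in totals.items()}


# Hypothetical per-feature attributions for a two-modality model.
scores = modality_importance({
    "text": [0.4, -0.1, 0.3],
    "image": [0.05, -0.05, 0.1],
})
```

With these made-up numbers the text modality carries most of the attribution mass, which is the kind of comparison the aggregated scores are meant to support.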

Conclusion

We presented a unified model interpretability library for PyTorch, called Captum, that supports generic implementations of a number of gradient- and perturbation-based attribution algorithms. Captum can be applied to NN models of any type and used both in research and production environments. Furthermore, we described how the computations are scaled and how we handle large inputs. We also added support for two generic quantitative evaluation metrics, infidelity and max-sensitivity. Lastly, we walked through different types of applications, including multi-modal models, and introduced the Captum Insights model debugging tool.

Future Work

Our future work involves both expanding the list of attribution algorithms and looking beyond attribution methods for model understanding. Beyond feature, neuron and layer attribution, we are also looking into adversarial robustness and the intersection between these two fields of research. Concept-based model interpretability, which aims to explain models globally using human-understandable concepts, is another interesting direction. In addition, visualizing high-dimensional embedding vectors in the latent layers, debugging models, and understanding what information a single neuron or a group of neurons encodes are further avenues to explore.

Acknowledgements

We are grateful to all Captum contributors and users, both external and internal, and to the PyTorch community for their support, contributions and feedback. Their input helped us to create a self-contained model understanding library that can be used both in research and production. We would like to thank Soumith Chintala, Joe Spisak, Alban Desmaison and Francisco Massa from the core PyTorch team for supporting us with Open Source and PyTorch related questions, and Davide Testuggine for his support and initial discussions on model interpretability and integrated gradients. We would also like to thank Tucker Hart for implementing the initial prototype of Captum Insights, Jessica Lin for helping us with the documentation and Fuchun Peng for reviewing this paper and providing valuable feedback.

References