SHAP Explainer

Overview

The SHAP (SHapley Additive exPlanations) explainer provides model-agnostic explanations for tabular molecular data by calculating Shapley values for each feature. A SHAP value quantifies how much a feature moves the model's prediction for a given instance away from the baseline, that is, the average prediction over the background data.

Class: Shap

Initialization

from chemxai.explainers import Shap

explainer = Shap(
    model=model,
    background_tensor=background_data,
    test_tensor=test_data,
    device=device
)

Parameters

  • model (torch.nn.Module): The trained model to be explained
  • background_tensor (torch.Tensor): Background data used to establish the baseline for SHAP calculations
  • test_tensor (torch.Tensor): Test data to be explained
  • device (torch.device): Device for computations (CPU or GPU)

Methods

explain_local(index)

Generates local explanations for a specific instance by calculating SHAP values.

Parameters:

  • index (int): Index of the instance in the test set to be explained

Returns:

  • list: SHAP values for each feature of the specified instance

Example:

# Explain the first instance
explanation = explainer.explain_local(index=0)
print(f"SHAP values: {explanation}")
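
A local explanation is easier to read once the features are ranked by the magnitude of their SHAP values. The sketch below assumes, as documented above, that explain_local returns a flat list with one value per feature. The helper name top_features and the example numbers are made up for illustration and are not part of chemxai.

```python
def top_features(shap_values, k=3):
    """Return (feature_index, shap_value) pairs for the k largest |SHAP| values."""
    order = sorted(
        range(len(shap_values)),
        key=lambda i: abs(shap_values[i]),
        reverse=True,
    )
    return [(i, shap_values[i]) for i in order[:k]]


# Made-up values standing in for the output of explainer.explain_local(index=0)
example_values = [0.02, -0.40, 0.00, 0.15, -0.07]

top = top_features(example_values, k=3)
for feature_index, value in top:
    # The sign shows whether the feature pushed the prediction up or down
    direction = "raises" if value > 0 else "lowers"
    print(f"feature {feature_index} {direction} the prediction by {abs(value):.2f}")
```

With Morgan fingerprints as descriptors, each feature index corresponds to one fingerprint bit, so the ranking points to the bits that mattered most for this molecule.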

explain_global()

Generates global explanations by calculating the mean absolute SHAP values across all instances.

Returns:

  • list: Mean absolute SHAP values for each feature across all instances

Example:

# Get global feature importance
global_explanation = explainer.explain_global()
print(f"Global feature importance: {global_explanation}")
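
The aggregation described for explain_global can be sketched in a few lines. This only illustrates the "mean absolute SHAP value" idea on a small made-up matrix (rows are instances, columns are features). It is not the library's code, and the name mean_absolute_shap is hypothetical.

```python
def mean_absolute_shap(shap_matrix):
    """Average |SHAP value| per feature over all instances (rows)."""
    n_instances = len(shap_matrix)
    n_features = len(shap_matrix[0])
    return [
        sum(abs(row[j]) for row in shap_matrix) / n_instances
        for j in range(n_features)
    ]


# Made-up SHAP values for 3 instances and 3 features
shap_matrix = [
    [0.5, -0.25, 0.0],
    [-1.5, 0.25, 0.0],
    [1.0, -0.5, 0.0],
]

importance = mean_absolute_shap(shap_matrix)

# Feature indices ordered from most to least important
ranking = sorted(range(len(importance)), key=lambda j: importance[j], reverse=True)
print(importance, ranking)
```

Taking absolute values before averaging keeps positive and negative contributions from cancelling out, so a feature that strongly raises some predictions and strongly lowers others still ranks as important.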

Usage Example

import torch
from chemxai.explainers import Shap
from chemxai.data import qm9_tabular

# Load data
qm9 = qm9_tabular()
train_loader, _, test_loader, _, _, _, _ = qm9.get_paired_dataloaders_tabular(
    batch_size=32,
    descriptor_type='Morgan',
    morgan_radius=3,
    morgan_nBits=512
)

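# Get background and test data: take one batch from each loader
background_batch = next(iter(train_loader))
test_batch = next(iter(test_loader))

# Initialize explainer; your_model is a placeholder for your trained torch.nn.Module,
# and element 0 of each batch is passed as the feature tensor
explainer = Shap(
    model=your_model,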
    background_tensor=background_batch[0],
    test_tensor=test_batch[0],
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Generate local explanation
local_explanation = explainer.explain_local(index=0)

# Generate global explanation
global_explanation = explainer.explain_global()

Technical Details

SHAP Values Calculation

The explainer uses KernelExplainer from the SHAP library, which:

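  1. Uses the background data to establish a baseline: the average model output over the background samples
  2. Estimates Shapley values by evaluating the model on sampled coalitions of features, with the features outside each coalition replaced by background values
  3. Returns values that, for each instance, sum to the difference between the model's prediction and the baseline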
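
The three steps above can be illustrated with a brute-force calculation on a tiny problem. This sketch enumerates every coalition exactly, which is only feasible for a handful of features. KernelExplainer instead approximates the same quantities by sampling coalitions, so this is not its actual algorithm. The toy model, the data, and all function names are made up for illustration.

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    # Hypothetical stand-in for a trained model, with a feature interaction
    return 2.0 * x[0] + x[1] * x[2] - 0.5 * x[2]


def coalition_value(model, x, background, subset):
    """Mean prediction when features in `subset` come from x and the
    remaining features come from each background row in turn."""
    total = 0.0
    for row in background:
        mixed = [x[i] if i in subset else row[i] for i in range(len(x))]
        total += model(mixed)
    return total / len(background)


def exact_shapley(model, x, background):
    """Exact Shapley values by enumerating all coalitions (exponential cost)."""
    n = len(x)
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                with_i = coalition_value(model, x, background, s | {i})
                without_i = coalition_value(model, x, background, s)
                phi += weight * (with_i - without_i)
        values.append(phi)
    return values


background = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 2.0, 0.0]]
x = [1.0, 3.0, 2.0]

phi = exact_shapley(toy_model, x, background)
baseline = sum(toy_model(row) for row in background) / len(background)
prediction = toy_model(x)

# Additivity: the SHAP values sum to prediction minus baseline
additivity_gap = abs(sum(phi) - (prediction - baseline))
print(phi, additivity_gap)
```

The final check mirrors step 3: up to floating-point error, the per-feature values add up to the gap between the instance's prediction and the average prediction over the background rows.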

Model Compatibility

  • Works with any PyTorch model that accepts tabular input
  • Automatically handles tensor conversion between PyTorch and NumPy
  • Supports both regression and classification tasks
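
The tensor conversion mentioned above typically takes the form of a small prediction wrapper, since SHAP's KernelExplainer works with a function that takes and returns NumPy arrays. The following is a minimal sketch of such a wrapper under that assumption, not chemxai's actual implementation. make_predict_fn and the toy linear model are hypothetical.

```python
import numpy as np
import torch


def make_predict_fn(model, device):
    """Wrap a PyTorch model as a NumPy-in, NumPy-out function (hypothetical helper)."""
    model = model.to(device)
    model.eval()  # disable dropout and batch-norm updates during explanation

    def predict(x_np):
        # NumPy array -> float32 tensor on the target device
        x = torch.as_tensor(x_np, dtype=torch.float32, device=device)
        with torch.no_grad():  # no gradients are needed for inference
            out = model(x)
        # Tensor -> NumPy array on the CPU
        return out.detach().cpu().numpy()

    return predict


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toy_model = torch.nn.Linear(4, 1)  # stand-in for a trained tabular model
predict = make_predict_fn(toy_model, device)

batch = np.zeros((3, 4), dtype=np.float64)  # 3 instances, 4 features
predictions = predict(batch)
print(predictions.shape)
```

Casting to float32 in the wrapper avoids dtype mismatches when the explainer passes float64 arrays to a model whose weights are float32.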

Performance Considerations

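  • Background data size affects computation time: the model is evaluated on every background row for each sampled coalition, so a small or subsampled background set keeps explanations tractable
  • Larger test sets take longer to process, because each explained instance requires its own set of model evaluations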
  • GPU acceleration is supported for model inference
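
A rough cost estimate shows why the background size matters. The sketch below assumes that each explained instance costs about (number of sampled coalitions) x (number of background rows) model evaluations, and that the number of sampled coalitions follows KernelExplainer's "auto" setting of 2 * n_features + 2048. Both are assumptions about the SHAP library's behaviour, not measurements of chemxai. The helper names are hypothetical.

```python
import random


def estimated_model_rows(n_features, n_background, nsamples=None):
    """Approximate number of rows passed to the model per explained instance."""
    if nsamples is None:
        # Assumed "auto" default of shap.KernelExplainer
        nsamples = 2 * n_features + 2048
    return nsamples * n_background


def subsample_rows(rows, k, seed=0):
    """Pick k background rows at random (without replacement), reproducibly."""
    if k >= len(rows):
        return list(rows)
    return random.Random(seed).sample(rows, k)


# 512-bit Morgan fingerprints with a batch of 32 background rows
rows_full = estimated_model_rows(n_features=512, n_background=32)

# The same setup with the background reduced to 8 rows
rows_small = estimated_model_rows(n_features=512, n_background=8)

background = [[float(i)] * 4 for i in range(32)]  # made-up background data
smaller_background = subsample_rows(background, k=8)

print(rows_full, rows_small, len(smaller_background))
```

Under these assumptions, reducing the background from 32 to 8 rows cuts the work per instance by a factor of four, at the cost of a noisier baseline.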