SHAP Explainer
Overview
The SHAP (SHapley Additive exPlanations) explainer provides model-agnostic explanations for tabular molecular data by estimating a Shapley value for each input feature. Each SHAP value is that feature's signed contribution to the model's prediction for a given instance, measured relative to a baseline derived from the background data.
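To make the idea of a per-feature contribution concrete, the following is a minimal sketch that computes exact Shapley values by enumerating every feature coalition. It is not how this explainer works internally; it only illustrates the definition. The names `exact_shapley` and `toy_model`, the weights, and the all-zero baseline are assumptions made for the illustration, and exhaustive enumeration is only feasible for a handful of features.

```python
from itertools import combinations
from math import factorial


def exact_shapley(predict, x, baseline):
    """Exact Shapley values for one instance.

    Features outside a coalition are replaced by the baseline value.
    Cost grows as 2**n, so this is for illustration only.
    """
    n = len(x)

    def value(coalition):
        # Coalition members keep their real value; the rest take the baseline.
        mixed = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return predict(mixed)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size that exclude feature i.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


# Hypothetical linear "model" over three fingerprint bits.
def toy_model(bits):
    return 2.0 * bits[0] - 1.0 * bits[1] + 0.5 * bits[2]


x = [1, 1, 0]
baseline = [0, 0, 0]
phi = exact_shapley(toy_model, x, baseline)
print(phi)  # approximately [2.0, -1.0, 0.0]
```

For a linear model each value reduces to the weight times the feature's deviation from the baseline, which is why the third bit (unchanged from the baseline) receives no attribution.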
Class: Shap
Initialization
from chemxai.explainers import Shap

explainer = Shap(
    model=model,
    background_tensor=background_data,
    test_tensor=test_data,
    device=device
)
Parameters
- model (torch.nn.Module): The trained model to be explained
- background_tensor (torch.Tensor): Background data used to establish the baseline for SHAP calculations
- test_tensor (torch.Tensor): Test data to be explained
- device (torch.device): Device for computations (CPU or GPU)
Methods
explain_local(index)
Generates local explanations for a specific instance by calculating SHAP values.
Parameters:
- index (int): Index of the instance in the test set to be explained
Returns:
- list: SHAP values for each feature of the specified instance
Example:
# Explain the first instance
explanation = explainer.explain_local(index=0)
print(f"SHAP values: {explanation}")
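With hundreds of fingerprint bits, the raw list is hard to read, so a common next step is to rank features by the magnitude of their SHAP value. The sketch below assumes the returned explanation is a flat list of floats, one per feature; `top_features` and the example values are hypothetical and not part of the library.

```python
def top_features(shap_values, k=5):
    """Return (feature_index, shap_value) pairs with the largest absolute values."""
    ranked = sorted(
        enumerate(shap_values),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical SHAP values for a 6-bit fingerprint (stand-in for explain_local output).
example_values = [0.02, -0.31, 0.0, 0.12, -0.05, 0.27]

top3 = top_features(example_values, k=3)
for idx, val in top3:
    direction = "raises" if val > 0 else "lowers"
    print(f"bit {idx}: {val:+.2f} ({direction} the prediction)")
```

Sorting by absolute value keeps strongly negative contributions near the top, while the sign is preserved so the direction of each effect is still visible.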
explain_global()
Generates global explanations by calculating the mean absolute SHAP values across all instances.
Returns:
- list: Mean absolute SHAP values for each feature across all instances
Example:
# Get global feature importance
global_explanation = explainer.explain_global()
print(f"Global feature importance: {global_explanation}")
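The aggregation described above can be sketched in a few lines. This is an illustration of mean absolute SHAP values under the assumption that per-instance SHAP values are available as rows of equal length; `mean_abs_shap` and the example matrix are hypothetical, and the library's own implementation may differ.

```python
def mean_abs_shap(shap_matrix):
    """Mean absolute SHAP value per feature.

    shap_matrix: list of rows, one row of per-feature SHAP values per instance.
    """
    n_instances = len(shap_matrix)
    n_features = len(shap_matrix[0])
    return [
        sum(abs(row[j]) for row in shap_matrix) / n_instances
        for j in range(n_features)
    ]


# Hypothetical SHAP values for two instances and three features.
shap_matrix = [
    [0.5, -0.25, 0.0],
    [-0.5, 0.75, 0.0],
]

importance = mean_abs_shap(shap_matrix)
print(importance)  # [0.5, 0.5, 0.0]
```

Taking the absolute value before averaging matters: the first feature's signed contributions cancel to zero, yet it is as influential as the second, and the mean absolute value reflects that.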
Usage Example
import torch
from chemxai.explainers import Shap
from chemxai.data import qm9_tabular

# Load data
qm9 = qm9_tabular()
train_loader, _, test_loader, _, _, _, _ = qm9.get_paired_dataloaders_tabular(
    batch_size=32,
    descriptor_type='Morgan',
    morgan_radius=3,
    morgan_nBits=512
)

# Get background and test data (one batch from each loader)
background_batch = next(iter(train_loader))
test_batch = next(iter(test_loader))

# Initialize explainer (your_model is a placeholder for a trained torch.nn.Module)
explainer = Shap(
    model=your_model,
    background_tensor=background_batch[0],
    test_tensor=test_batch[0],
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Generate local explanation for the first test instance
local_explanation = explainer.explain_local(index=0)

# Generate global explanation across the test batch
global_explanation = explainer.explain_global()
Technical Details
SHAP Values Calculation
The explainer uses KernelExplainer from the SHAP library, which:
- Uses the background data to establish a baseline (the expected model output over the background samples)
- Estimates Shapley values by evaluating the model on sampled coalitions of features, with features outside a coalition replaced by background values
- Returns values that, for each instance, sum to the difference between the model's prediction and the baseline
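The baseline and the additivity property can be checked by hand on a small case. The sketch below is an illustration under stated assumptions: `toy_model`, its weights, and the four background rows are hypothetical, and the SHAP values are obtained from the closed form for a linear model (weight times the feature's deviation from its background mean) rather than from KernelExplainer.

```python
import math


def base_value_from_background(predict, background_rows):
    """Baseline: the mean model output over the background samples."""
    preds = [predict(row) for row in background_rows]
    return sum(preds) / len(preds)


def satisfies_additivity(shap_values, base_value, prediction, tol=1e-6):
    """True if base_value + sum(shap_values) reproduces the prediction."""
    return math.isclose(base_value + sum(shap_values), prediction, abs_tol=tol)


# Hypothetical linear model over two features.
weights = [3.0, 1.0]


def toy_model(row):
    return sum(w * v for w, v in zip(weights, row))


background = [[0, 0], [1, 0], [0, 1], [1, 1]]
instance = [1, 1]

base_value = base_value_from_background(toy_model, background)

# Closed form for a linear model: w_i * (x_i - mean of feature i in the background).
feature_means = [sum(col) / len(col) for col in zip(*background)]
shap_values = [w * (x - m) for w, x, m in zip(weights, instance, feature_means)]

additive = satisfies_additivity(shap_values, base_value, toy_model(instance))
print(base_value, shap_values, additive)  # 2.0 [1.5, 0.5] True
```

The same check applies to values returned by a sampling-based estimator, although a looser tolerance may be appropriate there.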
Model Compatibility
- Works with any PyTorch model that accepts tabular input
- Automatically handles tensor conversion between PyTorch and NumPy
- Supports both regression and classification tasks
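Performance Considerations
- Background data size affects computation time: KernelExplainer evaluates the model on every background sample for each sampled coalition, so cost grows roughly linearly with the number of background samples
- Larger test sets take proportionally longer, since each instance is explained separately
- GPU acceleration is supported for model inference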
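A back-of-the-envelope estimate helps when choosing background and test set sizes. The sketch below counts the rows passed to the model, assuming each explained instance triggers one model call per sampled coalition per background sample, and assuming the coalition count follows the SHAP library's "auto" setting of 2 * M + 2048 for M features. The function name `kernel_shap_model_rows` is hypothetical, and the result is a rough guide rather than a measurement.

```python
def kernel_shap_model_rows(n_features, n_background, n_test, nsamples=None):
    """Rough number of rows the model must evaluate for Kernel SHAP."""
    if nsamples is None:
        # Assumption: mirrors the SHAP library's "auto" default, 2 * M + 2048.
        nsamples = 2 * n_features + 2048
    return n_test * nsamples * n_background


# 512-bit fingerprints, one batch of 32 test instances.
rows_full = kernel_shap_model_rows(n_features=512, n_background=32, n_test=32)
rows_small = kernel_shap_model_rows(n_features=512, n_background=8, n_test=32)

print(rows_full)   # 3145728
print(rows_small)  # 786432
```

Under these assumptions, shrinking the background set from 32 to 8 samples cuts the work by a factor of four, at the cost of a noisier baseline.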