Graph-based Explainers¶
Overview¶
This module provides several explainers specifically designed for graph neural networks, including SHAP-based and LIME-based methods adapted for graph data structures.
GraphShap¶
Class: GraphShap¶
Implements KernelSHAP for graph-level predictions by computing Shapley values for graph features.
Initialization¶
from chemxai.explainers import GraphShap

explainer = GraphShap(
    data=graph_data,
    model=model,
    device=device,
    gpu=False
)
Parameters¶
- data (torch_geometric.data.Data): Graph data with node features and edge indices
- model (torch.nn.Module): Trained GNN model
- device (torch.device): Computation device
- gpu (bool, optional): Whether to use GPU acceleration (default: False)
Methods¶
explain(num_samples=30, info=True)¶
Parameters:
- num_samples (int): Number of samples for SHAP approximation
- info (bool): Whether to return additional information
Returns:
- list: SHAP values for each feature in the graph
Example:
explainer = GraphShap(data=graph_data, model=model, device=device)
shap_values = explainer.explain(num_samples=50)
GraphLIME¶
Class: GraphLIME¶
Adapts LIME for graph-level predictions by perturbing node features and measuring impact on predictions.
Initialization¶
from chemxai.explainers import GraphLIME

explainer = GraphLIME(
    model=model,
    device='cpu',
    rho=0.1
)
Parameters¶
- model (torch.nn.Module): Trained GNN model
- device (str, optional): Computation device (default: 'cpu')
- rho (float, optional): Regularization parameter for Lasso regression (default: 0.1)
Methods¶
explain(data, num_samples=100)¶
Parameters:
- data (torch_geometric.data.Data): Graph data to explain
- num_samples (int): Number of perturbations for LIME
Returns:
- list: Feature importance values for the graph
Example:
explainer = GraphLIME(model=model, device='cpu')
importance = explainer.explain(data=graph_data, num_samples=100)
NodeGraphShap¶
Class: NodeGraphShap¶
Implements KernelSHAP for node-level predictions in graphs.
Initialization¶
from chemxai.explainers import NodeGraphShap

explainer = NodeGraphShap(
    data=graph_data,
    model=model,
    device=device,
    gpu=False
)
Parameters¶
- data (torch_geometric.data.Data): Graph data
- model (torch.nn.Module): Trained GNN model
- device (torch.device): Computation device
- gpu (bool, optional): GPU acceleration flag (default: False)
Methods¶
explain(node_index=0, hops=2, num_samples=10, info=True, multiclass=False)¶
Parameters:
- node_index (int): Target node index
- hops (int): Number of hops for subgraph extraction
- num_samples (int): Number of samples for SHAP approximation
- info (bool): Return additional information
- multiclass (bool): Whether the model is multiclass
Returns:
- list: SHAP values for node features
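Example (mirroring the graph-level usage above, with graph_data, model, and device prepared as before):
node_explainer = NodeGraphShap(data=graph_data, model=model, device=device)
node_shap_values = node_explainer.explain(node_index=0, hops=2, num_samples=30)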
NodeGrapLIME¶
Class: NodeGrapLIME¶
Adapts LIME for node-level explanations in graphs.
Initialization¶
from chemxai.explainers import NodeGrapLIME

explainer = NodeGrapLIME(
    data=graph_data,
    model=model,
    device=device,
    gpu=False,
    hop=2,
    rho=0.1,
    cached=True
)
Parameters¶
- data (torch_geometric.data.Data): Graph data
- model (torch.nn.Module): Trained GNN model
- device (torch.device): Computation device
- gpu (bool): GPU acceleration flag
- hop (int): Number of hops for subgraph extraction
- rho (float): Regularization parameter
- cached (bool): Whether to cache predictions
Methods¶
explain(node_index, hops, num_samples, info=False, multiclass=False)¶
Parameters:
- node_index (int): Target node index
- hops (int): Number of hops for subgraph
- num_samples (int): Number of perturbations
- info (bool): Return additional information
- multiclass (bool): Multiclass classification flag
Returns:
- list: Feature importance values for the node
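Example (mirroring the node-level usage above; note that node_index, hops, and num_samples have no defaults here):
node_explainer = NodeGrapLIME(data=graph_data, model=model, device=device)
node_lime_values = node_explainer.explain(node_index=0, hops=2, num_samples=100)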
Usage Examples¶
Complete Graph Explanation Workflow¶
import torch
from torch_geometric.data import Data
from chemxai.explainers import GraphShap, GraphLIME, NodeGraphShap, NodeGrapLIME

# Assumes a trained GNN `model` and a `device` (e.g. torch.device('cpu'))
# are already available.

# Prepare graph data
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
x = torch.randn(4, 10)  # 4 nodes, 10 features each
data = Data(x=x, edge_index=edge_index)

# Graph-level explanations
graph_shap = GraphShap(data=data, model=model, device=device)
graph_shap_values = graph_shap.explain(num_samples=30)

graph_lime = GraphLIME(model=model, device=device)
graph_lime_values = graph_lime.explain(data=data, num_samples=100)

# Node-level explanations
node_shap = NodeGraphShap(data=data, model=model, device=device)
node_shap_values = node_shap.explain(node_index=0, hops=2, num_samples=30)

node_lime = NodeGrapLIME(data=data, model=model, device=device)
node_lime_values = node_lime.explain(node_index=0, hops=2, num_samples=30)
Technical Details¶
Shapley Value Computation¶
All SHAP-based explainers use the Shapley kernel formula (see the sketch after this list):
- Weight samples based on coalition size
- Use weighted linear regression to approximate Shapley values
- Handle edge cases (empty and full coalitions) with high weights
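As a minimal sketch, the standard KernelSHAP weighting (not necessarily the exact chemxai implementation) assigns a coalition of size k out of M features the weight (M - 1) / (C(M, k) * k * (M - k)):
from math import comb

def shapley_kernel_weight(M, k):
    # Standard KernelSHAP weight for a coalition of size k out of M features.
    # Empty and full coalitions have infinite weight in theory; a large
    # constant stands in for it, matching the "high weights" noted above.
    if k == 0 or k == M:
        return 1e6
    return (M - 1) / (comb(M, k) * k * (M - k))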
LIME Perturbation Strategy¶
LIME-based explainers (see the sketch after this list):
- Generate binary masks for feature presence/absence
- Use Lasso regression for feature selection
- Apply PCA for dimensionality reduction when needed
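A minimal sketch of this perturbation loop, assuming a hypothetical predict_fn that maps a feature vector to a scalar prediction (the sampling and surrogate fitting inside chemxai may differ):
import numpy as np
import torch
from sklearn.linear_model import Lasso

def lime_feature_importance(predict_fn, x, num_samples=100, rho=0.1):
    # x: 1-D tensor of features for the instance being explained.
    num_features = x.shape[0]
    # Binary masks marking feature presence (1) or absence (0).
    masks = np.random.randint(0, 2, size=(num_samples, num_features))
    preds = []
    for mask in masks:
        x_pert = x * torch.as_tensor(mask, dtype=x.dtype)
        preds.append(float(predict_fn(x_pert)))
    # Lasso yields a sparse surrogate, keeping only influential features.
    surrogate = Lasso(alpha=rho)
    surrogate.fit(masks, np.array(preds))
    return surrogate.coef_  # one importance value per feature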
Subgraph Extraction¶
Node-level explainers use k-hop subgraph extraction (see the example after this list):
- Extract neighborhood around target node
- Maintain graph connectivity
- Relabel nodes for consistent indexing
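PyTorch Geometric's k_hop_subgraph utility performs exactly these steps; whether chemxai calls it internally is an assumption, but it illustrates the mechanics on the data object from the workflow above:
from torch_geometric.utils import k_hop_subgraph

# 2-hop neighborhood around node 0, with surviving nodes relabeled
# to a contiguous range so indexing stays consistent.
subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx=0,
    num_hops=2,
    edge_index=data.edge_index,
    relabel_nodes=True,
)
sub_x = data.x[subset]  # node features of the extracted subgraph
target = mapping        # position of node 0 within the subgraph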
Performance Optimization¶
- Caching mechanisms for repeated predictions
- GPU acceleration where applicable
- Efficient matrix operations using NumPy and PyTorch
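As an illustration of the caching idea (a hypothetical helper, not part of the chemxai API), predictions for repeated perturbation masks can be memoized so each coalition hits the model only once:
import torch

_prediction_cache = {}

def cached_predict(model, x, mask):
    # Key the cache on the binary mask so identical coalitions
    # sampled more than once reuse the stored prediction.
    key = tuple(int(v) for v in mask)
    if key not in _prediction_cache:
        with torch.no_grad():
            _prediction_cache[key] = model(x * mask)
    return _prediction_cache[key]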