GNN Explainer¶
Overview¶
GNNExplainer explains Graph Neural Network predictions by identifying the nodes and edges that contribute most to the model's decision. It learns soft masks over graph elements and optimizes them so that the masked graph preserves the model's prediction; the learned mask values then serve as importance scores.
Class: GNNExplain¶
Initialization¶
from chemxai.explainers import GNNExplain
explainer = GNNExplain(
    model=model,
    device=device,
    data=graph_data,
    epochs=100,
    mode='regression',
    task_level='graph',
    return_type='raw'
)
Parameters¶
- model (torch.nn.Module): The trained GNN model to be explained
- device (torch.device): Device for computations (CPU or GPU)
- data (torch_geometric.data.Data): Graph data containing node features and edge indices
- epochs (int): Number of epochs for training the explainer
- mode (str, optional): Task type, 'regression' or 'classification' (default: 'regression')
- task_level (str, optional): Explanation level, 'node' or 'graph' (default: 'graph')
- return_type (str, optional): Format of returned explanation, 'raw' or 'probabilities' (default: 'raw')
Methods¶
explain(index=None)¶
Explains the prediction by calculating node and edge importance masks.
Parameters:
- index (int, optional): Index of the node to explain in node-level explanations. For graph-level explanations it can be left as None, as in the first example below
Returns:
- tuple: A tuple containing:
  - node_mask (list): Feature importance values for each node
  - edge_mask (list): Feature importance values for each edge
  - explanation (Explanation): Full explanation object with additional metadata
Example:
# Explain graph-level prediction
node_mask, edge_mask, full_explanation = explainer.explain()
# Explain with specific index
node_mask, edge_mask, full_explanation = explainer.explain(index=0)
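Because the masks come back as plain lists, ranking them needs no extra dependencies. The following is a minimal sketch, assuming edge_mask holds one score per column of edge_index; top_k_edges is a hypothetical helper (not part of chemxai) and the scores are made up for illustration:

```python
def top_k_edges(edge_mask, edge_index, k=2):
    """Return the k highest-scoring edges as (source, target, score) tuples."""
    sources, targets = edge_index
    if not (len(edge_mask) == len(sources) == len(targets)):
        raise ValueError("edge_mask must have one score per edge")
    # Sort edge positions by score, highest first
    ranked = sorted(range(len(edge_mask)), key=lambda i: edge_mask[i], reverse=True)
    return [(sources[i], targets[i], edge_mask[i]) for i in ranked[:k]]


# Illustrative values only; real scores come from explainer.explain()
edge_index = [[0, 1, 2], [1, 2, 0]]
edge_mask = [0.12, 0.87, 0.45]
print(top_k_edges(edge_mask, edge_index, k=2))
```

If the graph data holds edge_index as a tensor, edge_index.tolist() gives the nested-list form used here.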
Usage Example¶
import torch
from torch_geometric.data import Data
from chemxai.explainers import GNNExplain
# Prepare graph data
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)
x = torch.randn(3, 4) # 3 nodes, 4 features each
data = Data(x=x, edge_index=edge_index)
# Initialize explainer
explainer = GNNExplain(
    model=your_gnn_model,  # placeholder: your trained torch.nn.Module
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    data=data,
    epochs=100,
    mode='regression',
    task_level='graph'
)
# Generate explanation
node_importance, edge_importance, explanation = explainer.explain()
print(f"Node importance: {node_importance}")
print(f"Edge importance: {edge_importance}")
Technical Details¶
Explanation Methodology¶
GNNExplainer works by:
- Mask Initialization: Creates learnable masks for nodes and edges
- Optimization: Trains masks to maximize mutual information between masked graph and predictions
- Regularization: Applies sparsity constraints to produce interpretable explanations
- Output: Returns importance scores for nodes and edges
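The four steps above can be sketched in plain PyTorch. This is an illustration under stated assumptions, not the chemxai or PyTorch Geometric implementation: toy_model is a hypothetical one-layer stand-in for a trained GNN, a squared-error term is used as a simple regression proxy for the mutual-information objective, and the learning rate and sparsity weight are arbitrary.

```python
import torch

torch.manual_seed(0)

# Toy graph: 3 nodes, 4 features, 3 directed edges
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
weight = torch.randn(4, 1)  # frozen stand-in for a trained model's parameters


def toy_model(x, edge_index, edge_weight):
    """One round of edge-weighted message passing followed by mean pooling."""
    src, dst = edge_index
    messages = (x[src] @ weight) * edge_weight.unsqueeze(-1)
    node_out = torch.zeros(x.size(0), 1).index_add(0, dst, messages)
    return node_out.mean()


# Prediction on the unmasked graph: the value the masked graph should reproduce
with torch.no_grad():
    target = toy_model(x, edge_index, torch.ones(edge_index.size(1)))

# 1. Mask initialization: one learnable logit per edge
edge_logits = torch.nn.Parameter(torch.randn(edge_index.size(1)) * 0.1)
optimizer = torch.optim.Adam([edge_logits], lr=0.05)

for epoch in range(100):
    optimizer.zero_grad()
    mask = torch.sigmoid(edge_logits)
    # 2. Optimization: keep the masked prediction close to the original one
    fidelity = (toy_model(x, edge_index, mask) - target) ** 2
    # 3. Regularization: penalize total mask size to encourage sparsity
    sparsity = 0.05 * mask.sum()
    (fidelity + sparsity).backward()
    optimizer.step()

# 4. Output: one importance score in (0, 1) per edge
edge_importance = torch.sigmoid(edge_logits).detach().tolist()
print(edge_importance)
```

Only the mask logits are updated; the model weights stay fixed, which is why the result explains the existing model rather than retraining it.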
Task Levels¶
- Graph-level: Explains predictions for entire graphs (e.g., molecular property prediction)
- Node-level: Explains predictions for individual nodes (e.g., atom property prediction)
Mask Types¶
- Node masks: Indicate importance of node features/attributes
- Edge masks: Indicate importance of edges/bonds in the graph
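When a node mask scores individual features, it is often useful to collapse it along one axis. The sketch below assumes the mask is a nested list of shape [num_nodes][num_features]; both helper names are hypothetical and the values are illustrative:

```python
def per_node_scores(node_mask):
    """Collapse a [num_nodes][num_features] mask into one score per node."""
    return [sum(row) for row in node_mask]


def per_feature_scores(node_mask):
    """Collapse the same mask into one score per feature (column sums)."""
    return [sum(column) for column in zip(*node_mask)]


# Illustrative mask for 3 nodes with 2 features each
node_mask = [[0.5, 0.25], [0.0, 0.125], [0.25, 0.75]]
print(per_node_scores(node_mask))     # [0.75, 0.125, 1.0]
print(per_feature_scores(node_mask))  # [0.75, 1.125]
```

Summing is one choice among several; a mean or maximum per row would rank nodes differently when feature counts or scales vary.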
Integration with PyTorch Geometric¶
The explainer uses PyTorch Geometric's Explainer framework:
- Automatically handles batching for graph-level tasks
- Supports various GNN architectures
- Provides consistent API across different explanation methods
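For orientation, here is a sketch of how a wrapper like this could configure PyTorch Geometric's Explainer directly. It is not taken from the chemxai source. Recent PyTorch Geometric releases spell some options differently ('multiclass_classification' or 'binary_classification' for mode, 'probs' for return_type), so the translation table below is an assumption, as are the chosen mask types and both function names.

```python
def to_pyg_model_config(mode="regression", task_level="graph", return_type="raw"):
    """Translate the option names documented above into PyG-style names."""
    # Assumed mapping; the real wrapper may translate differently
    modes = {"regression": "regression", "classification": "multiclass_classification"}
    return_types = {"raw": "raw", "probabilities": "probs"}
    if mode not in modes:
        raise ValueError(f"unknown mode: {mode!r}")
    if task_level not in ("node", "graph"):
        raise ValueError(f"unknown task_level: {task_level!r}")
    if return_type not in return_types:
        raise ValueError(f"unknown return_type: {return_type!r}")
    return {
        "mode": modes[mode],
        "task_level": task_level,
        "return_type": return_types[return_type],
    }


def build_pyg_explainer(model, epochs=100, **options):
    """Wrap `model` in a PyTorch Geometric Explainer running GNNExplainer."""
    # Imported lazily so the config helper works without torch_geometric
    from torch_geometric.explain import Explainer, GNNExplainer

    return Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=epochs),
        explanation_type="model",
        node_mask_type="attributes",  # assumption: per-feature node masks
        edge_mask_type="object",      # assumption: one score per edge
        model_config=to_pyg_model_config(**options),
    )


print(to_pyg_model_config(mode="classification", return_type="probabilities"))
```

Calling the returned object as explainer(data.x, data.edge_index) yields an Explanation whose node_mask and edge_mask attributes hold the learned masks.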
Performance Considerations¶
- Training epochs affect explanation quality and computation time
- Larger graphs may need more epochs for the masks to converge
- GPU acceleration significantly speeds up the explanation process
- Memory usage scales with graph size; the number of epochs mainly affects run time
Visualization¶
The explanation results can be used with molecular visualization tools:
- Node masks can highlight important atoms
- Edge masks can highlight important bonds
- Results integrate well with RDKit for molecular visualization
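As a starting point for atom highlighting, the sketch below turns per-atom scores into a list of atom indices and a color dictionary. It assumes one score per atom, with node indices matching atom indices; atom_highlights is a hypothetical helper and the threshold is arbitrary.

```python
def atom_highlights(node_scores, threshold=0.5):
    """Map per-atom scores to (atom indices, {index: RGB}) for highlighting."""
    low, high = min(node_scores), max(node_scores)
    span = (high - low) or 1.0  # avoid division by zero when all scores match
    atoms, colors = [], {}
    for idx, score in enumerate(node_scores):
        scaled = (score - low) / span  # min-max scale to [0, 1]
        if scaled >= threshold:
            atoms.append(idx)
            # Stronger red for more important atoms
            colors[idx] = (1.0, 1.0 - scaled, 1.0 - scaled)
    return atoms, colors


# Illustrative scores for a 3-atom molecule
atoms, colors = atom_highlights([0.0, 0.5, 1.0])
print(atoms)   # [1, 2]
print(colors)  # {1: (1.0, 0.5, 0.5), 2: (1.0, 0.0, 0.0)}
```

The two return values have the shape that RDKit's MolDraw2D drawers accept as highlightAtoms and highlightAtomColors; the same approach applies to bonds with highlightBonds and highlightBondColors once edges are mapped to bond indices.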