GNN Explainer

Overview

GNNExplainer explains Graph Neural Network predictions by identifying the nodes and edges that contribute most to the model's decision. It learns soft masks over node features and edges through gradient-based optimization, measuring how masking each element affects the prediction.

Class: GNNExplain

Initialization

from chemxai.explainers import GNNExplain

explainer = GNNExplain(
    model=model,
    device=device,
    data=graph_data,
    epochs=100,
    mode='regression',
    task_level='graph',
    return_type='raw'
)

Parameters

  • model (torch.nn.Module): The trained GNN model to be explained
  • device (torch.device): Device for computations (CPU or GPU)
  • data (torch_geometric.data.Data): Graph data containing node features and edge indices
  • epochs (int): Number of epochs for training the explainer
  • mode (str, optional): Task type - 'regression' or 'classification' (default: 'regression')
  • task_level (str, optional): Explanation level - 'node' or 'graph' (default: 'graph')
  • return_type (str, optional): Format of returned explanation - 'raw' or 'probabilities' (default: 'raw')

Methods

explain(index=None)

Explains the prediction by calculating node and edge importance masks.

Parameters:

  • index (int, optional): Index of the node to explain for node-level tasks; can be omitted for graph-level explanations (default: None)

Returns:

  • tuple: A tuple containing:
      • node_mask (list): Importance scores for each node's features
      • edge_mask (list): Importance scores for each edge
      • explanation (Explanation): Full explanation object with additional metadata

Example:

# Explain graph-level prediction
node_mask, edge_mask, full_explanation = explainer.explain()

# Explain with specific index
node_mask, edge_mask, full_explanation = explainer.explain(index=0)
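
The returned masks can be post-processed directly. A minimal sketch that ranks edges by importance, assuming edge_mask comes back as a flat sequence with one score per edge and that data is the Data object passed at initialization:

import torch

node_mask, edge_mask, full_explanation = explainer.explain()

# Turn the per-edge scores into a tensor and pick the top-3 edges
edge_scores = torch.as_tensor(edge_mask)
top_scores, top_edges = torch.topk(edge_scores, k=min(3, edge_scores.numel()))

for score, idx in zip(top_scores.tolist(), top_edges.tolist()):
    src, dst = data.edge_index[:, idx].tolist()
    print(f"edge {src} -> {dst}: importance {score:.3f}")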

Usage Example

import torch
from torch_geometric.data import Data
from chemxai.explainers import GNNExplain

# Prepare graph data
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)
x = torch.randn(3, 4)  # 3 nodes, 4 features each
data = Data(x=x, edge_index=edge_index)

# Initialize explainer
explainer = GNNExplain(
    model=your_gnn_model,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    data=data,
    epochs=100,
    mode='regression',
    task_level='graph'
)

# Generate explanation
node_importance, edge_importance, explanation = explainer.explain()

print(f"Node importance: {node_importance}")
print(f"Edge importance: {edge_importance}")

Technical Details

Explanation Methodology

GNNExplainer works by:

  1. Mask Initialization: Creates learnable masks for nodes and edges
  2. Optimization: Trains masks to maximize mutual information between masked graph and predictions
  3. Regularization: Applies sparsity constraints to produce interpretable explanations
  4. Output: Returns importance scores for nodes and edges
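
The sketch below illustrates steps 1-3 on a toy model. It is not the chemxai implementation, just the general recipe: a sigmoid edge mask is learned so that the masked prediction stays close to the original one, with an L1 penalty encouraging sparsity.

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

class TinyGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = GCNConv(4, 1)

    def forward(self, x, edge_index, edge_weight=None):
        # Mean-pool node embeddings into a single graph-level output
        return self.conv(x, edge_index, edge_weight).mean()

data = Data(x=torch.randn(3, 4),
            edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
model = TinyGNN().eval()
for p in model.parameters():
    p.requires_grad_(False)

with torch.no_grad():
    target = model(data.x, data.edge_index)   # prediction to be explained

# 1. Mask initialization: one learnable logit per edge
edge_logits = torch.nn.Parameter(torch.randn(data.edge_index.size(1)))
optimizer = torch.optim.Adam([edge_logits], lr=0.01)

for _ in range(100):
    optimizer.zero_grad()
    edge_mask = edge_logits.sigmoid()         # soft mask in (0, 1)
    out = model(data.x, data.edge_index, edge_mask)
    # 2. Optimization: keep the masked prediction close to the original
    # 3. Regularization: the L1 term pushes the mask towards sparsity
    loss = F.mse_loss(out, target) + 0.01 * edge_mask.sum()
    loss.backward()
    optimizer.step()

print(edge_logits.sigmoid())                  # 4. per-edge importance scores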

Task Levels

  • Graph-level: Explains predictions for entire graphs (e.g., molecular property prediction)
  • Node-level: Explains predictions for individual nodes (e.g., atom-level property prediction)
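
For node-level tasks, the only changes relative to the constructor shown earlier are task_level='node' and passing the target node's index to explain(). A brief sketch (your_gnn_model, device, and index=0 are placeholders):

# Node-level explainer: explain the model's output for a single node/atom
node_explainer = GNNExplain(
    model=your_gnn_model,
    device=device,
    data=data,
    epochs=100,
    mode='regression',
    task_level='node',
)

# index selects which node's prediction is explained
node_mask, edge_mask, explanation = node_explainer.explain(index=0)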

Mask Types

  • Node masks: Indicate importance of node features/attributes
  • Edge masks: Indicate importance of edges/bonds in the graph
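
A quick way to check what each mask covers, assuming the usual layout of one score per node feature and one score per edge (the exact shapes may differ):

import torch

node_mask, edge_mask, explanation = explainer.explain()

node_scores = torch.as_tensor(node_mask)
edge_scores = torch.as_tensor(edge_mask)

print(node_scores.shape)   # typically [num_nodes, num_features]
print(edge_scores.shape)   # typically [num_edges]

# Collapse feature scores into a single importance value per node/atom
per_node = node_scores.sum(dim=-1) if node_scores.dim() > 1 else node_scores
print(per_node)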

Integration with PyTorch Geometric

The explainer uses PyTorch Geometric's Explainer framework:

  • Automatically handles batching for graph-level tasks
  • Supports various GNN architectures
  • Provides a consistent API across different explanation methods
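
Internally this typically means wrapping the model with torch_geometric.explain.Explainer and the GNNExplainer algorithm. The configuration below is a sketch of what such a wrapper might look like with PyG 2.x (assuming the model's forward accepts x and edge_index), not a dump of chemxai's internals:

from torch_geometric.explain import Explainer, GNNExplainer

pyg_explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=100),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
)

# Calling the explainer runs the mask optimization and returns an Explanation
explanation = pyg_explainer(data.x, data.edge_index)
node_mask = explanation.node_mask
edge_mask = explanation.edge_mask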

Performance Considerations

  • Training epochs affect explanation quality and computation time
  • Larger graphs require more epochs for convergence
  • GPU acceleration significantly speeds up the explanation process
  • Memory usage scales with graph size and number of epochs
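
A simple way to gauge the epoch/runtime trade-off on your own data is to time a few settings. Illustrative only; your_gnn_model and data are the placeholders from the usage example above:

import time
import torch

for n_epochs in (50, 100, 200):
    explainer = GNNExplain(
        model=your_gnn_model,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        data=data,
        epochs=n_epochs,
        mode='regression',
        task_level='graph',
    )
    start = time.perf_counter()
    explainer.explain()
    print(f"{n_epochs} epochs: {time.perf_counter() - start:.2f} s")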

Visualization

The explanation results can be used with molecular visualization tools:

  • Node masks can highlight important atoms
  • Edge masks can highlight important bonds
  • Results integrate well with RDKit for molecular visualization
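
A minimal sketch of mapping node importance onto an RDKit drawing. It assumes the graph's node order matches the molecule's atom order; the SMILES string and node_importance come from your own pipeline (here 'CCO' matches the 3-node toy graph above):

import torch
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D

mol = Chem.MolFromSmiles('CCO')   # placeholder molecule

node_scores = torch.as_tensor(node_importance)
atom_scores = node_scores.sum(dim=-1) if node_scores.dim() > 1 else node_scores

# Highlight atoms whose importance exceeds the mean score
threshold = atom_scores.mean().item()
highlight_atoms = [i for i, s in enumerate(atom_scores.tolist()) if s > threshold]

drawer = rdMolDraw2D.MolDraw2DCairo(400, 400)
rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol, highlightAtoms=highlight_atoms)
drawer.FinishDrawing()
with open('explanation.png', 'wb') as f:
    f.write(drawer.GetDrawingText())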