Graph-based Explainers

Overview

This module provides several explainers specifically designed for graph neural networks, including SHAP-based and LIME-based methods adapted for graph data structures.

GraphShap

Class: GraphShap

Implements KernelSHAP for graph-level predictions by computing Shapley values for graph features.

Initialization

from chemxai.explainers import GraphShap

explainer = GraphShap(
    data=graph_data,
    model=model,
    device=device,
    gpu=False
)

Parameters

  • data (torch_geometric.data.Data): Graph data with node features and edge indices
  • model (torch.nn.Module): Trained GNN model
  • device (torch.device): Computation device
  • gpu (bool, optional): Whether to use GPU acceleration (default: False)

Methods

explain(num_samples=30, info=True)

Parameters:

  • num_samples (int): Number of samples for SHAP approximation
  • info (bool): Whether to return additional information

Returns:

  • list: SHAP values for each feature in the graph

Example:

explainer = GraphShap(data=graph_data, model=model, device=device)
shap_values = explainer.explain(num_samples=50)

GraphLIME

Class: GraphLIME

Adapts LIME for graph-level predictions by perturbing node features and measuring the impact on the model's output.

Initialization

from chemxai.explainers import GraphLIME

explainer = GraphLIME(
    model=model,
    device='cpu',
    rho=0.1
)

Parameters

  • model (torch.nn.Module): Trained GNN model
  • device (str, optional): Computation device (default: 'cpu')
  • rho (float, optional): Regularization parameter for Lasso regression (default: 0.1)

Methods

explain(data, num_samples=100)

Parameters:

  • data (torch_geometric.data.Data): Graph data to explain
  • num_samples (int): Number of perturbations for LIME

Returns:

  • list: Feature importance values for the graph

Example:

explainer = GraphLIME(model=model, device='cpu')
importance = explainer.explain(data=graph_data, num_samples=100)

NodeGraphShap

Class: NodeGraphShap

Implements KernelSHAP for node-level predictions in graphs.

Initialization

from chemxai.explainers import NodeGraphShap

explainer = NodeGraphShap(
    data=graph_data,
    model=model,
    device=device,
    gpu=False
)

Parameters

  • data (torch_geometric.data.Data): Graph data
  • model (torch.nn.Module): Trained GNN model
  • device (torch.device): Computation device
  • gpu (bool, optional): Whether to use GPU acceleration (default: False)

Methods

explain(node_index=0, hops=2, num_samples=10, info=True, multiclass=False)

Parameters:

  • node_index (int): Target node index
  • hops (int): Number of hops for subgraph extraction
  • num_samples (int): Number of samples for SHAP approximation
  • info (bool): Return additional information
  • multiclass (bool): Whether the model is multiclass

Returns:

  • list: SHAP values for node features
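Example:

explainer = NodeGraphShap(data=graph_data, model=model, device=device)
node_shap_values = explainer.explain(node_index=0, hops=2, num_samples=30)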

NodeGrapLIME

Class: NodeGrapLIME

Adapts LIME for node-level explanations in graphs.

Initialization

from chemxai.explainers import NodeGrapLIME

explainer = NodeGrapLIME(
    data=graph_data,
    model=model,
    device=device,
    gpu=False,
    hop=2,
    rho=0.1,
    cached=True
)

Parameters

  • data (torch_geometric.data.Data): Graph data
  • model (torch.nn.Module): Trained GNN model
  • device (torch.device): Computation device
  • gpu (bool): GPU acceleration flag
  • hop (int): Number of hops for subgraph extraction
  • rho (float): Regularization parameter
  • cached (bool): Whether to cache predictions

Methods

explain(node_index, hops, num_samples, info=False, multiclass=False)

Parameters:

  • node_index (int): Target node index
  • hops (int): Number of hops for subgraph extraction
  • num_samples (int): Number of perturbations
  • info (bool): Return additional information
  • multiclass (bool): Multiclass classification flag

Returns:

  • list: Feature importance values for the node
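Example:

explainer = NodeGrapLIME(data=graph_data, model=model, device=device)
node_lime_values = explainer.explain(node_index=0, hops=2, num_samples=30)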

Usage Examples

Complete Graph Explanation Workflow

import torch
from torch_geometric.data import Data
from chemxai.explainers import GraphShap, GraphLIME, NodeGraphShap, NodeGrapLIME

# Prepare graph data
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
x = torch.randn(4, 10)  # 4 nodes, 10 features each
data = Data(x=x, edge_index=edge_index)

# `model` is assumed to be a trained GNN; define the computation device
device = torch.device('cpu')

# Graph-level explanations
graph_shap = GraphShap(data=data, model=model, device=device)
graph_shap_values = graph_shap.explain(num_samples=30)

graph_lime = GraphLIME(model=model, device='cpu')
graph_lime_values = graph_lime.explain(data=data, num_samples=100)

# Node-level explanations
node_shap = NodeGraphShap(data=data, model=model, device=device)
node_shap_values = node_shap.explain(node_index=0, hops=2, num_samples=30)

node_lime = NodeGrapLIME(data=data, model=model, device=device)
node_lime_values = node_lime.explain(node_index=0, hops=2, num_samples=30)

Technical Details

Shapley Value Computation

All SHAP-based explainers use the Shapley kernel formula (see the sketch below):

  • Weight samples based on coalition size
  • Use weighted linear regression to approximate Shapley values
  • Handle edge cases (empty and full coalitions) with high weights
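For reference, the Shapley kernel assigns a coalition z of size |z| (out of M features) the weight (M - 1) / (C(M, |z|) * |z| * (M - |z|)). A minimal sketch of this weighting; the function name and the large edge-case constant are illustrative, not the module's internals:

from math import comb

def shapley_kernel_weight(M, s, large_weight=1e6):
    """Shapley kernel weight for a coalition of size s out of M features."""
    if s == 0 or s == M:
        # Empty and full coalitions are pinned with a very large weight
        return large_weight
    return (M - 1) / (comb(M, s) * s * (M - s))

# With M = 10, coalitions of size 1 or 9 get the largest finite weight,
# mid-sized coalitions the smallest.
weights = [shapley_kernel_weight(10, s) for s in range(11)]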

LIME Perturbation Strategy

LIME-based explainers (sketched below):

  • Generate binary masks for feature presence/absence
  • Use Lasso regression for feature selection
  • Apply PCA for dimensionality reduction when needed
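A minimal, self-contained sketch of this strategy; predict_fn, the masking scheme, and the use of scikit-learn's Lasso are illustrative assumptions rather than the module's internals (the PCA step is omitted for brevity):

import numpy as np
from sklearn.linear_model import Lasso

def lime_importance(predict_fn, x, num_samples=100, rho=0.1, seed=0):
    """Fit a sparse linear surrogate to predictions on perturbed inputs."""
    rng = np.random.default_rng(seed)
    # Binary masks: 1 keeps a feature, 0 zeroes it out
    masks = rng.integers(0, 2, size=(num_samples, x.shape[0]))
    preds = np.array([predict_fn(x * m) for m in masks])
    # Lasso drives unimportant coefficients to zero (feature selection)
    surrogate = Lasso(alpha=rho).fit(masks, preds)
    return surrogate.coef_  # per-feature importance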

Subgraph Extraction

Node-level explainers use k-hop subgraph extraction (see the sketch below):

  • Extract the neighborhood around the target node
  • Maintain graph connectivity
  • Relabel nodes for consistent indexing
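PyTorch Geometric provides a utility for exactly this kind of extraction; whether the module calls it internally is an assumption, but the sketch below shows the behavior described above:

import torch
from torch_geometric.utils import k_hop_subgraph

edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)

# Extract the 2-hop neighborhood of node 0; relabel_nodes=True re-indexes
# the retained nodes so the returned edge_index is valid on the subgraph.
subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx=0, num_hops=2, edge_index=edge_index, relabel_nodes=True
)
# subset: original indices of the retained nodes
# mapping: position of node 0 within the relabeled subgraph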

Performance Optimization

  • Caching mechanisms for repeated predictions
  • GPU acceleration where applicable
  • Efficient matrix operations using NumPy and PyTorch
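As one illustration of the caching idea, a hypothetical wrapper that memoizes forward passes for repeated inputs; CachedModel and its call signature are illustrative, not part of chemxai:

import torch

class CachedModel:
    """Memoizes forward passes keyed by the raw bytes of the inputs."""

    def __init__(self, model):
        self.model = model
        self._cache = {}

    @torch.no_grad()
    def __call__(self, x, edge_index):
        # Inputs are moved to CPU purely to build a hashable cache key
        key = (x.detach().cpu().numpy().tobytes(),
               edge_index.detach().cpu().numpy().tobytes())
        if key not in self._cache:
            self._cache[key] = self.model(x, edge_index)
        return self._cache[key]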