ChemXAI Documentation

Overview

ChemXAI is a comprehensive library for explainable AI in chemical machine learning. It provides a unified interface for explaining predictions from both tabular and graph-based molecular models using various explanation methods.

Quick Start

# Install ChemXAI
pip install chemxai

# Basic usage example
import torch

from chemxai.explainers import Shap
from chemxai.data import qm9_tabular

# Load data
qm9 = qm9_tabular()
train_loader, _, test_loader, _, _, _, _ = qm9.get_paired_dataloaders_tabular(
    batch_size=32, descriptor_type='Morgan'
)

# Create explainer (your_model is a trained PyTorch model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
explainer = Shap(
    model=your_model,
    background_tensor=next(iter(train_loader))[0],  # background batch for SHAP
    test_tensor=next(iter(test_loader))[0],         # batch to explain
    device=device
)

# Generate explanation
explanation = explainer.explain_local(index=0)

Documentation Structure

Explainers

Tabular Explainers

  • SHAP: SHapley Additive exPlanations for tabular molecular data
  • LIME: Local Interpretable Model-agnostic Explanations for tabular data

Graph-based Explainers

  • GNN Explainer: Native GNN explanation method using masking
  • Graph Explainers: Collection of graph-specific explanation methods:
      • GraphShap: SHAP for graph-level predictions
      • GraphLIME: LIME for graph-level predictions
      • NodeGraphShap: SHAP for node-level predictions
      • NodeGrapLIME: LIME for node-level predictions

Visualization and Analysis

  • Plots: Visualization functions for explanation results
  • Evaluator: Robustness evaluation and fingerprint analysis

Supported Model Types

Tabular Models

  • PyTorch neural networks
  • Models accepting molecular fingerprints
  • Regression and classification tasks
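
A tabular model here is any PyTorch module whose forward pass maps a fingerprint tensor to a prediction. A minimal sketch for orientation (illustrative only; ChemXAI does not require this exact architecture):

import torch.nn as nn

# Minimal MLP over a 2048-bit Morgan fingerprint, one regression target
model = nn.Sequential(
    nn.Linear(2048, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
)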

Graph Models

  • PyTorch Geometric compatible models
  • Message Passing Neural Networks (MPNNs)
  • Graph Convolutional Networks (GCNs)
  • Graph Attention Networks (GATs)
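
For orientation, here is a minimal PyTorch Geometric model of the kind listed above, sketched purely for illustration (not a ChemXAI class; layer sizes are arbitrary):

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class TinyGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        # Two rounds of message passing, then a graph-level readout
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.out(x)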

Supported Data Types

Molecular Representations

  • Morgan Fingerprints: Circular fingerprints for molecular structure
  • Graph Data: Node features and edge connectivity
  • SMILES: Molecular string representations
  • Custom Descriptors: User-defined molecular features
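
As a concrete example, a Morgan fingerprint can be generated from a SMILES string with RDKit (a core dependency). This sketch is independent of ChemXAI's own data loaders; the radius and bit count are common defaults, not requirements:

from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles('CCO')  # ethanol
# Radius-2 circular fingerprint folded to 2048 bits (ECFP4-style)
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)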

Tasks

  • Molecular Property Prediction: Regression tasks for chemical properties
  • Molecular Classification: Classification of molecular types
  • Node-level Prediction: Atom-specific property prediction
  • Graph-level Prediction: Whole-molecule property prediction

Core Features

Explanation Methods

  • Local Explanations: Instance-specific feature importance
  • Global Explanations: Dataset-wide feature importance patterns
  • Graph Explanations: Node and edge importance in molecular graphs
  • Substructure Analysis: Identification of important molecular fragments

Robustness Evaluation

  • Noise Sensitivity: Compare explanations under data perturbations
  • Multiple Metrics: Cosine similarity, L1/L2 distances, rank correlation
  • Statistical Analysis: Distribution analysis and significance testing
  • Visualization: Automatic generation of robustness plots
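
To make the metrics concrete, the listed comparisons can be computed between two attribution vectors as below. This is an illustrative sketch, not the Evaluator's internal code, and it assumes SciPy is available for the rank correlation:

import numpy as np
from scipy.stats import spearmanr

def compare_attributions(a, b):
    """Compare two attribution vectors for the same instance."""
    cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    l1 = np.abs(a - b).sum()
    l2 = np.linalg.norm(a - b)
    rank_corr, _ = spearmanr(a, b)  # agreement of feature rankings
    return cosine, l1, l2, rank_corr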

Molecular Analysis

  • Fragment Identification: Map important bits to molecular substructures
  • Chemical Interpretation: Visualize important molecular regions
  • Interactive Analysis: Detailed molecular fragment exploration
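
The bit-to-substructure mapping described above can be reproduced by hand with RDKit's bitInfo bookkeeping. The sketch below shows the general idea only; it is not FingerprintAnalyzer's implementation:

from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles('c1ccccc1O')  # phenol
bit_info = {}
AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048, bitInfo=bit_info)

# Each set bit maps to (center atom, radius) pairs; pick one with radius > 0
bit, centers = next((b, c) for b, c in bit_info.items() if c[0][1] > 0)
atom_idx, radius = centers[0]
env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom_idx)
print(f"bit {bit} -> fragment {Chem.MolToSmiles(Chem.PathToSubmol(mol, env))}")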

Installation

# Install from PyPI
pip install chemxai

# Install development version
git clone https://github.com/your-repo/chemxai.git
cd chemxai
pip install -e .

Dependencies

Core dependencies:

  • PyTorch >= 1.9.0
  • torch-geometric >= 2.0.0
  • RDKit >= 2021.09.1
  • scikit-learn >= 1.0.0
  • SHAP >= 0.40.0
  • LIME >= 0.2.0
  • matplotlib >= 3.3.0
  • numpy >= 1.20.0

API Reference

Quick Reference

# Tabular explainers
from chemxai.explainers import Shap, LIME

# Graph explainers  
from chemxai.explainers import (
    GNNExplain, GraphShap, GraphLIME, 
    NodeGraphShap, NodeGrapLIME
)

# Visualization
from chemxai.plots import radar_plot, horizontal_bar_plot

# Evaluation
from chemxai.evaluate import Evaluator, FingerprintAnalyzer

Common Patterns

Explaining Tabular Models

# 1. Prepare data
background = train_data[0]  # Background for SHAP
test_data = test_batch[0]   # Data to explain

# 2. Create explainer
explainer = Shap(model, background, test_data, device)

# 3. Generate explanations
local_exp = explainer.explain_local(index=0)
global_exp = explainer.explain_global()

# 4. Visualize
from chemxai.plots import horizontal_bar_plot
fig, ax = horizontal_bar_plot(local_exp, title="Feature Importance")

Explaining Graph Models

# 1. Prepare graph data
from torch_geometric.data import Data

graph_data = Data(x=node_features, edge_index=edge_indices)

# 2. Create explainer
explainer = GNNExplain(model, device, graph_data, epochs=100)

# 3. Generate explanation
node_mask, edge_mask, explanation = explainer.explain()

# 4. Analyze results (top-5 most important nodes)
import torch

important_nodes = torch.argsort(node_mask, descending=True)[:5]

Evaluating Robustness

# 1. Setup evaluator
evaluator = Evaluator(
    model_normal=clean_model,
    model_noise=noisy_model,
    train_loader_normal=clean_train,
    test_loader_normal=clean_test,
    train_loader_noise=noisy_train,
    test_loader_noise=noisy_test,
    device=device,
    explainer_type='shap_local'
)

# 2. Run evaluation
results = evaluator.robustness()
similarities, l1_diffs, l2_diffs, spearman_corrs, figs = results

# 3. Analyze results
import numpy as np

mean_similarity = np.mean(similarities)
print(f"Mean cosine similarity: {mean_similarity:.3f}")

Best Practices

Model Preparation

  1. Ensure Reproducibility: Set random seeds for consistent results
  2. Model State: Use model.eval() during explanation generation
  3. Device Consistency: Keep model and data on the same device
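
A minimal sketch covering all three points (plain PyTorch; the seed value is arbitrary, and model / test_tensor stand in for your own objects):

import random
import numpy as np
import torch

# 1. Reproducibility: seed every RNG in play
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# 2. Model state: freeze dropout and batch-norm statistics
model.eval()

# 3. Device consistency: model and data on the same device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
test_tensor = test_tensor.to(device)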

Data Preparation

  1. Background Selection: Choose representative background data for SHAP
  2. Batch Sizes: Use appropriate batch sizes for memory constraints
  3. Data Quality: Ensure clean data for baseline comparisons
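
For background selection, one simple and common strategy is a random subset of the training data; the subset size below is an arbitrary illustration, not a ChemXAI requirement:

import torch

# Draw a random subset of a training batch as the SHAP background
train_X = next(iter(train_loader))[0]
k = min(100, train_X.size(0))
idx = torch.randperm(train_X.size(0))[:k]
background = train_X[idx]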

Explanation Generation

  1. Method Selection: Choose appropriate explainer for your model type
  2. Parameter Tuning: Adjust explainer parameters for your use case
  3. Validation: Cross-validate explanations across multiple instances
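
Cross-validation of explanations can be as simple as explaining several instances and checking that the top-ranked features are stable. A sketch using the explain_local interface shown above, assuming each explanation is an array of per-feature attributions:

import numpy as np

# Explain the first 10 test instances and collect each one's top-5 features
explanations = [explainer.explain_local(index=i) for i in range(10)]
top5 = [np.argsort(np.abs(np.asarray(e)))[::-1][:5] for e in explanations]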

Evaluation and Analysis

  1. Multiple Metrics: Use multiple robustness metrics for comprehensive evaluation
  2. Statistical Significance: Ensure sufficient samples for reliable statistics
  3. Domain Knowledge: Validate explanations against chemical knowledge

Examples and Tutorials

Complete Workflow Example

import torch
from chemxai.explainers import Shap
from chemxai.evaluate import Evaluator, FingerprintAnalyzer
from chemxai.plots import horizontal_bar_plot

# 1. Load your trained model and data
model = torch.load('your_model.pt')
# ... load data loaders ...

# 2. Generate explanations
explainer = Shap(model, background_data, test_data, device)
explanation = explainer.explain_local(index=0)

# 3. Visualize explanations
fig, ax = horizontal_bar_plot(
    explanation, 
    title="Molecular Feature Importance",
    max_features=20
)

# 4. Analyze molecular fragments (for fingerprint data)
analyzer = FingerprintAnalyzer(explanation, batch_idx=0, mol_idx=0)
summary = analyzer.analyze()

# 5. Evaluate robustness
evaluator = Evaluator(
    model_normal=clean_model,
    model_noise=noisy_model,
    # ... data loaders ...
    explainer_type='shap_local'
)
robustness_results = evaluator.robustness()

Support

  • Documentation: This guide, covering installation, the API, and common workflows
  • Issues: GitHub issues for bug reports and feature requests
  • Discussions: GitHub discussions for questions and community support
  • Examples: Jupyter notebooks with detailed examples

License

ChemXAI is released under the MIT License. See LICENSE for details.