ChemXAI Documentation¶
Overview¶
ChemXAI is a comprehensive library for explainable AI in chemical machine learning. It provides a unified interface for explaining predictions from both tabular and graph-based molecular models using various explanation methods.
Quick Start¶
# Install ChemXAI
pip install chemxai
# Basic usage example
from chemxai.explainers import Shap
from chemxai.data import qm9_tabular

# Load data
qm9 = qm9_tabular()
train_loader, _, test_loader, _, _, _, _ = qm9.get_paired_dataloaders_tabular(
    batch_size=32, descriptor_type='Morgan'
)

# Create explainer (your_model is a trained PyTorch model; device is a torch.device)
explainer = Shap(
    model=your_model,
    background_tensor=next(iter(train_loader))[0],
    test_tensor=next(iter(test_loader))[0],
    device=device
)

# Generate a local explanation for the first test molecule
explanation = explainer.explain_local(index=0)
Documentation Structure¶
Explainers¶
Tabular Explainers¶
- SHAP: SHapley Additive exPlanations for tabular molecular data
- LIME: Local Interpretable Model-agnostic Explanations for tabular data (usage sketch below)
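LIME can stand in for Shap in the Quick Start workflow. A hypothetical sketch, assuming the LIME constructor mirrors Shap's shown above; check the API reference for the exact parameters:

from chemxai.explainers import LIME

# Hypothetical: constructor assumed to mirror Shap's from the Quick Start
explainer = LIME(
    model=your_model,
    background_tensor=next(iter(train_loader))[0],
    test_tensor=next(iter(test_loader))[0],
    device=device
)
explanation = explainer.explain_local(index=0)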
Graph-based Explainers¶
- GNN Explainer: Native GNN explanation method using masking
- Graph Explainers: Collection of graph-specific explanation methods (see the sketch after this list):
- GraphShap: SHAP for graph-level predictions
- GraphLIME: LIME for graph-level predictions
- NodeGraphShap: SHAP for node-level predictions
- NodeGrapLIME: LIME for node-level predictions
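Constructor signatures vary per method. As a hypothetical sketch only, assuming GraphShap takes the same model/device/graph arguments as GNNExplain (shown under Common Patterns):

from chemxai.explainers import GraphShap

# Hypothetical: argument order assumed to mirror GNNExplain(model, device, graph_data)
explainer = GraphShap(model, device, graph_data)
explanation = explainer.explain()  # importance scores over the molecular graph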
Visualization and Analysis¶
- Plots: Visualization functions for explanation results
- Evaluator: Robustness evaluation and fingerprint analysis
Supported Model Types¶
Tabular Models¶
- PyTorch neural networks
- Models accepting molecular fingerprints (see the example after this list)
- Regression and classification tasks
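Any PyTorch module that maps a fingerprint vector to a prediction fits this interface. A minimal illustration; the layer sizes and 2048-bit input are arbitrary choices, not requirements:

import torch.nn as nn

# Fingerprint regressor: 2048 inputs match a common Morgan bit length
model = nn.Sequential(
    nn.Linear(2048, 256),
    nn.ReLU(),
    nn.Linear(256, 1),  # single regression target
)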
Graph Models¶
- PyTorch Geometric compatible models (example after this list)
- Message Passing Neural Networks (MPNNs)
- Graph Convolutional Networks (GCNs)
- Graph Attention Networks (GATs)
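For reference, a minimal PyTorch Geometric GCN with the kind of interface a graph explainer expects; the forward signature and pooling choice here are illustrative, not prescribed by ChemXAI:

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class SimpleGCN(nn.Module):
    """Illustrative graph-level regressor operating on PyG Data attributes."""
    def __init__(self, num_node_features, hidden=64):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # one embedding per molecule
        return self.head(x)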
Supported Data Types¶
Molecular Representations¶
- Morgan Fingerprints: Circular fingerprints for molecular structure (generation example after this list)
- Graph Data: Node features and edge connectivity
- SMILES: Molecular string representations
- Custom Descriptors: User-defined molecular features
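For example, Morgan fingerprints can be generated from a SMILES string with RDKit; radius 2 and 2048 bits are common defaults, not ChemXAI requirements:

from rdkit import Chem
from rdkit.Chem import AllChem
import torch

mol = Chem.MolFromSmiles('CCO')  # ethanol
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
x = torch.tensor(list(fp), dtype=torch.float32)  # model-ready feature vector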
Tasks¶
- Molecular Property Prediction: Regression tasks for chemical properties
- Molecular Classification: Classification of molecular types
- Node-level Prediction: Atom-specific property prediction
- Graph-level Prediction: Whole-molecule property prediction
Core Features¶
Explanation Methods¶
- Local Explanations: Instance-specific feature importance
- Global Explanations: Dataset-wide feature importance patterns
- Graph Explanations: Node and edge importance in molecular graphs
- Substructure Analysis: Important molecular fragments identification
Robustness Evaluation¶
- Noise Sensitivity: Compare explanations under data perturbations
- Multiple Metrics: Cosine similarity, L1/L2 distances, and Spearman rank correlation (see the sketch after this list)
- Statistical Analysis: Distribution analysis and significance testing
- Visualization: Automatic generation of robustness plots
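For a single instance, these comparison metrics reduce to a few lines of NumPy/SciPy. A rough sketch of what they measure for a pair of clean/noisy attribution vectors (the Evaluator computes them internally; details may differ):

import numpy as np
from scipy.stats import spearmanr

def robustness_metrics(attr_clean, attr_noisy):
    """Compare attribution vectors for the same instance under clean vs. noisy data."""
    cos = np.dot(attr_clean, attr_noisy) / (
        np.linalg.norm(attr_clean) * np.linalg.norm(attr_noisy)
    )
    l1 = np.abs(attr_clean - attr_noisy).sum()
    l2 = np.linalg.norm(attr_clean - attr_noisy)
    rho, _ = spearmanr(attr_clean, attr_noisy)  # rank correlation of feature orderings
    return cos, l1, l2, rho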
Molecular Analysis¶
- Fragment Identification: Map important fingerprint bits to molecular substructures (sketched after this list)
- Chemical Interpretation: Visualize important molecular regions
- Interactive Analysis: Detailed molecular fragment exploration
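Mapping an important fingerprint bit back to a substructure relies on RDKit's bit info. A minimal sketch of the idea, independent of the FingerprintAnalyzer API (the bit index is illustrative):

from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
info = {}
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048, bitInfo=info)

bit = 1019  # e.g. a bit an explainer flagged as important (illustrative)
if bit in info:
    for atom_idx, radius in info[bit]:
        # Bonds within `radius` of the center atom define the fragment for this bit
        env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom_idx)
        submol = Chem.PathToSubmol(mol, env)
        print(f"bit {bit}: {Chem.MolToSmiles(submol)} (atom {atom_idx}, radius {radius})")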
Installation¶
# Install from PyPI
pip install chemxai
# Install development version
git clone https://github.com/your-repo/chemxai.git
cd chemxai
pip install -e .
Dependencies¶
Core dependencies:
- PyTorch >= 1.9.0
- torch-geometric >= 2.0.0
- RDKit >= 2021.09.1
- scikit-learn >= 1.0.0
- SHAP >= 0.40.0
- LIME >= 0.2.0
- matplotlib >= 3.3.0
- numpy >= 1.20.0
API Reference¶
Quick Reference¶
# Tabular explainers
from chemxai.explainers import Shap, LIME
# Graph explainers
from chemxai.explainers import (
    GNNExplain, GraphShap, GraphLIME,
    NodeGraphShap, NodeGrapLIME
)
# Visualization
from chemxai.plots import radar_plot, horizontal_bar_plot
# Evaluation
from chemxai.evaluate import Evaluator, FingerprintAnalyzer
Common Patterns¶
Explaining Tabular Models¶
# 1. Prepare data
background = train_data[0] # Background for SHAP
test_data = test_batch[0] # Data to explain
# 2. Create explainer
explainer = Shap(model, background, test_data, device)
# 3. Generate explanations
local_exp = explainer.explain_local(index=0)
global_exp = explainer.explain_global()
# 4. Visualize
from chemxai.plots import horizontal_bar_plot
fig, ax = horizontal_bar_plot(local_exp, title="Feature Importance")
Explaining Graph Models¶
import torch
from torch_geometric.data import Data
from chemxai.explainers import GNNExplain

# 1. Prepare graph data
graph_data = Data(x=node_features, edge_index=edge_indices)
# 2. Create explainer
explainer = GNNExplain(model, device, graph_data, epochs=100)
# 3. Generate explanation
node_mask, edge_mask, explanation = explainer.explain()
# 4. Analyze results: five highest-scoring nodes
important_nodes = torch.argsort(node_mask, descending=True)[:5]
Evaluating Robustness¶
import numpy as np

# 1. Setup evaluator
evaluator = Evaluator(
    model_normal=clean_model,
    model_noise=noisy_model,
    train_loader_normal=clean_train,
    test_loader_normal=clean_test,
    train_loader_noise=noisy_train,
    test_loader_noise=noisy_test,
    device=device,
    explainer_type='shap_local'
)
# 2. Run evaluation
results = evaluator.robustness()
similarities, l1_diffs, l2_diffs, spearman_corrs, figs = results
# 3. Analyze results
mean_similarity = np.mean(similarities)
print(f"Mean cosine similarity: {mean_similarity:.3f}")
Best Practices¶
Model Preparation¶
- Ensure Reproducibility: Set random seeds for consistent results
- Model State: Use model.eval() during explanation generation
- Device Consistency: Keep model and data on the same device (all three points are sketched after this list)
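A minimal setup that follows these practices; the seed value and device logic are illustrative:

import random
import numpy as np
import torch

# Reproducibility: seed every RNG the pipeline touches
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Device consistency and evaluation mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()  # disables dropout/batch-norm updates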
Data Preparation¶
- Background Selection: Choose representative background data for SHAP (see the sketch after this list)
- Batch Sizes: Use appropriate batch sizes for memory constraints
- Data Quality: Ensure clean data for baseline comparisons
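A simple way to choose a SHAP background is a random subsample of the training set; the size trades memory for fidelity. An illustrative sketch, reusing the train_loader from Quick Start:

import torch

train_X = next(iter(train_loader))[0]    # one batch of training features
k = min(100, len(train_X))               # background size: speed vs. fidelity
background = train_X[torch.randperm(len(train_X))[:k]]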
Explanation Generation¶
- Method Selection: Choose appropriate explainer for your model type
- Parameter Tuning: Adjust explainer parameters for your use case
- Validation: Cross-validate explanations across multiple instances (sketched below)
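One way to cross-validate explanations is to check that top-ranked features are stable across instances. An illustrative sketch, assuming explain_local returns a 1-D attribution array:

import numpy as np

# Local explanations for the first ten test instances
attrs = np.stack([np.asarray(explainer.explain_local(index=i)) for i in range(10)])

# Count how often each feature lands in an instance's top 10
top10 = np.argsort(-np.abs(attrs), axis=1)[:, :10]
counts = np.bincount(top10.ravel(), minlength=attrs.shape[1])
stable = np.nonzero(counts >= 8)[0]  # important in at least 8 of 10 explanations
print("Consistently important features:", stable)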
Evaluation and Analysis¶
- Multiple Metrics: Use multiple robustness metrics for comprehensive evaluation
- Statistical Significance: Ensure sufficient samples for reliable statistics
- Domain Knowledge: Validate explanations against chemical knowledge
Examples and Tutorials¶
Complete Workflow Example¶
import torch
from chemxai.explainers import Shap
from chemxai.evaluate import Evaluator, FingerprintAnalyzer
from chemxai.plots import horizontal_bar_plot

# 1. Load your trained model and data
model = torch.load('your_model.pt')
# ... load data loaders ...

# 2. Generate explanations
explainer = Shap(model, background_data, test_data, device)
explanation = explainer.explain_local(index=0)

# 3. Visualize explanations
fig, ax = horizontal_bar_plot(
    explanation,
    title="Molecular Feature Importance",
    max_features=20
)

# 4. Analyze molecular fragments (for fingerprint data)
analyzer = FingerprintAnalyzer(explanation, batch_idx=0, mol_idx=0)
summary = analyzer.analyze()

# 5. Evaluate robustness
evaluator = Evaluator(
    model_normal=clean_model,
    model_noise=noisy_model,
    # ... data loaders ...
    explainer_type='shap_local'
)
robustness_results = evaluator.robustness()
Support¶
- Documentation: This site, covering the full API and common workflows
- Issues: GitHub issues for bug reports and feature requests
- Discussions: GitHub discussions for questions and community support
- Examples: Jupyter notebooks with detailed examples
License¶
ChemXAI is released under the MIT License. See LICENSE for details.