Evaluator¶
Overview¶
The Evaluator class provides comprehensive robustness assessment for explainability methods by comparing explanations from models trained on clean data versus noisy data. It implements multiple similarity metrics to quantify explanation stability.
Class: Evaluator¶
Initialization¶
from chemxai.evaluate import Evaluator
evaluator = Evaluator(
model_normal=clean_model,
model_noise=noisy_model,
train_loader_normal=clean_train_loader,
test_loader_normal=clean_test_loader,
train_loader_noise=noisy_train_loader,
test_loader_noise=noisy_test_loader,
device=device,
model_type='graph',
explainer_type='shap_local',
mol_index=0,
atom_index=0
)
Parameters¶
- model_normal (torch.nn.Module): Model trained on clean data
- model_noise (torch.nn.Module): Model trained on noisy data
- train_loader_normal (DataLoader): Training data loader (clean)
- test_loader_normal (DataLoader): Test data loader (clean)
- train_loader_noise (DataLoader): Training data loader (noisy)
- test_loader_noise (DataLoader): Test data loader (noisy)
- device (torch.device): Computation device
- model_type (str): Type of model - 'graph' or 'tabular' (default: 'graph')
- explainer_type (str): Type of explainer to evaluate (default: 'shap_local')
- mol_index (int): Molecule index for analysis (default: 0)
- atom_index (int): Atom index for node-level explanations (default: 0)
Supported Explainer Types¶
- Tabular: 'shap_local', 'shap_global', 'lime'
- Graph: 'gnn_explainer', 'graph_shap', 'graph_lime', 'node_graph_shap', 'node_graph_lime'
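If several explainer types need to be compared under the same noise setting, one Evaluator can be built per type. The loop below is a sketch that reuses the variable names from the initialization example above; it is not a built-in batch mode of the class.
from chemxai.evaluate import Evaluator
# Sketch: one Evaluator per graph explainer type, reusing the loaders defined above
graph_explainers = ['gnn_explainer', 'graph_shap', 'graph_lime']
robustness_by_type = {}
for name in graph_explainers:
    evaluator = Evaluator(
        model_normal=clean_model,
        model_noise=noisy_model,
        train_loader_normal=clean_train_loader,
        test_loader_normal=clean_test_loader,
        train_loader_noise=noisy_train_loader,
        test_loader_noise=noisy_test_loader,
        device=device,
        model_type='graph',
        explainer_type=name,
    )
    # Keep only the cosine similarities; the other metrics and figures are discarded here
    similarities, _, _, _, _ = evaluator.robustness()
    robustness_by_type[name] = similarities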
Methods¶
robustness()¶
Evaluates the robustness of explanations by comparing results from clean and noisy models.
Returns:
- similarities (list): Cosine similarity values between explanations
- l1_differences (list): L1 norm differences between explanations
- l2_differences (list): L2 norm differences between explanations
- spearman_correlations (list): Spearman correlation coefficients
- figs (list): List of matplotlib figures showing distributions
Example:
import numpy as np
# Evaluate robustness
similarities, l1_diffs, l2_diffs, spearman_corrs, figures = evaluator.robustness()
print(f"Mean cosine similarity: {np.mean(similarities):.3f}")
print(f"Mean L1 difference: {np.mean(l1_diffs):.3f}")
print(f"Mean Spearman correlation: {np.mean(spearman_corrs):.3f}")
Usage Examples¶
Tabular Model Evaluation¶
import torch
from chemxai.evaluate import Evaluator
# Initialize evaluator for SHAP local explanations
evaluator = Evaluator(
model_normal=clean_tabular_model,
model_noise=noisy_tabular_model,
train_loader_normal=clean_train_loader,
test_loader_normal=clean_test_loader,
train_loader_noise=noisy_train_loader,
test_loader_noise=noisy_test_loader,
device=torch.device('cuda'),
model_type='tabular',
explainer_type='shap_local'
)
# Run robustness evaluation
results = evaluator.robustness()
similarities, l1_diffs, l2_diffs, spearman_corrs, figs = results
# Analyze results
import numpy as np
print(f"Robustness Metrics:")
print(f" Cosine Similarity: {np.mean(similarities):.3f} ± {np.std(similarities):.3f}")
print(f" L1 Difference: {np.mean(l1_diffs):.3f} ± {np.std(l1_diffs):.3f}")
print(f" Spearman Correlation: {np.mean(spearman_corrs):.3f} ± {np.std(spearman_corrs):.3f}")
Graph Model Evaluation¶
# Initialize evaluator for GNN explainer
evaluator = Evaluator(
model_normal=clean_gnn_model,
model_noise=noisy_gnn_model,
train_loader_normal=clean_graph_train_loader,
test_loader_normal=clean_graph_test_loader,
train_loader_noise=noisy_graph_train_loader,
test_loader_noise=noisy_graph_test_loader,
device=torch.device('cuda'),
model_type='graph',
explainer_type='gnn_explainer'
)
# Evaluate robustness
results = evaluator.robustness()
Node-Level Evaluation¶
# Evaluate robustness for node-level explanations
evaluator = Evaluator(
model_normal=clean_gnn_model,
model_noise=noisy_gnn_model,
train_loader_normal=clean_train_loader,
test_loader_normal=clean_test_loader,
train_loader_noise=noisy_train_loader,
test_loader_noise=noisy_test_loader,
device=device,
model_type='graph',
explainer_type='node_graph_shap',
atom_index=2 # Focus on specific atom
)
results = evaluator.robustness()
Technical Details¶
Robustness Metrics¶
Cosine Similarity¶
Measures the cosine of the angle between explanation vectors:
- Range: [-1, 1]
- Interpretation: 1 = identical direction, 0 = orthogonal, -1 = opposite
- Higher values indicate more robust explanations
L1 Difference (Manhattan Distance)¶
Sum of absolute differences between explanation values:
- Range: [0, ∞)
- Interpretation: 0 = identical explanations
- Lower values indicate more robust explanations
L2 Difference (Euclidean Distance)¶
Euclidean norm of the difference vector:
- Range: [0, ∞)
- Interpretation: 0 = identical explanations
- Lower values indicate more robust explanations
Spearman Correlation¶
Rank correlation coefficient between explanations:
- Range: [-1, 1]
- Interpretation: 1 = perfect rank correlation, 0 = no correlation
- Higher values indicate more robust ranking
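For reference, all four metrics can be reproduced from a pair of explanation vectors with NumPy and SciPy. This is a minimal standalone sketch, not the Evaluator's internal implementation; the example vectors are made up.
import numpy as np
from scipy.stats import spearmanr
# Toy explanation vectors for the same molecule from the clean and noisy models
e_clean = np.array([0.40, -0.10, 0.25, 0.05])
e_noisy = np.array([0.35, -0.05, 0.30, 0.02])
cosine = np.dot(e_clean, e_noisy) / (np.linalg.norm(e_clean) * np.linalg.norm(e_noisy))
l1 = np.sum(np.abs(e_clean - e_noisy))   # Manhattan distance
l2 = np.linalg.norm(e_clean - e_noisy)   # Euclidean distance
rho, _ = spearmanr(e_clean, e_noisy)     # rank correlation
print(f"cosine={cosine:.3f}  L1={l1:.3f}  L2={l2:.3f}  spearman={rho:.3f}")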
Explanation Processing¶
The evaluator handles various explanation formats:
- Tuples: Automatically extracts relevant components (e.g., node masks from GNNExplainer)
- Multi-dimensional arrays: Flattens complex structures
- Missing values: Handles NaN and infinite values gracefully
- Size mismatches: Aligns explanations of different dimensions
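The processing steps above can be approximated as in the sketch below; the helper names _to_vector and _align are hypothetical and only illustrate the described behaviour, not the library's exact code.
import numpy as np
def _to_vector(explanation):
    # Tuples: keep the first component (e.g. the node mask from GNNExplainer)
    if isinstance(explanation, tuple):
        explanation = explanation[0]
    # Multi-dimensional arrays: flatten to a single vector
    arr = np.asarray(explanation, dtype=float).ravel()
    # Missing values: replace NaN and infinities with zeros
    return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
def _align(a, b):
    # Size mismatches: truncate both vectors to the shared length
    n = min(len(a), len(b))
    return a[:n], b[:n]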
Error Handling¶
Robust error handling for:
- Invalid explanation formats
- Numerical instabilities (NaN, infinity)
- Matrix inversion failures
- Empty or malformed explanations
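On the caller side, a defensive check on the returned metric lists is still worthwhile; a small sketch using the variables from the earlier robustness example:
import numpy as np
sims = np.asarray(similarities, dtype=float)
valid = np.isfinite(sims)   # drop NaN/inf entries from comparisons that could not be made
print(f"Usable comparisons: {valid.sum()} / {len(sims)}")
print(f"Mean cosine similarity: {sims[valid].mean():.3f}")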
Class: FingerprintAnalyzer¶
Overview¶
Analyzes Morgan fingerprint explanations by visualizing important molecular fragments and their contributions to predictions.
Initialization¶
from chemxai.evaluate import FingerprintAnalyzer
analyzer = FingerprintAnalyzer(
explanation=shap_values,
batch_idx=0,
mol_idx=0,
dataset_type='test',
device='cpu'
)
Parameters¶
- explanation (list or numpy.ndarray): Pre-calculated SHAP values for fingerprint bits
- batch_idx (int): Batch index containing the molecule of interest
- mol_idx (int): Molecule index within the batch
- dataset_type (str): Dataset type - 'train', 'test', or 'val' (default: 'test')
- device (str): Computation device (default: 'cpu')
Methods¶
analyze()¶
Performs complete analysis of fingerprint explanations with molecular visualization.
Returns:
- str: Text summary of the analysis
Example:
# Analyze fingerprint explanation
analyzer = FingerprintAnalyzer(
explanation=shap_values,
batch_idx=0,
mol_idx=0,
dataset_type='test'
)
analysis_summary = analyzer.analyze()
print(analysis_summary)
Visualization Features¶
Original Molecule Display¶
Shows the complete molecular structure being analyzed.
Bits Summary¶
Displays:
- Total number of important bits
- Active vs inactive bits in the fingerprint
- Top 10 most important bits with their SHAP values
Fragment Analysis¶
For each important active bit:
- Molecular fragment visualization
- Side-by-side comparison of full molecule and fragment
- Atom environment highlighting
- Substructure extraction when possible
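The fragment views correspond to standard Morgan-bit atom environments. Outside of the analyzer, a comparable picture can be produced directly with RDKit; the SMILES string, radius, and bit size below are placeholders.
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
mol = Chem.MolFromSmiles('CCO')   # placeholder; normally the molecule selected by batch_idx/mol_idx
# Recompute the fingerprint with bitInfo to recover the atom environment behind each bit
bit_info = {}
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048, bitInfo=bit_info)
# Draw the environment of one active bit (here simply the first recorded bit)
bit_id = next(iter(bit_info))
img = Draw.DrawMorganBit(mol, bit_id, bit_info)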
Color Coding¶
- Red: Central atoms in fragments
- Orange: Neighboring atoms in the environment
- Blue: Inactive but important bits
Usage Example¶
import torch
from chemxai.explainers import Shap
from chemxai.evaluate import FingerprintAnalyzer
# Generate SHAP explanation
explainer = Shap(model=model, background_tensor=background,
                 test_tensor=test_data, device=device)
explanation = explainer.explain_local(index=0)
# Analyze the fingerprint explanation
analyzer = FingerprintAnalyzer(
explanation=explanation,
batch_idx=0,
mol_idx=0,
dataset_type='test'
)
# Run complete analysis
summary = analyzer.analyze()
# The analysis will display:
# 1. Original molecule structure
# 2. Summary of important bits
# 3. Individual fragment visualizations
# 4. Text summary of findings
Integration with ChemXAI¶
The evaluator integrates seamlessly with other ChemXAI components:
With Explainers¶
from chemxai.explainers import Shap, LIME, GNNExplain
from chemxai.evaluate import Evaluator
# Works with all explainer types
evaluator = Evaluator(..., explainer_type='shap_local')
# or
evaluator = Evaluator(..., explainer_type='gnn_explainer')
With Data Loaders¶
from chemxai.data import qm9_tabular
qm9 = qm9_tabular()
train_loader, _, test_loader, train_noise, _, test_noise, _ = qm9.get_paired_dataloaders_tabular(
batch_size=32, n_noise=3
)
evaluator = Evaluator(
train_loader_normal=train_loader,
test_loader_normal=test_loader,
train_loader_noise=train_noise,
test_loader_noise=test_noise,
# ... other parameters
)
Output Visualization¶
The robustness evaluation automatically generates and saves:
- Cosine similarity distribution histogram
- L1 difference distribution histogram
- Spearman correlation distribution histogram
Files are saved to a graphs/ directory with descriptive names.
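Because the returned figures are standard matplotlib Figure objects, copies can also be written manually; a short sketch with illustrative file names:
import os
os.makedirs('graphs', exist_ok=True)
similarities, l1_diffs, l2_diffs, spearman_corrs, figures = evaluator.robustness()
for i, fig in enumerate(figures):
    # File names here are illustrative; the evaluator also writes its own copies
    fig.savefig(os.path.join('graphs', f'robustness_fig_{i}.png'), dpi=300)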
Performance Considerations¶
- Batch Processing: Evaluates all molecules in provided batches
- Memory Management: Handles large datasets efficiently
- GPU Acceleration: Supports CUDA for faster computation
- Parallel Processing: Can process multiple explanation types simultaneously
Best Practices¶
- Data Balance: Ensure clean and noisy datasets have similar distributions
- Noise Level: Use appropriate noise levels that reflect real-world conditions
- Sample Size: Use sufficient samples for statistical significance
- Metric Selection: Choose metrics appropriate for your explanation type
- Visualization: Always examine distribution plots, not just summary statistics