Plots and Visualization¶
Overview¶
This module provides visualization functions for explainability results, including radar plots, horizontal bar plots, and utility functions for graph manipulation.
Functions¶
k_hop_subgraph¶
Computes the k-hop subgraph around specified nodes.
from chemxai.plots import k_hop_subgraph
subset, edge_index, inv, edge_mask = k_hop_subgraph(
    node_idx, num_hops, edge_index, relabel_nodes=False,
    num_nodes=None, flow='source_to_target'
)
Parameters¶
- node_idx (int, list, tuple, or torch.Tensor): Central node(s)
- num_hops (int): Number of hops k
- edge_index (torch.LongTensor): Edge indices of the graph
- relabel_nodes (bool, optional): Whether to relabel nodes consecutively (default: False)
- num_nodes (int, optional): Total number of nodes in the graph (default: None)
- flow (str, optional): Flow direction, 'source_to_target' or 'target_to_source' (default: 'source_to_target')
Returns¶
- subset (torch.LongTensor): Nodes in the subgraph
- edge_index (torch.LongTensor): Filtered edge connectivity
- inv (torch.LongTensor): Mapping from original to new node indices
- edge_mask (torch.BoolTensor): Mask indicating preserved edges
Example¶
import torch
from chemxai.plots import k_hop_subgraph

# Directed 4-node cycle: 0 -> 1 -> 2 -> 3 -> 0
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)

# Extract the 2-hop neighbourhood of node 0, relabelling nodes consecutively
subset, new_edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx=0, num_hops=2, edge_index=edge_index, relabel_nodes=True
)
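To make the meaning of "k-hop" and of the flow argument concrete, here is a minimal pure-Python sketch of the neighbourhood expansion. It is not the chemxai implementation: the helper name k_hop_nodes and the edge-list input format are hypothetical, and it assumes that flow follows the usual message-passing convention, where 'source_to_target' collects the sources of edges pointing into the current frontier.

```python
from typing import Iterable, List, Set, Tuple


def k_hop_nodes(
    seeds: Iterable[int],
    num_hops: int,
    edges: List[Tuple[int, int]],
    flow: str = "source_to_target",
) -> List[int]:
    """Return the sorted node ids reached from the seeds within num_hops hops.

    Hypothetical helper for illustration only.
    """
    if flow not in ("source_to_target", "target_to_source"):
        raise ValueError(f"unknown flow: {flow!r}")

    frontier: Set[int] = set(seeds)
    visited: Set[int] = set(frontier)
    for _ in range(num_hops):
        next_frontier: Set[int] = set()
        for src, dst in edges:
            # Assumed convention: with 'source_to_target' a node is influenced
            # by the sources of its incoming edges; the other flow swaps roles.
            if flow == "source_to_target":
                sender, receiver = src, dst
            else:
                sender, receiver = dst, src
            if receiver in frontier and sender not in visited:
                next_frontier.add(sender)
        visited |= next_frontier
        frontier = next_frontier
    return sorted(visited)


# The directed 4-node cycle from the example above: 0 -> 1 -> 2 -> 3 -> 0.
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
two_hop = k_hop_nodes([0], 2, edges)  # [0, 2, 3]
```

Under this assumed convention, node 1 is not within two hops of node 0 because the cycle is directed. Passing flow="target_to_source" walks the edges the other way and yields [0, 1, 2].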
radar_plot¶
Creates a radar plot to visualize feature importance values.
from chemxai.plots import radar_plot
fig, ax = radar_plot(
    values, feature_names=None,
    title="Feature Importance Radar Plot"
)
Parameters¶
- values (numpy.ndarray): Feature importance values
- feature_names (list, optional): Names of features. If None, uses generic names
- title (str): Plot title
Returns¶
- fig (matplotlib.figure.Figure): Figure object
- ax (matplotlib.axes.Axes): Axes object
Example¶
import numpy as np
import matplotlib.pyplot as plt
from chemxai.plots import radar_plot

# Sample SHAP values
shap_values = np.random.randn(10)
feature_names = [f"Feature_{i}" for i in range(10)]

fig, ax = radar_plot(shap_values, feature_names, "SHAP Values Radar Plot")
plt.show()
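A radar plot places one spoke per feature at evenly spaced angles and joins the values into a closed outline. The sketch below shows only that geometry, under the assumption that radar_plot follows the standard construction. The helper name radar_polygon is hypothetical and is not part of chemxai.

```python
import math
from typing import List, Sequence, Tuple


def radar_polygon(values: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Return (angles, radii) describing a closed radar outline.

    Hypothetical helper for illustration only.
    """
    n = len(values)
    if n == 0:
        raise ValueError("values must not be empty")
    # One evenly spaced spoke per feature, in radians.
    angles = [2.0 * math.pi * i / n for i in range(n)]
    radii = [float(v) for v in values]
    # Repeat the first point so the outline closes on itself.
    return angles + angles[:1], radii + radii[:1]


angles, radii = radar_polygon([0.5, -0.2, 0.9, 0.1])
```

On a polar axes created with plt.subplots(subplot_kw={"projection": "polar"}), the pair can be drawn with ax.plot(angles, radii) and ax.fill(angles, radii, alpha=0.25). Negative importances do not map cleanly onto a radius. How radar_plot treats them is not stated here, so check the rendered figure when the values are signed.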
horizontal_bar_plot¶
Creates a horizontal bar plot for feature importance visualization.
from chemxai.plots import horizontal_bar_plot
fig, ax = horizontal_bar_plot(
    values, feature_names=None, title="Feature Importance",
    sort=True, max_features=None, color_positive='blue',
    color_negative='red', figsize=(12, 8), save_path='graphs',
    filename='feature_importance.png'
)
Parameters¶
- values (array-like): Importance values for each feature
- feature_names (list, optional): Feature names. If None, uses generic names
- title (str): Plot title
- sort (bool): Whether to sort features by absolute importance (default: True)
- max_features (int, optional): Maximum number of features to display
- color_positive (str): Color for positive values (default: 'blue')
- color_negative (str): Color for negative values (default: 'red')
- figsize (tuple): Figure size as (width, height) (default: (12, 8))
- save_path (str): Directory to save the plot (default: 'graphs')
- filename (str): Filename for saved plot (default: 'feature_importance.png')
Returns¶
- fig (matplotlib.figure.Figure): Figure object
- ax (matplotlib.axes.Axes): Axes object
Example¶
import numpy as np
from chemxai.plots import horizontal_bar_plot
# Sample feature importance values
importance_values = np.random.randn(20)
feature_names = [f"Bit_{i}" for i in range(20)]
fig, ax = horizontal_bar_plot(
    values=importance_values,
    feature_names=feature_names,
    title="Morgan Fingerprint Bit Importance",
    max_features=15,
    color_positive='green',
    color_negative='red',
    save_path='results',
    filename='fingerprint_importance.png'
)
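The sort, max_features, color_positive, and color_negative parameters interact, so a small sketch of one plausible selection step may help. It assumes that sorting happens before truncation and that zero counts as positive. Neither is confirmed by this page. The helper select_bars and its generic Feature_i naming are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def select_bars(
    values: Sequence[float],
    feature_names: Optional[Sequence[str]] = None,
    sort: bool = True,
    max_features: Optional[int] = None,
    color_positive: str = "blue",
    color_negative: str = "red",
) -> List[Tuple[str, float, str]]:
    """Return (name, value, color) triples in display order.

    Hypothetical helper for illustration only.
    """
    if feature_names is None:
        # Assumed generic naming scheme; the library's may differ.
        feature_names = [f"Feature_{i}" for i in range(len(values))]
    if len(feature_names) != len(values):
        raise ValueError("values and feature_names must have equal length")

    rows = list(zip(feature_names, values))
    if sort:
        # Largest absolute importance first.
        rows.sort(key=lambda row: abs(row[1]), reverse=True)
    if max_features is not None:
        # Truncate after sorting so the strongest features are kept.
        rows = rows[:max_features]
    return [
        (name, value, color_positive if value >= 0 else color_negative)
        for name, value in rows
    ]


bars = select_bars([0.1, -0.7, 0.4, -0.05], max_features=2)
# [('Feature_1', -0.7, 'red'), ('Feature_2', 0.4, 'blue')]
```

With sort=False the same truncation would keep the first max_features entries in input order, which is rarely what is wanted for an importance plot.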
Usage Examples¶
Complete Visualization Workflow¶
import numpy as np
import matplotlib.pyplot as plt
from chemxai.plots import radar_plot, horizontal_bar_plot
# Sample explanation results
explanation_values = np.random.randn(50)
feature_names = [f"Feature_{i}" for i in range(50)]
# Create radar plot for the 10 features with the largest absolute values
top_10_indices = np.argsort(np.abs(explanation_values))[-10:]
top_10_values = explanation_values[top_10_indices]
top_10_names = [feature_names[i] for i in top_10_indices]
radar_fig, radar_ax = radar_plot(
    values=top_10_values,
    feature_names=top_10_names,
    title="Top 10 Most Important Features"
)
# Create horizontal bar plot, limited to 20 features by max_features
bar_fig, bar_ax = horizontal_bar_plot(
    values=explanation_values,
    feature_names=feature_names,
    title="Feature Importance Analysis",
    max_features=20,
    sort=True,
    save_path='visualization_results',
    filename='feature_importance_bars.png'
)
plt.show()
Molecular Fingerprint Visualization¶
# Example for visualizing Morgan fingerprint explanations
import numpy as np
from chemxai.plots import horizontal_bar_plot

morgan_bits = 512
shap_values = np.random.randn(morgan_bits)  # Replace with actual SHAP values
# Create bit names
bit_names = [f"Morgan_Bit_{i}" for i in range(morgan_bits)]
# Visualize at most 30 bits (max_features=30)
fig, ax = horizontal_bar_plot(
    values=shap_values,
    feature_names=bit_names,
    title="Morgan Fingerprint Bit Importance (SHAP)",
    max_features=30,
    color_positive='darkblue',
    color_negative='darkred',
    figsize=(14, 10),
    save_path='molecular_explanations',
    filename='morgan_bits_shap.png'
)
Technical Details¶
Plot Customization¶
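The plotting functions expose the following customization options. Per the signatures above, these are parameters of horizontal_bar_plot; radar_plot accepts only values, feature_names, and title.
- Colors: Separate colors for positive and negative values (color_positive, color_negative)
- Sizing: Configurable figure dimensions (figsize)
- Sorting: Optional sorting by absolute importance (sort)
- Filtering: Limit on the number of displayed features (max_features)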
File Handling¶
- Automatic directory creation for save paths
- High-resolution output (300 DPI)
- Tight bounding box for clean exports
- Support for various image formats
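The file-handling behaviour listed above can be sketched as follows. This is an assumption about how such a save step is typically written, not the chemxai source. The helper save_figure is hypothetical and accepts any object with a Matplotlib-style savefig method.

```python
import os


def save_figure(
    fig,
    save_path: str = "graphs",
    filename: str = "feature_importance.png",
    dpi: int = 300,
) -> str:
    """Save fig to save_path/filename and return the full path.

    Hypothetical helper for illustration only.
    """
    # Create the target directory (and any parents) if it is missing.
    os.makedirs(save_path, exist_ok=True)
    full_path = os.path.join(save_path, filename)
    # Matplotlib infers the image format from the filename extension;
    # bbox_inches="tight" trims surrounding whitespace.
    fig.savefig(full_path, dpi=dpi, bbox_inches="tight")
    return full_path
```

Called with a Matplotlib figure, save_figure(fig, "results", "plot.pdf") would write a PDF instead of a PNG, since the output format follows the file extension.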
Integration with Explainers¶
The plotting functions are designed to work seamlessly with explainer outputs:
- Handle both single-instance and multi-instance explanations
- Automatic feature name generation when not provided
- Robust handling of different data types and shapes
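How multi-instance explanations are reduced to one value per feature is not specified on this page. One common convention, shown here purely as an assumption, is to average the absolute attribution of each feature across instances. The helper to_feature_vector is hypothetical.

```python
from typing import List, Sequence, Union

Number = Union[int, float]


def to_feature_vector(
    explanation: Union[Sequence[Number], Sequence[Sequence[Number]]],
) -> List[float]:
    """Collapse an explanation to one importance value per feature.

    Hypothetical helper for illustration only.
    """
    rows = list(explanation)
    if not rows:
        raise ValueError("explanation must not be empty")
    if not isinstance(rows[0], (list, tuple)):
        # Single instance: already one signed value per feature.
        return [float(v) for v in rows]
    # Multiple instances: mean absolute attribution per feature column.
    n = len(rows)
    return [sum(abs(v) for v in column) / n for column in zip(*rows)]


single = to_feature_vector([0.2, -0.4])                 # [0.2, -0.4]
multi = to_feature_vector([[1.0, -2.0], [3.0, 2.0]])    # [2.0, 2.0]
```

The mean absolute value discards sign, so the reduced vector ranks features by overall influence but no longer shows direction.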
Performance Considerations¶
- Efficient handling of large feature spaces
- Memory-efficient plotting for high-dimensional data
- Fast rendering for interactive exploration