Models API
This page documents the neural network models available in molax for molecular property prediction with uncertainty quantification.
GCN Models
The core Graph Convolutional Network models with MC Dropout uncertainty.
GCNConfig
molax.models.gcn.GCNConfig
dataclass
Configuration for Graph Convolutional Network.
Attributes:

| Name | Type | Description |
|---|---|---|
| node_features | int | Input node feature dimension |
| hidden_features | Sequence[int] | List of hidden layer dimensions |
| out_features | int | Output dimension |
| dropout_rate | float | Dropout rate for regularization |
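
To show how these four fields fit together, the sketch below defines a local stand-in dataclass with the same attribute names. `GCNConfigSketch` and every example value are hypothetical; the real class is molax.models.gcn.GCNConfig, and its defaults are not listed on this page.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class GCNConfigSketch:
    """Hypothetical stand-in mirroring the documented GCNConfig attributes."""

    node_features: int
    hidden_features: Sequence[int]
    out_features: int
    dropout_rate: float


# Example values are assumptions chosen for illustration only.
config = GCNConfigSketch(
    node_features=6,
    hidden_features=(64, 64),
    out_features=1,
    dropout_rate=0.1,
)

# The embedding width documented for extract_embeddings is the last hidden size.
embedding_dim = config.hidden_features[-1]
print(embedding_dim)
```

With the real class, the same keyword arguments would presumably be passed to GCNConfig directly.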
UncertaintyGCN
molax.models.gcn.UncertaintyGCN
Bases: Module
GCN with uncertainty estimation via mean and variance heads.
Outputs both mean prediction and predicted variance for uncertainty quantification.
__call__
Forward pass returning mean and variance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched (and possibly padded) jraph.GraphsTuple | required |
| training | bool | Whether in training mode | False |

Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray] | Tuple of (mean, variance), each of shape [n_graphs, out_features] |
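
One common way to consume a (mean, variance) pair is to convert it into a standard deviation and an approximate prediction interval. The sketch below uses NumPy in place of jax.numpy and made-up output values; the Gaussian assumption behind the 1.96 factor is an assumption of this example, not something the page states about the model.

```python
import numpy as np

# Hypothetical model outputs for 3 graphs and 1 target: [n_graphs, out_features].
mean = np.array([[0.5], [1.2], [-0.3]])
variance = np.array([[0.04], [0.25], [0.01]])

# Standard deviation is the square root of the predicted variance.
std = np.sqrt(variance)

# Approximate 95% interval, assuming a Gaussian predictive distribution.
lower = mean - 1.96 * std
upper = mean + 1.96 * std
```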
extract_embeddings
Extract penultimate layer embeddings (graph-level).
Extracts the pooled graph representations before the output heads, which can be used for Core-Set selection, DPP sampling, or other embedding-based acquisition strategies.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode (enables dropout) | False |

Returns:

| Type | Description |
|---|---|
| ndarray | Embeddings of shape [n_graphs, hidden_dim], where hidden_dim is the last element of hidden_features in the config. |
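
As an illustration of embedding-based acquisition, here is a minimal sketch of greedy k-center selection, a common formulation of Core-Set selection. The function name `greedy_k_center` and the toy embeddings are hypothetical, NumPy stands in for jax.numpy, and molax's own acquisition code may differ.

```python
import numpy as np


def greedy_k_center(pool_emb, labeled_emb, k):
    """Pick k pool indices, each time taking the point farthest from all centers."""
    # Distance from every pool point to its nearest labeled point.
    diffs = pool_emb[:, None, :] - labeled_emb[None, :, :]
    min_dist = np.linalg.norm(diffs, axis=-1).min(axis=1)
    selected = []
    for _ in range(k):
        idx = int(np.argmax(min_dist))
        selected.append(idx)
        # The newly chosen center may now be the nearest one for some pool points.
        new_dist = np.linalg.norm(pool_emb - pool_emb[idx], axis=-1)
        min_dist = np.minimum(min_dist, new_dist)
    return selected


# Toy 2-D embeddings; in practice these would come from extract_embeddings.
labeled_emb = np.array([[0.0, 0.0]])
pool_emb = np.array([[0.1, 0.0], [5.0, 0.0], [0.0, 3.0], [5.1, 0.0]])

selected = greedy_k_center(pool_emb, labeled_emb, k=2)
```

The second pick skips the point right next to the first pick, which is the diversity effect that motivates selecting on embeddings rather than on uncertainty alone.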
MolecularGCN
molax.models.gcn.MolecularGCN
Bases: Module
Graph Convolutional Network for molecular property prediction.
Uses jraph for efficient batched processing of variable-sized graphs.
__call__
Forward pass through the GCN.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched (and possibly padded) jraph.GraphsTuple | required |
| training | bool | Whether in training mode (enables dropout) | False |

Returns:

| Type | Description |
|---|---|
| ndarray | Graph-level predictions of shape [n_graphs, out_features] |
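
A batched jraph.GraphsTuple stores the nodes of all graphs in one concatenated array, with the per-graph node counts in its n_node field. The sketch below shows how node features can be reduced to one row per graph under that layout. It uses NumPy in place of jax.numpy, and sum pooling is an assumption of this example; the page does not say which pooling MolecularGCN uses.

```python
import numpy as np

# Two graphs batched together: graph 0 has 2 nodes, graph 1 has 3 nodes.
n_node = np.array([2, 3])
nodes = np.array([[1.0, 0.0],
                  [2.0, 1.0],
                  [0.5, 0.5],
                  [0.5, 1.5],
                  [1.0, 1.0]])

# Map every node to the index of the graph it belongs to: [0, 0, 1, 1, 1].
graph_ids = np.repeat(np.arange(len(n_node)), n_node)

# Sum-pool node features per graph (a segment sum).
pooled = np.zeros((len(n_node), nodes.shape[1]))
np.add.at(pooled, graph_ids, nodes)
```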
Deep Ensembles
Ensemble methods for improved uncertainty quantification through model disagreement.
EnsembleConfig
molax.models.ensemble.EnsembleConfig
dataclass
Configuration for Deep Ensemble.
Attributes:

| Name | Type | Description |
|---|---|---|
| base_config | GCNConfig | Configuration for each ensemble member (GCNConfig) |
| n_members | int | Number of ensemble members (default: 5) |
DeepEnsemble
molax.models.ensemble.DeepEnsemble
Bases: Module
Deep Ensemble of UncertaintyGCN models.
Trains N independent GCN models with different random initializations. Provides improved uncertainty estimation by decomposing into:

- Epistemic uncertainty: disagreement between models (reducible with more data)
- Aleatoric uncertainty: average predicted variance (inherent noise)
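
The decomposition described above can be sketched with a few array reductions. The per-member outputs here are made up, NumPy stands in for jax.numpy, and two details are assumptions of this example rather than documented behavior: epistemic uncertainty is taken as the population variance of the member means, and total uncertainty as the sum of the two parts.

```python
import numpy as np

# Hypothetical per-member outputs, shape [n_members, n_graphs, out_features].
member_means = np.array([[[1.0], [2.0]],
                         [[1.2], [2.0]],
                         [[0.8], [2.0]]])
member_vars = np.array([[[0.10], [0.30]],
                        [[0.20], [0.30]],
                        [[0.30], [0.30]]])

# Ensemble prediction: average of the member means.
mean = member_means.mean(axis=0)

# Epistemic: disagreement between members (variance of their means).
epistemic = member_means.var(axis=0)

# Aleatoric: average of the variances predicted by the members.
aleatoric = member_vars.mean(axis=0)

# Total uncertainty under the additive assumption stated above.
total = epistemic + aleatoric
```

In this toy data the members agree exactly on the second graph, so its epistemic term is zero and all of its uncertainty is aleatoric.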
__init__
Initialize ensemble with N independent models.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | EnsembleConfig | EnsembleConfig with base model config and n_members | required |
| rngs | Rngs | Random number generators for initialization | required |
__call__
__call__(
graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
Forward pass returning mean, total uncertainty, and epistemic uncertainty.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode | False |
Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray, ndarray] | Tuple of (mean, total uncertainty, epistemic uncertainty). Each has shape [n_graphs, out_features]. |
predict_member
predict_member(
member_idx: int, graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray]
Get prediction from a specific ensemble member.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| member_idx | int | Index of the ensemble member (0 to n_members-1) | required |
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode | False |

Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray] | Tuple of (mean, variance) from the specified member |
extract_embeddings
Extract averaged embeddings from all ensemble members.
Each member extracts embeddings independently, and the results are averaged to produce a single embedding per graph. This can be used for Core-Set selection, DPP sampling, or other embedding-based acquisition strategies.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode | False |

Returns:

| Type | Description |
|---|---|
| ndarray | Averaged embeddings of shape [n_graphs, hidden_dim] |
Evidential Deep Learning
Single-pass uncertainty estimation using evidential neural networks.
EvidentialConfig
molax.models.evidential.EvidentialConfig
dataclass
Configuration for Evidential GCN.
Attributes:

| Name | Type | Description |
|---|---|---|
| base_config | GCNConfig | Configuration for the GCN backbone (GCNConfig) |
| lambda_reg | float | Regularization weight for evidence on errors (default: 0.1) |
EvidentialGCN
molax.models.evidential.EvidentialGCN
Bases: Module
GCN with Evidential Deep Learning for uncertainty quantification.
Uses a GCN backbone followed by an evidential head that predicts Normal-Inverse-Gamma parameters, enabling single-pass uncertainty estimation with separation of aleatoric and epistemic components.
__init__
Initialize EvidentialGCN.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | EvidentialConfig | EvidentialConfig with base model config and lambda_reg | required |
| rngs | Rngs | Random number generators for initialization | required |
forward_raw
forward_raw(
graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
Forward pass returning raw NIG parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode | False |

Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray, ndarray, ndarray] | Tuple of (gamma, nu, alpha, beta), the NIG parameters. Each has shape [n_graphs, 1]. |
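
For readers unfamiliar with Normal-Inverse-Gamma outputs, the sketch below converts (gamma, nu, alpha, beta) into a mean and two uncertainty terms using the standard NIG moments from the evidential regression literature. The helper `nig_uncertainty` and the parameter values are hypothetical, NumPy stands in for jax.numpy, and whether EvidentialGCN computes its uncertainties exactly this way, or sums them to get the total, is an assumption.

```python
import numpy as np


def nig_uncertainty(gamma, nu, alpha, beta):
    """Convert NIG parameters to (mean, aleatoric, epistemic); needs alpha > 1."""
    mean = gamma                              # E[mu]
    aleatoric = beta / (alpha - 1.0)          # E[sigma^2]
    epistemic = beta / (nu * (alpha - 1.0))   # Var[mu]
    return mean, aleatoric, epistemic


# Hypothetical parameters for 2 graphs, each of shape [n_graphs, 1].
gamma = np.array([[0.5], [1.5]])
nu = np.array([[2.0], [0.5]])
alpha = np.array([[3.0], [2.0]])
beta = np.array([[1.0], [1.0]])

mean, aleatoric, epistemic = nig_uncertainty(gamma, nu, alpha, beta)

# Additive total, assumed here for illustration.
total = aleatoric + epistemic
```

A small nu means little evidence for the predicted mean, which is why the second graph ends up with the larger epistemic term.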
__call__
__call__(
graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
Forward pass returning mean, total uncertainty, and epistemic uncertainty.
This signature matches DeepEnsemble for drop-in replacement.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode | False |
Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray, ndarray] | Tuple of (mean, total uncertainty, epistemic uncertainty) |
extract_embeddings
Extract penultimate layer embeddings (graph-level).
Extracts the pooled graph representations before the evidential head, which can be used for Core-Set selection, DPP sampling, or other embedding-based acquisition strategies.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | GraphsTuple | Batched jraph.GraphsTuple | required |
| training | bool | Whether in training mode (enables dropout) | False |

Returns:

| Type | Description |
|---|---|
| ndarray | Embeddings of shape [n_graphs, hidden_dim], where hidden_dim is the last element of hidden_features in the config. |
Training Utilities
molax.models.gcn.train_step
train_step(
model: UncertaintyGCN,
optimizer: Optimizer,
graph: GraphsTuple,
labels: ndarray,
mask: ndarray,
) -> jnp.ndarray
JIT-compiled training step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | UncertaintyGCN | The model | required |
| optimizer | Optimizer | The optimizer | required |
| graph | GraphsTuple | Batched (padded) input graphs | required |
| labels | ndarray | Target labels of shape [n_graphs] (padded with zeros) | required |
| mask | ndarray | Boolean mask indicating real graphs (not padding) | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Loss value |
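
Because JIT-compiled functions are traced per input shape, labels are typically padded to a fixed length and paired with a mask. The sketch below shows one way to build the zero-padded labels and the boolean mask described in the table. The helper `pad_labels` is hypothetical and NumPy stands in for jax.numpy; molax may ship its own batching utilities that do this for you.

```python
import numpy as np


def pad_labels(labels, n_graphs_padded):
    """Zero-pad labels to a fixed length and return a mask marking real entries."""
    labels = np.asarray(labels, dtype=np.float32)
    n_real = labels.shape[0]
    if n_real > n_graphs_padded:
        raise ValueError("n_graphs_padded must be at least the number of labels")
    padded = np.zeros(n_graphs_padded, dtype=np.float32)
    padded[:n_real] = labels
    # True for real graphs, False for padding slots.
    mask = np.arange(n_graphs_padded) < n_real
    return padded, mask


padded, mask = pad_labels([0.7, -1.2, 3.4], n_graphs_padded=5)
```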
molax.models.gcn.eval_step
eval_step(
model: UncertaintyGCN, graph: GraphsTuple, labels: ndarray, mask: ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]
JIT-compiled evaluation step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | UncertaintyGCN | The model | required |
| graph | GraphsTuple | Batched (padded) input graphs | required |
| labels | ndarray | Target labels (padded) | required |
| mask | ndarray | Boolean mask for real graphs | required |

Returns:

| Type | Description |
|---|---|
| Tuple[ndarray, ndarray] | Tuple of (mse, mean_predictions) |
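
To make the role of the mask concrete, here is a minimal sketch of a mean squared error that ignores padded graphs. The function `masked_mse` is hypothetical and written with NumPy in place of jax.numpy; it illustrates the idea and is not the library's implementation of eval_step.

```python
import numpy as np


def masked_mse(predictions, labels, mask):
    """Mean squared error over real graphs only; padding contributes nothing."""
    mask = mask.astype(predictions.dtype)
    squared_error = (predictions - labels) ** 2
    # Guard against an all-padding batch to avoid dividing by zero.
    return float((squared_error * mask).sum() / max(mask.sum(), 1.0))


# The last entry is padding: its large error must not affect the result.
predictions = np.array([1.0, 2.0, 3.0, 9.0])
labels = np.array([1.5, 2.0, 2.0, 0.0])
mask = np.array([True, True, True, False])

mse = masked_mse(predictions, labels, mask)
```

Dividing by the number of real graphs rather than the padded batch size keeps the metric comparable across batches with different amounts of padding.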