
Models API

This page documents the neural network models available in molax for molecular property prediction with uncertainty quantification.

GCN Models

The core Graph Convolutional Network models with MC Dropout uncertainty.

GCNConfig

molax.models.gcn.GCNConfig (dataclass)

Configuration for the Graph Convolutional Network.

Attributes:

  node_features (int): Input node feature dimension.
  hidden_features (Sequence[int]): List of hidden layer dimensions.
  out_features (int): Output dimension.
  dropout_rate (float): Dropout rate for regularization.

UncertaintyGCN

molax.models.gcn.UncertaintyGCN

Bases: Module

GCN with uncertainty estimation via mean and variance heads.

Outputs both a mean prediction and a predicted variance for uncertainty quantification.

__call__

__call__(
    graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray]

Forward pass returning mean and variance.

Parameters:

  graph (GraphsTuple, required): Batched (and possibly padded) jraph.GraphsTuple.
  training (bool, default False): Whether in training mode.

Returns:

  Tuple[ndarray, ndarray]: Tuple of (mean, variance), each of shape [n_graphs, out_features].
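Because the variance head returns a per-graph variance, a prediction interval follows directly from the two outputs. A minimal sketch, assuming Gaussian predictive noise and that the second output is a variance (not a log-variance); NumPy arrays stand in for model outputs, and `gaussian_interval` is a hypothetical helper, not part of molax:

```python
import numpy as np

def gaussian_interval(mean, variance, z=1.96):
    """Return (lower, upper) bounds of a central Gaussian interval.

    z=1.96 gives an approximately 95% interval if the predictive
    distribution is Gaussian (an assumption, not a molax guarantee).
    """
    std = np.sqrt(np.maximum(variance, 0.0))  # guard against tiny negative values
    return mean - z * std, mean + z * std

# Hypothetical outputs for 2 graphs with out_features=1.
mean = np.array([[1.0], [0.0]])
variance = np.array([[4.0], [0.25]])
lower, upper = gaussian_interval(mean, variance)
```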

extract_embeddings

extract_embeddings(graph: GraphsTuple, training: bool = False) -> jnp.ndarray

Extract penultimate layer embeddings (graph-level).

Extracts the pooled graph representations before the output heads, which can be used for Core-Set selection, DPP sampling, or other embedding-based acquisition strategies.

Parameters:

  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode (enables dropout).

Returns:

  ndarray: Embeddings of shape [n_graphs, hidden_dim], where hidden_dim is the last element of hidden_features in the config.
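As one example of an embedding-based strategy, Core-Set selection is often implemented as greedy k-center selection: repeatedly pick the unlabeled graph farthest from every already-chosen graph. A minimal sketch, assuming Euclidean distance on the [n_graphs, hidden_dim] embeddings; `k_center_greedy` is a hypothetical helper and not necessarily how molax's acquisition code works:

```python
import numpy as np

def k_center_greedy(embeddings, labeled_idx, k):
    """Greedy k-center (Core-Set) selection over graph embeddings.

    embeddings: [n_graphs, hidden_dim] array, e.g. from extract_embeddings.
    labeled_idx: indices of graphs that are already labeled.
    k: number of new graphs to select.
    Returns the selected indices in selection order.
    """
    n = embeddings.shape[0]
    if len(labeled_idx) == 0:
        # No labeled set yet: every graph is infinitely far from it.
        min_dist = np.full(n, np.inf)
    else:
        centers = embeddings[np.asarray(labeled_idx)]
        # Distance from every graph to its nearest labeled graph.
        min_dist = np.linalg.norm(
            embeddings[:, None, :] - centers[None, :, :], axis=-1
        ).min(axis=1)
    selected = []
    for _ in range(k):
        idx = int(np.argmax(min_dist))  # farthest graph from current centers
        selected.append(idx)
        # The new center may now be the nearest center for some graphs.
        min_dist = np.minimum(min_dist, np.linalg.norm(embeddings - embeddings[idx], axis=1))
    return selected

emb = np.array([[0.0], [1.0], [2.0], [10.0]])
print(k_center_greedy(emb, labeled_idx=[0], k=2))  # [3, 2]
```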

MolecularGCN

molax.models.gcn.MolecularGCN

Bases: Module

Graph Convolutional Network for molecular property prediction.

Uses jraph for efficient batched processing of variable-sized graphs.
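Batched processing works by concatenating all nodes into one array, storing edges as `senders`/`receivers` index arrays, and recording each graph's node count in `n_node` (the jraph GraphsTuple layout). A minimal NumPy sketch of one graph convolution plus mean pooling to one vector per graph; the symmetric normalization with self-loops and the mean pooling are assumptions (the standard GCN formulation), not necessarily what MolecularGCN does, and `gcn_layer` and `mean_pool` are hypothetical helpers. In JAX the scatter-sums would typically use `jax.ops.segment_sum`.

```python
import numpy as np

def gcn_layer(nodes, senders, receivers, weight):
    """One symmetric-normalized GCN layer on a jraph-style edge list.

    nodes: [n_nodes, in_dim]; senders/receivers: [n_edges] int arrays;
    weight: [in_dim, out_dim]. Self-loops keep each node's own features.
    """
    n = nodes.shape[0]
    loops = np.arange(n)
    s = np.concatenate([senders, loops])
    r = np.concatenate([receivers, loops])
    deg = np.bincount(r, minlength=n).astype(float)  # in-degree incl. self-loop
    h = nodes @ weight
    norm = 1.0 / np.sqrt(deg[s] * deg[r])            # symmetric normalization per edge
    out = np.zeros_like(h)
    np.add.at(out, r, h[s] * norm[:, None])          # scatter-sum messages to receivers
    return np.maximum(out, 0.0)                       # ReLU

def mean_pool(nodes, n_node):
    """Average node features per graph; n_node[i] is the node count of graph i."""
    n_graphs = len(n_node)
    graph_ids = np.repeat(np.arange(n_graphs), n_node)
    sums = np.zeros((n_graphs, nodes.shape[1]))
    np.add.at(sums, graph_ids, nodes)
    # Empty (padding) graphs get a zero vector instead of dividing by zero.
    return sums / np.maximum(n_node, 1)[:, None]

# Graph 0: two connected nodes; graph 1: a single isolated node.
nodes = np.array([[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]])
senders, receivers = np.array([0, 1]), np.array([1, 0])
n_node = np.array([2, 1])
h = gcn_layer(nodes, senders, receivers, np.eye(2))
pooled = mean_pool(h, n_node)  # [[0.5, 0.5], [3.0, 3.0]]
```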

__call__

__call__(graph: GraphsTuple, training: bool = False) -> jnp.ndarray

Forward pass through the GCN.

Parameters:

  graph (GraphsTuple, required): Batched (and possibly padded) jraph.GraphsTuple.
  training (bool, default False): Whether in training mode (enables dropout).

Returns:

  ndarray: Graph-level predictions of shape [n_graphs, out_features].
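Since training=True enables dropout, an MC Dropout estimate can be formed by running several stochastic passes and taking the sample mean and variance. A minimal sketch, assuming a forward function that applies dropout on every call; `mc_dropout_predict` and `toy_forward` are hypothetical stand-ins. With molax the stochastic pass would be a model call with training=True, and the randomness would presumably come from the model's RNG state rather than a NumPy generator.

```python
import numpy as np

def mc_dropout_predict(forward, x, n_samples=50, seed=0):
    """Monte Carlo Dropout: average several stochastic forward passes.

    forward(x, rng) must apply dropout using rng (training-mode behavior).
    Returns (mean, variance) over the samples.
    """
    rng = np.random.default_rng(seed)
    samples = np.stack([forward(x, rng) for _ in range(n_samples)])
    return samples.mean(axis=0), samples.var(axis=0)

# Toy stand-in for a training-mode model: inverted dropout on the input,
# then a fixed linear readout.
W = np.array([[1.0], [2.0]])

def toy_forward(x, rng, rate=0.5):
    keep = rng.random(x.shape) >= rate          # drop each input with prob. `rate`
    return (x * keep / (1.0 - rate)) @ W        # rescale so the expectation is unchanged

x = np.array([[1.0, 1.0]])
mean, var = mc_dropout_predict(toy_forward, x, n_samples=200)
```

A larger n_samples reduces the Monte Carlo noise in both estimates at the cost of more forward passes.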


Deep Ensembles

Ensemble methods for improved uncertainty quantification through model disagreement.

EnsembleConfig

molax.models.ensemble.EnsembleConfig (dataclass)

Configuration for a Deep Ensemble.

Attributes:

  base_config (GCNConfig): Configuration for each ensemble member.
  n_members (int): Number of ensemble members (default: 5).

DeepEnsemble

molax.models.ensemble.DeepEnsemble

Bases: Module

Deep Ensemble of UncertaintyGCN models.

Combines N independently trained GCN models with different random initializations. It improves uncertainty estimation by decomposing the predictive uncertainty into:

  • Epistemic uncertainty: disagreement between members (reducible with more data)
  • Aleatoric uncertainty: average predicted variance across members (inherent noise)
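This decomposition is the law of total variance applied across members. A minimal sketch over stacked member outputs (for example, the (mean, variance) pairs that predict_member returns), assuming the population variance (ddof=0) across members; `ensemble_decomposition` is a hypothetical helper whose return order mirrors DeepEnsemble.__call__:

```python
import numpy as np

def ensemble_decomposition(member_means, member_vars):
    """Combine per-member (mean, variance) predictions.

    member_means, member_vars: [n_members, n_graphs, out_features].
    Returns (ensemble_mean, total_var, epistemic_var).
    """
    ensemble_mean = member_means.mean(axis=0)
    epistemic_var = member_means.var(axis=0)   # disagreement between members
    aleatoric_var = member_vars.mean(axis=0)   # average predicted noise
    total_var = epistemic_var + aleatoric_var  # law of total variance
    return ensemble_mean, total_var, epistemic_var

# Two members, one graph, out_features=1.
means = np.array([[[1.0]], [[3.0]]])
variances = np.array([[[0.5]], [[1.5]]])
m, total, epi = ensemble_decomposition(means, variances)
# m = [[2.0]], total = [[2.0]], epi = [[1.0]]
```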

__init__

__init__(config: EnsembleConfig, rngs: Rngs)

Initialize ensemble with N independent models.

Parameters:

  config (EnsembleConfig, required): EnsembleConfig with base model config and n_members.
  rngs (Rngs, required): Random number generators for initialization.

__call__

__call__(
    graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Forward pass returning mean, total uncertainty, and epistemic uncertainty.

Parameters:

  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode.

Returns:

  Tuple[ndarray, ndarray, ndarray]: Tuple of:
  • ensemble_mean: Mean prediction across all members
  • total_var: Total uncertainty (epistemic + aleatoric)
  • epistemic_var: Epistemic uncertainty (model disagreement)

  Each has shape [n_graphs, out_features].

predict_member

predict_member(
    member_idx: int, graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray]

Get prediction from a specific ensemble member.

Parameters:

  member_idx (int, required): Index of the ensemble member (0 to n_members-1).
  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode.

Returns:

  Tuple[ndarray, ndarray]: Tuple of (mean, variance) from the specified member.

extract_embeddings

extract_embeddings(graph: GraphsTuple, training: bool = False) -> jnp.ndarray

Extract averaged embeddings from all ensemble members.

Each member extracts embeddings independently, and the results are averaged to produce a single embedding per graph. This can be used for Core-Set selection, DPP sampling, or other embedding-based acquisition strategies.

Parameters:

  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode.

Returns:

  ndarray: Averaged embeddings of shape [n_graphs, hidden_dim].


Evidential Deep Learning

Single-pass uncertainty estimation using evidential neural networks.

EvidentialConfig

molax.models.evidential.EvidentialConfig (dataclass)

Configuration for EvidentialGCN.

Attributes:

  base_config (GCNConfig): Configuration for the GCN backbone.
  lambda_reg (float): Weight of the evidence regularizer, which penalizes evidence placed on erroneous predictions (default: 0.1).
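For context, Deep Evidential Regression (Amini et al., 2020) trains with the Normal-Inverse-Gamma negative log-likelihood plus an evidence regularizer |y - gamma| * (2 nu + alpha), scaled by a coefficient playing the role of lambda_reg. A minimal NumPy sketch, assuming molax follows that standard formulation (this page does not document its exact loss); `nig_nll`, `evidence_regularizer`, and `evidential_loss` are hypothetical names:

```python
import math
import numpy as np

_lgamma = np.vectorize(math.lgamma)  # elementwise log-gamma without SciPy

def nig_nll(y, gamma, nu, alpha, beta):
    """Negative log-likelihood of y under the NIG evidential distribution."""
    omega = 2.0 * beta * (1.0 + nu)
    return (
        0.5 * np.log(np.pi / nu)
        - alpha * np.log(omega)
        + (alpha + 0.5) * np.log(nu * (y - gamma) ** 2 + omega)
        + _lgamma(alpha)
        - _lgamma(alpha + 0.5)
    )

def evidence_regularizer(y, gamma, nu, alpha):
    """Penalizes total evidence (2*nu + alpha) in proportion to the error."""
    return np.abs(y - gamma) * (2.0 * nu + alpha)

def evidential_loss(y, gamma, nu, alpha, beta, lambda_reg=0.1):
    """Mean over graphs of NLL + lambda_reg * regularizer."""
    per_graph = nig_nll(y, gamma, nu, alpha, beta) + lambda_reg * evidence_regularizer(
        y, gamma, nu, alpha
    )
    return per_graph.mean()
```

The NLL term is the negative log-density of a Student-t with 2 alpha degrees of freedom, location gamma, and squared scale beta (1 + nu) / (nu alpha), which is a convenient correctness check.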

EvidentialGCN

molax.models.evidential.EvidentialGCN

Bases: Module

GCN with Evidential Deep Learning for uncertainty quantification.

Uses a GCN backbone followed by an evidential head that predicts Normal-Inverse-Gamma parameters, enabling single-pass uncertainty estimation with separation of aleatoric and epistemic components.

__init__

__init__(config: EvidentialConfig, rngs: Rngs)

Initialize EvidentialGCN.

Parameters:

  config (EvidentialConfig, required): EvidentialConfig with base model config and lambda_reg.
  rngs (Rngs, required): Random number generators for initialization.

forward_raw

forward_raw(
    graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Forward pass returning raw NIG parameters.

Parameters:

  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode.

Returns:

  Tuple[ndarray, ndarray, ndarray, ndarray]: Tuple of NIG parameters (gamma, nu, alpha, beta), each of shape [n_graphs, 1].
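For the NIG parameters to be usable, nu and beta must be positive and alpha must exceed 1 (so that variance terms of the form beta / (alpha - 1) stay finite). A common way to enforce this, used in the original evidential regression work, is a softplus on the raw head outputs. A minimal sketch, assuming a 4-column head output; how molax's evidential head actually constrains its outputs is not stated on this page, and `to_nig_params` is a hypothetical helper:

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)  # numerically stable log(1 + exp(x))

def to_nig_params(raw):
    """Map unconstrained head outputs [n_graphs, 4] to valid NIG parameters.

    gamma is unconstrained; nu > 0, alpha > 1, beta > 0.
    Each returned array has shape [n_graphs, 1].
    """
    gamma = raw[:, 0:1]
    nu = softplus(raw[:, 1:2])
    alpha = softplus(raw[:, 2:3]) + 1.0  # shift so alpha > 1
    beta = softplus(raw[:, 3:4])
    return gamma, nu, alpha, beta

raw = np.array([[0.3, -2.0, 0.0, 5.0], [-1.0, 4.0, -3.0, -0.5]])
gamma, nu, alpha, beta = to_nig_params(raw)
```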

__call__

__call__(
    graph: GraphsTuple, training: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Forward pass returning mean, total uncertainty, and epistemic uncertainty.

This signature matches DeepEnsemble.__call__, so EvidentialGCN can be used as a drop-in replacement.

Parameters:

  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode.

Returns:

  Tuple[ndarray, ndarray, ndarray]: Tuple of:
  • mean: Mean prediction [n_graphs, 1]
  • total_var: Total uncertainty (aleatoric + epistemic) [n_graphs, 1]
  • epistemic_var: Epistemic uncertainty [n_graphs, 1]
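Under the standard evidential regression formulas, all three outputs follow in closed form from (gamma, nu, alpha, beta): the mean is gamma, the aleatoric variance is E[sigma^2] = beta / (alpha - 1), and the epistemic variance is Var[mu] = beta / (nu (alpha - 1)). A minimal sketch, assuming molax uses these formulas and sums the two variances into total_var as the return description states; `nig_uncertainty` is a hypothetical helper:

```python
import numpy as np

def nig_uncertainty(gamma, nu, alpha, beta):
    """Turn NIG parameters into (mean, total_var, epistemic_var).

    Requires nu > 0, beta > 0, alpha > 1.
    """
    aleatoric_var = beta / (alpha - 1.0)          # E[sigma^2]
    epistemic_var = beta / (nu * (alpha - 1.0))   # Var[mu]
    total_var = aleatoric_var + epistemic_var
    return gamma, total_var, epistemic_var

mean, total, epi = nig_uncertainty(
    np.array([[0.5]]), np.array([[2.0]]), np.array([[3.0]]), np.array([[4.0]])
)
# mean = [[0.5]], total = [[3.0]], epi = [[1.0]]
```

Because epistemic_var scales with 1 / nu, a model that reports more virtual evidence for the mean reports less epistemic uncertainty.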

extract_embeddings

extract_embeddings(graph: GraphsTuple, training: bool = False) -> jnp.ndarray

Extract penultimate layer embeddings (graph-level).

Extracts the pooled graph representations before the evidential head, which can be used for Core-Set selection, DPP sampling, or other embedding-based acquisition strategies.

Parameters:

  graph (GraphsTuple, required): Batched jraph.GraphsTuple.
  training (bool, default False): Whether in training mode (enables dropout).

Returns:

  ndarray: Embeddings of shape [n_graphs, hidden_dim], where hidden_dim is the last element of hidden_features in the config.


Training Utilities

molax.models.gcn.train_step

train_step(
    model: UncertaintyGCN,
    optimizer: Optimizer,
    graph: GraphsTuple,
    labels: ndarray,
    mask: ndarray,
) -> jnp.ndarray

JIT-compiled training step.

Parameters:

  model (UncertaintyGCN, required): The model to train.
  optimizer (Optimizer, required): The optimizer that updates the model parameters.
  graph (GraphsTuple, required): Batched (padded) input graphs.
  labels (ndarray, required): Target labels of shape [n_graphs] (padded with zeros).
  mask (ndarray, required): Boolean mask indicating real graphs (not padding).

Returns:

  ndarray: Loss value.
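The mask matters because padded graphs would otherwise contribute to the loss. This page does not state which loss train_step optimizes; a Gaussian negative log-likelihood is a common choice for mean/variance models such as UncertaintyGCN. A minimal sketch of such a masked loss, assuming out_features=1 and the label and mask shapes listed above; `masked_gaussian_nll` is a hypothetical helper, not molax's implementation:

```python
import numpy as np

def masked_gaussian_nll(mean, variance, labels, mask, eps=1e-6):
    """Average Gaussian NLL over real graphs only.

    mean, variance: [n_graphs, 1] model outputs; labels: [n_graphs];
    mask: [n_graphs] boolean, True for real graphs, False for padding.
    """
    mean = mean[:, 0]                               # drop the out_features=1 axis
    var = np.maximum(variance[:, 0], eps)           # avoid log(0) and division by 0
    nll = 0.5 * (np.log(2 * np.pi * var) + (labels - mean) ** 2 / var)
    nll = np.where(mask, nll, 0.0)                  # zero out padded graphs
    return nll.sum() / np.maximum(mask.sum(), 1)    # mean over real graphs

# The second graph is padding, so its huge error is ignored.
loss = masked_gaussian_nll(
    np.array([[0.0], [100.0]]), np.array([[1.0], [1.0]]),
    np.array([0.0, 0.0]), np.array([True, False]),
)
```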

molax.models.gcn.eval_step

eval_step(
    model: UncertaintyGCN, graph: GraphsTuple, labels: ndarray, mask: ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]

JIT-compiled evaluation step.

Parameters:

  model (UncertaintyGCN, required): The model to evaluate.
  graph (GraphsTuple, required): Batched (padded) input graphs.
  labels (ndarray, required): Target labels (padded).
  mask (ndarray, required): Boolean mask for real graphs.

Returns:

  Tuple[ndarray, ndarray]: Tuple of (mse, mean_predictions).
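With padded batches, a meaningful mse should only reflect real graphs. A minimal sketch of a masked MSE, assuming predictions of shape [n_graphs, 1] and labels and mask of shape [n_graphs]; `masked_mse` is hypothetical and may differ from eval_step's exact computation:

```python
import numpy as np

def masked_mse(predictions, labels, mask):
    """MSE over real graphs; predictions: [n_graphs, 1], labels/mask: [n_graphs]."""
    sq_err = (predictions[:, 0] - labels) ** 2
    sq_err = np.where(mask, sq_err, 0.0)            # ignore padded graphs
    return sq_err.sum() / np.maximum(mask.sum(), 1)

preds = np.array([[1.0], [2.0], [42.0]])            # last row is padding
labels = np.array([0.0, 4.0, 0.0])
mask = np.array([True, True, False])
print(masked_mse(preds, labels, mask))  # 2.5
```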