
Molax Feature Roadmap

A prioritized roadmap of killer features for molax, focusing on advancing uncertainty quantification for molecular active learning.

Target Audiences:
- Drug discovery researchers seeking reliable predictions with confidence estimates
- ML researchers exploring uncertainty quantification and active learning methods

Scope: 2D graph-based molecular representations only (no 3D conformers).


Current Capabilities

- Efficient jraph batching (~400x speedup)
- GCN with uncertainty head
- MC Dropout uncertainty
- Uncertainty sampling
- Diversity sampling
- Combined acquisition
- ESOL dataset support
- Flax NNX integration

Phase 1: Uncertainty Excellence (High Priority)

Better uncertainty quantification is the core differentiator for active learning. These features directly improve the quality and reliability of uncertainty estimates.

1.1 Deep Ensembles ✅

Status: Implemented in molax/models/ensemble.py

What: Train N independent GCN models with different random initializations; use prediction disagreement as uncertainty.

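Why: In published comparisons, ensembles often outperform single-model uncertainty methods such as MC Dropout. Because each member predicts both a mean and a variance, the ensemble captures aleatoric (data) uncertainty through the averaged member variances and epistemic (model) uncertainty through disagreement between members.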

Implementation:

# molax/models/ensemble.py
from dataclasses import dataclass
from flax import nnx
import jax.numpy as jnp
from molax.models.gcn import GCNConfig, UncertaintyGCN

@dataclass
class EnsembleConfig:
    base_config: GCNConfig
    n_members: int = 5

class DeepEnsemble(nnx.Module):
    def __init__(self, config: EnsembleConfig, rngs: nnx.Rngs):
        self.members = [
            UncertaintyGCN(config.base_config, rngs=nnx.Rngs(i))
            for i in range(config.n_members)
        ]

    def __call__(self, graphs, training: bool = False):
        # Collect predictions from all members
        predictions = [m(graphs, training=training) for m in self.members]
        means = jnp.stack([p[0] for p in predictions])  # (N, batch)

        # Ensemble mean and variance
        ensemble_mean = jnp.mean(means, axis=0)
        epistemic_var = jnp.var(means, axis=0)  # Disagreement
        aleatoric_var = jnp.mean(jnp.stack([p[1] for p in predictions]), axis=0)
        total_var = epistemic_var + aleatoric_var

        return ensemble_mean, total_var, epistemic_var

Acceptance Criteria:
- [x] DeepEnsemble class with configurable number of members
- [x] Separate epistemic and aleatoric uncertainty outputs
- [x] Parallel training support for ensemble members
- [x] Tests comparing ensemble vs MC Dropout uncertainty quality
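The variance split in `DeepEnsemble.__call__` is the law of total variance for a uniform mixture of members. The standalone sketch below reproduces it on plain arrays so the arithmetic can be checked in isolation. `decompose_ensemble_variance` is a hypothetical helper, not part of molax, and it assumes each member returns one mean and one variance per molecule.

```python
import jax.numpy as jnp

def decompose_ensemble_variance(member_means, member_vars):
    """Law of total variance for a uniform mixture of Gaussian members.

    member_means, member_vars: arrays of shape (n_members, batch).
    """
    mean = jnp.mean(member_means, axis=0)
    epistemic = jnp.var(member_means, axis=0)   # spread of member means (disagreement)
    aleatoric = jnp.mean(member_vars, axis=0)   # average predicted noise level
    return mean, epistemic + aleatoric, epistemic, aleatoric

# Three members that agree on molecule 0 but disagree on molecule 1
means = jnp.array([[1.0, 0.0], [1.0, 2.0], [1.0, 4.0]])
vars_ = jnp.full((3, 2), 0.5)
mean, total, epistemic, aleatoric = decompose_ensemble_variance(means, vars_)
```

Molecule 1 receives an epistemic term of 8/3 purely because the members disagree, while both molecules share the same aleatoric term of 0.5.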


1.2 Evidential Deep Learning ✅

Status: Implemented in molax/models/evidential.py

What: Directly predict uncertainty without MC sampling by modeling output as a higher-order distribution (Normal-Inverse-Gamma).

Why: A single forward pass yields uncertainty (faster inference than MC sampling), and Amini et al. report that the resulting epistemic uncertainty helps flag out-of-distribution inputs.

Reference: Amini et al., NeurIPS 2020

Implementation:

# molax/models/evidential.py
import jax.numpy as jnp
from flax import nnx
from jax.scipy.special import gammaln  # log-Gamma (jax.numpy has no lgamma)

class EvidentialHead(nnx.Module):
    """Predicts Normal-Inverse-Gamma parameters for evidential regression."""

    def __init__(self, in_features: int, rngs: nnx.Rngs):
        # Output: (gamma, nu, alpha, beta) - NIG parameters
        self.linear = nnx.Linear(in_features, 4, rngs=rngs)

    def __call__(self, x):
        out = self.linear(x)
        # Ensure valid parameter ranges
        gamma = out[..., 0]  # Mean prediction
        nu = nnx.softplus(out[..., 1]) + 1e-6  # > 0
        alpha = nnx.softplus(out[..., 2]) + 1.0  # > 1
        beta = nnx.softplus(out[..., 3]) + 1e-6  # > 0
        return gamma, nu, alpha, beta

def evidential_loss(gamma, nu, alpha, beta, targets, lambda_reg=0.1):
    """NIG negative log-likelihood with regularization."""
    omega = 2 * beta * (1 + nu)
    nll = (
        0.5 * jnp.log(jnp.pi / nu)
        - alpha * jnp.log(omega)
        + (alpha + 0.5) * jnp.log((targets - gamma)**2 * nu + omega)
        + gammaln(alpha) - gammaln(alpha + 0.5)
    )
    # Regularize evidence on errors
    reg = lambda_reg * jnp.abs(targets - gamma) * (2 * nu + alpha)
    return jnp.mean(nll + reg)

def evidential_uncertainty(nu, alpha, beta):
    """Extract aleatoric and epistemic uncertainty from NIG params."""
    aleatoric = beta / (alpha - 1)  # Expected variance
    epistemic = aleatoric / nu      # Var[mu]: uncertainty in the mean
    return aleatoric, epistemic

Acceptance Criteria:
- [x] EvidentialGCN model variant
- [x] NIG loss function with configurable regularization
- [x] Separate aleatoric/epistemic uncertainty outputs
- [x] Comparison with MC Dropout on OOD detection (in tests)
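A useful sanity check on the NIG loss: integrating out the NIG prior gives a Student-t predictive distribution with location γ, 2α degrees of freedom and squared scale β(1+ν)/(να), as derived in Amini et al., so the NIG negative log-likelihood should equal the Student-t negative log-density. The sketch below checks this numerically with `jax.scipy` and evaluates the two uncertainty formulas for one hypothetical set of parameters. `nig_nll` and `student_t_nll` are illustrative names, not molax functions.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import t as student_t

def nig_nll(gamma, nu, alpha, beta, y):
    """Same NLL expression as evidential_loss, without the regularizer."""
    omega = 2 * beta * (1 + nu)
    return (
        0.5 * jnp.log(jnp.pi / nu)
        - alpha * jnp.log(omega)
        + (alpha + 0.5) * jnp.log((y - gamma) ** 2 * nu + omega)
        + gammaln(alpha) - gammaln(alpha + 0.5)
    )

def student_t_nll(gamma, nu, alpha, beta, y):
    # Marginal of the NIG model: Student-t with 2*alpha degrees of freedom
    scale = jnp.sqrt(beta * (1 + nu) / (nu * alpha))
    return -student_t.logpdf(y, df=2 * alpha, loc=gamma, scale=scale)

gamma, nu, alpha, beta = 0.3, 2.0, 3.0, 1.5
ys = jnp.linspace(-3.0, 3.0, 7)
nll_a = nig_nll(gamma, nu, alpha, beta, ys)
nll_b = student_t_nll(gamma, nu, alpha, beta, ys)

aleatoric = beta / (alpha - 1)   # E[sigma^2] = 0.75
epistemic = aleatoric / nu       # Var[mu] = 0.375
```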


1.3 Calibration Metrics ✅

Status: Implemented in molax/metrics/

What: Quantify how well predicted uncertainties match actual error frequencies.

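Why: Uncertainty estimates are only actionable if they are calibrated, i.e. a predicted 90% interval should contain the true value about 90% of the time. These metrics let users check whether the confidence estimates can be trusted.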

Implementation:

# Example usage of molax.metrics (implemented in molax/metrics/calibration.py)
from molax.metrics import (
    expected_calibration_error,
    negative_log_likelihood,
    compute_calibration_curve,
    sharpness,
    evaluate_calibration,
    TemperatureScaling,
    plot_reliability_diagram,
    plot_calibration_comparison,
    create_calibration_report,
)

# Compute ECE
ece = expected_calibration_error(predictions, uncertainties, targets, n_bins=10)

# Compute NLL (proper scoring rule)
nll = negative_log_likelihood(mean, var, targets)

# Comprehensive evaluation
metrics = evaluate_calibration(mean, var, targets)
# Returns: {'nll': ..., 'ece': ..., 'rmse': ..., 'sharpness': ..., 'mean_z_score': ...}

# Temperature scaling for post-hoc calibration
scaler = TemperatureScaling()
scaler.fit(val_mean, val_var, val_targets)
calibrated_var = scaler.transform(test_var)
print(f"Learned temperature: {scaler.temperature}")

# Visualization
plot_reliability_diagram(predictions, uncertainties, targets)
fig = plot_calibration_comparison({
    "Model A": (preds_a, var_a, targets),
    "Model B": (preds_b, var_b, targets),
})

Acceptance Criteria:
- [x] ECE computation (Expected Calibration Error)
- [x] Reliability diagram plotting utility
- [x] NLL as proper scoring rule
- [x] Temperature scaling for post-hoc calibration
- [x] Integration into evaluation pipeline
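To illustrate what these metrics compute, the self-contained sketch below uses one common regression form of calibration error (mean absolute gap between expected and observed coverage of central Gaussian intervals) and a single variance multiplier fitted in closed form, which is one way to do temperature scaling for regression. The function names are hypothetical, and molax's own binning or fitting procedure may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def regression_ece(mean, var, targets, n_levels=10):
    """Mean |observed - expected| coverage over central Gaussian intervals."""
    levels = jnp.linspace(0.05, 0.95, n_levels)        # expected coverage p
    z = jnp.abs(targets - mean) / jnp.sqrt(var)        # standardized absolute errors
    half_widths = norm.ppf(0.5 + levels / 2)           # interval = mean +/- h * std
    observed = jnp.mean(z[None, :] <= half_widths[:, None], axis=1)
    return jnp.mean(jnp.abs(observed - levels))

def fit_variance_scale(mean, var, targets):
    """Closed-form s minimizing the Gaussian NLL of N(mean, s * var) on held-out data."""
    return jnp.mean((targets - mean) ** 2 / var)

# Synthetic check: targets really are drawn from N(mean, 0.25)
key_mu, key_eps = jax.random.split(jax.random.PRNGKey(0))
mean = jax.random.normal(key_mu, (20_000,))
true_var = jnp.full_like(mean, 0.25)
targets = mean + jnp.sqrt(true_var) * jax.random.normal(key_eps, mean.shape)

overconfident_var = true_var / 4                    # model under-reports variance 4x
s = fit_variance_scale(mean, overconfident_var, targets)
```

Setting the derivative of NLL(s) = ½Σ[log(2π s σ²) + r²/(s σ²)] to zero gives s = mean(r²/σ²), so no optimizer is needed in this one-parameter case; here s comes out close to 4 and rescaling restores calibration.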


Phase 2: Advanced Acquisition Strategies

Better acquisition functions select more informative samples, improving data efficiency.

2.1 BALD (Bayesian Active Learning by Disagreement)

What: Maximize mutual information between predictions and model parameters.

Why: Theoretically principled; targets samples that maximally reduce model uncertainty.

Implementation:

# molax/acquisition/bald.py
import jax.numpy as jnp

def bald_acquisition(
    model,
    graphs,
    n_mc_samples: int = 20,
) -> jnp.ndarray:
    """
    BALD = H[y|x, D] - E_{theta}[H[y|x, theta]]
    = Total uncertainty - Expected aleatoric uncertainty

    Assumes the model's dropout layers hold their own nnx.Rngs, so each
    call with training=True draws a fresh dropout mask.
    """
    # Collect MC samples
    mc_means = []
    mc_vars = []
    for _ in range(n_mc_samples):
        mean, var = model(graphs, training=True)  # Dropout active
        mc_means.append(mean)
        mc_vars.append(var)

    mc_means = jnp.stack(mc_means)  # (n_mc, n_samples)
    mc_vars = jnp.stack(mc_vars)

    # Total uncertainty: entropy of a Gaussian moment-matched to the MC mixture
    # (an upper bound on the exact mixture entropy)
    predictive_var = jnp.var(mc_means, axis=0) + jnp.mean(mc_vars, axis=0)
    total_entropy = 0.5 * jnp.log(2 * jnp.pi * jnp.e * predictive_var)

    # Expected aleatoric uncertainty
    expected_entropy = 0.5 * jnp.mean(jnp.log(2 * jnp.pi * jnp.e * mc_vars), axis=0)

    # BALD score = mutual information
    return total_entropy - expected_entropy

Acceptance Criteria:
- [ ] bald_acquisition function
- [ ] Efficient batched MC sampling
- [ ] Comparison benchmark vs uncertainty sampling
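The Gaussian BALD score can be computed directly from stacked MC predictions, which keeps the scoring separate from the sampling loop and easy to unit test. The sketch below assumes the same `(n_mc, n_samples)` layout as `bald_acquisition`; `bald_from_mc_samples` is a hypothetical name.

```python
import jax.numpy as jnp

def bald_from_mc_samples(mc_means, mc_vars):
    """Gaussian BALD score from MC predictions of shape (n_mc, n_samples)."""
    predictive_var = jnp.var(mc_means, axis=0) + jnp.mean(mc_vars, axis=0)
    total_entropy = 0.5 * jnp.log(2 * jnp.pi * jnp.e * predictive_var)
    expected_entropy = jnp.mean(0.5 * jnp.log(2 * jnp.pi * jnp.e * mc_vars), axis=0)
    return total_entropy - expected_entropy

# Molecule 0: dropout samples agree; molecule 1: they disagree; same noise level
mc_means = jnp.array([[0.0, -1.0], [0.0, 0.0], [0.0, 1.0]])
mc_vars = jnp.full((3, 2), 0.5)
scores = bald_from_mc_samples(mc_means, mc_vars)
```

With equal noise levels, molecule 0 scores zero because its samples agree, while molecule 1 scores 0.5·log(7/3) ≈ 0.42 from disagreement alone.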


2.2 Core-Set Selection

What: Select samples that maximize coverage of the learned feature space using the greedy K-center algorithm.

Why: Ensures diversity in learned representations, not just input space.

Implementation:

# molax/acquisition/coreset.py
import jax.numpy as jnp

def extract_embeddings(model, graphs) -> jnp.ndarray:
    """Get penultimate layer representations."""
    # Add embedding extraction hook to model
    pass

def k_center_greedy(
    embeddings: jnp.ndarray,
    labeled_mask: jnp.ndarray,
    n_select: int
) -> jnp.ndarray:
    """
    Greedy K-center: iteratively select point furthest from labeled set.
    """
    n_samples = embeddings.shape[0]
    selected = jnp.where(labeled_mask)[0]

    # Compute pairwise distances once
    distances = jnp.linalg.norm(
        embeddings[:, None] - embeddings[None, :], axis=-1
    )

    for _ in range(n_select):
        # Distance from each point to nearest labeled point
        min_dist_to_labeled = jnp.min(distances[:, selected], axis=1)
        min_dist_to_labeled = jnp.where(labeled_mask, -jnp.inf, min_dist_to_labeled)

        # Select furthest point
        new_idx = jnp.argmax(min_dist_to_labeled)
        selected = jnp.append(selected, new_idx)
        labeled_mask = labeled_mask.at[new_idx].set(True)

    return selected[-n_select:]

Acceptance Criteria:
- [ ] Embedding extraction from any model layer
- [ ] K-center greedy implementation
- [ ] GPU-accelerated distance computations
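The greedy loop above runs eagerly and grows `selected` with `jnp.append`, which `jax.jit` cannot trace. One way to approach the GPU-acceleration criterion is a fixed-size `jax.lax.fori_loop` that keeps a running nearest-center distance per point, as in the sketch below. `k_center_greedy_jit` is a hypothetical name; it compares squared distances (same argmax as plain distances) and, if nothing is labeled yet, starts from index 0.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_select")
def k_center_greedy_jit(embeddings, labeled_mask, n_select):
    """Jittable greedy K-center; returns indices of n_select new points."""
    sq_norms = jnp.sum(embeddings ** 2, axis=-1)
    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b (one matmul)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * embeddings @ embeddings.T
    sq_dist = jnp.maximum(sq_dist, 0.0)  # clip small negatives from rounding
    # Squared distance to the nearest labeled point (+inf if none is labeled)
    min_dist = jnp.min(jnp.where(labeled_mask[None, :], sq_dist, jnp.inf), axis=1)

    def body(i, state):
        min_dist, mask, chosen = state
        idx = jnp.argmax(jnp.where(mask, -jnp.inf, min_dist))  # furthest unlabeled point
        chosen = chosen.at[i].set(idx)
        mask = mask.at[idx].set(True)
        min_dist = jnp.minimum(min_dist, sq_dist[idx])  # O(N) update per step
        return min_dist, mask, chosen

    chosen = jnp.zeros(n_select, dtype=jnp.int32)
    _, _, chosen = jax.lax.fori_loop(0, n_select, body, (min_dist, labeled_mask, chosen))
    return chosen

emb = jnp.array([[0.0], [1.0], [10.0], [11.0], [5.0]])
mask = jnp.array([True, False, False, False, False])
picked = k_center_greedy_jit(emb, mask, 2)
```

Starting from the labeled point at 0, the sketch first picks 11 (index 3), then 5 (index 4), the point furthest from both centers.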


2.3 Batch-Aware Acquisition

What: When selecting K samples, account for redundancy between them.

Why: Naive top-K selection often picks near-duplicates; batch-aware methods improve diversity.

Implementation:

# molax/acquisition/batch.py
import jax.numpy as jnp

def batch_bald(
    model, graphs, n_select: int, n_mc_samples: int = 20
) -> jnp.ndarray:
    """
    BatchBALD: Select batch jointly to maximize mutual information.
    Approximated via greedy selection with joint entropy tracking.
    """
    pass

def determinantal_point_process(
    scores: jnp.ndarray,
    similarity_matrix: jnp.ndarray,
    n_select: int
) -> jnp.ndarray:
    """
    DPP sampling: balance high scores with diversity.
    Uses fast greedy MAP inference.
    """
    pass

Acceptance Criteria:
- [ ] BatchBALD implementation
- [ ] DPP-based diverse selection
- [ ] Configurable diversity-quality tradeoff
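A common way to combine scores and diversity is a DPP kernel L = diag(q) S diag(q), where q are acquisition scores and S is a similarity matrix. The sketch below is a naive greedy MAP under that assumption: each step adds the candidate that most increases log det(L_Y). It recomputes determinants from scratch, so it is not the fast incremental variant mentioned in the `determinantal_point_process` docstring; `greedy_dpp_map` is a hypothetical name.

```python
import jax.numpy as jnp

def greedy_dpp_map(scores, similarity_matrix, n_select):
    """Naive greedy MAP for a DPP with kernel L = diag(q) S diag(q).

    O(n_select * N) determinant evaluations, so this suits small pools only.
    """
    q = jnp.asarray(scores)
    kernel = q[:, None] * similarity_matrix * q[None, :]
    selected = []
    for _ in range(n_select):
        best_idx, best_logdet = None, -jnp.inf
        for i in range(kernel.shape[0]):
            if i in selected:
                continue
            idx = jnp.array(selected + [i])
            sign, logdet = jnp.linalg.slogdet(kernel[jnp.ix_(idx, idx)])
            if sign > 0 and logdet > best_logdet:  # skip singular (redundant) sets
                best_idx, best_logdet = i, logdet
        if best_idx is None:  # every remaining candidate makes L_Y singular
            break
        selected.append(best_idx)
    return selected

# Items 0 and 1 are near-duplicates (similarity 1); item 2 is unrelated
S = jnp.array([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
picked = greedy_dpp_map(jnp.array([1.0, 0.99, 0.5]), S, 2)
```

Naive top-2 by score would return the duplicates [0, 1]; the DPP keeps item 0 and then prefers the lower-scoring but dissimilar item 2.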


2.4 Expected Model Change

What: Select samples that would maximally change model predictions if labeled.

Why: Directly targets samples whose labels would change the model most, rather than relying on predictive uncertainty alone.

Implementation:

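# molax/acquisition/emc.py
import jax
import jax.numpy as jnp
from flax import nnx

def expected_gradient_length(model, graphs, n_label_samples: int = 10, seed: int = 0):
    """
    EGL: expected parameter-gradient norm under hypothetical labels.
    Labels are sampled from the model's predictive distribution; using the
    predicted mean itself as the label would make every gradient zero.
    Returns one score for `graphs`; call it per molecule to rank a pool.
    """
    def loss_fn(model, x, y):
        mean, _ = model(x, training=False)
        return jnp.mean((mean - y)**2)

    mean, var = model(graphs, training=False)
    norms = []
    for key in jax.random.split(jax.random.PRNGKey(seed), n_label_samples):
        y = mean + jnp.sqrt(var) * jax.random.normal(key, mean.shape)
        grads = nnx.grad(loss_fn)(model, graphs, y)  # nnx.State of parameter gradients
        squared = [jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)]
        norms.append(jnp.sqrt(sum(squared)))
    return jnp.mean(jnp.stack(norms))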

Acceptance Criteria:
- [ ] Expected Gradient Length implementation
- [ ] Fisher Information-based variant
- [ ] Efficient gradient computation
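For efficient gradient computation, per-sample gradients can be vectorized with `jax.vmap(jax.grad(...))` instead of looping over molecules. The sketch below shows EGL on a toy linear regressor, assuming hypothetical labels drawn from N(pred_mean, pred_var); `squared_error` and `egl_scores` are illustrative names, and a GNN version would need the same pattern applied to its parameter pytree.

```python
import jax
import jax.numpy as jnp

def squared_error(params, x, y):
    w, b = params
    return (jnp.dot(w, x) + b - y) ** 2

# One gradient per (x, y) pair: vmap over data rows, broadcast the parameters
per_sample_grad = jax.vmap(jax.grad(squared_error), in_axes=(None, 0, 0))

def egl_scores(params, xs, pred_var, key, n_label_samples=256):
    """Expected gradient norm under labels sampled from N(pred_mean, pred_var)."""
    w, b = params
    pred_mean = xs @ w + b
    eps = jax.random.normal(key, (n_label_samples, xs.shape[0]))
    ys = pred_mean + jnp.sqrt(pred_var) * eps            # (n_label_samples, n_pool)

    def grad_norms(y):
        gw, gb = per_sample_grad(params, xs, y)          # (n_pool, d), (n_pool,)
        return jnp.sqrt(jnp.sum(gw ** 2, axis=-1) + gb ** 2)

    return jnp.mean(jax.vmap(grad_norms)(ys), axis=0)

params = (jnp.array([0.5, -0.5]), jnp.array(0.0))
xs = jnp.array([[0.0, 0.0], [3.0, 4.0]])
scores = egl_scores(params, xs, jnp.array([1.0, 1.0]), jax.random.PRNGKey(0))
```

For this linear model the expected score is proportional to σ·‖(x, 1)‖, so EGL combines predictive uncertainty with how strongly an input can move the parameters; with zero predictive variance every score is zero.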


Phase 3: Architecture Diversity

Multiple architectures capture different inductive biases about molecular structure.

3.1 Message Passing Neural Network (MPNN) ✅

Status: Implemented in molax/models/mpnn.py

What: Generalized framework with explicit edge feature processing.

Why: Enables richer molecular representations using bond features.

Implementation:

# molax/models/mpnn.py
from flax import nnx
from molax.models.mpnn import MPNNConfig, UncertaintyMPNN

config = MPNNConfig(
    node_features=6,
    edge_features=1,  # Bond type feature
    hidden_features=[64, 64],
    out_features=1,
    aggregation="sum",  # or "mean", "max"
    dropout_rate=0.1,
)
model = UncertaintyMPNN(config, rngs=nnx.Rngs(0))

# Same API as UncertaintyGCN
mean, variance = model(batched_graphs, training=False)

# Extract embeddings for Core-Set selection
embeddings = model.extract_embeddings(batched_graphs)

Acceptance Criteria:
- [x] MPNN with edge feature support
- [x] Configurable aggregation (sum, mean, max)
- [x] Same API as UncertaintyGCN for acquisition function compatibility
- [x] MC Dropout uncertainty via get_mpnn_uncertainties()
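To make "explicit edge feature processing" and the three aggregation modes concrete, here is a minimal single-layer sketch in plain JAX on an edge list. It is not molax's `UncertaintyMPNN`: `message_passing_step` is a hypothetical name, and the plain weight arrays `w_msg` and `w_upd` stand in for learned layers.

```python
import jax
import jax.numpy as jnp

def message_passing_step(node_feats, edge_feats, senders, receivers,
                         w_msg, w_upd, aggregation="sum"):
    """One MPNN layer; edge e points from senders[e] to receivers[e]."""
    n_nodes = node_feats.shape[0]
    # Message on each edge depends on the sender atom AND the bond features
    messages = jax.nn.relu(
        jnp.concatenate([node_feats[senders], edge_feats], axis=-1) @ w_msg
    )
    if aggregation == "sum":
        agg = jax.ops.segment_sum(messages, receivers, num_segments=n_nodes)
    elif aggregation == "mean":
        total = jax.ops.segment_sum(messages, receivers, num_segments=n_nodes)
        counts = jax.ops.segment_sum(
            jnp.ones_like(receivers, dtype=messages.dtype), receivers, num_segments=n_nodes
        )
        agg = total / jnp.maximum(counts, 1.0)[:, None]
    elif aggregation == "max":
        agg = jax.ops.segment_max(messages, receivers, num_segments=n_nodes)
        agg = jnp.where(jnp.isfinite(agg), agg, 0.0)  # atoms with no incoming edges
    else:
        raise ValueError(f"unknown aggregation: {aggregation}")
    # Update: combine each atom's own features with its aggregated messages
    return jax.nn.relu(jnp.concatenate([node_feats, agg], axis=-1) @ w_upd)

# Three-atom chain 0-1-2, each bond stored in both directions
senders = jnp.array([0, 1, 1, 2])
receivers = jnp.array([1, 0, 2, 1])
nodes, edges = jnp.ones((3, 1)), jnp.ones((4, 1))
w_msg, w_upd = jnp.ones((2, 1)), jnp.ones((2, 1))
out_sum = message_passing_step(nodes, edges, senders, receivers, w_msg, w_upd, "sum")
out_mean = message_passing_step(nodes, edges, senders, receivers, w_msg, w_upd, "mean")
out_max = message_passing_step(nodes, edges, senders, receivers, w_msg, w_upd, "max")
```

Only sum aggregation distinguishes the middle atom (two neighbors) from the terminal atoms; mean and max give every atom the same output here.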


3.2 Graph Attention Network (GAT) ✅

Status: Implemented in molax/models/gat.py

What: Learn edge importance dynamically via attention mechanism.

Why: Adaptively weights neighbor contributions based on learned relevance.

Implementation:

# molax/models/gat.py
from flax import nnx
from molax.models.gat import GATConfig, UncertaintyGAT

config = GATConfig(
    node_features=6,
    edge_features=1,  # Optional: include edge features in attention
    hidden_features=[64, 64],
    out_features=1,
    n_heads=4,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    negative_slope=0.2,
)
model = UncertaintyGAT(config, rngs=nnx.Rngs(0))

# Same API as UncertaintyGCN/UncertaintyMPNN
mean, variance = model(batched_graphs, training=False)

# Extract embeddings for Core-Set selection
embeddings = model.extract_embeddings(batched_graphs)

Acceptance Criteria:
- [x] Multi-head attention implementation
- [x] Edge feature incorporation option
- [x] Dropout on attention weights
- [x] Same API as UncertaintyGCN/UncertaintyMPNN for acquisition function compatibility
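The core of GAT attention is a softmax over each atom's incoming edges. The sketch below shows a single head in plain JAX, following the original GAT scoring e_ij = LeakyReLU(aᵀ[W h_i ‖ W h_j]) with the `negative_slope` used in the config above; it omits multiple heads, edge features and dropout, and `gat_attention` is a hypothetical name rather than molax's implementation.

```python
import jax
import jax.numpy as jnp

def gat_attention(node_feats, senders, receivers, w, a, negative_slope=0.2):
    """Single-head GAT layer on an edge list (edges point sender -> receiver)."""
    n_nodes = node_feats.shape[0]
    h = node_feats @ w                                         # (n_nodes, d_out)
    # Unnormalized score for edge j -> i: LeakyReLU(a^T [W h_i || W h_j])
    pair = jnp.concatenate([h[receivers], h[senders]], axis=-1)
    logits = jax.nn.leaky_relu(pair @ a, negative_slope=negative_slope)
    # Softmax over each receiver's incoming edges (max-subtracted for stability)
    seg_max = jax.ops.segment_max(logits, receivers, num_segments=n_nodes)
    exp = jnp.exp(logits - seg_max[receivers])
    denom = jax.ops.segment_sum(exp, receivers, num_segments=n_nodes)
    alpha = exp / denom[receivers]                             # (n_edges,)
    out = jax.ops.segment_sum(alpha[:, None] * h[senders], receivers, num_segments=n_nodes)
    return out, alpha

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
node_feats = jax.random.normal(k1, (3, 4))
w = jax.random.normal(k2, (4, 8))
a = jax.random.normal(k3, (16,))
senders = jnp.array([0, 1, 1, 2])
receivers = jnp.array([1, 0, 2, 1])
out, alpha = gat_attention(node_feats, senders, receivers, w, a)
```

Attention weights sum to one per receiving atom; atom 0 has a single neighbor, so its output is exactly that neighbor's transformed features.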


3.3 SchNet (Continuous-Filter Convolutions)

What: Distance-based convolutions using RBF expansions (adapted for 2D without coordinates).

Why: Smooth distance-aware aggregation; useful when edge weights encode bond lengths.

Note: For 2D graphs, use topological distance (shortest path) or bond order as edge weights.
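Following the note above, a 2D variant could replace interatomic distances with shortest-path lengths. The sketch below computes them with Floyd-Warshall, expands them in Gaussian radial basis functions, and applies a continuous-filter convolution. It is a simplified illustration under stated assumptions: the RBF centers, `gamma`, and the single linear filter map (`w_filter`) are placeholders for SchNet's learned filter network, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def topological_distances(adj):
    """All-pairs shortest-path lengths in bonds (Floyd-Warshall); inf if disconnected."""
    n = adj.shape[0]
    dist = jnp.where(adj > 0, 1.0, jnp.inf)
    dist = dist.at[jnp.diag_indices(n)].set(0.0)

    def relax(k, d):
        # Allow paths that pass through atom k
        return jnp.minimum(d, d[:, k][:, None] + d[k, :][None, :])

    return jax.lax.fori_loop(0, n, relax, dist)

def rbf_expand(dist, n_centers=8, max_dist=7.0, gamma=1.0):
    """Gaussian RBF expansion; disconnected pairs (inf) map to all zeros."""
    centers = jnp.linspace(0.0, max_dist, n_centers)
    return jnp.exp(-gamma * (dist[..., None] - centers) ** 2)

def cfconv(node_feats, rbf, w_filter):
    """Continuous-filter convolution with a filter generated per atom pair."""
    filters = rbf @ w_filter                   # (n, n, d)
    # The i == j term (distance 0) acts as a self-interaction
    return jnp.einsum("jd,ijd->id", node_feats, filters)

# Chain 0-1-2-3 plus an isolated atom 4
adj = jnp.zeros((5, 5))
adj = adj.at[jnp.array([0, 1, 1, 2, 2, 3]), jnp.array([1, 0, 2, 1, 3, 2])].set(1.0)
dist = topological_distances(adj)
rbf = rbf_expand(dist)
h = cfconv(jnp.ones((5, 4)), rbf, jnp.ones((8, 4)))
```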


3.4 Graph Transformer ✅

Status: Implemented in molax/models/graph_transformer.py

What: Full self-attention over molecular graphs with positional encodings.

Why: Strong reported results on molecular benchmarks; global attention can capture long-range dependencies that local message passing only reaches after many layers.

Implementation:

# molax/models/graph_transformer.py
from flax import nnx
from molax.models.graph_transformer import GraphTransformerConfig, UncertaintyGraphTransformer

config = GraphTransformerConfig(
    node_features=6,
    edge_features=1,  # Optional: include edge features as attention bias
    hidden_features=[64, 64],
    out_features=1,
    n_heads=4,
    ffn_ratio=4.0,  # FFN hidden dim = 4 * model dim
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    pe_type="rwpe",  # Random Walk PE (or "laplacian", "none")
    pe_dim=16,
)
model = UncertaintyGraphTransformer(config, rngs=nnx.Rngs(0))

# Same API as UncertaintyGCN/UncertaintyMPNN/UncertaintyGAT
mean, variance = model(batched_graphs, training=False)

# Extract embeddings for Core-Set selection
embeddings = model.extract_embeddings(batched_graphs)

Acceptance Criteria:
- [x] Graph-aware attention masking
- [x] Positional encodings (Laplacian eigenvectors, random walk)
- [x] Configurable depth and width
- [x] Same API as UncertaintyGCN/UncertaintyMPNN/UncertaintyGAT for acquisition function compatibility
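The `pe_type="rwpe"` option refers to random-walk positional encodings: for each atom, the probability of returning to it after k steps of a random walk, for k = 1..pe_dim. The sketch below computes these from a dense adjacency matrix; it illustrates the standard definition and is not necessarily how molax implements it. `random_walk_pe` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def random_walk_pe(adj, pe_dim=16):
    """Return probability of each node after 1..pe_dim random-walk steps."""
    deg = adj.sum(axis=1)
    transition = adj / jnp.maximum(deg, 1.0)[:, None]   # row-normalized D^-1 A

    def step(power, _):
        power = power @ transition                       # transition^k
        return power, jnp.diag(power)

    _, diags = jax.lax.scan(step, jnp.eye(adj.shape[0]), None, length=pe_dim)
    return diags.T                                       # (n_nodes, pe_dim)

# Three-atom chain 0-1-2
adj = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
pe = random_walk_pe(adj, pe_dim=4)
```

The middle atom always returns after two steps (probability 1.0) while the end atoms return with probability 0.5, so the encoding separates structurally different positions that plain node features may not.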


Phase 4: Rich Molecular Featurization

Richer input features give every architecture more chemical information to learn from.

4.1 Extended Node Features

Current: 6 features (atomic num, degree, charge, chirality, hybridization, aromaticity)

Proposed: 20+ features including:

# molax/utils/featurizers.py
# is_hydrogen_donor, is_hydrogen_acceptor and ELECTRONEGATIVITY are planned
# helpers that still need to be defined; booleans are cast to int.

def extended_atom_features(atom) -> list:
    """Comprehensive RDKit atom features (atom is an rdkit.Chem.Atom)."""
    return [
        # Current features
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetFormalCharge(),
        int(atom.GetChiralTag()),
        int(atom.GetHybridization()),
        int(atom.GetIsAromatic()),

        # Ring features
        int(atom.IsInRing()),
        int(atom.IsInRingSize(3)),
        int(atom.IsInRingSize(4)),
        int(atom.IsInRingSize(5)),
        int(atom.IsInRingSize(6)),

        # Radical electrons and hydrogen counts
        atom.GetNumRadicalElectrons(),
        atom.GetNumImplicitHs(),
        atom.GetNumExplicitHs(),

        # Neighborhood (including hydrogens)
        atom.GetTotalNumHs(),
        atom.GetTotalDegree(),

        # Pharmacophore-related
        int(is_hydrogen_donor(atom)),
        int(is_hydrogen_acceptor(atom)),

        # Electronegativity (from table)
        ELECTRONEGATIVITY.get(atom.GetAtomicNum(), 0),

        # Atomic mass
        atom.GetMass(),
    ]

Acceptance Criteria:
- [ ] Configurable feature sets (minimal, standard, extended)
- [ ] Feature normalization utilities
- [ ] Documentation of each feature
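The extended features mix very different scales (atomic mass around 12 to 127 versus 0/1 flags), so normalization matters. A minimal sketch of the normalization utility, assuming per-feature z-scoring with statistics fitted on the training split only; both function names are hypothetical.

```python
import jax.numpy as jnp

def fit_feature_normalizer(train_feats, eps=1e-8):
    """Per-feature mean/std from the training split only (avoids test-set leakage)."""
    mean = jnp.mean(train_feats, axis=0)
    std = jnp.std(train_feats, axis=0)
    # Constant features (e.g. a ring size absent from the training set) keep std = 1
    return mean, jnp.where(std < eps, 1.0, std)

def normalize_features(feats, mean, std):
    return (feats - mean) / std

# Columns: atomic number, degree, a flag that is always 0 in this toy set
train = jnp.array([[6.0, 1.0, 0.0], [8.0, 2.0, 0.0], [7.0, 3.0, 0.0]])
mean, std = fit_feature_normalizer(train)
z = normalize_features(train, mean, std)
```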


4.2 Edge Feature Support

What: Include bond features in message passing.

def bond_features(bond) -> list:
    """RDKit bond features (bond is an rdkit.Chem.Bond)."""
    return [
        int(bond.GetBondType()),  # Enum value: SINGLE=1, DOUBLE=2, TRIPLE=3, AROMATIC=12
        int(bond.GetIsConjugated()),
        int(bond.IsInRing()),
        int(bond.GetStereo()),  # Stereochemistry (BondStereo enum value)
    ]

Acceptance Criteria:
- [ ] Edge features in smiles_to_jraph
- [ ] Models that consume edge features (MPNN, GAT)
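Message-passing graphs usually store each chemical bond as two directed edges, so bond features must be duplicated in the same order as senders and receivers. A minimal sketch, assuming bonds arrive as (i, j) index pairs with one feature row each; `to_directed_edges` is a hypothetical helper, not the current `smiles_to_jraph` code.

```python
import jax.numpy as jnp

def to_directed_edges(bond_pairs, bond_feats):
    """Each undirected bond (i, j) becomes edges i->j and j->i with shared features."""
    pairs = jnp.asarray(bond_pairs)
    feats = jnp.asarray(bond_feats)
    senders = jnp.concatenate([pairs[:, 0], pairs[:, 1]])
    receivers = jnp.concatenate([pairs[:, 1], pairs[:, 0]])
    edges = jnp.concatenate([feats, feats], axis=0)  # same row order as senders
    return senders, receivers, edges

# Two bonds, 0-1 and 1-2, with two features each
s, r, e = to_directed_edges([[0, 1], [1, 2]], [[1.0, 0.0], [2.0, 1.0]])
```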


4.3 Pre-trained Embedding Integration

What: Use embeddings from pre-trained molecular language models.

Why: Transfer learning from large-scale pre-training.

# molax/utils/pretrained.py
import jax.numpy as jnp

def load_chemberta_embeddings(smiles_list: list[str]) -> jnp.ndarray:
    """Extract ChemBERTa first-token (<s>, RoBERTa's [CLS] equivalent) embeddings."""
    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()

    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # Inference only; also lets .numpy() work on the output
        outputs = model(**inputs)
    return jnp.asarray(outputs.last_hidden_state[:, 0, :].numpy())  # First token

Acceptance Criteria:
- [ ] ChemBERTa integration
- [ ] Caching for efficiency
- [ ] Option to use as initial node features or auxiliary input


Phase 5: Production Readiness

Features required for real-world deployment.

5.1 Model Checkpointing

What: Save and load model state.

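# molax/utils/checkpointing.py
import os

import orbax.checkpoint as ocp
from flax import nnx

def save_model(model, optimizer, path: str, step: int):
    """Save model and optimizer state to <path>/step_<step>."""
    state = nnx.state((model, optimizer))
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(os.path.abspath(f"{path}/step_{step}"), state)  # Orbax expects absolute paths
    checkpointer.wait_until_finished()  # Saving is asynchronous

def load_model(model, optimizer, path: str, step: int):
    """Restore model and optimizer state written by save_model."""
    checkpointer = ocp.StandardCheckpointer()
    # Use the live state as the restore target so the structure matches
    target = nnx.state((model, optimizer))
    state = checkpointer.restore(os.path.abspath(f"{path}/step_{step}"), target)
    nnx.update((model, optimizer), state)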

Acceptance Criteria:
- [ ] Orbax-based save/load
- [ ] Checkpoint management (keep last N)
- [ ] Resume training from checkpoint


5.2 Multi-Dataset Support

What: Support MoleculeNet benchmarks.

# molax/datasets/moleculenet.py

MOLECULENET_DATASETS = {
    'esol': {'task': 'regression', 'n_tasks': 1, 'metric': 'rmse'},
    'freesolv': {'task': 'regression', 'n_tasks': 1, 'metric': 'rmse'},
    'lipophilicity': {'task': 'regression', 'n_tasks': 1, 'metric': 'rmse'},
    'bbbp': {'task': 'classification', 'n_tasks': 1, 'metric': 'auroc'},
    'tox21': {'task': 'classification', 'n_tasks': 12, 'metric': 'auroc'},
    'sider': {'task': 'classification', 'n_tasks': 27, 'metric': 'auroc'},
    'clintox': {'task': 'classification', 'n_tasks': 2, 'metric': 'auroc'},
    'muv': {'task': 'classification', 'n_tasks': 17, 'metric': 'prc-auc'},
    'hiv': {'task': 'classification', 'n_tasks': 1, 'metric': 'auroc'},
    'bace': {'task': 'classification', 'n_tasks': 1, 'metric': 'auroc'},
}

def load_moleculenet(name: str) -> MolecularDataset:
    """Load a MoleculeNet dataset."""
    pass

Acceptance Criteria:
- [ ] All MoleculeNet datasets supported
- [ ] Scaffold and random split utilities
- [ ] Classification task support
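A scaffold split keeps all molecules that share a core scaffold in the same partition, so test molecules are structurally novel. The sketch below follows a common convention (largest scaffold groups fill the training set first) and takes precomputed scaffold keys as input. In practice these keys might come from RDKit's `MurckoScaffold.MurckoScaffoldSmiles`, which is an assumption about the eventual implementation; `scaffold_split` is a hypothetical name.

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split molecule indices by scaffold key; returns (train, valid, test) index lists."""
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    # Largest groups first, so rare scaffolds end up in valid/test
    ordered = sorted(groups.values(), key=lambda g: (len(g), g[0]), reverse=True)
    n = len(scaffolds)
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= frac_train * n:
            train.extend(group)
        elif len(valid) + len(group) <= frac_valid * n:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test

scaffolds = ["a"] * 5 + ["b"] * 3 + ["c", "d"]
train, valid, test = scaffold_split(scaffolds)
```

Because whole scaffold groups move together, no scaffold appears in more than one partition, which typically makes scaffold splits harder than random splits.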


5.3 Hyperparameter Optimization

What: Automated hyperparameter tuning.

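# molax/tuning/optuna_search.py
import optuna
from flax import nnx
from molax.models.gcn import GCNConfig, UncertaintyGCN

def create_objective(dataset, n_epochs):
    def objective(trial):
        config = GCNConfig(
            # Remaining GCNConfig fields omitted here
            hidden_features=[
                trial.suggest_int('hidden_dim', 32, 256),
            ] * trial.suggest_int('n_layers', 1, 4),
            dropout_rate=trial.suggest_float('dropout', 0.0, 0.5),
        )
        # Train and evaluate
        model = UncertaintyGCN(config, rngs=nnx.Rngs(0))
        # train_and_evaluate is a hypothetical helper: trains for n_epochs
        # on dataset and returns the validation loss
        val_loss = train_and_evaluate(model, dataset, n_epochs)
        return val_loss
    return objective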

def run_hpo(dataset, n_trials=100):
    study = optuna.create_study(direction='minimize')
    study.optimize(create_objective(dataset, n_epochs=50), n_trials=n_trials)
    return study.best_params

Acceptance Criteria:
- [ ] Optuna integration
- [ ] Search space definitions for each model
- [ ] Pruning for early stopping


5.4 Experiment Tracking

What: Log metrics, hyperparameters, and artifacts.

# molax/tracking/wandb_logger.py
import wandb

class WandbLogger:
    def __init__(self, project: str, config: dict):
        wandb.init(project=project, config=config)

    def log(self, metrics: dict, step: int):
        wandb.log(metrics, step=step)

    def log_model(self, model_path: str):
        wandb.save(model_path)

Acceptance Criteria:
- [ ] W&B integration
- [ ] MLflow as alternative
- [ ] Automatic logging in training loop


Phase 6: Advanced ML Research

Features for ML researchers pushing the boundaries.

6.1 Multi-Task Learning

What: Predict multiple properties with shared GNN backbone.

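from flax import nnx

# GCNBackbone and UncertaintyHead are planned building blocks (not yet defined)
class MultiTaskGCN(nnx.Module):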
    def __init__(self, config, n_tasks, rngs):
        self.backbone = GCNBackbone(config, rngs)
        self.heads = [
            UncertaintyHead(config.hidden_features[-1], rngs)
            for _ in range(n_tasks)
        ]

    def __call__(self, graphs, training=False):
        embeddings = self.backbone(graphs, training)
        return [head(embeddings) for head in self.heads]

Acceptance Criteria:
- [ ] Multi-head architecture
- [ ] Task weighting strategies
- [ ] Uncertainty per task
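In multi-task datasets, not every molecule is labeled for every task. The sketch below shows one way to combine per-task uncertainty with task weighting for regression heads: a Gaussian NLL per task that ignores missing labels (marked as NaN), averaged per task and then weighted. The NaN convention and `masked_multitask_nll` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def masked_multitask_nll(means, variances, targets, task_weights):
    """Weighted Gaussian NLL over tasks; NaN targets mark missing labels.

    means, variances, targets: (batch, n_tasks); task_weights: (n_tasks,).
    """
    observed = ~jnp.isnan(targets)
    # Replace NaNs before the arithmetic so gradients stay finite
    safe_targets = jnp.where(observed, targets, 0.0)
    nll = 0.5 * (jnp.log(2 * jnp.pi * variances) + (safe_targets - means) ** 2 / variances)
    nll = jnp.where(observed, nll, 0.0)
    # Average over the molecules that actually have a label for each task
    per_task = nll.sum(axis=0) / jnp.maximum(observed.sum(axis=0), 1)
    return jnp.sum(task_weights * per_task), per_task

means = jnp.zeros((2, 3))
variances = jnp.ones((2, 3))
targets = jnp.array([[0.0, jnp.nan, 0.0], [0.0, jnp.nan, jnp.nan]])
total, per_task = masked_multitask_nll(means, variances, targets, jnp.ones(3))
grads = jax.grad(lambda m: masked_multitask_nll(m, variances, targets, jnp.ones(3))[0])(means)
```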


6.2 Transfer Learning

What: Pre-train on large dataset, fine-tune on small target dataset.

Acceptance Criteria:
- [ ] Backbone freezing/unfreezing
- [ ] Learning rate schedules for fine-tuning
- [ ] Pre-trained weights for common backbones


6.3 Semi-Supervised Learning

What: Leverage unlabeled molecules to improve representations.

import jax.numpy as jnp

def consistency_loss(model, unlabeled_graphs):
    """Encourage consistent predictions under stochastic perturbation (here: dropout)."""
    # Two forward passes; assumes the dropout layers hold their own nnx.Rngs,
    # so each call samples a different mask
    pred1, _ = model(unlabeled_graphs, training=True)
    pred2, _ = model(unlabeled_graphs, training=True)
    return jnp.mean((pred1 - pred2)**2)

Acceptance Criteria:
- [ ] Consistency regularization
- [ ] Pseudo-labeling
- [ ] Graph augmentation utilities
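For pseudo-labeling, the model's own uncertainty can decide which unlabeled molecules to trust. A minimal sketch, assuming the most confident fraction (lowest predicted variance) is kept and its predicted means are used as labels; `select_pseudo_labels` and `keep_fraction` are hypothetical.

```python
import jax.numpy as jnp

def select_pseudo_labels(means, variances, keep_fraction=0.2):
    """Indices and pseudo-labels of the lowest-variance (most confident) predictions."""
    n_keep = max(1, int(keep_fraction * means.shape[0]))
    idx = jnp.argsort(variances)[:n_keep]
    return idx, means[idx]

idx, labels = select_pseudo_labels(
    jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.array([0.5, 0.1, 0.9, 0.2]), keep_fraction=0.5
)
```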


6.4 Meta-Learning (MAML)

What: Learn to adapt quickly to new molecular tasks with few examples.

Why: Critical for low-data drug discovery scenarios.

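import jax
import jax.numpy as jnp
from flax import nnx

def loss_fn(model, graphs, labels):
    """Simple MSE objective; swap in the model's own training loss."""
    mean, _ = model(graphs, training=True)
    return jnp.mean((mean - labels)**2)

def maml_inner_loop(model, support_graphs, support_labels, inner_lr, n_steps):
    """Adapt a copy of model to the support set."""
    adapted_model = nnx.clone(model)  # Leave the meta-model untouched

    for _ in range(n_steps):
        loss, grads = nnx.value_and_grad(loss_fn)(adapted_model, support_graphs, support_labels)
        # Manual gradient descent on the parameters only
        params = nnx.state(adapted_model, nnx.Param)
        nnx.update(adapted_model, jax.tree.map(lambda p, g: p - inner_lr * g, params, grads))

    return adapted_model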

Acceptance Criteria:
- [ ] MAML implementation for few-shot property prediction
- [ ] Task distribution utilities
- [ ] First-order approximation option
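To show how the first-order approximation differs from full MAML, the self-contained sketch below uses a plain linear regressor whose parameters are a tuple, so `jax.grad` can differentiate through the adaptation steps. It is an illustration under these assumptions, not a molax API; `mse`, `adapt` and `maml_loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def mse(params, x, y):
    w, b = params
    return jnp.mean((x @ w + b - y) ** 2)

def adapt(params, x, y, inner_lr=0.1, n_steps=1, first_order=False):
    """Inner loop: a few plain gradient steps on the support set."""
    for _ in range(n_steps):
        grads = jax.grad(mse)(params, x, y)
        if first_order:
            # FOMAML: treat inner gradients as constants (no second derivatives)
            grads = jax.tree.map(jax.lax.stop_gradient, grads)
        params = jax.tree.map(lambda p, g: p - inner_lr * g, params, grads)
    return params

def maml_loss(params, task, inner_lr=0.1, n_steps=1, first_order=False):
    """Query-set loss after adapting to the support set."""
    (x_support, y_support), (x_query, y_query) = task
    adapted = adapt(params, x_support, y_support, inner_lr, n_steps, first_order)
    return mse(adapted, x_query, y_query)

meta_grad_fn = jax.grad(maml_loss)  # differentiates w.r.t. params only

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_true = jnp.array([1.0, -2.0, 0.5])
xs, xq = jax.random.normal(k1, (8, 3)), jax.random.normal(k2, (8, 3))
task = ((xs, xs @ w_true), (xq, xq @ w_true))
params = (jnp.zeros(3), jnp.array(0.0))
full = meta_grad_fn(params, task)
first_order = meta_grad_fn(params, task, first_order=True)
```

With `first_order=True` the meta-gradient equals the query-loss gradient at the adapted parameters; without it, JAX also propagates through the inner-loop Hessian term.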


Implementation Priority

| Phase | Timeline | Impact | Effort |
|---|---|---|---|
| 1. Uncertainty Excellence | High | High | Medium |
| 2. Advanced Acquisition | High | High | Medium |
| 3. Architecture Diversity | Medium | Medium | High |
| 4. Rich Featurization | Medium | Medium | Low |
| 5. Production Readiness | Medium | High | Medium |
| 6. Advanced ML | Low | Medium | High |

Contributing

We welcome contributions! Priority areas:
1. Uncertainty methods - Ensembles, evidential learning
2. Calibration tools - Metrics and visualization
3. Acquisition functions - BALD, batch-aware selection

See CONTRIBUTING.md for guidelines.