Data Utilities

This page documents the data loading and processing utilities.

Dataset Classes

MolecularDataset

molax.utils.data.MolecularDataset

Dataset class for molecular graphs using jraph format.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| graphs | List[GraphsTuple] | List of jraph.GraphsTuple objects |
| labels | List[float] | Array of property labels |
| n_node_features | | Number of node features |

__init__

__init__(
    data: Union[DataFrame, str, Path],
    smiles_col: str = "smiles",
    label_col: str = "property",
)

Initialize dataset from DataFrame or CSV file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | Union[DataFrame, str, Path] | DataFrame or path to CSV file | required |
| smiles_col | str | Column name for SMILES strings | 'smiles' |
| label_col | str | Column name for property labels | 'property' |

get_batched

get_batched(
    indices: Optional[List[int]] = None,
    pad_to_nodes: Optional[int] = None,
    pad_to_edges: Optional[int] = None,
    pad_to_graphs: Optional[int] = None,
) -> Tuple[jraph.GraphsTuple, jnp.ndarray]

Get a batched GraphsTuple for the specified indices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| indices | Optional[List[int]] | List of indices to include. If None, returns all data. | None |
| pad_to_nodes | Optional[int] | Pad to this many nodes for consistent JIT shapes | None |
| pad_to_edges | Optional[int] | Pad to this many edges | None |
| pad_to_graphs | Optional[int] | Pad to this many graphs | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[GraphsTuple, ndarray] | Tuple of (batched_graphs, labels) |

compute_padding_sizes

compute_padding_sizes(batch_size: int) -> Tuple[int, int, int]

Compute fixed padding sizes for efficient JIT compilation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Maximum batch size | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[int, int, int] | Tuple of (max_nodes, max_edges, n_graphs) for padding |
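
The two methods above are typically combined: the padding sizes are computed once and then passed to every `get_batched` call, so that each batch has the same shape and a jitted function does not need to recompile. The following is a minimal sketch that relies only on the signatures documented on this page. `iter_padded_batches` is a hypothetical helper, not part of molax, and it assumes that `len(dataset.graphs)` gives the dataset size.

```python
from typing import Any, Iterator, List, Tuple


def iter_padded_batches(dataset: Any, batch_size: int) -> Iterator[Tuple[Any, Any]]:
    """Yield (batched_graphs, labels) pairs that all share one padded shape.

    Hypothetical helper: `dataset` is assumed to expose `graphs`,
    `compute_padding_sizes`, and `get_batched` as documented above.
    """
    # Compute the padding targets once so every batch has the same shape.
    max_nodes, max_edges, n_graphs = dataset.compute_padding_sizes(batch_size)
    n_total = len(dataset.graphs)
    for start in range(0, n_total, batch_size):
        # The last chunk may be shorter; it is padded to the same targets.
        indices: List[int] = list(range(start, min(start + batch_size, n_total)))
        yield dataset.get_batched(
            indices=indices,
            pad_to_nodes=max_nodes,
            pad_to_edges=max_edges,
            pad_to_graphs=n_graphs,
        )
```

With a `MolecularDataset` instance this would be used as `for graphs, labels in iter_padded_batches(dataset, 32): ...`.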

split

split(
    test_size: float = 0.2, seed: Optional[int] = None
) -> Tuple[MolecularDataset, MolecularDataset]

Split dataset into train and test sets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| test_size | float | Fraction for test set | 0.2 |
| seed | Optional[int] | Random seed for reproducibility | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[MolecularDataset, MolecularDataset] | Tuple of (train_dataset, test_dataset) |
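
To make the roles of `test_size` and `seed` concrete, the following is a minimal sketch of a seeded index split. It is an illustration only and not molax's implementation: `split_indices` is a hypothetical name, and rounding the test-set size down is an assumption.

```python
import random
from typing import List, Optional, Tuple


def split_indices(
    n_items: int, test_size: float = 0.2, seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Return (train_indices, test_indices) for a dataset of n_items."""
    indices = list(range(n_items))
    # A dedicated Random instance leaves the global RNG state untouched.
    random.Random(seed).shuffle(indices)
    # Assumption: the test-set size is rounded down to a whole item count.
    n_test = int(n_items * test_size)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(10, test_size=0.2, seed=0)
```

Passing the same `seed` reproduces the same partition; with `seed=None` the shuffle generally differs between runs.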


Graph Conversion

Functions for converting molecular representations to graph format.

smiles_to_jraph

molax.utils.data.smiles_to_jraph

smiles_to_jraph(smiles: str) -> jraph.GraphsTuple

Convert SMILES string to jraph GraphsTuple format.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| smiles | str | SMILES string representing the molecule | required |

Returns:

| Type | Description |
| --- | --- |
| GraphsTuple | jraph.GraphsTuple containing the molecular graph |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the SMILES string is invalid |
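
Because an invalid SMILES string raises `ValueError`, bulk conversion has to decide whether to skip or abort on bad rows. The following is a minimal sketch of the skip-and-report approach. `convert_all` is a hypothetical helper, and the converter is passed in as an argument (in practice `smiles_to_jraph`) so that the sketch does not depend on molax being installed.

```python
from typing import Any, Callable, Iterable, List, Tuple


def convert_all(
    smiles_list: Iterable[str], converter: Callable[[str], Any]
) -> Tuple[List[Any], List[str]]:
    """Convert each SMILES string, collecting the ones that fail to parse."""
    graphs: List[Any] = []
    failed: List[str] = []
    for smiles in smiles_list:
        try:
            graphs.append(converter(smiles))
        except ValueError:
            # Documented failure mode: the SMILES string is invalid.
            failed.append(smiles)
    return graphs, failed
```

With molax available, this would be called as `convert_all(smiles_list, smiles_to_jraph)`.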

batch_graphs

molax.utils.data.batch_graphs

batch_graphs(
    graphs: List[GraphsTuple],
    pad_to_nodes: Optional[int] = None,
    pad_to_edges: Optional[int] = None,
    pad_to_graphs: Optional[int] = None,
) -> jraph.GraphsTuple

Batch multiple graphs into a single padded GraphsTuple.

Padding ensures consistent shapes for JIT compilation efficiency.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| graphs | List[GraphsTuple] | List of individual GraphsTuple objects | required |
| pad_to_nodes | Optional[int] | Pad total nodes to this number (default: auto) | None |
| pad_to_edges | Optional[int] | Pad total edges to this number (default: auto) | None |
| pad_to_graphs | Optional[int] | Pad to this many graphs (default: len(graphs) + 1) | None |

Returns:

| Type | Description |
| --- | --- |
| GraphsTuple | Single batched and padded GraphsTuple |
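
The default of `len(graphs) + 1` for `pad_to_graphs` presumably follows the usual jraph padding convention, in which leftover nodes and edges are assigned to at least one dummy graph appended after the real ones. The sketch below illustrates only the bookkeeping on the per-graph node counts. It is a simplified illustration, not molax's or jraph's code: `padded_n_node` is a hypothetical function, and both the strict size checks and the choice to put all padding nodes in the first dummy graph are assumptions.

```python
from typing import List


def padded_n_node(n_node: List[int], pad_to_nodes: int, pad_to_graphs: int) -> List[int]:
    """Return per-graph node counts after padding with dummy graphs."""
    n_real = sum(n_node)
    if pad_to_nodes <= n_real:
        raise ValueError("pad_to_nodes must exceed the real node count")
    if pad_to_graphs <= len(n_node):
        raise ValueError("pad_to_graphs must leave room for a dummy graph")
    # The first dummy graph absorbs every padding node ...
    padding = [pad_to_nodes - n_real]
    # ... and any further dummy graphs are empty.
    padding += [0] * (pad_to_graphs - len(n_node) - 1)
    return n_node + padding


counts = padded_n_node([3, 5], pad_to_nodes=12, pad_to_graphs=4)
# → [3, 5, 4, 0]
```

The padded counts always sum to `pad_to_nodes` and have length `pad_to_graphs`, which is what keeps array shapes identical from batch to batch.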

unbatch_graphs

molax.utils.data.unbatch_graphs

unbatch_graphs(batched: GraphsTuple) -> List[jraph.GraphsTuple]

Unbatch a batched GraphsTuple back to individual graphs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batched | GraphsTuple | Batched GraphsTuple | required |

Returns:

| Type | Description |
| --- | --- |
| List[GraphsTuple] | List of individual GraphsTuple objects |
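
Conceptually, unbatching reverses the concatenation: the per-graph node counts say where to cut the stacked node array. The following is a minimal sketch on plain Python lists, not molax's implementation. `split_by_counts` is a hypothetical helper, and the handling of padding graphs is left out.

```python
from typing import Any, List, Sequence


def split_by_counts(items: Sequence[Any], counts: Sequence[int]) -> List[List[Any]]:
    """Cut a concatenated sequence back into per-graph chunks."""
    if sum(counts) != len(items):
        raise ValueError("counts must sum to the number of items")
    chunks: List[List[Any]] = []
    start = 0
    for count in counts:
        chunks.append(list(items[start:start + count]))
        start += count
    return chunks


# Node features of two graphs (2 nodes and 3 nodes) stacked into one list.
per_graph = split_by_counts(["a0", "a1", "b0", "b1", "b2"], [2, 3])
# → [['a0', 'a1'], ['b0', 'b1', 'b2']]
```

Edge endpoints additionally need each graph's node offset subtracted to become local indices again.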