Calibration Metrics

This page documents the uncertainty calibration metrics and visualization tools.

Calibration Metrics

expected_calibration_error

molax.metrics.calibration.expected_calibration_error

expected_calibration_error(
    predictions: ndarray,
    uncertainties: ndarray,
    targets: ndarray,
    n_bins: int = 10,
    mask: Optional[ndarray] = None,
) -> jnp.ndarray

Compute Expected Calibration Error (ECE) for regression.

ECE measures the average gap between expected and observed confidence across different confidence levels. Perfect calibration = 0.

For regression: ECE = mean(|observed_coverage - expected_coverage|)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| n_bins | int | Number of confidence level bins | 10 |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |

Returns:

ndarray: Scalar ECE value in [0, 1]. Lower is better.
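
A minimal usage sketch with toy arrays (illustrative values only, not from the library's examples):

```python
import jax.numpy as jnp
from molax.metrics.calibration import expected_calibration_error

# Toy predictions, predicted variances, and true targets
preds = jnp.array([0.1, 0.5, 0.9, 1.3])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

ece = expected_calibration_error(preds, variances, targets, n_bins=10)
print(float(ece))  # 0 means perfect calibration; lower is better
```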

negative_log_likelihood

molax.metrics.calibration.negative_log_likelihood

negative_log_likelihood(
    mean: ndarray,
    var: ndarray,
    targets: ndarray,
    mask: Optional[ndarray] = None,
) -> jnp.ndarray

Compute Gaussian negative log-likelihood (proper scoring rule).

NLL = 0.5 * (log(2 * pi * var) + (y - mean)^2 / var)

Lower is better. This is the proper scoring rule for probabilistic predictions with Gaussian likelihood.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mean | ndarray | Predicted means of shape [n_samples] or [n_samples, 1] | required |
| var | ndarray | Predicted variances of shape [n_samples] or [n_samples, 1] | required |
| targets | ndarray | True values of shape [n_samples] or [n_samples, 1] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples (True = include) | None |

Returns:

ndarray: Scalar NLL value (mean over valid samples)
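
A short sketch using the same toy arrays as above (illustrative only):

```python
import jax.numpy as jnp
from molax.metrics.calibration import negative_log_likelihood

mean = jnp.array([0.1, 0.5, 0.9, 1.3])
var = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

nll = negative_log_likelihood(mean, var, targets)  # lower is better
```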

compute_calibration_curve

molax.metrics.calibration.compute_calibration_curve

compute_calibration_curve(
    predictions: ndarray,
    uncertainties: ndarray,
    targets: ndarray,
    n_bins: int = 10,
    mask: Optional[ndarray] = None,
) -> Dict[str, jnp.ndarray]

Compute data for reliability diagrams.

For regression with Gaussian uncertainty, we compute calibration by checking what fraction of targets fall within various confidence intervals.

For each confidence level p (e.g., 50%, 68%, 90%, 95%):

- Compute the interval [mean - z_p * std, mean + z_p * std]
- Count the fraction of targets within the interval (observed coverage)
- Perfect calibration: observed coverage = expected coverage

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| n_bins | int | Number of confidence level bins | 10 |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |

Returns:

Dict[str, ndarray]: Dictionary with:

- expected_coverage: Expected confidence levels (bin centers)
- observed_coverage: Actual fraction of targets within interval
- bin_counts: Number of samples per bin (all same for regression)
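
A sketch of retrieving the curve data (toy arrays as above):

```python
import jax.numpy as jnp
from molax.metrics.calibration import compute_calibration_curve

preds = jnp.array([0.1, 0.5, 0.9, 1.3])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

curve = compute_calibration_curve(preds, variances, targets, n_bins=10)
# Keys per the Returns section: expected_coverage, observed_coverage, bin_counts
print(curve["expected_coverage"], curve["observed_coverage"])
```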

sharpness

molax.metrics.calibration.sharpness

sharpness(
    uncertainties: ndarray, mask: Optional[ndarray] = None
) -> jnp.ndarray

Compute sharpness (average predicted uncertainty).

Sharpness measures how confident the model is on average. Lower = sharper/more confident predictions.

Note: Sharpness alone doesn't indicate quality - a model can be overconfidently wrong. Use together with calibration metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |

Returns:

ndarray: Scalar mean standard deviation
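
A one-line sketch; note that the input is variances while the result is reported as a standard deviation:

```python
import jax.numpy as jnp
from molax.metrics.calibration import sharpness

mean_std = sharpness(jnp.array([0.05, 0.10, 0.20, 0.15]))
```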

evaluate_calibration

molax.metrics.calibration.evaluate_calibration

evaluate_calibration(
    predictions: ndarray,
    uncertainties: ndarray,
    targets: ndarray,
    mask: Optional[ndarray] = None,
    n_bins: int = 10,
) -> Dict[str, float]

Compute comprehensive calibration metrics.

Convenience function that computes all calibration metrics at once.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| n_bins | int | Number of bins for ECE computation | 10 |

Returns:

Dict[str, float]: Dictionary with:

- nll: Negative log-likelihood
- ece: Expected calibration error
- rmse: Root mean squared error
- sharpness: Average predicted std
- mean_z_score: Mean |z-score| (should be ~0.8 for calibrated)
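
A sketch that prints all metrics for the toy arrays used above:

```python
import jax.numpy as jnp
from molax.metrics.calibration import evaluate_calibration

preds = jnp.array([0.1, 0.5, 0.9, 1.3])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

metrics = evaluate_calibration(preds, variances, targets, n_bins=10)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")  # nll, ece, rmse, sharpness, mean_z_score
```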

Post-hoc Calibration

TemperatureScaling

molax.metrics.calibration.TemperatureScaling

Temperature scaling for post-hoc calibration.

Temperature scaling learns a single parameter T to scale uncertainties:

calibrated_variance = T * predicted_variance

- T > 1 increases uncertainty (model is overconfident)
- T < 1 decreases uncertainty (model is underconfident)

The temperature is optimized to minimize NLL on a validation set.

Usage:

    scaler = TemperatureScaling()
    scaler.fit(val_predictions, val_uncertainties, val_targets)
    calibrated_var = scaler.transform(test_uncertainties)

Reference: Guo et al., "On Calibration of Modern Neural Networks", ICML 2017

temperature property

temperature: float

Get learned temperature value.

is_fitted property

is_fitted: bool

Check if the scaler has been fitted.

__init__

__init__()

Initialize temperature scaler with T=1 (no scaling).

fit

fit(
    val_predictions: ndarray,
    val_uncertainties: ndarray,
    val_targets: ndarray,
    mask: Optional[ndarray] = None,
    max_iter: int = 100,
    lr: float = 0.1,
) -> TemperatureScaling

Optimize temperature on validation data.

Uses gradient descent to minimize NLL with respect to temperature.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| val_predictions | ndarray | Validation predictions of shape [n_val] | required |
| val_uncertainties | ndarray | Validation variances of shape [n_val] | required |
| val_targets | ndarray | Validation targets of shape [n_val] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| max_iter | int | Maximum optimization iterations | 100 |
| lr | float | Learning rate for gradient descent | 0.1 |

Returns:

TemperatureScaling: Self for method chaining

transform

transform(uncertainties: ndarray) -> jnp.ndarray

Apply learned temperature scaling to uncertainties.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |

Returns:

ndarray: Scaled variances: T * uncertainties
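
A slightly fuller sketch of the fit/transform workflow; the arrays below are toy placeholders standing in for your model's validation and test outputs:

```python
import jax.numpy as jnp
from molax.metrics.calibration import TemperatureScaling

# Toy validation split; in practice these come from your model's outputs
val_preds = jnp.array([0.1, 0.5, 0.9, 1.3])
val_vars = jnp.array([0.05, 0.10, 0.20, 0.15])
val_targets = jnp.array([0.0, 0.6, 1.0, 1.2])
test_vars = jnp.array([0.08, 0.12, 0.25])

scaler = TemperatureScaling().fit(val_preds, val_vars, val_targets)
print(scaler.is_fitted, scaler.temperature)  # T > 1 if the model was overconfident

calibrated_test_vars = scaler.transform(test_vars)  # = T * test_vars
```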


Visualization

plot_reliability_diagram

molax.metrics.visualization.plot_reliability_diagram

plot_reliability_diagram(
    predictions: ndarray,
    uncertainties: ndarray,
    targets: ndarray,
    n_bins: int = 10,
    mask: Optional[ndarray] = None,
    ax: Optional[Axes] = None,
    title: str = "Reliability Diagram",
    color: str = "steelblue",
    show_ece: bool = True,
) -> plt.Axes

Plot reliability diagram showing calibration quality.

A reliability diagram visualizes how well-calibrated uncertainty estimates are. The x-axis shows expected confidence (coverage), and the y-axis shows observed confidence (actual coverage).

Perfect calibration: points lie on the diagonal (y = x).

- Above diagonal: underconfident (uncertainties too high)
- Below diagonal: overconfident (uncertainties too low)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| n_bins | int | Number of confidence level bins | 10 |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| ax | Optional[Axes] | Optional matplotlib Axes. Creates new figure if None. | None |
| title | str | Plot title | 'Reliability Diagram' |
| color | str | Color for the calibration curve | 'steelblue' |
| show_ece | bool | Whether to display ECE value in legend | True |

Returns:

Axes: matplotlib Axes object
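
A minimal plotting sketch with toy arrays (assumes matplotlib is installed):

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
from molax.metrics.visualization import plot_reliability_diagram

preds = jnp.array([0.1, 0.5, 0.9, 1.3])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

ax = plot_reliability_diagram(preds, variances, targets, n_bins=10)
ax.figure.savefig("reliability.png", dpi=150)
plt.close(ax.figure)
```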

plot_calibration_comparison

molax.metrics.visualization.plot_calibration_comparison

plot_calibration_comparison(
    results: Dict[str, Tuple[ndarray, ndarray, ndarray]],
    n_bins: int = 10,
    figsize: Tuple[int, int] = (12, 5),
    colors: Optional[Dict[str, str]] = None,
) -> plt.Figure

Compare calibration across multiple models.

Creates a figure with two subplots:

1. Reliability diagrams for all models overlaid
2. Bar chart comparing ECE values

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results | Dict[str, Tuple[ndarray, ndarray, ndarray]] | Dictionary mapping model names to tuples of (predictions, uncertainties, targets) | required |
| n_bins | int | Number of bins for calibration computation | 10 |
| figsize | Tuple[int, int] | Figure size as (width, height) | (12, 5) |
| colors | Optional[Dict[str, str]] | Optional dictionary mapping model names to colors | None |

Returns:

Figure: matplotlib Figure object
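
A sketch of the expected input structure; the model names and arrays are hypothetical stand-ins for real model outputs:

```python
import jax.numpy as jnp
from molax.metrics.visualization import plot_calibration_comparison

targets = jnp.array([0.0, 0.6, 1.0, 1.2])
# Each entry maps a model name to (predictions, uncertainties, targets)
results = {
    "Model A": (jnp.array([0.1, 0.5, 0.9, 1.3]), jnp.array([0.05, 0.10, 0.20, 0.15]), targets),
    "Model B": (jnp.array([0.2, 0.7, 1.1, 1.0]), jnp.array([0.30, 0.30, 0.40, 0.35]), targets),
}
fig = plot_calibration_comparison(results, n_bins=10)
fig.savefig("calibration_comparison.png", dpi=150)
```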

plot_uncertainty_vs_error

molax.metrics.visualization.plot_uncertainty_vs_error

plot_uncertainty_vs_error(
    predictions: ndarray,
    uncertainties: ndarray,
    targets: ndarray,
    mask: Optional[ndarray] = None,
    ax: Optional[Axes] = None,
    title: str = "Uncertainty vs Error",
    color: str = "steelblue",
    show_correlation: bool = True,
) -> plt.Axes

Scatter plot of predicted uncertainty vs actual error.

For well-calibrated models, higher uncertainty should correlate with higher error. Points should cluster around the diagonal.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| ax | Optional[Axes] | Optional matplotlib Axes. Creates new figure if None. | None |
| title | str | Plot title | 'Uncertainty vs Error' |
| color | str | Color for scatter points | 'steelblue' |
| show_correlation | bool | Whether to display Pearson correlation | True |

Returns:

Axes: matplotlib Axes object
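
A minimal sketch, reusing the toy arrays from the earlier examples:

```python
import jax.numpy as jnp
from molax.metrics.visualization import plot_uncertainty_vs_error

preds = jnp.array([0.1, 0.5, 0.9, 1.3])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

ax = plot_uncertainty_vs_error(preds, variances, targets)
ax.figure.savefig("uncertainty_vs_error.png", dpi=150)
```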

create_calibration_report

molax.metrics.visualization.create_calibration_report

create_calibration_report(
    predictions: ndarray,
    uncertainties: ndarray,
    targets: ndarray,
    mask: Optional[ndarray] = None,
    model_name: str = "Model",
    figsize: Tuple[int, int] = (14, 10),
) -> plt.Figure

Create a comprehensive calibration report with multiple plots.

Generates a figure with four subplots:

1. Reliability diagram
2. Uncertainty vs error scatter
3. Uncertainty histogram
4. Z-score histogram

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| model_name | str | Name of the model for titles | 'Model' |
| figsize | Tuple[int, int] | Figure size as (width, height) | (14, 10) |

Returns:

Figure: matplotlib Figure object
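
A sketch producing a full report for a single model; real usage would pass your model's test-set outputs instead of the toy arrays:

```python
import jax.numpy as jnp
from molax.metrics.visualization import create_calibration_report

preds = jnp.array([0.1, 0.5, 0.9, 1.3])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.2])

fig = create_calibration_report(preds, variances, targets, model_name="Toy model")
fig.savefig("calibration_report.png", dpi=150)
```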