Calibration Metrics
This page documents the uncertainty calibration metrics and visualization tools.
Calibration Metrics
expected_calibration_error
molax.metrics.calibration.expected_calibration_error
expected_calibration_error(
predictions: ndarray,
uncertainties: ndarray,
targets: ndarray,
n_bins: int = 10,
mask: Optional[ndarray] = None,
) -> jnp.ndarray
Compute Expected Calibration Error (ECE) for regression.
ECE measures the average gap between expected and observed confidence across different confidence levels. Perfect calibration = 0.
For regression: ECE = mean(|observed_coverage - expected_coverage|)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| n_bins | int | Number of confidence level bins | 10 |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar ECE value in [0, 1]. Lower is better. |
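A minimal usage sketch, assuming 1-D JAX arrays of predicted means, predicted variances, and true targets; the values below are synthetic and only for illustration:

```python
import jax.numpy as jnp

from molax.metrics.calibration import expected_calibration_error

# Synthetic predicted means, predicted variances, and true targets.
preds = jnp.array([0.1, 0.5, 0.9, 1.2])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.1])

ece = expected_calibration_error(preds, variances, targets, n_bins=10)
print(f"ECE: {float(ece):.3f}")  # closer to 0 means better calibration
```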
negative_log_likelihood
molax.metrics.calibration.negative_log_likelihood
negative_log_likelihood(
mean: ndarray,
var: ndarray,
targets: ndarray,
mask: Optional[ndarray] = None,
) -> jnp.ndarray
Compute Gaussian negative log-likelihood (proper scoring rule).
NLL = 0.5 * (log(2 * pi * var) + (y - mean)^2 / var)
Lower is better. This is the proper scoring rule for probabilistic predictions with Gaussian likelihood.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | ndarray | Predicted means of shape [n_samples] or [n_samples, 1] | required |
| var | ndarray | Predicted variances of shape [n_samples] or [n_samples, 1] | required |
| targets | ndarray | True values of shape [n_samples] or [n_samples, 1] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples (True = include) | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar NLL value (mean over valid samples) |
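A minimal sketch with synthetic values, showing the optional mask (True = include the sample):

```python
import jax.numpy as jnp

from molax.metrics.calibration import negative_log_likelihood

mean = jnp.array([0.1, 0.5, 0.9])
var = jnp.array([0.05, 0.10, 0.20])
targets = jnp.array([0.0, 0.6, 1.0])
mask = jnp.array([True, True, False])  # ignore the last sample

nll = negative_log_likelihood(mean, var, targets, mask=mask)
print(f"NLL: {float(nll):.3f}")  # lower is better
```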
compute_calibration_curve
molax.metrics.calibration.compute_calibration_curve
compute_calibration_curve(
predictions: ndarray,
uncertainties: ndarray,
targets: ndarray,
n_bins: int = 10,
mask: Optional[ndarray] = None,
) -> Dict[str, jnp.ndarray]
Compute data for reliability diagrams.
For regression with Gaussian uncertainty, we compute calibration by checking what fraction of targets fall within various confidence intervals.
For each confidence level p (e.g., 50%, 68%, 90%, 95%):

- Compute the interval [mean - z_p * std, mean + z_p * std]
- Count the fraction of targets within the interval (observed coverage)
- Perfect calibration: observed coverage = expected coverage
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| n_bins | int | Number of confidence level bins | 10 |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
Returns:
| Type | Description |
|---|---|
| Dict[str, ndarray] | Dictionary with the expected and observed coverage at each confidence level |
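A sketch with synthetic arrays; since the exact dictionary keys are not listed here, the example inspects whatever the function returns rather than assuming key names:

```python
import jax.numpy as jnp

from molax.metrics.calibration import compute_calibration_curve

preds = jnp.array([0.1, 0.5, 0.9, 1.2])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.1])

curve = compute_calibration_curve(preds, variances, targets, n_bins=10)

# Inspect the returned arrays (expected vs. observed coverage per confidence level).
for name, values in curve.items():
    print(name, values.shape)
```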
sharpness
molax.metrics.calibration.sharpness
sharpness(
uncertainties: ndarray,
mask: Optional[ndarray] = None,
) -> jnp.ndarray
Compute sharpness (average predicted uncertainty).
Sharpness measures how confident the model is on average. Lower = sharper/more confident predictions.
Note: Sharpness alone doesn't indicate quality; a model can be confidently wrong. Use it together with calibration metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar mean standard deviation |
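A minimal sketch with synthetic variances; sharpness is usually reported next to ECE or NLL rather than on its own:

```python
import jax.numpy as jnp

from molax.metrics.calibration import sharpness

variances = jnp.array([0.05, 0.10, 0.20, 0.15])  # predicted variances

sharp = sharpness(variances)  # mean predicted standard deviation
print(f"sharpness: {float(sharp):.3f}")
```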
evaluate_calibration
molax.metrics.calibration.evaluate_calibration
evaluate_calibration(
predictions: ndarray,
uncertainties: ndarray,
targets: ndarray,
mask: Optional[ndarray] = None,
n_bins: int = 10,
) -> Dict[str, float]
Compute comprehensive calibration metrics.
Convenience function that computes all calibration metrics at once.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| n_bins | int | Number of bins for ECE computation | 10 |
Returns:
| Type | Description |
|---|---|
| Dict[str, float] | Dictionary with all calibration metrics |
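A minimal sketch with synthetic arrays; the metric names are whatever keys the implementation returns, so they are printed generically:

```python
import jax.numpy as jnp

from molax.metrics.calibration import evaluate_calibration

preds = jnp.array([0.1, 0.5, 0.9, 1.2])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.1])

metrics = evaluate_calibration(preds, variances, targets, n_bins=10)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```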
Post-hoc Calibration
TemperatureScaling
molax.metrics.calibration.TemperatureScaling
Temperature scaling for post-hoc calibration.
Temperature scaling learns a single parameter T that rescales the predicted variances:

calibrated_variance = T * predicted_variance

T > 1 increases uncertainty (the model was overconfident); T < 1 decreases uncertainty (the model was underconfident).
The temperature is optimized to minimize NLL on a validation set.
Usage
    scaler = TemperatureScaling()
    scaler.fit(val_predictions, val_uncertainties, val_targets)
    calibrated_var = scaler.transform(test_uncertainties)
Reference: Guo et al., "On Calibration of Modern Neural Networks", ICML 2017
fit
fit(
val_predictions: ndarray,
val_uncertainties: ndarray,
val_targets: ndarray,
mask: Optional[ndarray] = None,
max_iter: int = 100,
lr: float = 0.1,
) -> TemperatureScaling
Optimize temperature on validation data.
Uses gradient descent to minimize NLL with respect to temperature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| val_predictions | ndarray | Validation predictions of shape [n_val] | required |
| val_uncertainties | ndarray | Validation variances of shape [n_val] | required |
| val_targets | ndarray | Validation targets of shape [n_val] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| max_iter | int | Maximum optimization iterations | 100 |
| lr | float | Learning rate for gradient descent | 0.1 |
Returns:
| Type | Description |
|---|---|
| TemperatureScaling | Self for method chaining |
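A fuller sketch of the fit/transform workflow on synthetic validation and test splits; the array values are illustrative only:

```python
import jax.numpy as jnp

from molax.metrics.calibration import TemperatureScaling, negative_log_likelihood

# Tiny synthetic validation / test splits (illustrative values only).
val_preds = jnp.array([0.2, 0.8, 1.1, 0.4])
val_vars = jnp.array([0.02, 0.05, 0.04, 0.03])
val_targets = jnp.array([0.5, 1.0, 0.7, 0.1])
test_preds = jnp.array([0.3, 0.9])
test_vars = jnp.array([0.02, 0.06])
test_targets = jnp.array([0.6, 0.8])

# Fit the temperature on the validation split (fit returns self).
scaler = TemperatureScaling().fit(val_preds, val_vars, val_targets, max_iter=100, lr=0.1)

# Rescale the test variances and check that NLL improves.
calibrated_vars = scaler.transform(test_vars)
nll_before = negative_log_likelihood(test_preds, test_vars, test_targets)
nll_after = negative_log_likelihood(test_preds, calibrated_vars, test_targets)
print(f"NLL before: {float(nll_before):.3f}, after: {float(nll_after):.3f}")
```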
Visualization
plot_reliability_diagram
molax.metrics.visualization.plot_reliability_diagram
plot_reliability_diagram(
predictions: ndarray,
uncertainties: ndarray,
targets: ndarray,
n_bins: int = 10,
mask: Optional[ndarray] = None,
ax: Optional[Axes] = None,
title: str = "Reliability Diagram",
color: str = "steelblue",
show_ece: bool = True,
) -> plt.Axes
Plot reliability diagram showing calibration quality.
A reliability diagram visualizes how well-calibrated uncertainty estimates are. The x-axis shows expected confidence (coverage), and the y-axis shows observed confidence (actual coverage).
Interpretation:

- Perfect calibration: points lie on the diagonal (y = x)
- Above the diagonal: underconfident (uncertainties too high)
- Below the diagonal: overconfident (uncertainties too low)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| n_bins | int | Number of confidence level bins | 10 |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| ax | Optional[Axes] | Optional matplotlib Axes. Creates a new figure if None. | None |
| title | str | Plot title | 'Reliability Diagram' |
| color | str | Color for the calibration curve | 'steelblue' |
| show_ece | bool | Whether to display the ECE value in the legend | True |
Returns:
| Type | Description |
|---|---|
| Axes | matplotlib Axes object |
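A minimal plotting sketch with synthetic arrays; the output filename is arbitrary:

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

from molax.metrics.visualization import plot_reliability_diagram

preds = jnp.array([0.1, 0.5, 0.9, 1.2])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.1])

ax = plot_reliability_diagram(preds, variances, targets, n_bins=10, show_ece=True)
plt.savefig("reliability.png")
```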
plot_calibration_comparison
molax.metrics.visualization.plot_calibration_comparison
plot_calibration_comparison(
results: Dict[str, Tuple[ndarray, ndarray, ndarray]],
n_bins: int = 10,
figsize: Tuple[int, int] = (12, 5),
colors: Optional[Dict[str, str]] = None,
) -> plt.Figure
Compare calibration across multiple models.
Creates a figure with two subplots:

1. Reliability diagrams for all models overlaid
2. Bar chart comparing ECE values
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| results | Dict[str, Tuple[ndarray, ndarray, ndarray]] | Dictionary mapping model names to tuples of (predictions, uncertainties, targets) | required |
| n_bins | int | Number of bins for calibration computation | 10 |
| figsize | Tuple[int, int] | Figure size as (width, height) | (12, 5) |
| colors | Optional[Dict[str, str]] | Optional dictionary mapping model names to colors | None |
Returns:
| Type | Description |
|---|---|
| Figure | matplotlib Figure object |
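A sketch comparing two models on synthetic data; the model names and array values are illustrative placeholders:

```python
import jax.numpy as jnp

from molax.metrics.visualization import plot_calibration_comparison

targets = jnp.array([0.0, 0.6, 1.0, 1.1])
results = {
    "Model A": (jnp.array([0.1, 0.5, 0.9, 1.2]),
                jnp.array([0.05, 0.10, 0.20, 0.15]), targets),
    "Model B": (jnp.array([0.2, 0.7, 0.8, 1.0]),
                jnp.array([0.30, 0.40, 0.30, 0.35]), targets),
}

fig = plot_calibration_comparison(results, n_bins=10)
fig.savefig("calibration_comparison.png")
```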
plot_uncertainty_vs_error
molax.metrics.visualization.plot_uncertainty_vs_error
plot_uncertainty_vs_error(
predictions: ndarray,
uncertainties: ndarray,
targets: ndarray,
mask: Optional[ndarray] = None,
ax: Optional[Axes] = None,
title: str = "Uncertainty vs Error",
color: str = "steelblue",
show_correlation: bool = True,
) -> plt.Axes
Scatter plot of predicted uncertainty vs actual error.
For well-calibrated models, higher uncertainty should correlate with higher error. Points should cluster around the diagonal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| ax | Optional[Axes] | Optional matplotlib Axes. Creates a new figure if None. | None |
| title | str | Plot title | 'Uncertainty vs Error' |
| color | str | Color for scatter points | 'steelblue' |
| show_correlation | bool | Whether to display the Pearson correlation | True |
Returns:
| Type | Description |
|---|---|
| Axes | matplotlib Axes object |
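A minimal sketch with synthetic arrays; the output filename is arbitrary:

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

from molax.metrics.visualization import plot_uncertainty_vs_error

preds = jnp.array([0.1, 0.5, 0.9, 1.2])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.1])

ax = plot_uncertainty_vs_error(preds, variances, targets, show_correlation=True)
plt.savefig("uncertainty_vs_error.png")
```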
create_calibration_report
molax.metrics.visualization.create_calibration_report
create_calibration_report(
predictions: ndarray,
uncertainties: ndarray,
targets: ndarray,
mask: Optional[ndarray] = None,
model_name: str = "Model",
figsize: Tuple[int, int] = (14, 10),
) -> plt.Figure
Create a comprehensive calibration report with multiple plots.
Generates a figure with four subplots:

1. Reliability diagram
2. Uncertainty vs error scatter
3. Uncertainty histogram
4. Z-score histogram
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted means of shape [n_samples] | required |
| uncertainties | ndarray | Predicted variances of shape [n_samples] | required |
| targets | ndarray | True values of shape [n_samples] | required |
| mask | Optional[ndarray] | Optional boolean mask for valid samples | None |
| model_name | str | Name of the model used in titles | 'Model' |
| figsize | Tuple[int, int] | Figure size as (width, height) | (14, 10) |
Returns:
| Type | Description |
|---|---|
| Figure | matplotlib Figure object |
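A minimal sketch with synthetic arrays; the model name and output filename are illustrative:

```python
import jax.numpy as jnp

from molax.metrics.visualization import create_calibration_report

preds = jnp.array([0.1, 0.5, 0.9, 1.2])
variances = jnp.array([0.05, 0.10, 0.20, 0.15])
targets = jnp.array([0.0, 0.6, 1.0, 1.1])

fig = create_calibration_report(preds, variances, targets, model_name="MyModel")
fig.savefig("calibration_report.png")
```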