11 changes: 7 additions & 4 deletions pymc_extras/inference/laplace_approx/laplace.py
@@ -224,12 +224,15 @@ def model_to_laplace_approx(
 elif name in model.named_vars_to_dims:
     dims = (*batch_dims, *model.named_vars_to_dims[name])
 else:
-    dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
     initval = initial_point.get(name, None)
     dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-    laplace_model.add_coords(
-        {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-    )
+    if dim_shapes[0] is not None:
ricardoV94 (Member), Dec 10, 2025:
I don't follow what's happening here. If dim_shapes is coming from rv.type.shape, you could have dim_shapes[0] not be None while a later entry is None. For example, x = pm.Data("x", np.zeros((5, 3)), shape=(5, None)) is valid.

The whole approach seems a bit backwards. Instead of trying to plug the sample dims into a model that may not have properly defined dims to begin with, why not sample first and then attach them to the final inference data object, which always has dims, explicit or automatic?

If you just want to patch this for now, be more exhaustive and check not any(d is None for d in dim_shapes).
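
A minimal sketch of that more exhaustive check (static_core_shape is a made-up helper name for illustration, not something in the PR):

def static_core_shape(dim_shapes):
    # True only if every entry of the trailing shape is statically known.
    return not any(d is None for d in dim_shapes)

# pm.Data("x", np.zeros((5, 3)), shape=(5, None)) has type shape (5, None):
static_core_shape((5, 3))     # True  -> safe to build f"{name}_dim_{i}" coords
static_core_shape((5, None))  # False -> fall back to dims=None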

Contributor Author:
@ricardoV94 Thank you for taking the time to review my changes. I went back and looked at the logic, and I agree with you that it would be more appropriate to sample first and let the inference data object create dims automatically. However, because of the (temp_chain, temp_draw) dims added in model_to_laplace_approx, the automatic dims are shifted by 2: for example, the expected mu_dim_0 becomes mu_dim_2. I wrote a helper function to rename these automatically generated dims/coords after creation. Please let me know if that looks okay to you.
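
A rough sketch of what such a renaming helper could look like (the name shift_default_dims and its signature are hypothetical; the helper in the PR may differ), assuming the auto-generated dims follow the f"{name}_dim_{i}" pattern:

import re
import xarray as xr

def shift_default_dims(posterior: xr.Dataset, offset: int = 2) -> xr.Dataset:
    # Rename e.g. "mu_dim_2" back to "mu_dim_0" after the two temporary
    # batch dims have been stripped from the front of each variable.
    mapping = {}
    for dim in posterior.dims:
        match = re.fullmatch(r"(.+)_dim_(\d+)", str(dim))
        if match and int(match.group(2)) >= offset:
            mapping[dim] = f"{match.group(1)}_dim_{int(match.group(2)) - offset}"
    return posterior.rename(mapping)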

Member:
Makes sense, yes. I think the ultimate issue here was trying to put the sampling dims in the model instead of working on the random function directly, but that's a larger question that doesn't need to be settled in this PR.

Contributor Author:
Yes, you are correct. It is odd that we are assigning the sampling dims (temp_chain, temp_draws) in the model. I need to dig a little deeper and see how to change this architectural design so that the random function handles the sampling dims rather than the model. There are some comments in the implementation mentioning that (temp_chain, temp_draws) are supposedly batch dimensions. Maybe the correct approach would be to treat those separately from (chain, draw) and let the random function just name them using defaults. For example (using the example above), we would then get (mu_dim_0, mu_dim_1, mu_dim_2) with shapes (2, 500, 100). I am not sure, though.

Member:
The point, though, is that the chain/draw dimensions from regular sampling don't exist in the graph either. Here we were trying to create a vectorized posterior predictive / compute deterministics sort of function. Which is fine, but maybe it shouldn't 1) be restricted to this method specifically, or 2) be hacked into the model. We basically need to get the underlying function that this routine uses after creating the batched model, call it once, and then handle the conversion to InferenceData ourselves; that is the point where you tell it that every variable has two leading batch dimensions named chain and draw.

This should be thought about separately from this PR.
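
For illustration, a rough sketch of that last step under assumed shapes (this is not the PR's code): if the vectorized function returns arrays whose first two axes are the temporary batch dimensions, those axes can be labelled chain/draw when the InferenceData is built, rather than inside the model.

import numpy as np
import arviz as az

# Pretend draws for one variable, shaped (chain, draw, mu_dim_0) = (2, 500, 100).
raw_draws = {"mu": np.zeros((2, 500, 100))}

# az.from_dict treats the first two axes of each array as chain/draw and
# auto-names the remaining dims, so idata.posterior["mu"] gets dims
# ("chain", "draw", "mu_dim_0").
idata = az.from_dict(posterior=raw_draws)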

+        dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
+        laplace_model.add_coords(
+            {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
+        )
+    else:
+        dims = None

 pm.Deterministic(name, batched_rv, dims=dims)

38 changes: 38 additions & 0 deletions tests/inference/laplace_approx/test_laplace.py
@@ -193,6 +193,44 @@ def test_fit_laplace_ragged_coords(rng):
    assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()


def test_fit_laplace_no_data_or_deterministic_dims(rng):
    coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
    with pm.Model(coords=coords) as ragged_dim_model:
        X = pm.Data("X", np.ones((100, 2)))
        beta = pm.Normal(
            "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
        )
        mu = pm.Deterministic("mu", (X[:, None, :] * beta[None]).sum(axis=-1))
        sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])

        obs = pm.Normal(
            "obs",
            mu=mu,
            sigma=sigma,
            observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
            dims=["obs_idx", "city"],
        )

        idata = fit_laplace(
            optimize_method="Newton-CG",
            progressbar=False,
            use_grad=True,
            use_hessp=True,
        )

    # These should have been dropped when the laplace idata was created
    assert "laplace_approximation" not in list(idata.posterior.data_vars.keys())
    assert "unpacked_var_names" not in list(idata.posterior.coords.keys())

    assert idata["posterior"].beta.shape[-2:] == (3, 2)
Member:
Shouldn't we test that we get the expected sample dims as well?

Contributor Author:
Yes, this is definitely a good idea. I added assert statements to test for the sampling dims, and I also consolidated this test with test_fit_laplace_ragged_coords because a lot of code was duplicated.
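
For example, the added checks could look something like the following (hypothetical asserts, assuming idata is the result of fit_laplace in the test above; the asserts actually added in the PR may differ):

# Sampling dims should come first on every posterior variable.
assert idata.posterior["beta"].dims[:2] == ("chain", "draw")
assert idata.posterior["mu"].dims[:2] == ("chain", "draw")
# The exact sizes depend on the chains/draws arguments passed to fit_laplace.
assert idata.posterior.sizes["chain"] > 0 and idata.posterior.sizes["draw"] > 0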

assert idata["posterior"].sigma.shape[-1:] == (3,)

# Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
# strictly positive
assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all()
assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()


def test_model_with_nonstandard_dimensionality(rng):
    y_obs = np.concatenate(
        [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)]