bugfix for fit_laplace absent dims #609
@@ -193,6 +193,44 @@ def test_fit_laplace_ragged_coords(rng):

```python
    assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()


def test_fit_laplace_no_data_or_deterministic_dims(rng):
    coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
    with pm.Model(coords=coords) as ragged_dim_model:
        X = pm.Data("X", np.ones((100, 2)))
        beta = pm.Normal(
            "beta",
            mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]],
            dims=["city", "feature"],
        )
        mu = pm.Deterministic("mu", (X[:, None, :] * beta[None]).sum(axis=-1))
        sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])

        obs = pm.Normal(
            "obs",
            mu=mu,
            sigma=sigma,
            observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
            dims=["obs_idx", "city"],
        )

        idata = fit_laplace(
            optimize_method="Newton-CG",
            progressbar=False,
            use_grad=True,
            use_hessp=True,
        )

    # These should have been dropped when the laplace idata was created
    assert "laplace_approximation" not in list(idata.posterior.data_vars.keys())
    assert "unpacked_var_names" not in list(idata.posterior.coords.keys())

    assert idata["posterior"].beta.shape[-2:] == (3, 2)
```
Member: Shouldn't we test that we get the expected sample dims as well?

Contributor (Author): Yes, this is definitely a good idea. I added assert statements to test for the sampling dims, and I also consolidated the test with […]
```python
    assert idata["posterior"].sigma.shape[-1:] == (3,)

    # Check that everything got unraveled correctly -- feature 0 should be
    # strictly negative, feature 1 strictly positive
    assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all()
    assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()


def test_model_with_nonstandard_dimensionality(rng):
    y_obs = np.concatenate(
        [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)]
```
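The sampling-dim checks requested in the thread above landed in a later commit and are not shown in this hunk; a hypothetical sketch of what such asserts might look like, assuming the `fit_laplace` output carries the standard `("chain", "draw", ...)` dims:

```python
# Hypothetical sketch only -- not the committed change.
assert idata.posterior.beta.dims == ("chain", "draw", "city", "feature")
assert idata.posterior.sigma.dims == ("chain", "draw", "city")
```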
Reviewer: I don't follow what's happening here. If `dim_shapes` is coming from `rv.type.shape`, you could have `dim_shapes[0]` not be `None` while a later entry is `None`. Example: `x = pm.Data("x", np.zeros((5, 3)), shape=(5, None))` is valid.

The whole approach seems a bit backwards. Instead of trying to plug the sample dims into a model that may not have properly defined dims to begin with, why not sample first and then attach those dims in the final inference data object, which always has dims, explicit or automatic?
If you just want to patch it for now, be more exhaustive and check `not any(d is None for d in dim_shapes)`.
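A minimal runnable sketch of both points, reusing the reviewer's `shape=(5, None)` example (the variable name `all_dims_static` is ours):

```python
import numpy as np
import pymc as pm

with pm.Model():
    # Valid: the first dim is static, the second is unknown,
    # so checking only dim_shapes[0] would miss the None.
    x = pm.Data("x", np.zeros((5, 3)), shape=(5, None))

dim_shapes = x.type.shape  # (5, None)

# The exhaustive guard suggested above:
all_dims_static = not any(d is None for d in dim_shapes)
print(all_dims_static)  # False
```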
Author: @ricardoV94 Thank you for taking the time to review my changes. I went back and looked at the logic, and I agree with you that it would be more appropriate to sample and allow the inference data object to create dims automatically. However, due to the addition of the (temp_chain, temp_draw) dims in `model_to_laplace_approx`, the automatically generated dims are incremented by 2: for example, the expected `mu_dim_0` becomes `mu_dim_2`. I wrote a helper function to rename these automatically generated dims/coords after creation. Please let me know if that looks okay to you.
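The helper itself isn't shown in the thread; a hypothetical sketch of such a post-hoc rename, shifting xarray's auto-generated `<var>_dim_<n>` names down by the two temporary batch dims, could look like:

```python
import re
import xarray as xr

def shift_default_dim_names(ds: xr.Dataset, offset: int = 2) -> xr.Dataset:
    # Rename auto-generated dims such as "mu_dim_2" back to "mu_dim_0",
    # undoing the offset introduced by the (temp_chain, temp_draw) dims.
    renames = {}
    for dim in map(str, ds.dims):
        match = re.fullmatch(r"(.+)_dim_(\d+)", dim)
        if match and int(match.group(2)) >= offset:
            renames[dim] = f"{match.group(1)}_dim_{int(match.group(2)) - offset}"
    return ds.rename(renames)
```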
Reviewer: Makes sense, yes. I think the ultimate issue here was trying to put sampling dims in the model instead of working on the random function directly, but that's a larger question that doesn't need to be settled in this PR.
Author: Yes, you are correct. It is odd that we are assigning the sampling dims (temp_chain, temp_draw) in the model. I need to dig a little deeper and see how to change this architectural design so that the random function handles the sampling dims rather than the model. There are some comments in the implementation mentioning that (temp_chain, temp_draw) are supposedly batch dimensions. Maybe the correct approach would be to treat those separately from (chain, draw) and let the random function name them using defaults: continuing the example above, we would then get (mu_dim_0, mu_dim_1, mu_dim_2) with shapes (2, 500, 100). I am not sure, though.
Reviewer: The point, though, is that the chain/draw dimensions from regular sampling don't exist in the graph either. Here we were trying to create a vectorized posterior-predictive/compute-deterministics sort of function. Which is fine, but it maybe shouldn't 1) be restricted to this method specifically, or 2) be hacked into the model. We basically need to get the underlying function used by this routine after creating the batched model, call it once, and then handle the conversion to InferenceData ourselves, which is when you tell it that there are 2 batch dimensions named chain/draw for every variable.

This should be thought about separately from this PR.
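A rough sketch of that conversion step, with illustrative names and shapes (2 chains × 500 draws, reusing the `mu`/`beta` example from above):

```python
import numpy as np
import arviz as az

chains, draws = 2, 500
batched_draws = {
    "mu": np.zeros((chains, draws, 100)),     # -> (chain, draw, mu_dim_0)
    "beta": np.zeros((chains, draws, 3, 2)),  # -> (chain, draw, city, feature)
}

# az.from_dict treats the two leading axes of each array as chain/draw,
# which is where the batch dimensions get their names.
idata = az.from_dict(
    posterior=batched_draws,
    dims={"beta": ["city", "feature"]},
    coords={"city": ["A", "B", "C"], "feature": [0, 1]},
)
print(idata.posterior.beta.dims)  # ('chain', 'draw', 'city', 'feature')
```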