diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 38bc0e840a..76db660f14 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from functools import singledispatch import numpy as np @@ -134,6 +135,10 @@ def log_jac_det(self, value, *inputs): y = pt.zeros(value.shape) return pt.sum(y, axis=-1) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Drop the last label since SumTo1 reduces dimensionality by 1.""" + return labels[:-1] + class CholeskyCovPacked(Transform): """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.""" @@ -311,6 +316,10 @@ def backward(self, value, *rv_inputs): def log_jac_det(self, value, *rv_inputs): return pt.constant(0.0) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Drop the last label since ZeroSumTransform reduces dimensionality by 1.""" + return labels[:-1] + log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8d2bbacd26..6c6c7910da 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -35,7 +35,7 @@ # SOFTWARE. import abc -from collections.abc import Callable +from collections.abc import Callable, Sequence import numpy as np import pytensor.tensor as pt @@ -154,6 +154,10 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: phi_inv = self.backward(value, *inputs) return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Mutate user-provided coordinates associated with the variable to label transformed values returned by this class.""" + return labels + def __str__(self): """Return a string representation of the object.""" return f"{self.__class__.__name__}" @@ -1006,6 +1010,10 @@ def log_jac_det(self, value, *inputs): res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded) return pt.sum(res, -1) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Drop the last label since Simplex reduces dimensionality by 1.""" + return labels[:-1] + class CircularTransform(Transform): name = "circular"