1515
1616from typing import Sequence , Union
1717
18+ import numpy as np
1819import pymc as pm
1920import pytensor .tensor as pt
2021
2122__all__ = ["R2D2M2CP" ]
2223
2324
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
    """Turn a positive-sign probability and a variance into ``(mu, sigma)``.

    Parameters
    ----------
    psi : tensor
        Probability values; mapped through ``erfinv(2 * psi - 1)``.
    explained_var : tensor
        Variance attributed to each entry.
    psi_mask : array-like of bool or None
        When given, entries where the mask is false do not use the closed-form
        result: their mean becomes ``sign(erfinv(2 * psi - 1)) * sqrt(explained_var)``
        and their standard deviation becomes ``0``. When ``None`` the
        closed-form values are returned untouched.

    Returns
    -------
    tuple
        ``(mu, sigma)`` tensors.
    """
    z = pt.erfinv(2 * psi - 1)
    shrink = (1 / (2 * z**2 + 1)) ** 0.5
    std = explained_var**0.5 * shrink
    mean = std * z * 2**0.5
    if psi_mask is None:
        return mean, std
    # Unmasked entries are presumably the psi == 0 / psi == 1 limit cases
    # (see how callers build the mask); they get a point value with zero spread.
    limit_mean = pt.sign(z) * explained_var**0.5
    return (
        pt.where(psi_mask, mean, limit_mean),
        pt.where(psi_mask, std, 0),
    )
3037
3138
3239def _R2D2M2CP_beta (
@@ -37,6 +44,7 @@ def _R2D2M2CP_beta(
3744 phi : pt .TensorVariable ,
3845 psi : pt .TensorVariable ,
3946 * ,
47+ psi_mask ,
4048 dims : Union [str , Sequence [str ]],
4149 centered = False ,
4250):
@@ -59,16 +67,141 @@ def _R2D2M2CP_beta(
5967 """
6068 tau2 = r2 / (1 - r2 )
6169 explained_variance = phi * pt .expand_dims (tau2 * output_sigma ** 2 , - 1 )
62- mu_param , std_param = _psivar2musigma (psi , explained_variance )
70+ mu_param , std_param = _psivar2musigma (psi , explained_variance , psi_mask = psi_mask )
6371 if not centered :
6472 with pm .Model (name ):
65- raw = pm .Normal ("raw" , dims = dims )
73+ if psi_mask is not None and psi_mask .any ():
74+ # limit case where some probs are not 1 or 0
75+ # setsubtensor is required
76+ r_idx = psi_mask .nonzero ()
77+ with pm .Model ("raw" ):
78+ raw = pm .Normal ("masked" , shape = len (r_idx [0 ]))
79+ raw = pt .set_subtensor (pt .zeros_like (mu_param )[r_idx ], raw )
80+ raw = pm .Deterministic ("raw" , raw , dims = dims )
81+ elif psi_mask is not None :
82+ # all variables are deterministic
83+ raw = pt .zeros_like (mu_param )
84+ else :
85+ raw = pm .Normal ("raw" , dims = dims )
6686 beta = pm .Deterministic (name , (raw * std_param + mu_param ) / input_sigma , dims = dims )
6787 else :
68- beta = pm .Normal (name , mu_param / input_sigma , std_param / input_sigma , dims = dims )
88+ if psi_mask is not None and psi_mask .any ():
89+ # limit case where some probs are not 1 or 0
90+ # setsubtensor is required
91+ r_idx = psi_mask .nonzero ()
92+ with pm .Model (name ):
93+ mean = (mu_param / input_sigma )[r_idx ]
94+ sigma = (std_param / input_sigma )[r_idx ]
95+ masked = pm .Normal (
96+ "masked" ,
97+ mean ,
98+ sigma ,
99+ shape = len (r_idx [0 ]),
100+ )
101+ beta = pt .set_subtensor (mean , masked )
102+ beta = pm .Deterministic (name , beta , dims = dims )
103+ elif psi_mask is not None :
104+ # all variables are deterministic
105+ beta = pm .Deterministic (name , (mu_param / input_sigma ), dims = dims )
106+ else :
107+ beta = pm .Normal (name , mu_param / input_sigma , std_param / input_sigma , dims = dims )
69108 return beta
70109
71110
def _broadcast_as_dims(*values, dims):
    """Broadcast each value to the shape implied by ``dims`` in the current model.

    The target shape is the list of coordinate lengths, looked up in the
    coords of the model on the context stack, one per name in ``dims``.

    Parameters
    ----------
    *values
        Objects accepted by ``np.broadcast_to``.
    dims : sequence of str
        Dimension names that must be present in the model coords.

    Returns
    -------
    The single broadcast array when exactly one value was passed, otherwise a
    tuple of broadcast arrays in the order the values were given.
    """
    coords = pm.modelcontext(None).coords
    target_shape = [len(coords[d]) for d in dims]
    broadcasted = tuple(np.broadcast_to(value, target_shape) for value in values)
    # A lone input is returned bare rather than wrapped in a 1-tuple.
    return broadcasted[0] if len(values) == 1 else broadcasted
119+
120+
def _psi_masked(positive_probs, positive_probs_std, *, dims):
    """Build ``psi`` when an uncertainty over ``positive_probs`` is supplied.

    Entries whose probability is exactly 0 or 1 are kept as fixed values; the
    remaining entries get a ``Beta`` prior.

    Parameters
    ----------
    positive_probs, positive_probs_std : pt.Constant
        Constant tensors; their ``.data`` is broadcast to the shape of ``dims``.
    dims : sequence of str
        Dimension names of the resulting ``psi``.

    Returns
    -------
    tuple
        ``(mask, psi)``. ``mask`` is a boolean array that is true where the
        probability is neither 0 nor 1, or ``None`` when every entry is
        strictly between the limits.

    Raises
    ------
    TypeError
        If either argument is not a constant tensor.
    ValueError
        If an entry equal to 0 or 1 has a non-zero standard deviation.
    """
    both_constant = isinstance(positive_probs, pt.Constant) and isinstance(
        positive_probs_std, pt.Constant
    )
    if not both_constant:
        raise TypeError(
            "Only constant values for positive_probs and positive_probs_std are accepted"
        )
    probs, probs_std = _broadcast_as_dims(
        positive_probs.data, positive_probs_std.data, dims=dims
    )
    at_limit = np.bitwise_or(probs == 1, probs == 0)
    mask = ~at_limit
    if np.bitwise_and(at_limit, probs_std != 0).any():
        raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")
    if at_limit.any() and mask.any():
        # Mixed case: only the non-limit entries are random, and they are
        # written into a copy of the constant probabilities.
        free_idx = mask.nonzero()
        with pm.Model("psi"):
            free_psi = pm.Beta(
                "masked",
                mu=probs[free_idx],
                sigma=probs_std[free_idx],
                shape=len(free_idx[0]),
            )
        combined = pt.set_subtensor(pt.as_tensor(probs)[free_idx], free_psi)
        psi = pm.Deterministic("psi", combined, dims=dims)
    elif at_limit.all():
        # Every entry is a limit case: psi is a plain constant tensor.
        psi = pt.as_tensor(probs)
    else:
        # No limit cases at all: a regular Beta prior, and no mask is needed.
        psi = pm.Beta("psi", mu=probs, sigma=probs_std, dims=dims)
        mask = None
    return mask, psi
154+
155+
def _psi(positive_probs, positive_probs_std, *, dims):
    """Create ``psi`` and the mask of its non-limit entries.

    Parameters
    ----------
    positive_probs
        Value convertible with ``pt.as_tensor``; must convert to a constant.
    positive_probs_std
        Optional uncertainty. When given, construction is delegated to
        ``_psi_masked``; when ``None``, ``psi`` is the constant
        ``positive_probs`` broadcast to ``dims``.
    dims : sequence of str
        Dimension names of ``psi``.

    Returns
    -------
    tuple
        ``(mask, psi)``. In the constant case ``mask`` is a boolean array that
        is true where ``psi`` is neither 0 nor 1, or ``None`` when that holds
        for every entry.

    Raises
    ------
    TypeError
        If ``positive_probs`` is not constant (constant case only; the
        delegated case raises from ``_psi_masked``).
    """
    if positive_probs_std is None:
        probs = pt.as_tensor(positive_probs)
        if not isinstance(probs, pt.Constant):
            raise TypeError("Only constant values for positive_probs are allowed")
        psi = _broadcast_as_dims(probs.data, dims=dims)
        at_limit = np.bitwise_or(psi == 1, psi == 0)
        mask = np.atleast_1d(~at_limit)
        if mask.all():
            # No limit cases present, so callers need no mask.
            mask = None
        return mask, psi

    mask, psi = _psi_masked(
        positive_probs=pt.as_tensor(positive_probs),
        positive_probs_std=pt.as_tensor(positive_probs_std),
        dims=dims,
    )
    return mask, psi
172+
173+
def _phi(
    variables_importance,
    variance_explained,
    importance_concentration,
    *,
    dims,
):
    """Build ``phi``, the split of explained variance across the last dim.

    Parameters
    ----------
    variables_importance : tensor-like, optional
        Dirichlet concentration; when given, ``phi`` is a ``Dirichlet``
        variable named ``"phi"``. Mutually exclusive with
        ``variance_explained``.
    variance_explained : tensor-like, optional
        Point estimate used directly as ``phi`` (or as the Dirichlet base
        when ``importance_concentration`` is given).
    importance_concentration : tensor-like, optional
        Multiplier applied to the Dirichlet concentration. When it is given
        without ``variables_importance``, ``phi`` becomes
        ``Dirichlet(importance_concentration * phi)``.
    dims : sequence of str
        Dimension names; the last one indexes the variables.

    Returns
    -------
    A ``Dirichlet`` random variable, a tensor, or a NumPy array of the uniform
    value ``1 / n`` broadcast to ``dims``.

    Raises
    ------
    TypeError
        If both ``variables_importance`` and ``variance_explained`` are given,
        or if either is used with fewer than two variables on the last dim.
    """
    *broadcast_dims, dim = dims
    model = pm.modelcontext(None)
    if variables_importance is not None:
        if variance_explained is not None:
            raise TypeError("Can't use variable importance with variance explained")
        if len(model.coords[dim]) <= 1:
            raise TypeError("Can't use variable importance with less than two variables")
        variables_importance = pt.as_tensor(variables_importance)
        if importance_concentration is not None:
            variables_importance *= importance_concentration
        return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim])
    elif variance_explained is not None:
        if len(model.coords[dim]) <= 1:
            raise TypeError("Can't use variance explained with less than two variables")
        # Keep the tensor symbolic: ``_broadcast_as_dims`` relies on
        # ``np.broadcast_to``, which is meant for NumPy array-likes. Shape
        # alignment is left to tensor broadcasting where ``phi`` is used.
        phi = pt.as_tensor(variance_explained)
    else:
        # Uniform split: a plain Python float, which NumPy can broadcast.
        phi = _broadcast_as_dims(1 / len(model.coords[dim]), dims=dims)
    if importance_concentration is not None:
        return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim])
    return phi
203+
204+
72205def R2D2M2CP (
73206 name ,
74207 output_sigma ,
@@ -78,6 +211,7 @@ def R2D2M2CP(
78211 r2 ,
79212 variables_importance = None ,
80213 variance_explained = None ,
214+ importance_concentration = None ,
81215 r2_std = None ,
82216 positive_probs = 0.5 ,
83217 positive_probs_std = None ,
@@ -102,6 +236,8 @@ def R2D2M2CP(
102236 variance_explained : tensor, optional
103237 Alternative estimate for variables importance which is point estimate of
104238 variance explained, should sum up to one, by default None
239+ importance_concentration : tensor, optional
240+ Confidence around variance explained or variable importance estimate
105241 r2_std : tensor, optional
106242 Optional uncertainty over :math:`R^2`, by default None
107243 positive_probs : tensor, optional
@@ -125,8 +261,8 @@ def R2D2M2CP(
125261 -----
126262 The R2D2M2CP prior is a modification of R2D2M2 prior.
127263
128- - ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
129- - R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
264+ - ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
265+ - R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
130266
131267 Examples
132268 --------
@@ -259,31 +395,20 @@ def R2D2M2CP(
259395 input_sigma = pt .as_tensor (input_sigma )
260396 output_sigma = pt .as_tensor (output_sigma )
261397 with pm .Model (name ) as model :
262- if variables_importance is not None :
263- if variance_explained is not None :
264- raise TypeError ("Can't use variable importance with variance explained" )
265- if len (model .coords [dim ]) <= 1 :
266- raise TypeError ("Can't use variable importance with less than two variables" )
267- phi = pm .Dirichlet (
268- "phi" , pt .as_tensor (variables_importance ), dims = broadcast_dims + [dim ]
269- )
270- elif variance_explained is not None :
271- if len (model .coords [dim ]) <= 1 :
272- raise TypeError ("Can't use variance explained with less than two variables" )
273- phi = pt .as_tensor (variance_explained )
274- else :
275- phi = 1 / len (model .coords [dim ])
398+ if not all (isinstance (model .dim_lengths [d ], pt .TensorConstant ) for d in dims ):
399+ raise ValueError (f"{ dims !r} should be constant length immutable dims" )
276400 if r2_std is not None :
277401 r2 = pm .Beta ("r2" , mu = r2 , sigma = r2_std , dims = broadcast_dims )
278- if positive_probs_std is not None :
279- psi = pm .Beta (
280- "psi" ,
281- mu = pt .as_tensor (positive_probs ),
282- sigma = pt .as_tensor (positive_probs_std ),
283- dims = broadcast_dims + [dim ],
284- )
285- else :
286- psi = pt .as_tensor (positive_probs )
402+ phi = _phi (
403+ variables_importance = variables_importance ,
404+ variance_explained = variance_explained ,
405+ importance_concentration = importance_concentration ,
406+ dims = dims ,
407+ )
408+ mask , psi = _psi (
409+ positive_probs = positive_probs , positive_probs_std = positive_probs_std , dims = dims
410+ )
411+
287412 beta = _R2D2M2CP_beta (
288413 name ,
289414 output_sigma ,
@@ -293,6 +418,7 @@ def R2D2M2CP(
293418 psi ,
294419 dims = broadcast_dims + [dim ],
295420 centered = centered ,
421+ psi_mask = mask ,
296422 )
297423 resid_sigma = (1 - r2 ) ** 0.5 * output_sigma
298424 return resid_sigma , beta
0 commit comments