1414
1515from pytensor .configdefaults import config
1616from pytensor .gradient import grad_not_implemented
17+ from pytensor .scalar .basic import BinaryScalarOp , ScalarOp , UnaryScalarOp
18+ from pytensor .scalar .basic import abs as scalar_abs
1719from pytensor .scalar .basic import (
18- BinaryScalarOp ,
19- ScalarOp ,
20- UnaryScalarOp ,
2120 as_scalar ,
2221 complex_types ,
2322 constant ,
2726 expm1 ,
2827 float64 ,
2928 float_types ,
29+ identity ,
3030 isinf ,
3131 log ,
3232 log1p ,
33+ reciprocal ,
34+ scalar_maximum ,
3335 sqrt ,
3436 switch ,
3537 true_div ,
@@ -1329,8 +1331,8 @@ def grad(self, inp, grads):
13291331 (gz ,) = grads
13301332
13311333 return [
1332- gz * betainc_der (a , b , x , True ),
1333- gz * betainc_der (a , b , x , False ),
1334+ gz * betainc_grad (a , b , x , True ),
1335+ gz * betainc_grad (a , b , x , False ),
13341336 gz
13351337 * exp (
13361338 log1p (- x ) * (b - 1 )
@@ -1346,28 +1348,28 @@ def c_code(self, *args, **kwargs):
13461348betainc = BetaInc (upgrade_to_float_no_complex , name = "betainc" )
13471349
13481350
1349- class BetaIncDer (ScalarOp ):
1350- """
1351- Gradient of the regularized incomplete beta function wrt to the first
1352- argument (alpha) or the second argument (beta), depending on whether the
1353- fourth argument to betainc_der is `True` or `False`, respectively.
1351+ def betainc_grad (p , q , x , wrtp : bool ):
1352+ """Gradient of the regularized incomplete beta function with respect to the first
1353+ argument (alpha) when `wrtp` is `True`, or the second argument (beta) when `wrtp` is `False`.
13541354
1355- Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
1356- Journal of Statistical Software, 3(1), 1-20.
1355+ The derivative is evaluated from the continued fraction representation of the function.
1356+
1357+ Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
1358+ Journal of Statistical Software, 3(1), 1-20.
13571359 """
13581360
1359- nin = 4
1361+ def _betainc_der (p , q , x , wrtp , skip_loop ):
1362+ dtype = upcast (p .type .dtype , q .type .dtype , x .type .dtype , "float32" )
1363+
1364+ def betaln (a , b ):
1365+ return gammaln (a ) + (gammaln (b ) - gammaln (a + b ))
13601366
1361- def impl (self , p , q , x , wrtp ):
13621367 def _betainc_a_n (f , p , q , n ):
13631368 """
13641369 Numerator (a_n) of the nth approximant of the continued fraction
13651370 representation of the regularized incomplete beta function
13661371 """
13671372
1368- if n == 1 :
1369- return p * f * (q - 1 ) / (q * (p + 1 ))
1370-
13711373 p2n = p + 2 * n
13721374 F1 = p ** 2 * f ** 2 * (n - 1 ) / (q ** 2 )
13731375 F2 = (
@@ -1377,7 +1379,11 @@ def _betainc_a_n(f, p, q, n):
13771379 / ((p2n - 3 ) * (p2n - 2 ) ** 2 * (p2n - 1 ))
13781380 )
13791381
1380- return F1 * F2
1382+ return switch (
1383+ eq (n , 1 ),
1384+ p * f * (q - 1 ) / (q * (p + 1 )),
1385+ F1 * F2 ,
1386+ )
13811387
13821388 def _betainc_b_n (f , p , q , n ):
13831389 """
@@ -1397,9 +1403,6 @@ def _betainc_da_n_dp(f, p, q, n):
13971403 Derivative of a_n wrt p
13981404 """
13991405
1400- if n == 1 :
1401- return - p * f * (q - 1 ) / (q * (p + 1 ) ** 2 )
1402-
14031406 pp = p ** 2
14041407 ppp = pp * p
14051408 p2n = p + 2 * n
@@ -1414,20 +1417,25 @@ def _betainc_da_n_dp(f, p, q, n):
14141417 D1 = q ** 2 * (p2n - 3 ) ** 2
14151418 D2 = (p2n - 2 ) ** 3 * (p2n - 1 ) ** 2
14161419
1417- return (N1 / D1 ) * (N2a + N2b + N2c + N2d + N2e ) / D2
1420+ return switch (
1421+ eq (n , 1 ),
1422+ - p * f * (q - 1 ) / (q * (p + 1 ) ** 2 ),
1423+ (N1 / D1 ) * (N2a + N2b + N2c + N2d + N2e ) / D2 ,
1424+ )
14181425
14191426 def _betainc_da_n_dq (f , p , q , n ):
14201427 """
14211428 Derivative of a_n wrt q
14221429 """
1423- if n == 1 :
1424- return p * f / (q * (p + 1 ))
1425-
14261430 p2n = p + 2 * n
14271431 F1 = (p ** 2 * f ** 2 / (q ** 2 )) * (n - 1 ) * (p + n - 1 ) * (2 * q + p - 2 )
14281432 D1 = (p2n - 3 ) * (p2n - 2 ) ** 2 * (p2n - 1 )
14291433
1430- return F1 / D1
1434+ return switch (
1435+ eq (n , 1 ),
1436+ p * f / (q * (p + 1 )),
1437+ F1 / D1 ,
1438+ )
14311439
14321440 def _betainc_db_n_dp (f , p , q , n ):
14331441 """
@@ -1452,42 +1460,44 @@ def _betainc_db_n_dq(f, p, q, n):
14521460 p2n = p + 2 * n
14531461 return - (p ** 2 * f ) / (q * (p2n - 2 ) * p2n )
14541462
1455- # Input validation
1456- if not (0 <= x <= 1 ) or p < 0 or q < 0 :
1457- return np .nan
1458-
1459- if x > (p / (p + q )):
1460- return - self .impl (q , p , 1 - x , not wrtp )
1461-
1462- min_iters = 3
1463- max_iters = 200
1464- err_threshold = 1e-12
1465-
1466- derivative_old = 0
1463+ min_iters = np .array (3 , dtype = "int32" )
1464+ max_iters = switch (
1465+ skip_loop , np .array (0 , dtype = "int32" ), np .array (200 , dtype = "int32" )
1466+ )
1467+ err_threshold = np .array (1e-12 , dtype = config .floatX )
14671468
1468- Am2 , Am1 = 1 , 1
1469- Bm2 , Bm1 = 0 , 1
1470- dAm2 , dAm1 = 0 , 0
1471- dBm2 , dBm1 = 0 , 0
1469+ Am2 , Am1 = np . array ( 1 , dtype = dtype ), np . array ( 1 , dtype = dtype )
1470+ Bm2 , Bm1 = np . array ( 0 , dtype = dtype ), np . array ( 1 , dtype = dtype )
1471+ dAm2 , dAm1 = np . array ( 0 , dtype = dtype ), np . array ( 0 , dtype = dtype )
1472+ dBm2 , dBm1 = np . array ( 0 , dtype = dtype ), np . array ( 0 , dtype = dtype )
14721473
14731474 f = (q * x ) / (p * (1 - x ))
1474- K = np .exp (
1475- p * np .log (x )
1476- + (q - 1 ) * np .log1p (- x )
1477- - np .log (p )
1478- - scipy .special .betaln (p , q )
1479- )
1475+ K = exp (p * log (x ) + (q - 1 ) * log1p (- x ) - log (p ) - betaln (p , q ))
14801476 if wrtp :
1481- dK = (
1482- np .log (x )
1483- - 1 / p
1484- + scipy .special .digamma (p + q )
1485- - scipy .special .digamma (p )
1486- )
1477+ dK = log (x ) - reciprocal (p ) + psi (p + q ) - psi (p )
14871478 else :
1488- dK = np .log1p (- x ) + scipy .special .digamma (p + q ) - scipy .special .digamma (q )
1489-
1490- for n in range (1 , max_iters + 1 ):
1479+ dK = log1p (- x ) + psi (p + q ) - psi (q )
1480+
1481+ derivative = np .array (0 , dtype = dtype )
1482+ n = np .array (1 , dtype = "int16" ) # Enough for 200 max iters
1483+
1484+ def inner_loop (
1485+ derivative ,
1486+ Am2 ,
1487+ Am1 ,
1488+ Bm2 ,
1489+ Bm1 ,
1490+ dAm2 ,
1491+ dAm1 ,
1492+ dBm2 ,
1493+ dBm1 ,
1494+ n ,
1495+ f ,
1496+ p ,
1497+ q ,
1498+ K ,
1499+ dK ,
1500+ ):
14911501 a_n_ = _betainc_a_n (f , p , q , n )
14921502 b_n_ = _betainc_b_n (f , p , q , n )
14931503 if wrtp :
@@ -1502,36 +1512,53 @@ def _betainc_db_n_dq(f, p, q, n):
15021512 dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
15031513 dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
15041514
1505- Am2 , Am1 = Am1 , A
1506- Bm2 , Bm1 = Bm1 , B
1507- dAm2 , dAm1 = dAm1 , dA
1508- dBm2 , dBm1 = dBm1 , dB
1509-
1510- if n < min_iters - 1 :
1511- continue
1515+ Am2 , Am1 = identity (Am1 ), identity (A )
1516+ Bm2 , Bm1 = identity (Bm1 ), identity (B )
1517+ dAm2 , dAm1 = identity (dAm1 ), identity (dA )
1518+ dBm2 , dBm1 = identity (dBm1 ), identity (dB )
15121519
15131520 F1 = A / B
15141521 F2 = (dA - F1 * dB ) / B
1515- derivative = K * (F1 * dK + F2 )
1522+ derivative_new = K * (F1 * dK + F2 )
15161523
1517- errapx = abs (derivative_old - derivative )
1518- d_errapx = errapx / max (err_threshold , abs (derivative ))
1519- derivative_old = derivative
1520-
1521- if d_errapx <= err_threshold :
1522- return derivative
1524+ errapx = scalar_abs (derivative - derivative_new )
1525+ d_errapx = errapx / scalar_maximum (
1526+ err_threshold , scalar_abs (derivative_new )
1527+ )
15231528
1524- warnings .warn (
1525- f"betainc_der did not converge after { n } iterations" ,
1526- RuntimeWarning ,
1527- )
1528- return np .nan
1529+ min_iters_cond = n > (min_iters - 1 )
1530+ derivative = switch (
1531+ min_iters_cond ,
1532+ derivative_new ,
1533+ derivative ,
1534+ )
1535+ n += 1
15291536
1530- def c_code (self , * args , ** kwargs ):
1531- raise NotImplementedError ()
1537+ return (
1538+ (derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ),
1539+ (d_errapx <= err_threshold ) & min_iters_cond ,
1540+ )
15321541
1542+ init = [derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ]
1543+ constant = [f , p , q , K , dK ]
1544+ grad = _make_scalar_loop (
1545+ max_iters , init , constant , inner_loop , name = "betainc_grad"
1546+ )
1547+ return grad
15331548
1534- betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1549+ # Input validation
1550+ nan_branch = (x < 0 ) | (x > 1 ) | (p < 0 ) | (q < 0 )
1551+ flip_branch = x > (p / (p + q ))
1552+ grad = switch (
1553+ nan_branch ,
1554+ np .nan ,
1555+ switch (
1556+ flip_branch ,
1557+ - _betainc_der (q , p , 1 - x , not wrtp , skip_loop = nan_branch | (~ flip_branch )),
1558+ _betainc_der (p , q , x , wrtp , skip_loop = nan_branch | flip_branch ),
1559+ ),
1560+ )
1561+ return grad
15351562
15361563
15371564class Hyp2F1 (ScalarOp ):
0 commit comments