|
18 | 18 | BinaryScalarOp, |
19 | 19 | ScalarOp, |
20 | 20 | UnaryScalarOp, |
| 21 | + as_scalar, |
21 | 22 | complex_types, |
| 23 | + constant, |
22 | 24 | discrete_types, |
| 25 | + eq, |
23 | 26 | exp, |
24 | 27 | expm1, |
25 | 28 | float64, |
26 | 29 | float_types, |
27 | 30 | isinf, |
28 | 31 | log, |
29 | 32 | log1p, |
| 33 | + sqrt, |
30 | 34 | switch, |
31 | 35 | true_div, |
32 | 36 | upcast, |
33 | 37 | upgrade_to_float, |
34 | 38 | upgrade_to_float64, |
35 | 39 | upgrade_to_float_no_complex, |
36 | 40 | ) |
| 41 | +from pytensor.scalar.loop import ScalarLoop |
37 | 42 |
|
38 | 43 |
|
39 | 44 | class Erf(UnaryScalarOp): |
@@ -595,7 +600,7 @@ def grad(self, inputs, grads): |
595 | 600 | (k, x) = inputs |
596 | 601 | (gz,) = grads |
597 | 602 | return [ |
598 | | - gz * gammainc_der(k, x), |
| 603 | + gz * gammainc_grad(k, x), |
599 | 604 | gz * exp(-x + (k - 1) * log(x) - gammaln(k)), |
600 | 605 | ] |
601 | 606 |
|
@@ -644,7 +649,7 @@ def grad(self, inputs, grads): |
644 | 649 | (k, x) = inputs |
645 | 650 | (gz,) = grads |
646 | 651 | return [ |
647 | | - gz * gammaincc_der(k, x), |
| 652 | + gz * gammaincc_grad(k, x), |
648 | 653 | gz * -exp(-x + (k - 1) * log(x) - gammaln(k)), |
649 | 654 | ] |
650 | 655 |
|
@@ -675,162 +680,209 @@ def __hash__(self): |
675 | 680 | gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") |
676 | 681 |
|
677 | 682 |
|
678 | | -class GammaIncDer(BinaryScalarOp): |
679 | | - """ |
680 | | - Gradient of the the regularized lower gamma function (P) wrt to the first |
681 | | - argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp` |
| 683 | +def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): |
| 684 | + init = [as_scalar(x) for x in init] |
| 685 | + constant = [as_scalar(x) for x in constant] |
| 686 | + # Create dummy types, in case some variables have the same initial form |
| 687 | + init_ = [x.type() for x in init] |
| 688 | + constant_ = [x.type() for x in constant] |
| 689 | + update_, until_ = inner_loop_fn(*init_, *constant_) |
| 690 | + op = ScalarLoop( |
| 691 | + init=init_, |
| 692 | + constant=constant_, |
| 693 | + update=update_, |
| 694 | + until=until_, |
| 695 | + until_condition_failed="warn", |
| 696 | + name=name, |
| 697 | + ) |
| 698 | + S, *_ = op(n_steps, *init, *constant) |
| 699 | + return S |
| 700 | + |
| 701 | + |
| 702 | +def gammainc_grad(k, x): |
| 703 | + """Gradient of the regularized lower gamma function (P) wrt to the first |
| 704 | + argument (k, a.k.a. alpha). |
| 705 | +
|
| 706 | + Adapted from STAN `grad_reg_lower_inc_gamma.hpp` |
682 | 707 |
|
683 | 708 | Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions. |
684 | 709 | ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481. |
685 | 710 | """ |
| 711 | + dtype = upcast(k.type.dtype, x.type.dtype, "float32") |
686 | 712 |
|
687 | | - def impl(self, k, x): |
688 | | - if x == 0: |
689 | | - return 0 |
690 | | - |
691 | | - sqrt_exp = -756 - x**2 + 60 * x |
692 | | - if ( |
693 | | - (k < 0.8 and x > 15) |
694 | | - or (k < 12 and x > 30) |
695 | | - or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp)) |
696 | | - ): |
697 | | - return -GammaIncCDer.st_impl(k, x) |
698 | | - |
699 | | - precision = 1e-10 |
700 | | - max_iters = int(1e5) |
| 713 | + def grad_approx(skip_loop): |
| 714 | + precision = np.array(1e-10, dtype=config.floatX) |
| 715 | + max_iters = switch( |
| 716 | + skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32") |
| 717 | + ) |
701 | 718 |
|
702 | | - log_x = np.log(x) |
703 | | - log_gamma_k_plus_1 = scipy.special.gammaln(k + 1) |
| 719 | + log_x = log(x) |
| 720 | + log_gamma_k_plus_1 = gammaln(k + 1) |
704 | 721 |
|
705 | | - k_plus_n = k |
| 722 | + # First loop |
| 723 | + k_plus_n = k # Should not overflow unless k > 2,147,383,647 |
706 | 724 | log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 |
707 | | - sum_a = 0.0 |
708 | | - for n in range(0, max_iters + 1): |
709 | | - term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) |
710 | | - sum_a += term |
| 725 | + sum_a0 = np.array(0.0, dtype=dtype) |
711 | 726 |
|
712 | | - if term <= precision: |
713 | | - break |
| 727 | + def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x): |
| 728 | + term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) |
| 729 | + sum_a += term |
714 | 730 |
|
715 | | - log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) |
| 731 | + log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) |
716 | 732 | k_plus_n += 1 |
717 | | - |
718 | | - if n >= max_iters: |
719 | | - warnings.warn( |
720 | | - f"gammainc_der did not converge after {n} iterations", |
721 | | - RuntimeWarning, |
| 733 | + return ( |
| 734 | + (sum_a, log_gamma_k_plus_n_plus_1, k_plus_n), |
| 735 | + (term <= precision), |
722 | 736 | ) |
723 | | - return np.nan |
724 | 737 |
|
725 | | - k_plus_n = k |
| 738 | + init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] |
| 739 | + constant = [log_x] |
| 740 | + sum_a = _make_scalar_loop( |
| 741 | + max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" |
| 742 | + ) |
| 743 | + |
| 744 | + # Second loop |
| 745 | + n = np.array(0, dtype="int32") |
726 | 746 | log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 |
727 | | - sum_b = 0.0 |
728 | | - for n in range(0, max_iters + 1): |
729 | | - term = np.exp( |
730 | | - k_plus_n * log_x - log_gamma_k_plus_n_plus_1 |
731 | | - ) * scipy.special.digamma(k_plus_n + 1) |
732 | | - sum_b += term |
| 747 | + k_plus_n = k |
| 748 | + sum_b0 = np.array(0.0, dtype=dtype) |
733 | 749 |
|
734 | | - if term <= precision and n >= 1: # Require at least two iterations |
735 | | - return np.exp(-x) * (log_x * sum_a - sum_b) |
| 750 | + def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x): |
| 751 | + term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1) |
| 752 | + sum_b += term |
736 | 753 |
|
737 | | - log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) |
| 754 | + log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) |
| 755 | + n += 1 |
738 | 756 | k_plus_n += 1 |
| 757 | + return ( |
| 758 | + (sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n), |
| 759 | + # Require at least two iterations |
| 760 | + ((term <= precision) & (n > 1)), |
| 761 | + ) |
739 | 762 |
|
740 | | - warnings.warn( |
741 | | - f"gammainc_der did not converge after {n} iterations", |
742 | | - RuntimeWarning, |
| 763 | + init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n] |
| 764 | + constant = [log_x] |
| 765 | + sum_b, *_ = _make_scalar_loop( |
| 766 | + max_iters, init, constant, inner_loop_b, name="gammainc_grad_b" |
743 | 767 | ) |
744 | | - return np.nan |
745 | | - |
746 | | - def c_code(self, *args, **kwargs): |
747 | | - raise NotImplementedError() |
748 | | - |
749 | | - |
750 | | -gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der") |
751 | 768 |
|
752 | | - |
753 | | -class GammaIncCDer(BinaryScalarOp): |
754 | | - """ |
755 | | - Gradient of the the regularized upper gamma function (Q) wrt to the first |
756 | | - argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp` |
| 769 | + grad_approx = exp(-x) * (log_x * sum_a - sum_b) |
| 770 | + return grad_approx |
| 771 | + |
| 772 | + zero_branch = eq(x, 0) |
| 773 | + sqrt_exp = -756 - x**2 + 60 * x |
| 774 | + gammaincc_branch = ( |
| 775 | + ((k < 0.8) & (x > 15)) |
| 776 | + | ((k < 12) & (x > 30)) |
| 777 | + | ((sqrt_exp > 0) & (k < sqrt(sqrt_exp))) |
| 778 | + ) |
| 779 | + grad = switch( |
| 780 | + zero_branch, |
| 781 | + 0, |
| 782 | + switch( |
| 783 | + gammaincc_branch, |
| 784 | + -gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)), |
| 785 | + grad_approx(skip_loop=zero_branch | gammaincc_branch), |
| 786 | + ), |
| 787 | + ) |
| 788 | + return grad |
| 789 | + |
| 790 | + |
| 791 | +def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): |
| 792 | + """Gradient of the regularized upper gamma function (Q) wrt to the first |
| 793 | + argument (k, a.k.a. alpha). |
| 794 | +
|
| 795 | + Adapted from STAN `grad_reg_inc_gamma.hpp` |
| 796 | +
|
| 797 | + skip_loops is used for faster branching when this function is called by `gammainc_der` |
757 | 798 | """ |
| 799 | + dtype = upcast(k.type.dtype, x.type.dtype, "float32") |
758 | 800 |
|
759 | | - @staticmethod |
760 | | - def st_impl(k, x): |
761 | | - gamma_k = scipy.special.gamma(k) |
762 | | - digamma_k = scipy.special.digamma(k) |
763 | | - log_x = np.log(x) |
764 | | - |
765 | | - # asymptotic expansion http://dlmf.nist.gov/8.11#E2 |
766 | | - if (x >= k) and (x >= 8): |
767 | | - S = 0 |
768 | | - k_minus_one_minus_n = k - 1 |
769 | | - fac = k_minus_one_minus_n |
770 | | - dfac = 1 |
771 | | - xpow = x |
| 801 | + gamma_k = gamma(k) |
| 802 | + digamma_k = psi(k) |
| 803 | + log_x = log(x) |
| 804 | + |
| 805 | + def approx_a(skip_loop): |
| 806 | + n_steps = switch( |
| 807 | + skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32") |
| 808 | + ) |
| 809 | + sum_a0 = np.array(0.0, dtype=dtype) |
| 810 | + dfac = np.array(1.0, dtype=dtype) |
| 811 | + xpow = x |
| 812 | + k_minus_one_minus_n = k - 1 |
| 813 | + fac = k_minus_one_minus_n |
| 814 | + delta = true_div(dfac, xpow) |
| 815 | + |
| 816 | + def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x): |
| 817 | + sum_a += delta |
| 818 | + xpow *= x |
| 819 | + k_minus_one_minus_n -= 1 |
| 820 | + dfac = k_minus_one_minus_n * dfac + fac |
| 821 | + fac *= k_minus_one_minus_n |
772 | 822 | delta = dfac / xpow |
| 823 | + return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), () |
773 | 824 |
|
774 | | - for n in range(1, 10): |
775 | | - k_minus_one_minus_n -= 1 |
776 | | - S += delta |
777 | | - xpow *= x |
778 | | - dfac = k_minus_one_minus_n * dfac + fac |
779 | | - fac *= k_minus_one_minus_n |
780 | | - delta = dfac / xpow |
781 | | - if np.isinf(delta): |
782 | | - warnings.warn( |
783 | | - "gammaincc_der did not converge", |
784 | | - RuntimeWarning, |
785 | | - ) |
786 | | - return np.nan |
| 825 | + init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] |
| 826 | + constant = [x] |
| 827 | + sum_a = _make_scalar_loop( |
| 828 | + n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a" |
| 829 | + ) |
| 830 | + grad_approx_a = ( |
| 831 | + gammaincc(k, x) * (log_x - digamma_k) |
| 832 | + + exp(-x + (k - 1) * log_x) * sum_a / gamma_k |
| 833 | + ) |
| 834 | + return grad_approx_a |
787 | 835 |
|
| 836 | + def approx_b(skip_loop): |
| 837 | + max_iters = switch( |
| 838 | + skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32") |
| 839 | + ) |
| 840 | + log_precision = np.array(np.log(1e-6), dtype=config.floatX) |
| 841 | + |
| 842 | + sum_b0 = np.array(0.0, dtype=dtype) |
| 843 | + log_s = np.array(0.0, dtype=dtype) |
| 844 | + s_sign = np.array(1, dtype="int8") |
| 845 | + n = np.array(1, dtype="int32") |
| 846 | + log_delta = log_s - 2 * log(k) |
| 847 | + |
| 848 | + def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x): |
| 849 | + delta = exp(log_delta) |
| 850 | + sum_b += switch(s_sign > 0, delta, -delta) |
| 851 | + s_sign = -s_sign |
| 852 | + |
| 853 | + # log will cast >int16 to float64 |
| 854 | + log_s_inc = log_x - log(n) |
| 855 | + if log_s_inc.type.dtype != log_s.type.dtype: |
| 856 | + log_s_inc = log_s_inc.astype(log_s.type.dtype) |
| 857 | + log_s += log_s_inc |
| 858 | + |
| 859 | + new_log_delta = log_s - 2 * log(n + k) |
| 860 | + if new_log_delta.type.dtype != log_delta.type.dtype: |
| 861 | + new_log_delta = new_log_delta.astype(log_delta.type.dtype) |
| 862 | + log_delta = new_log_delta |
| 863 | + |
| 864 | + n += 1 |
788 | 865 | return ( |
789 | | - scipy.special.gammaincc(k, x) * (log_x - digamma_k) |
790 | | - + np.exp(-x + (k - 1) * log_x) * S / gamma_k |
791 | | - ) |
792 | | - |
793 | | - # gradient of series expansion http://dlmf.nist.gov/8.7#E3 |
794 | | - else: |
795 | | - log_precision = np.log(1e-6) |
796 | | - max_iters = int(1e5) |
797 | | - S = 0 |
798 | | - log_s = 0.0 |
799 | | - s_sign = 1 |
800 | | - log_delta = log_s - 2 * np.log(k) |
801 | | - for n in range(1, max_iters + 1): |
802 | | - S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta) |
803 | | - s_sign = -s_sign |
804 | | - log_s += log_x - np.log(n) |
805 | | - log_delta = log_s - 2 * np.log(n + k) |
806 | | - |
807 | | - if np.isinf(log_delta): |
808 | | - warnings.warn( |
809 | | - "gammaincc_der did not converge", |
810 | | - RuntimeWarning, |
811 | | - ) |
812 | | - return np.nan |
813 | | - |
814 | | - if log_delta <= log_precision: |
815 | | - return ( |
816 | | - scipy.special.gammainc(k, x) * (digamma_k - log_x) |
817 | | - + np.exp(k * log_x) * S / gamma_k |
818 | | - ) |
819 | | - |
820 | | - warnings.warn( |
821 | | - f"gammaincc_der did not converge after {n} iterations", |
822 | | - RuntimeWarning, |
| 866 | + (sum_b, log_s, s_sign, log_delta, n), |
| 867 | + log_delta <= log_precision, |
823 | 868 | ) |
824 | | - return np.nan |
825 | | - |
826 | | - def impl(self, k, x): |
827 | | - return self.st_impl(k, x) |
828 | | - |
829 | | - def c_code(self, *args, **kwargs): |
830 | | - raise NotImplementedError() |
831 | 869 |
|
832 | | - |
833 | | -gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") |
| 870 | + init = [sum_b0, log_s, s_sign, log_delta, n] |
| 871 | + constant = [k, log_x] |
| 872 | + sum_b = _make_scalar_loop( |
| 873 | + max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" |
| 874 | + ) |
| 875 | + grad_approx_b = ( |
| 876 | + gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k |
| 877 | + ) |
| 878 | + return grad_approx_b |
| 879 | + |
| 880 | + branch_a = (x >= k) & (x >= 8) |
| 881 | + return switch( |
| 882 | + branch_a, |
| 883 | + approx_a(skip_loop=~branch_a | skip_loops), |
| 884 | + approx_b(skip_loop=branch_a | skip_loops), |
| 885 | + ) |
834 | 886 |
|
835 | 887 |
|
836 | 888 | class GammaU(BinaryScalarOp): |
|
0 commit comments