|
7 | 7 | ) |
8 | 8 | from ._array_object import Array |
9 | 9 | from ._dtypes import float32, complex64 |
10 | | -from ._flags import requires_api_version |
| 10 | +from ._flags import requires_api_version, get_array_api_strict_flags |
11 | 11 | from ._creation_functions import zeros |
12 | 12 | from ._manipulation_functions import concat |
13 | 13 |
|
@@ -89,14 +89,16 @@ def prod( |
89 | 89 | ) -> Array: |
90 | 90 | if x.dtype not in _numeric_dtypes: |
91 | 91 | raise TypeError("Only numeric dtypes are allowed in prod") |
92 | | - # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that |
93 | | - # for integers, but not for float32 or complex64, so we need to |
94 | | - # special-case it here |
| 92 | + |
95 | 93 | if dtype is None: |
96 | | - if x.dtype == float32: |
97 | | - dtype = np.float64 |
98 | | - elif x.dtype == complex64: |
99 | | - dtype = np.complex128 |
| 94 | + # Note: In versions prior to 2023.12, sum() and prod() upcast for all |
| 95 | + # dtypes when dtype=None. For 2023.12, the behavior is the same as in |
| 96 | + # NumPy (only upcast for integral dtypes). |
| 97 | + if get_array_api_strict_flags()['api_version'] < '2023.12': |
| 98 | + if x.dtype == float32: |
| 99 | + dtype = np.float64 |
| 100 | + elif x.dtype == complex64: |
| 101 | + dtype = np.complex128 |
100 | 102 | else: |
101 | 103 | dtype = dtype._np_dtype |
102 | 104 | return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) |
@@ -126,14 +128,16 @@ def sum( |
126 | 128 | ) -> Array: |
127 | 129 | if x.dtype not in _numeric_dtypes: |
128 | 130 | raise TypeError("Only numeric dtypes are allowed in sum") |
129 | | - # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that |
130 | | - # for integers, but not for float32 or complex64, so we need to |
131 | | - # special-case it here |
| 131 | + |
132 | 132 | if dtype is None: |
133 | | - if x.dtype == float32: |
134 | | - dtype = np.float64 |
135 | | - elif x.dtype == complex64: |
136 | | - dtype = np.complex128 |
| 133 | + # Note: In versions prior to 2023.12, sum() and prod() upcast for all |
| 134 | + # dtypes when dtype=None. For 2023.12, the behavior is the same as in |
| 135 | + # NumPy (only upcast for integral dtypes). |
| 136 | + if get_array_api_strict_flags()['api_version'] < '2023.12': |
| 137 | + if x.dtype == float32: |
| 138 | + dtype = np.float64 |
| 139 | + elif x.dtype == complex64: |
| 140 | + dtype = np.complex128 |
137 | 141 | else: |
138 | 142 | dtype = dtype._np_dtype |
139 | 143 | return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) |
|
0 commit comments