|
3 | 3 | from llvmlite import ir |
4 | 4 | from numba import types |
5 | 5 | from numba.core import cgutils |
| 6 | +from numba.core.base import BaseContext |
6 | 7 | from numba.np import arrayobj |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def compute_itershape( |
10 | | - ctx, |
| 11 | + ctx: BaseContext, |
11 | 12 | builder: ir.IRBuilder, |
12 | 13 | in_shapes, |
13 | 14 | broadcast_pattern, |
14 | 15 | ): |
15 | 16 | one = ir.IntType(64)(1) |
16 | 17 | ndim = len(in_shapes[0]) |
17 | | - #shape = [ir.IntType(64)(1) for _ in range(ndim)] |
18 | 18 | shape = [None] * ndim |
19 | 19 | for i in range(ndim): |
20 | | - # TODO Error checking... |
21 | | - # What if all shapes are 0? |
22 | | - for bc, in_shape in zip(broadcast_pattern, in_shapes): |
| 20 | + for j, (bc, in_shape) in enumerate( |
| 21 | + zip(broadcast_pattern, in_shapes, strict=True) |
| 22 | + ): |
| 23 | + length = in_shape[i] |
23 | 24 | if bc[i]: |
24 | | - # TODO |
25 | | - # raise error if length != 1 |
26 | | - pass |
| 25 | + with builder.if_then( |
| 26 | + builder.icmp_unsigned("!=", length, one), likely=False |
| 27 | + ): |
| 28 | + msg = ( |
| 29 | + f"Input {j} to elemwise is expected to have shape 1 in axis {i}" |
| 30 | + ) |
| 31 | + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) |
| 32 | + elif shape[i] is not None: |
| 33 | + with builder.if_then( |
| 34 | + builder.icmp_unsigned("!=", length, shape[i]), likely=False |
| 35 | + ): |
| 36 | + with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( |
| 37 | + then, |
| 38 | + otherwise, |
| 39 | + ): |
| 40 | + with then: |
| 41 | + msg = ( |
| 42 | + f"Incompative shapes for input {j} and axis {i} of " |
| 43 | + f"elemwise. Input {j} has shape 1, but is not statically " |
| 44 | + "known to have shape 1, and thus not broadcastable." |
| 45 | + ) |
| 46 | + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) |
| 47 | + with otherwise: |
| 48 | + msg = ( |
| 49 | + f"Input {j} to elemwise has an incompatible " |
| 50 | + f"shape in axis {i}." |
| 51 | + ) |
| 52 | + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) |
27 | 53 | else: |
28 | | - # TODO |
29 | | - # if shape[i] is not None: |
30 | | - # raise Error if != |
31 | | - shape[i] = in_shape[i] |
| 54 | + shape[i] = length |
32 | 55 | for i in range(ndim): |
33 | 56 | if shape[i] is None: |
34 | 57 | shape[i] = one |
|
0 commit comments