|
| 1 | +import base64 |
| 2 | +import pickle |
1 | 3 | from functools import singledispatch |
2 | 4 | from numbers import Number |
3 | | -import pickle |
4 | 5 | from textwrap import indent |
5 | | -from typing import Any, Callable, Literal, Optional, Union |
6 | | -import base64 |
| 6 | +from typing import Any, Callable, Optional, Union |
7 | 7 |
|
8 | 8 | import numba |
9 | 9 | import numpy as np |
10 | | -from llvmlite import ir |
11 | | -from numba import TypingError, literal_unroll, types, literally |
| 10 | +from numba import TypingError, types |
12 | 11 | from numba.core import cgutils |
13 | | -from numba.cpython.unsafe.tuple import tuple_setitem |
14 | 12 | from numba.np import arrayobj |
15 | 13 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
16 | 14 |
|
17 | 15 | from pytensor import config |
18 | 16 | from pytensor.graph.basic import Apply |
19 | 17 | from pytensor.graph.op import Op |
20 | 18 | from pytensor.link.numba.dispatch import basic as numba_basic |
| 19 | +from pytensor.link.numba.dispatch import elemwise_codegen |
21 | 20 | from pytensor.link.numba.dispatch.basic import ( |
22 | 21 | create_numba_signature, |
23 | 22 | create_tuple_creator, |
24 | 23 | numba_funcify, |
25 | 24 | numba_njit, |
26 | 25 | use_optimized_cheap_pass, |
27 | 26 | ) |
28 | | -from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper |
29 | | -from pytensor.link.numba.dispatch import elemwise_codegen |
30 | 27 | from pytensor.link.utils import compile_function_src, get_name_for_object |
31 | 28 | from pytensor.scalar.basic import ( |
32 | 29 | AND, |
|
45 | 42 | from pytensor.scalar.basic import add as add_as |
46 | 43 | from pytensor.scalar.basic import scalar_maximum |
47 | 44 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
48 | | -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros |
| 45 | +from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum |
49 | 46 | from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad |
50 | 47 | from pytensor.tensor.type import scalar |
51 | 48 |
|
@@ -376,8 +373,7 @@ def careduce_maximum(input): |
376 | 373 | careduce_def_src = f""" |
377 | 374 | def {careduce_fn_name}({input_name}): |
378 | 375 | {careduce_assign_lines} |
379 | | - #return np.asarray({var_name}) |
380 | | - return {var_name} |
| 376 | + return np.asarray({var_name}) |
381 | 377 | """ |
382 | 378 |
|
383 | 379 | careduce_fn = compile_function_src( |
@@ -447,6 +443,7 @@ def axis_apply_fn(x): |
447 | 443 | } |
448 | 444 | } |
449 | 445 |
|
| 446 | + |
450 | 447 | @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) |
451 | 448 | def _vectorized( |
452 | 449 | typingctx, |
@@ -490,7 +487,6 @@ def _vectorized( |
490 | 487 | inplace_pattern = inplace_pattern.literal_value |
491 | 488 | inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) |
492 | 489 |
|
493 | | - n_inputs = len(inputs) |
494 | 490 | n_outputs = len(output_bc_patterns) |
495 | 491 |
|
496 | 492 | if not len(inputs) > 0: |
@@ -531,7 +527,10 @@ def codegen( |
531 | 527 |
|
532 | 528 | [_, _, _, _, _, inputs] = args |
533 | 529 | inputs = cgutils.unpack_tuple(builder, inputs) |
534 | | - inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)] |
| 530 | + inputs = [ |
| 531 | + arrayobj.make_array(ty)(ctx, builder, val) |
| 532 | + for ty, val in zip(input_types, inputs) |
| 533 | + ] |
535 | 534 | in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] |
536 | 535 |
|
537 | 536 | iter_shape = elemwise_codegen.compute_itershape( |
@@ -586,14 +585,22 @@ def _check_input_shapes(*_): |
586 | 585 | return outputs[0]._getvalue() |
587 | 586 |
|
588 | 587 | for inplace_idx in dict(inplace_pattern): |
589 | | - ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value()) |
590 | | - return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs]) |
| 588 | + ctx.nrt.incref( |
| 589 | + builder, |
| 590 | + sig.return_type.types[inplace_idx], |
| 591 | +            outputs[inplace_idx]._getvalue(), |
| 592 | + ) |
| 593 | + return ctx.make_tuple( |
| 594 | + builder, sig.return_type, [out._getvalue() for out in outputs] |
| 595 | + ) |
591 | 596 |
|
592 | 597 | # TODO check inplace_pattern |
593 | | - ret_type = types.Tuple([ |
594 | | - types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") |
595 | | - for dtype in output_dtypes |
596 | | - ]) |
| 598 | + ret_type = types.Tuple( |
| 599 | + [ |
| 600 | + types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") |
| 601 | + for dtype in output_dtypes |
| 602 | + ] |
| 603 | + ) |
597 | 604 | if len(output_dtypes) == 1: |
598 | 605 | ret_type = ret_type.types[0] |
599 | 606 | sig = ret_type(*arg_types) |
@@ -649,6 +656,40 @@ def elemwise_wrapper(*inputs): |
649 | 656 | return elemwise_wrapper |
650 | 657 |
|
651 | 658 |
|
| 659 | +@numba_funcify.register(Sum) |
| 660 | +def numba_funcify_Sum(op, node, **kwargs): |
| 661 | + axes = op.axis |
| 662 | + if axes is None: |
| 663 | + axes = list(range(node.inputs[0].ndim)) |
| 664 | + |
| 665 | + axes = list(axes) |
| 666 | + |
| 667 | + ndim_input = node.inputs[0].ndim |
| 668 | + |
| 669 | + if hasattr(op, "acc_dtype") and op.acc_dtype is not None: |
| 670 | + acc_dtype = op.acc_dtype |
| 671 | + else: |
| 672 | + acc_dtype = node.outputs[0].type.dtype |
| 673 | + |
| 674 | + np_acc_dtype = np.dtype(acc_dtype) |
| 675 | + |
| 676 | + if ndim_input == len(axes): |
| 677 | + |
| 678 | + @numba_njit(fastmath=True) |
| 679 | + def impl_sum(array): |
| 680 | + # TODO The accumulation itself should happen in acc_dtype... |
| 681 | + return np.asarray(array.sum()).astype(np_acc_dtype) |
| 682 | + |
| 683 | + else: |
| 684 | + |
| 685 | + @numba_njit(fastmath=True) |
| 686 | + def impl_sum(array): |
| 687 | + # TODO The accumulation itself should happen in acc_dtype... |
| 688 | + return array.sum(axes).astype(np_acc_dtype) |
| 689 | + |
| 690 | + return impl_sum |
| 691 | + |
| 692 | + |
652 | 693 | @numba_funcify.register(CAReduce) |
653 | 694 | def numba_funcify_CAReduce(op, node, **kwargs): |
654 | 695 | axes = op.axis |
|
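A quick way to see the new `Sum` dispatch in action (this snippet is not part of the diff; it assumes a pytensor installation where the Numba backend is selected via `mode="NUMBA"`, which routes `Sum` nodes through the `numba_funcify_Sum` registration added above):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
row_sums = x.sum(axis=1)  # Sum over one axis -> the array.sum(axes) branch
total = x.sum()           # Sum over all axes -> the np.asarray(array.sum()) branch

fn = pytensor.function([x], [row_sums, total], mode="NUMBA")
print(fn(np.ones((2, 3), dtype=pytensor.config.floatX)))
# expected: [array([3., 3.]), array(6.)]
```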