|
23 | 23 | from pytensor.scalar import upcast |
24 | 24 | from pytensor.tensor import as_tensor_variable |
25 | 25 | from pytensor.tensor import basic as at |
26 | | -from pytensor.tensor.basic import get_vector_length, second |
| 26 | +from pytensor.tensor.basic import alloc, second |
27 | 27 | from pytensor.tensor.exceptions import NotScalarConstantError |
28 | 28 | from pytensor.tensor.math import abs as pt_abs |
29 | 29 | from pytensor.tensor.math import all as pt_all |
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter( |
1584 | 1584 | return tuple(result_dims) |
1585 | 1585 |
|
1586 | 1586 |
|
class BroadcastTo(COp):
    """An `Op` for `numpy.broadcast_to`.

    The target shape is passed as a variable number of scalar inputs
    (one per output dimension) rather than a single vector, so the node
    has inputs ``[a, dim_0, dim_1, ...]``.

    Notes
    -----
    The output is a *view* of the input (declared via ``view_map``), so
    no data is copied; consequently the output must not be written to
    in place.
    """

    # The output's static type (its shape) depends on the *values* of the
    # shape inputs, not just their types.
    _output_type_depends_on_input_value = True

    __props__ = ()

    # Output 0 aliases input 0: broadcasting produces a view, not a copy.
    view_map = {0: [0]}

    def __call__(self, a, shape, **kwargs):
        # Convenience wrapper: accept `shape` as a single sequence and
        # splat it into the per-dimension scalar inputs `make_node` expects.
        return super().__call__(a, *shape, **kwargs)

    def make_node(self, a, *shape):
        """Build the `Apply` node.

        Parameters
        ----------
        a
            The tensor to broadcast.
        *shape
            One scalar per target output dimension.

        Raises
        ------
        ValueError
            If the target shape has fewer dimensions than ``a``.
        """
        a = at.as_tensor_variable(a)

        # Recover as much static (compile-time) shape information as
        # possible from the symbolic shape inputs.
        shape, static_shape = at.infer_static_shape(shape)

        if len(shape) < a.ndim:
            raise ValueError(
                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
            )

        out = TensorType(dtype=a.type.dtype, shape=static_shape)()

        # Attempt to prevent in-place operations on this view-based output
        out.tag.indestructible = True

        return Apply(self, [a] + shape, [out])

    def perform(self, node, inputs, output_storage):
        # Python-mode evaluation: delegate directly to NumPy.  Note that
        # `np.broadcast_to` returns a read-only view, consistent with the
        # view semantics declared in `view_map`.
        a, *shape = inputs
        z = output_storage[0]
        z[0] = np.broadcast_to(a, shape)

    def grad(self, inputs, outputs_gradients):
        """Gradient w.r.t. ``a``; the shape inputs get undefined gradients.

        Dimensions introduced or expanded by the broadcast are summed out
        of the output gradient so it matches the shape of ``a``.
        """
        a, *shape = inputs
        (dout,) = outputs_gradients

        # Determine the dimensions that were added by broadcasting
        new_dims = list(range(dout.ndim - a.ndim))

        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

        # Determine the dimensions that were broadcast
        _, static_shape = at.infer_static_shape(shape)

        # TODO: This needs to be performed at run-time when static shape
        # information isn't available.
        bcast_sums = [
            i
            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
            if a_s == 1 and s_s != 1
        ]

        if bcast_sums:
            # Sum over broadcast dimensions but keep them (as size-1 axes)
            # so the gradient's rank matches `a`'s.
            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)

        # Shape inputs are integer-valued positions; their gradients are
        # undefined (enumerate from 1 to skip the `a` input).
        return [d_wrt_a] + [
            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
        ]

    def infer_shape(self, fgraph, node, ins_shapes):
        # The output shape is, by construction, exactly the shape inputs.
        return [node.inputs[1:]]

    def c_code(self, node, name, inputs, outputs, sub):
        """Generate C code that builds a broadcast view via NumPy's NpyIter."""
        inp_dims = node.inputs[0].ndim
        out_dims = node.outputs[0].ndim
        new_dims = out_dims - inp_dims

        (x, *shape) = inputs
        (out,) = outputs
        fail = sub["fail"]

        # Build a C initializer list reading each scalar shape input's
        # value; this becomes the iterator's target `itershape`.
        # TODO: Could just use `PyArray_Return`, no?
        dims_array = ", ".join(
            [
                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
                for i, shape in enumerate(shape)
            ]
        )

        src = (
            """
        npy_intp itershape[%(out_dims)s] = {%(dims_array)s};

        NpyIter *iter;
        PyArrayObject *ops[1] = {%(x)s};
        npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
        npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
        PyArray_Descr *op_dtypes[1] = {NULL};
        int oa_ndim = %(out_dims)s;
        int* op_axes[1] = {NULL};
        npy_intp buffersize = 0;

        /* Validate broadcast compatibility: each input dim must be 1 or
           equal to the corresponding (right-aligned) target dim. */
        for(int i = 0; i < %(inp_dims)s; i++)
        {
            if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
            {
                PyErr_Format(PyExc_ValueError,
                             "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
                             i,
                             (long long int) itershape[i + %(new_dims)s],
                             (long long int) PyArray_DIMS(%(x)s)[i]
                );
                %(fail)s
            }
        }

        /* NOTE(review): `iter` is not checked for NULL here before the
           NpyIter_GetIterView call below; if NpyIter_AdvancedNew fails
           and returns NULL this would dereference it — verify. */
        iter = NpyIter_AdvancedNew(
            1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
        );
        /* The iterator view is the broadcast result: a zero-copy view of
           the input with the requested shape. */
        %(out)s = NpyIter_GetIterView(iter, 0);

        if(%(out)s == NULL){
            NpyIter_Deallocate(iter);
            %(fail)s;
        }

        if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
            %(fail)s;
        }

        """
            % locals()
        )

        return src

    def c_code_cache_version(self):
        # Bump when the generated C code above changes.
        return (2,)


# Module-level singleton instance, applied by the `broadcast_to` helper.
broadcast_to_ = BroadcastTo()
1721 | | - |
1722 | 1587 | def geomspace(start, end, steps, base=10.0): |
1723 | 1588 | from pytensor.tensor.math import log |
1724 | 1589 |
|
@@ -1762,13 +1627,7 @@ def broadcast_to( |
1762 | 1627 | broadcasted array may refer to a single memory location. |
1763 | 1628 |
|
1764 | 1629 | """ |
1765 | | - x = at.as_tensor(x) |
1766 | | - shape_len = get_vector_length(shape) |
1767 | | - |
1768 | | - if x.ndim == 0 and shape_len == 0: |
1769 | | - return x |
1770 | | - |
1771 | | - return broadcast_to_(x, shape) |
| 1630 | + return alloc(x, *shape) |
1772 | 1631 |
|
1773 | 1632 |
|
1774 | 1633 | def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: |
|
0 commit comments