|
3 | 3 | import warnings |
4 | 4 |
|
5 | 5 | from ._lib import _compat, _utils |
6 | | -from ._lib._compat import array_namespace |
| 6 | +from ._lib._compat import ( |
| 7 | + array_namespace, is_torch_namespace, is_array_api_strict_namespace |
| 8 | +) |
7 | 9 | from ._lib._typing import Array, ModuleType |
8 | 10 |
|
9 | 11 | __all__ = [ |
|
14 | 16 | "kron", |
15 | 17 | "setdiff1d", |
16 | 18 | "sinc", |
| 19 | + "pad", |
17 | 20 | ] |
18 | 21 |
|
19 | 22 |
|
@@ -538,3 +541,54 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
538 | 541 | xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), |
539 | 542 | ) |
540 | 543 | return xp.sin(y) / y |
| 544 | + |
| 545 | + |
def pad(
    x: Array,
    pad_width: int,
    mode: str = 'constant',
    *,
    xp: ModuleType | None = None,
    **kwargs,
) -> Array:
    """
    Pad the input array.

    Parameters
    ----------
    x : array
        Input array
    pad_width: int
        Pad the input array with this many elements from each side
    mode: str, optional
        Only "constant" mode is currently supported.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.
    constant_values: python scalar, optional
        Use this value to pad the input. Default is zero.

    Returns
    -------
    array
        The input array, padded with ``pad_width`` elements equal to ``constant_values``

    Raises
    ------
    NotImplementedError
        If ``mode`` is anything other than ``"constant"``.
    ValueError
        If a keyword argument other than ``constant_values`` is passed.
    """
    # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
    # https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045

    if mode != 'constant':
        raise NotImplementedError(f"Only mode='constant' is implemented, got {mode!r}")

    # Validate kwargs before reading from them, so a typo such as
    # ``constant_value=`` fails loudly instead of being silently ignored.
    if kwargs and list(kwargs.keys()) != ['constant_values']:
        raise ValueError(f"Unknown kwargs: {kwargs}")
    value = kwargs.get("constant_values", 0)

    if xp is None:
        xp = array_namespace(x)

    if is_array_api_strict_namespace(xp):
        # array_api_strict has no `pad`; emulate it with full + slice-assign.
        # Create the result on the same device as `x` (matches `sinc` above).
        padded = xp.full(
            tuple(s + 2 * pad_width for s in x.shape),
            fill_value=value,
            dtype=x.dtype,
            device=_compat.device(x),
        )
        # Use absolute stop indices rather than ``slice(pad_width, -pad_width)``:
        # the negative-stop form breaks for pad_width == 0, where slice(0, -0)
        # == slice(0, 0) is empty and the input would be silently dropped.
        padded[tuple(slice(pad_width, pad_width + s) for s in x.shape)] = x
        return padded
    elif is_torch_namespace(xp):
        # torch.nn.functional.pad takes a flat (before, after) pair per
        # dimension, listed from the *last* dimension to the first — hence
        # the broadcast to (ndim, 2) followed by a flip along dim 0.
        pad_width = xp.asarray(pad_width)
        pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
        pad_width = xp.flip(pad_width, axis=(0,)).flatten()
        return xp.nn.functional.pad(x, tuple(pad_width), value=value)
    else:
        # numpy, cupy and jax.numpy all provide a numpy-compatible `pad`.
        return xp.pad(x, pad_width, mode=mode, **kwargs)
0 commit comments