|
13 | 13 | Callable, |
14 | 14 | Hashable, |
15 | 15 | ) |
16 | | -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, overload |
| 16 | +from typing import TYPE_CHECKING, Any, Literal, cast, overload |
17 | 17 |
|
18 | 18 | import numpy as np |
19 | 19 |
|
20 | 20 | from xarray.compat.array_api_compat import to_like_array |
21 | | -from xarray.computation.apply_ufunc import apply_ufunc |
22 | 21 | from xarray.core import dtypes, duck_array_ops, utils |
23 | 22 | from xarray.core.common import zeros_like |
24 | 23 | from xarray.core.duck_array_ops import datetime_to_numeric |
@@ -467,6 +466,8 @@ def cross( |
467 | 466 | " dimensions without coordinates must have have a length of 2 or 3" |
468 | 467 | ) |
469 | 468 |
|
| 469 | + from xarray.computation.apply_ufunc import apply_ufunc |
| 470 | + |
470 | 471 | c = apply_ufunc( |
471 | 472 | duck_array_ops.cross, |
472 | 473 | a, |
@@ -629,6 +630,8 @@ def dot( |
629 | 630 | # subscripts should be passed to np.einsum as arg, not as kwargs. We need |
630 | 631 | # to construct a partial function for apply_ufunc to work. |
631 | 632 | func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) |
| 633 | + from xarray.computation.apply_ufunc import apply_ufunc |
| 634 | + |
632 | 635 | result = apply_ufunc( |
633 | 636 | func, |
634 | 637 | *arrays, |
@@ -729,6 +732,8 @@ def where(cond, x, y, keep_attrs=None): |
729 | 732 | keep_attrs = _get_keep_attrs(default=False) |
730 | 733 |
|
731 | 734 | # alignment for three arguments is complicated, so don't support it yet |
| 735 | + from xarray.computation.apply_ufunc import apply_ufunc |
| 736 | + |
732 | 737 | result = apply_ufunc( |
733 | 738 | duck_array_ops.where, |
734 | 739 | cond, |
@@ -951,80 +956,3 @@ def _calc_idxminmax( |
951 | 956 | res.attrs = indx.attrs |
952 | 957 |
|
953 | 958 | return res |
954 | | - |
955 | | - |
956 | | -_T = TypeVar("_T", bound=Union["Dataset", "DataArray"]) |
957 | | -_U = TypeVar("_U", bound=Union["Dataset", "DataArray"]) |
958 | | -_V = TypeVar("_V", bound=Union["Dataset", "DataArray"]) |
959 | | - |
960 | | - |
961 | | -@overload |
962 | | -def unify_chunks(__obj: _T) -> tuple[_T]: ... |
963 | | - |
964 | | - |
965 | | -@overload |
966 | | -def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: ... |
967 | | - |
968 | | - |
969 | | -@overload |
970 | | -def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: ... |
971 | | - |
972 | | - |
973 | | -@overload |
974 | | -def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: ... |
975 | | - |
976 | | - |
977 | | -def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: |
978 | | - """ |
979 | | - Given any number of Dataset and/or DataArray objects, returns |
980 | | - new objects with unified chunk size along all chunked dimensions. |
981 | | -
|
982 | | - Returns |
983 | | - ------- |
984 | | - unified (DataArray or Dataset) – Tuple of objects with the same type as |
985 | | - *objects with consistent chunk sizes for all dask-array variables |
986 | | -
|
987 | | - See Also |
988 | | - -------- |
989 | | - dask.array.core.unify_chunks |
990 | | - """ |
991 | | - from xarray.core.dataarray import DataArray |
992 | | - |
993 | | - # Convert all objects to datasets |
994 | | - datasets = [ |
995 | | - obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy() |
996 | | - for obj in objects |
997 | | - ] |
998 | | - |
999 | | - # Get arguments to pass into dask.array.core.unify_chunks |
1000 | | - unify_chunks_args = [] |
1001 | | - sizes: dict[Hashable, int] = {} |
1002 | | - for ds in datasets: |
1003 | | - for v in ds._variables.values(): |
1004 | | - if v.chunks is not None: |
1005 | | - # Check that sizes match across different datasets |
1006 | | - for dim, size in v.sizes.items(): |
1007 | | - try: |
1008 | | - if sizes[dim] != size: |
1009 | | - raise ValueError( |
1010 | | - f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}" |
1011 | | - ) |
1012 | | - except KeyError: |
1013 | | - sizes[dim] = size |
1014 | | - unify_chunks_args += [v._data, v._dims] |
1015 | | - |
1016 | | - # No dask arrays: Return inputs |
1017 | | - if not unify_chunks_args: |
1018 | | - return objects |
1019 | | - |
1020 | | - chunkmanager = get_chunked_array_type(*list(unify_chunks_args)) |
1021 | | - _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args) |
1022 | | - chunked_data_iter = iter(chunked_data) |
1023 | | - out: list[Dataset | DataArray] = [] |
1024 | | - for obj, ds in zip(objects, datasets, strict=True): |
1025 | | - for k, v in ds._variables.items(): |
1026 | | - if v.chunks is not None: |
1027 | | - ds._variables[k] = v.copy(data=next(chunked_data_iter)) |
1028 | | - out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds) |
1029 | | - |
1030 | | - return tuple(out) |
0 commit comments