Skip to content

Commit 0dbf7dd

Browse files
author
Hongyuhe
committed
update
1 parent 34af083 commit 0dbf7dd

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,11 @@ def full(
891891
) -> array:
892892
if isinstance(shape, int):
893893
shape = (shape,)
894-
894+
if dtype is None :
895+
if isinstance(fill_value, (bool)):
896+
dtype = "bool"
897+
elif isinstance(fill_value, int):
898+
dtype = 'int64'
895899
return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device)
896900

897901

@@ -1148,11 +1152,9 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
11481152
def sign(x: array, /) -> array:
11491153
# paddle sign() does not support complex numbers and does not propagate
11501154
# nans. See https://github.com/data-apis/array-api-compat/issues/136
1151-
if paddle.is_complex(x):
1152-
out = x / paddle.abs(x)
1153-
# sign(0) = 0 but the above formula would give nan
1154-
out[x == 0 + 0j] = 0 + 0j
1155-
return out
1155+
if paddle.is_complex(x) and x.ndim == 0 and x.item() == 0j:
1156+
# Handle 0-D complex zero explicitly
1157+
return paddle.zeros_like(x, dtype=x.dtype)
11561158
else:
11571159
out = paddle.sign(x)
11581160
if paddle.is_floating_point(x):

array_api_compat/paddle/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def qr(x: array, mode: Optional[str] = None) -> array:
154154

155155
def svd(x: array, full_matrices: Optional[bool]= None) -> array:
156156
if full_matrices is None :
157-
return tuple_to_namedtuple(paddle.linalg.svd(x), ['U', 'S', 'Vh'])
157+
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
158158
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])
159159

160160
__all__ = linalg_all + [

0 commit comments

Comments
 (0)