Skip to content

Commit 34af083

Browse files
author
Hongyuhe
committed
update
1 parent b946e82 commit 34af083

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
10141014

10151015

10161016
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
1017-
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
1017+
shape1 = x1.shape
1018+
shape2 = x2.shape
1019+
rank1 = len(shape1)
1020+
rank2 = len(shape2)
1021+
if rank1 == 0 or rank2 == 0:
1022+
raise ValueError(
1023+
f"Vector dot product requires non-scalar inputs (rank > 0). "
1024+
f"Got ranks {rank1} and {rank2} for shapes {shape1} and {shape2}."
1025+
)
1026+
try:
1027+
norm_axis1 = axis if axis >= 0 else rank1 + axis
1028+
if not (0 <= norm_axis1 < rank1):
1029+
raise IndexError # Axis out of bounds for x1
1030+
norm_axis2 = axis if axis >= 0 else rank2 + axis
1031+
if not (0 <= norm_axis2 < rank2):
1032+
raise IndexError # Axis out of bounds for x2
1033+
size1 = shape1[norm_axis1]
1034+
size2 = shape2[norm_axis2]
1035+
except IndexError:
1036+
raise ValueError(
1037+
f"Axis {axis} is out of bounds for input shapes {shape1} (rank {rank1}) "
1038+
f"and/or {shape2} (rank {rank2})."
1039+
)
1040+
1041+
if size1 != size2:
1042+
raise ValueError(
1043+
f"Inputs must have the same dimension size along the reduction axis ({axis}). "
1044+
f"Got shapes {shape1} and {shape2}, with sizes {size1} and {size2} "
1045+
f"along the normalized axis {norm_axis1} and {norm_axis2} respectively."
1046+
)
10181047
return paddle.linalg.vecdot(x1, x2, axis=axis)
10191048

10201049

0 commit comments

Comments
 (0)