Skip to content

Commit 114ed73

Browse files
committed
allow ewise ops between gbvector and gbmat
1 parent 8e8b29c commit 114ed73

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

src/indexutils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,15 @@ function idx(I)
3333
Union{UnitRange, StepRange, Vector, Integer}, typeof(I)))
3434
end
3535
end
36+
37+
# This function assumes that szA and szB are
38+
# technically equal and that
39+
# 1 <= length(szA | szB) <= 2
40+
# size checks should be done elsewhere.
41+
function _combinesizes(A, B)
42+
if A isa AbstractVector || B isa AbstractVector
43+
return (size(A, 1), size(A, 2))
44+
else
45+
return size(A)
46+
end
47+
end

src/operations/ewise.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ function emul!(
3232
)
3333
mask, accum = _handlenothings(mask, accum)
3434
desc = _handledescriptor(desc; in1=A, in2=B)
35-
size(C) == size(A) == size(B) || throw(DimensionMismatch())
35+
size(C, 1) == size(A, 1) == size(B, 1) &&
36+
size(C, 2) == size(A, 2) == size(B, 2) || throw(DimensionMismatch())
3637
op = binaryop(op, eltype(A), eltype(B))
3738
accum = getaccum(accum, eltype(C))
3839
if op isa TypedBinaryOperator
@@ -76,7 +77,7 @@ function emul(
7677
desc = nothing
7778
)
7879
t = inferbinarytype(eltype(A), eltype(B), op)
79-
C = similar(A, t, size(A); fill=_promotefill(parent(A).fill, parent(B).fill))
80+
C = similar(A, t, _combinesizes(A, B); fill=_promotefill(parent(A).fill, parent(B).fill))
8081
return emul!(C, A, B, op; mask, accum, desc)
8182
end
8283

@@ -116,7 +117,8 @@ function eadd!(
116117
)
117118
mask, accum = _handlenothings(mask, accum)
118119
desc = _handledescriptor(desc; in1=A, in2 = B)
119-
size(C) == size(A) == size(B) || throw(DimensionMismatch())
120+
size(C, 1) == size(A, 1) == size(B, 1) &&
121+
size(C, 2) == size(A, 2) == size(B, 2) || throw(DimensionMismatch())
120122
op = binaryop(op, eltype(A), eltype(B))
121123
accum = getaccum(accum, eltype(C))
122124
if op isa TypedBinaryOperator
@@ -159,7 +161,7 @@ function eadd(
159161
desc = nothing
160162
)
161163
t = inferbinarytype(eltype(A), eltype(B), op)
162-
C = similar(A, t, size(A); fill=_promotefill(parent(A).fill, parent(B).fill))
164+
C = similar(A, t, _combinesizes(A, B); fill=_promotefill(parent(A).fill, parent(B).fill))
163165
return eadd!(C, A, B, op; mask, accum, desc)
164166
end
165167

@@ -200,7 +202,8 @@ function eunion!(
200202
) where {T, U}
201203
mask, accum = _handlenothings(mask, accum)
202204
desc = _handledescriptor(desc; in1=A, in2 = B)
203-
size(C) == size(A) == size(B) || throw(DimensionMismatch())
205+
size(C, 1) == size(A, 1) == size(B, 1) &&
206+
size(C, 2) == size(A, 2) == size(B, 2) || throw(DimensionMismatch())
204207
op = binaryop(op, eltype(A), eltype(B))
205208
accum = getaccum(accum, eltype(C))
206209
if op isa TypedBinaryOperator
@@ -243,7 +246,7 @@ function eunion(
243246
desc = nothing
244247
) where {T, U}
245248
t = inferbinarytype(eltype(A), eltype(B), op)
246-
C = similar(A, t, size(A); fill=_promotefill(parent(A).fill, parent(B).fill))
249+
C = similar(A, t, _combinesizes(A, B); fill=_promotefill(parent(A).fill, parent(B).fill))
247250
return eunion!(C, A, α, B, β, op; mask, accum, desc)
248251
end
249252

0 commit comments

Comments
 (0)