Skip to content

Commit ba702df

Browse files
committed
correct mask handling, add Complement and Structural
1 parent 9736ceb commit ba702df

File tree

14 files changed

+118
-349
lines changed

14 files changed

+118
-349
lines changed

src/abstractgbarray.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -380,14 +380,29 @@ function subassign!(
380380
C::AbstractGBArray, A::GBArrayOrTranspose, I, J;
381381
mask = nothing, accum = nothing, desc = nothing
382382
)
383+
# before we make I and J into GraphBLAS internal types
384+
# get their size to check if A should be reshaped from nx1 -> 1xn
385+
ni_sizecheck = I isa Colon ? size(C, 1) : length(I)
386+
nj_sizecheck = J isa Colon ? size(C, 2) : length(J)
383387
I, ni = idx(I)
384388
J, nj = idx(J)
385-
mask === nothing && (mask = C_NULL)
389+
desc = _handledescriptor(desc; in1=A)
390+
mask = _handlemask!(desc, mask)
386391
I = decrement!(I)
387392
J = decrement!(J)
388-
# we know A isn't adjoint/transpose on input
389-
desc = _handledescriptor(desc; in1=A)
390-
@wraperror LibGraphBLAS.GxB_Matrix_subassign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
393+
rereshape = false
394+
sz1 = size(A, 1)
395+
# reshape A: nx1 -> 1xn
396+
if A isa GBVector && (ni_sizecheck == size(A, 2) && nj_sizecheck == sz1)
397+
@wraperror LibGraphBLAS.GxB_Matrix_reshape(gbpointer(parent(A)), true, 1, sz1, C_NULL)
398+
rereshape = true
399+
end
400+
@wraperror LibGraphBLAS.GxB_Matrix_subassign(gbpointer(C), mask,
401+
_handleaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
402+
if rereshape # undo the reshape. Need size(A, 2) here
403+
@wraperror LibGraphBLAS.GxB_Matrix_reshape(
404+
gbpointer(parent(A)), true, sz1, 1, C_NULL)
405+
end
391406
increment!(I)
392407
increment!(J)
393408
return A
@@ -402,8 +417,8 @@ function subassign!(C::AbstractGBArray{T}, x, I, J;
402417
I = decrement!(I)
403418
J = decrement!(J)
404419
desc = _handledescriptor(desc)
405-
mask, accum = _handlenothings(mask, accum)
406-
_subassign(C, x, I, ni, J, nj, mask, getaccum(accum, eltype(C)), desc)
420+
mask = _handlemask!(desc, mask)
421+
_subassign(C, x, I, ni, J, nj, mask, _handleaccum(accum, eltype(C)), desc)
407422
increment!(I)
408423
increment!(J)
409424
return x
@@ -465,12 +480,11 @@ function assign!(
465480
)
466481
I, ni = idx(I)
467482
J, nj = idx(J)
468-
mask === nothing && (mask = C_NULL)
483+
desc = _handledescriptor(desc; in1=A)
484+
mask = _handlemask!(desc, mask)
469485
I = decrement!(I)
470486
J = decrement!(J)
471-
# we know A isn't adjoint/transpose on input
472-
desc = _handledescriptor(desc; in1=A)
473-
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
487+
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, _handleaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
474488
increment!(I)
475489
increment!(J)
476490
return A
@@ -485,7 +499,7 @@ function assign!(C::AbstractGBArray{T}, x, I, J;
485499
I = decrement!(I)
486500
J = decrement!(J)
487501
desc = _handledescriptor(desc)
488-
_assign(gbpointer(C), x, I, ni, J, nj, mask, getaccum(accum, eltype(C)), desc)
502+
_assign(gbpointer(C), x, I, ni, J, nj, mask, _handleaccum(accum, eltype(C)), desc)
489503
increment!(I)
490504
increment!(J)
491505
return x

src/operations/ewise.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ function emul!(
3030
accum = nothing,
3131
desc = nothing
3232
)
33-
mask, accum = _handlenothings(mask, accum)
3433
desc = _handledescriptor(desc; in1=A, in2=B)
34+
mask = _handlemask!(desc, mask)
3535
size(C, 1) == size(A, 1) == size(B, 1) &&
3636
size(C, 2) == size(A, 2) == size(B, 2) || throw(DimensionMismatch())
3737
op = binaryop(op, eltype(A), eltype(B))
38-
accum = getaccum(accum, eltype(C))
38+
accum = _handleaccum(accum, eltype(C))
3939
if op isa TypedBinaryOperator
4040
@wraperror LibGraphBLAS.GrB_Matrix_eWiseMult_BinaryOp(gbpointer(C), mask, accum, op, gbpointer(parent(A)), gbpointer(parent(B)), desc)
4141
return C
@@ -115,12 +115,13 @@ function eadd!(
115115
accum = nothing,
116116
desc = nothing
117117
)
118-
mask, accum = _handlenothings(mask, accum)
118+
119119
desc = _handledescriptor(desc; in1=A, in2 = B)
120+
mask = _handlemask!(desc, mask)
120121
size(C, 1) == size(A, 1) == size(B, 1) &&
121122
size(C, 2) == size(A, 2) == size(B, 2) || throw(DimensionMismatch())
122123
op = binaryop(op, eltype(A), eltype(B))
123-
accum = getaccum(accum, eltype(C))
124+
accum = _handleaccum(accum, eltype(C))
124125
if op isa TypedBinaryOperator
125126
@wraperror LibGraphBLAS.GrB_Matrix_eWiseAdd_BinaryOp(gbpointer(C), mask, accum, op, gbpointer(parent(A)), gbpointer(parent(B)), desc)
126127
return C
@@ -200,12 +201,13 @@ function eunion!(
200201
accum = nothing,
201202
desc = nothing
202203
) where {T, U}
203-
mask, accum = _handlenothings(mask, accum)
204+
204205
desc = _handledescriptor(desc; in1=A, in2 = B)
206+
mask = _handlemask!(desc, mask)
205207
size(C, 1) == size(A, 1) == size(B, 1) &&
206208
size(C, 2) == size(A, 2) == size(B, 2) || throw(DimensionMismatch())
207209
op = binaryop(op, eltype(A), eltype(B))
208-
accum = getaccum(accum, eltype(C))
210+
accum = _handleaccum(accum, eltype(C))
209211
if op isa TypedBinaryOperator
210212
@wraperror LibGraphBLAS.GxB_Matrix_eWiseUnion(gbpointer(C), mask, accum, op, gbpointer(parent(A)), GBScalar(α), gbpointer(parent(B)), GBScalar(β), desc)
211213
return C

src/operations/extract.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ function extract!(
6464
J, nj = idx(J)
6565
I isa Number && (I = UInt64[I])
6666
J isa Number && (J = UInt64[J])
67-
mask === nothing && (mask = C_NULL)
6867
desc = _handledescriptor(desc; in1 = A)
68+
mask = _handlemask!(desc, mask)
6969
I = decrement!(I)
7070
J = decrement!(J)
71-
@wraperror LibGraphBLAS.GrB_Matrix_extract(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
71+
@wraperror LibGraphBLAS.GrB_Matrix_extract(gbpointer(C), mask, _handleaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
7272
I isa AbstractVector && increment!(I)
7373
J isa AbstractVector && increment!(J)
7474
return C
@@ -136,8 +136,8 @@ function extract!(
136136
I, ni = idx(I)
137137
I = decrement!(I)
138138
desc = _handledescriptor(desc)
139-
mask === nothing && (mask = C_NULL)
140-
@wraperror LibGraphBLAS.GrB_Matrix_extract(gbpointer(w), mask, getaccum(accum, eltype(w)), gbpointer(u), I, ni, UInt64[0], 1, desc)
139+
mask = _handlemask!(desc, mask)
140+
@wraperror LibGraphBLAS.GrB_Matrix_extract(gbpointer(w), mask, _handleaccum(accum, eltype(w)), gbpointer(u), I, ni, UInt64[0], 1, desc)
141141
I isa AbstractVector && increment!(I)
142142
return w
143143
end

src/operations/kronecker.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ function LinearAlgebra.kron!(
1212
accum = nothing,
1313
desc = nothing
1414
)
15-
mask, accum = _handlenothings(mask, accum)
1615
desc = _handledescriptor(desc; in1=A, in2=B)
16+
mask = _handlemask!(desc, mask)
1717
op = binaryop(op, eltype(A), eltype(B))
18-
accum = getaccum(accum, eltype(C))
18+
accum = _handleaccum(accum, eltype(C))
1919
@wraperror LibGraphBLAS.GxB_kron(gbpointer(C), mask, accum, op, gbpointer(parent(A)), gbpointer(parent(B)), desc)
2020
return C
2121
end

src/operations/map.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ function apply!(
22
op, C::GBVecOrMat, A::GBArrayOrTranspose{T};
33
mask = nothing, accum = nothing, desc = nothing
44
) where {T}
5-
mask, accum = _handlenothings(mask, accum)
65
desc = _handledescriptor(desc; in1=A)
6+
mask = _handlemask!(desc, mask)
77
op = unaryop(op, eltype(A))
8-
accum = getaccum(accum, eltype(C))
8+
accum = _handleaccum(accum, eltype(C))
99
@wraperror LibGraphBLAS.GrB_Matrix_apply(gbpointer(C), mask, accum, op, gbpointer(parent(A)), desc)
1010
return C
1111
end
@@ -52,10 +52,10 @@ function apply!(
5252
op, C::GBVecOrMat, x, A::GBArrayOrTranspose{T};
5353
mask = nothing, accum = nothing, desc = nothing
5454
) where {T}
55-
mask, accum = _handlenothings(mask, accum)
5655
desc = _handledescriptor(desc; in2=A)
56+
mask = _handlemask!(desc, mask)
5757
op = binaryop(op, eltype(A), typeof(x))
58-
accum = getaccum(accum, eltype(C))
58+
accum = _handleaccum(accum, eltype(C))
5959
@wraperror LibGraphBLAS.GxB_Matrix_apply_BinaryOp1st(gbpointer(C), mask, accum, op, GBScalar(x), gbpointer(parent(A)), desc)
6060
return C
6161
end
@@ -79,10 +79,10 @@ function apply!(
7979
op, C::GBVecOrMat, A::GBArrayOrTranspose{T}, x;
8080
mask = nothing, accum = nothing, desc = nothing
8181
) where {T}
82-
mask, accum = _handlenothings(mask, accum)
8382
desc = _handledescriptor(desc; in1=A)
83+
mask = _handlemask!(desc, mask)
8484
op = binaryop(op, eltype(A), typeof(x))
85-
accum = getaccum(accum, eltype(C))
85+
accum = _handleaccum(accum, eltype(C))
8686
@wraperror LibGraphBLAS.GxB_Matrix_apply_BinaryOp2nd(gbpointer(C), mask, accum, op, gbpointer(parent(A)), GBScalar(x), desc)
8787
return C
8888
end

src/operations/mul.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ function LinearAlgebra.mul!(
1111
accum = nothing,
1212
desc = nothing
1313
)
14-
mask, accum = _handlenothings(mask, accum)
1514
desc = _handledescriptor(desc; in1=A, in2=B)
15+
mask = _handlemask!(desc, mask)
1616
size(A, 2) == size(B, 1) || throw(DimensionMismatch("size(A, 2) != size(B, 1)"))
1717
size(A, 1) == size(C, 1) || throw(DimensionMismatch("size(A, 1) != size(C, 1)"))
1818
size(B, 2) == size(C, 2) || throw(DimensionMismatch("size(B, 2) != size(C, 2)"))
1919
op = semiring(op, eltype(A), eltype(B))
20-
accum = getaccum(accum, eltype(C))
20+
accum = _handleaccum(accum, eltype(C))
2121
op isa TypedSemiring || throw(ArgumentError("$op is not a valid TypedSemiring"))
2222
@wraperror LibGraphBLAS.GrB_mxm(gbpointer(C), mask, accum, op, gbpointer(parent(A)), gbpointer(parent(B)), desc)
2323
return C

src/operations/operationutils.jl

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
2-
getaccum(::Nothing, t) = C_NULL
3-
getaccum(::Ptr{Nothing}, t) = C_NULL
4-
getaccum(op::Function, t) = binaryop(op, t, t)
5-
getaccum(op::Function, tleft, tright) = binaryop(op, tleft, tright)
6-
getaccum(op::TypedBinaryOperator, x...) = op
7-
81
inferunarytype(::Type{T}, f::F) where {T, F<:Base.Callable} = Base._return_type(f, Tuple{T})
92
inferunarytype(::Type{X}, op::TypedUnaryOperator) where X = ztype(op)
103

@@ -20,9 +13,54 @@ inferbinarytype(::Type{X}, ::Type{Y}, op::TypedBinaryOperator{F, X, Y, Z}) where
2013
inferbinarytype(::Type{X}, ::Type{X}, op::TypedMonoid{F, X, Z}) where {F, X, Z} = ztype(op)
2114
inferbinarytype(::Type{X}, ::Type{Y}, op::TypedSemiring{F, X, Y, Z}) where {F, X, Y, Z} = ztype(op)
2215
inferbinarytype(::Type{X}, ::Type{Y}, op::TypedBinaryOperator{F, X2, Y2, Z}) where {F, X, X2, Y, Y2, Z} = ztype(op)
23-
function _handlenothings(kwargs...)
24-
return (x === nothing ? C_NULL : x for x in kwargs)
16+
17+
struct Complement{T}
18+
parent::T
19+
end
20+
21+
Complement(A::T) where {T<:GBArrayOrTranspose}= Complement{T}(A)
22+
Base.:~(A::T) where {T<:GBArrayOrTranspose} = Complement(A)
23+
Base.parent(C::Complement) = C.parent
24+
25+
struct Structural{T}
26+
parent::T
27+
end
28+
29+
Structural(A::T) where {T<:GBArrayOrTranspose}= Structural{T}(A)
30+
Base.parent(C::Structural) = C.parent
31+
32+
_handlemask!(desc, mask::Nothing) = C_NULL
33+
_handlemask!(desc, mask::AbstractGBArray) = mask
34+
function _handlemask!(desc, mask)
35+
while !(mask isa AbstractGBArray)
36+
if mask isa Transpose
37+
mask = copy(mask)
38+
elseif mask isa Complement
39+
mask = parent(mask)
40+
desc.complement_mask = true
41+
elseif mask isa Structural
42+
mask = parent(mask)
43+
desc.structural_mask = true
44+
end
45+
end
46+
return mask
47+
end
48+
49+
50+
_handleaccum(::Nothing, t) = C_NULL
51+
_handleaccum(::Ptr{Nothing}, t) = C_NULL
52+
_handleaccum(op::Function, t) = binaryop(op, t, t)
53+
_handleaccum(op::Function, tleft, tright) = binaryop(op, tleft, tright)
54+
_handleaccum(op::TypedBinaryOperator, x...) = op
55+
56+
function _kwargtoc(desc, x)
57+
x.second === nothing && return C_NULL
58+
if x.first === :mask
59+
return _handlemask!(desc, x.second)
60+
end
61+
return x.second
2562
end
63+
2664

2765
"""
2866
xtype(op::GrBOp)::DataType

src/operations/reduce.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ function reduce!(
22
op, w::AbstractGBVector, A::GBArrayOrTranspose;
33
mask = nothing, accum = nothing, desc = nothing
44
)
5-
mask, accum = _handlenothings(mask, accum)
65
desc = _handledescriptor(desc; in1=A)
6+
mask = _handlemask!(desc, mask)
7+
78
op = typedmonoid(op, eltype(w))
8-
accum = getaccum(accum, eltype(w))
9+
accum = _handleaccum(accum, eltype(w))
910
@wraperror LibGraphBLAS.GrB_Matrix_reduce_Monoid(
1011
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(w)), mask, accum, op, gbpointer(parent(A)), desc
1112
)
@@ -23,7 +24,7 @@ function Base.reduce(
2324
desc = nothing
2425
)
2526
desc = _handledescriptor(desc; in1=A)
26-
mask, accum = _handlenothings(mask, accum)
27+
mask = _handlemask!(desc, mask)
2728
if typeout === nothing
2829
typeout = eltype(A)
2930
end
@@ -47,7 +48,7 @@ function Base.reduce(
4748
if nnz(c) == 1 && accum == C_NULL
4849
accum = binaryop(op)
4950
end
50-
accum = getaccum(accum, typeout)
51+
accum = _handleaccum(accum, typeout)
5152
@wraperror LibGraphBLAS.GrB_Matrix_reduce_Monoid_Scalar(c, accum, op, gbpointer(parent(A)), desc)
5253
return c[]
5354
end

src/operations/select.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ function select!(
1111
desc = nothing
1212
)
1313
op = SelectOp(op)
14-
mask, accum = _handlenothings(mask, accum)
1514
desc = _handledescriptor(desc; in1=A)
15+
mask = _handlemask!(desc, mask)
1616
thunk === nothing && (thunk = C_NULL)
17-
accum = getaccum(accum, eltype(C))
17+
accum = _handleaccum(accum, eltype(C))
1818
if thunk isa Number
1919
thunk = GBScalar(thunk)
2020
end
@@ -59,7 +59,6 @@ function select(
5959
desc = nothing
6060
)
6161
op = SelectOp(op)
62-
mask, accum = _handlenothings(mask, accum)
6362
C = similar(A)
6463
select!(op, C, A, thunk; accum, mask, desc)
6564
return C

src/operations/sort.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ function Base.sort!(
1212
)
1313
A isa GBMatrixOrTranspose && dims === nothing && throw(ArgumentError("dims must be either 1 (sort columns) or 2 (sort rows) for matrix arguments."))
1414
A isa GBVector && (dims = 1)
15-
C, P = _handlenothings(C, P)
15+
C === nothing && (C = C_NULL)
16+
P === nothing && (P = C_NULL)
1617
C == C_NULL && P == C_NULL && throw(ArgumentError("One (or both) of C and P must not be nothing."))
1718
op = binaryop(lt, eltype(A))
1819
if dims == 1

0 commit comments

Comments
 (0)