@@ -380,14 +380,29 @@ function subassign!(
380380 C:: AbstractGBArray , A:: GBArrayOrTranspose , I, J;
381381 mask = nothing , accum = nothing , desc = nothing
382382)
383+ # before we make I and J into GraphBLAS internal types
384+ # get their size to check if A should be reshaped from nx1 -> 1xn
385+ ni_sizecheck = I isa Colon ? size (C, 1 ) : length (I)
386+ nj_sizecheck = J isa Colon ? size (C, 2 ) : length (J)
383387 I, ni = idx (I)
384388 J, nj = idx (J)
385- mask === nothing && (mask = C_NULL )
389+ desc = _handledescriptor (desc; in1= A)
390+ mask = _handlemask! (desc, mask)
386391 I = decrement! (I)
387392 J = decrement! (J)
388- # we know A isn't adjoint/transpose on input
389- desc = _handledescriptor (desc; in1= A)
390- @wraperror LibGraphBLAS. GxB_Matrix_subassign (gbpointer (C), mask, getaccum (accum, eltype (C)), gbpointer (parent (A)), I, ni, J, nj, desc)
393+ rereshape = false
394+ sz1 = size (A, 1 )
395+ # reshape A: nx1 -> 1xn
396+ if A isa GBVector && (ni_sizecheck == size (A, 2 ) && nj_sizecheck == sz1)
397+ @wraperror LibGraphBLAS. GxB_Matrix_reshape (gbpointer (parent (A)), true , 1 , sz1, C_NULL )
398+ rereshape = true
399+ end
400+ @wraperror LibGraphBLAS. GxB_Matrix_subassign (gbpointer (C), mask,
401+ _handleaccum (accum, eltype (C)), gbpointer (parent (A)), I, ni, J, nj, desc)
402+ if rereshape # undo the reshape. Need size(A, 2) here
403+ @wraperror LibGraphBLAS. GxB_Matrix_reshape (
404+ gbpointer (parent (A)), true , sz1, 1 , C_NULL )
405+ end
391406 increment! (I)
392407 increment! (J)
393408 return A
@@ -402,8 +417,8 @@ function subassign!(C::AbstractGBArray{T}, x, I, J;
402417 I = decrement! (I)
403418 J = decrement! (J)
404419 desc = _handledescriptor (desc)
405- mask, accum = _handlenothings (mask, accum )
406- _subassign (C, x, I, ni, J, nj, mask, getaccum (accum, eltype (C)), desc)
420+ mask = _handlemask! (desc, mask )
421+ _subassign (C, x, I, ni, J, nj, mask, _handleaccum (accum, eltype (C)), desc)
407422 increment! (I)
408423 increment! (J)
409424 return x
@@ -465,12 +480,11 @@ function assign!(
465480)
466481 I, ni = idx (I)
467482 J, nj = idx (J)
468- mask === nothing && (mask = C_NULL )
483+ desc = _handledescriptor (desc; in1= A)
484+ mask = _handlemask! (desc, mask)
469485 I = decrement! (I)
470486 J = decrement! (J)
471- # we know A isn't adjoint/transpose on input
472- desc = _handledescriptor (desc; in1= A)
473- @wraperror LibGraphBLAS. GrB_Matrix_assign (gbpointer (C), mask, getaccum (accum, eltype (C)), gbpointer (parent (A)), I, ni, J, nj, desc)
487+ @wraperror LibGraphBLAS. GrB_Matrix_assign (gbpointer (C), mask, _handleaccum (accum, eltype (C)), gbpointer (parent (A)), I, ni, J, nj, desc)
474488 increment! (I)
475489 increment! (J)
476490 return A
@@ -485,7 +499,7 @@ function assign!(C::AbstractGBArray{T}, x, I, J;
485499 I = decrement! (I)
486500 J = decrement! (J)
487501 desc = _handledescriptor (desc)
488- _assign (gbpointer (C), x, I, ni, J, nj, mask, getaccum (accum, eltype (C)), desc)
502+ _assign (gbpointer (C), x, I, ni, J, nj, mask, _handleaccum (accum, eltype (C)), desc)
489503 increment! (I)
490504 increment! (J)
491505 return x
0 commit comments