Skip to content

Commit f3f2cd6

Browse files
committed
update m*Pos functions in algorithm.d
1 parent 84e7d1f commit f3f2cd6

File tree

1 file changed

+71
-17
lines changed

1 file changed

+71
-17
lines changed

source/mir/ndslice/algorithm.d

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,10 @@ bool[2] minmaxPosImpl(alias fun, SliceKind kind, size_t[] packs, Iterator)(ref s
449449
}
450450

451451
/++
452-
Finds a backward indexes such that
453-
`slice.backward(indexes[0])` is minimal and `slice.backward(indexes[1])` is maximal elements in the slice.
452+
Finds a positions (ndslices) such that
453+
`position[0].first` is minimal and `position[1].first` is maximal elements in the slice.
454+
455+
Each position is sub-ndslice of the same dimension in the right-(down-(etc)) corner.
454456
455457
Params:
456458
pred = A predicate.
@@ -472,7 +474,8 @@ template minmaxPos(alias pred = "a < b")
472474
Multidimensional backward index such that element is minimal(maximal).
473475
Backward index equals zeros, if slice is empty.
474476
+/
475-
@fastmath size_t[packs[0]][2] minmaxPos(SliceKind kind, size_t[] packs, Iterator)(Slice!(kind, packs, Iterator) slice)
477+
@fastmath Slice!(kind == Contiguous && packs[0] > 1 ? Canonical : kind, packs, Iterator)[2]
478+
minmaxPos(SliceKind kind, size_t[] packs, Iterator)(Slice!(kind, packs, Iterator) slice)
476479
{
477480
import mir.ndslice.topology: map;
478481
typeof(return) pret;
@@ -484,10 +487,34 @@ template minmaxPos(alias pred = "a < b")
484487
minmaxPosImpl!(pred, kind, packs, Iterator)(ret, iterator, slice);
485488
foreach (i; Iota!(packs[0]))
486489
{
487-
pret[0][i] = ret[i][0];
488-
pret[1][i] = ret[i][1];
490+
pret[0]._lengths[i] = ret[i][0];
491+
pret[1]._lengths[i] = ret[i][1];
492+
}
493+
static if (packs.length > 1)
494+
{
495+
pret[0]._iterator = iterator[0]._iterator;
496+
pret[1]._iterator = iterator[1]._iterator;
497+
}
498+
else
499+
{
500+
pret[0]._iterator = iterator[0];
501+
pret[1]._iterator = iterator[1];
489502
}
490503
}
504+
static if (packs.length > 1)
505+
{
506+
foreach (i; Iota!(packs[0], slice.N))
507+
{
508+
pret[0]._lengths[i] = slice._lengths[i];
509+
pret[1]._lengths[i] = slice._lengths[i];
510+
}
511+
}
512+
auto strides = slice.strides;
513+
foreach(i; Iota!(0, pret[0].S))
514+
{
515+
pret[0]._strides[i] = strides[i];
516+
pret[1]._strides[i] = strides[i];
517+
}
491518
return pret;
492519
}
493520
else
@@ -503,11 +530,15 @@ unittest
503530
-3, -2, 7, 2,
504531
].sliced(3, 4);
505532

506-
auto backwardIndex = s.minmaxPos;
533+
auto pos = s.minmaxPos;
507534

508-
assert(backwardIndex == [[2, 3], [1, 2]]);
509-
assert(s.backward(backwardIndex[0]) == -4);
510-
assert(s.backward(backwardIndex[1]) == 7);
535+
assert(pos[0] == s[$ - 2 .. $, $ - 3 .. $]);
536+
assert(pos[1] == s[$ - 1 .. $, $ - 2 .. $]);
537+
538+
assert(pos[0].first == -4);
539+
assert(s.backward(pos[0].shape) == -4);
540+
assert(pos[1].first == 7);
541+
assert(s.backward(pos[1].shape) == 7);
511542
}
512543

513544
/++
@@ -596,14 +627,35 @@ template minPos(alias pred = "a < b")
596627
Multidimensional backward index such that element is minimal(maximal).
597628
Backward index equals zeros, if slice is empty.
598629
+/
599-
@fastmath size_t[packs[0]] minPos(SliceKind kind, size_t[] packs, Iterator)(Slice!(kind, packs, Iterator) slice)
630+
@fastmath Slice!(kind == Contiguous && packs[0] > 1 ? Canonical : kind, packs, Iterator)
631+
minPos(SliceKind kind, size_t[] packs, Iterator)(Slice!(kind, packs, Iterator) slice)
600632
{
601633
typeof(return) ret;
602634
import mir.ndslice.topology: map;
603635
if (!slice.anyEmpty)
604636
{
605637
auto iterator = slice.map!"a"._iterator;
606-
minPosImpl!(pred, kind, packs, Iterator)(ret, iterator, slice);
638+
minPosImpl!(pred, kind, packs, Iterator)(ret._lengths, iterator, slice);
639+
static if (packs.length > 1)
640+
{
641+
ret._iterator = iterator._iterator;
642+
}
643+
else
644+
{
645+
ret._iterator = iterator;
646+
}
647+
}
648+
static if (packs.length > 1)
649+
{
650+
foreach (i; Iota!(packs[0], slice.N))
651+
{
652+
ret._lengths[i] = slice._lengths[i];
653+
}
654+
}
655+
auto strides = slice.strides;
656+
foreach(i; Iota!(0, ret.S))
657+
{
658+
ret._strides[i] = strides[i];
607659
}
608660
return ret;
609661
}
@@ -627,15 +679,17 @@ unittest
627679
-3, -2, 7, 2,
628680
].sliced(3, 4);
629681

630-
auto backwardIndex = s.minPos;
682+
auto pos = s.minPos;
631683

632-
assert(backwardIndex == [2, 3]);
633-
assert(s.backward(backwardIndex) == -4);
684+
assert(pos == s[$ - 2 .. $, $ - 3 .. $]);
685+
assert(pos.first == -4);
686+
assert(s.backward(pos.shape) == -4);
634687

635-
backwardIndex = s.maxPos;
688+
pos = s.maxPos;
636689

637-
assert(backwardIndex == [1, 2]);
638-
assert(s.backward(backwardIndex) == 7);
690+
assert(pos == s[$ - 1 .. $, $ - 2 .. $]);
691+
assert(pos.first == 7);
692+
assert(s.backward(pos.shape) == 7);
639693
}
640694

641695
/++

0 commit comments

Comments
 (0)