Skip to content

Commit ca8fce7

Browse files
committed
add specialisation for map!"a"
1 parent 53797c8 commit ca8fce7

File tree

1 file changed

+27
-24
lines changed

1 file changed

+27
-24
lines changed

source/mir/ndslice/topology.d

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,32 +1915,18 @@ auto retro
19151915
{
19161916
static if (kind == Contiguous || kind == Canonical)
19171917
{
1918-
static if (kind == Contiguous)
1919-
{
1920-
ptrdiff_t shift = 1;
1921-
foreach(i; Iota!(packs[0]))
1922-
shift *= slice._lengths[i];
1923-
--shift;
1924-
}
1925-
else
1926-
{
1927-
ptrdiff_t shift = 0;
1928-
foreach(i; Iota!(packs[0]))
1929-
shift += slice.backIndex!i;
1930-
}
19311918
static if (is(Iterator : RetroIterator!It, It))
19321919
{
19331920
alias Ret = Slice!(kind, packs, It);
19341921
mixin _DefineRet_;
1935-
ret._iterator = slice._iterator._iterator;
1922+
ret._iterator = slice._iterator._iterator - ret.lastIndex;
19361923
}
19371924
else
19381925
{
19391926
alias Ret = Slice!(kind, packs, RetroIterator!Iterator);
19401927
mixin _DefineRet_;
1941-
ret._iterator = RetroIterator!Iterator(slice._iterator);
1928+
ret._iterator = RetroIterator!Iterator(slice._iterator + slice.lastIndex);
19421929
}
1943-
ret._iterator -= shift;
19441930
foreach (i; Iota!(ret.N))
19451931
ret._lengths[i] = slice._lengths[i];
19461932
foreach (i; Iota!(ret.S))
@@ -2115,9 +2101,8 @@ template map(fun...)
21152101
import mir.functional: adjoin, naryFun, pipe;
21162102
static if (fun.length == 1)
21172103
{
2118-
static if (__traits(isSame, naryFun!fun, fun))
2104+
static if (__traits(isSame, naryFun!"a", fun[0]))
21192105
{
2120-
alias f = fun[0];
21212106
/++
21222107
Params:
21232108
slice = An input slice.
@@ -2128,25 +2113,43 @@ template map(fun...)
21282113
@fastmath auto map(SliceKind kind, size_t[] packs, Iterator)
21292114
(Slice!(kind, packs, Iterator) slice)
21302115
{
2131-
// Specialization for packed tensors (tensors composed of tensors).
21322116
static if (packs.length == 1)
21332117
{
2134-
import mir.ndslice.iterator: mapIterator;
2135-
auto iterator = slice._iterator.mapIterator!f;
2136-
return Slice!(kind, packs, typeof(iterator))(slice._lengths, slice._strides, iterator);
2118+
return slice;
21372119
}
21382120
else
21392121
{
21402122
alias It = SliceIterator!(TemplateArgsOf!(slice.DeepElemType));
21412123
auto sl = slice.universal;
2142-
return .map!f(Slice!(Universal, packs[0 .. 1], It)(
2124+
return Slice!(Universal, packs[0 .. 1], It)(
21432125
sl._lengths[0 .. packs[0]],
21442126
sl._strides[0 .. packs[0]],
21452127
It(
21462128
sl._lengths[packs[0] .. packs[0] + It._lengths.length],
21472129
sl._strides[packs[0] .. packs[0] + It._strides.length],
21482130
sl._iterator,
2149-
)));
2131+
));
2132+
}
2133+
}
2134+
}
2135+
else
2136+
static if (__traits(isSame, naryFun!(fun[0]), fun[0]))
2137+
{
2138+
alias f = fun[0];
2139+
@fastmath auto map(SliceKind kind, size_t[] packs, Iterator)
2140+
(Slice!(kind, packs, Iterator) slice)
2141+
{
2142+
// Specialization for packed tensors (tensors composed of tensors).
2143+
static if (packs.length == 1)
2144+
{
2145+
import mir.ndslice.iterator: mapIterator;
2146+
auto iterator = slice._iterator.mapIterator!f;
2147+
return Slice!(kind, packs, typeof(iterator))(slice._lengths, slice._strides, iterator);
2148+
}
2149+
else
2150+
{
2151+
alias It = SliceIterator!(TemplateArgsOf!(slice.DeepElemType));
2152+
return .map!f(.map!(naryFun!"a")(slice));
21502153
}
21512154
}
21522155
}

0 commit comments

Comments
 (0)