Skip to content

Commit cb5d12d

Browse files
committed
rework fuseShape
1 parent b773d5c commit cb5d12d

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

source/mir/ndslice/fuse.d

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,7 @@ template fuseImpl(bool RC, T_, Dimensions...)
277277
private template fuseDimensionCount(R)
278278
{
279279
static if (is(typeof(R.init.shape) : size_t[N], size_t N) && (isDynamicArray!R || __traits(hasMember, R, "front")))
280-
{
281-
import mir.ndslice.topology: repeat;
282280
enum size_t fuseDimensionCount = N + fuseDimensionCount!(DeepElementType!R);
283-
}
284281
else
285282
enum size_t fuseDimensionCount = 0;
286283
}
@@ -309,26 +306,37 @@ size_t[fuseDimensionCount!Range] fuseShape(Range)(Range r)
309306
import mir.ndslice.topology: repeat;
310307
typeof(return) ret;
311308
ret[0 .. N] = r.shape;
312-
if (!ret[0 .. N].anyEmptyShape)
309+
bool next;
310+
if (!ret[0 .. N].anyEmptyShape) foreach (ref elem; r)
313311
{
314-
ret[N .. $] = fuseShape(mixin("r" ~ ".front".repeat(N).fuseCells.field));
315-
import mir.algorithm.iteration: all;
316-
if (!all!((a) => cast(size_t[M]) ret[N .. $] == .fuseShape(a))(r))
312+
auto elemShape = fuseShape(elem);
313+
if (next)
317314
{
318-
version (D_Exceptions)
319-
throw shapeException;
320-
else
321-
assert(0, shapeExceptionMsg);
315+
if (elemShape != ret[N .. $])
316+
{
317+
version (D_Exceptions)
318+
throw shapeException;
319+
else
320+
assert(0, shapeExceptionMsg);
321+
}
322+
next = true;
323+
}
324+
else
325+
{
326+
ret[N .. $] = elemShape;
322327
}
323328
}
324329
return ret;
325330
}
326331
}
327332

328333
private template FuseElementType(NDRange)
334+
if (fuseDimensionCount!NDRange >= 1)
329335
{
330-
import mir.ndslice.topology: repeat;
331-
alias FuseElementType = typeof(mixin("NDRange.init" ~ ".front".repeat(fuseDimensionCount!NDRange).fuseCells.field));
336+
static if (fuseDimensionCount!NDRange == 1)
337+
alias FuseElementType = typeof(NDRange.init.front);
338+
else
339+
alias FuseElementType = FuseElementType!(typeof(NDRange.init.front));
332340
}
333341

334342
/++

0 commit comments

Comments
 (0)