Skip to content

Commit c778cd3

Browse files
committed
relax linear interpolant contraints
1 parent b9b343d commit c778cd3

File tree

1 file changed

+84
-59
lines changed

1 file changed

+84
-59
lines changed

source/mir/interpolate/linear.d

Lines changed: 84 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ version(mir_test)
7171
static immutable data = [0.0011, 0.0030, 0.0064, 0.0104, 0.0144, 0.0176, 0.0207, 0.0225, 0.0243, 0.0261, 0.0268, 0.0274, 0.0281, 0.0288, 0.0295, 0.0302, 0.0309, 0.0316, 0.0322, 0.0329, 0.0332, 0.0335, 0.0337, 0.0340, 0.0342, 0.0345, 0.0348, 0.0350, 0.0353, 0.0356];
7272

7373
assert(xs.sliced.vmap(interpolant).all!((a, b) => approxEqual(a, b, 1e-4, 1e-4))(data));
74+
75+
auto d = interpolant.withDerivative(9.0);
76+
auto de = interpolant.opCall!2(9.0);
77+
assert(de[0 .. 2] == d);
78+
assert(de[2] == 0);
7479
}
7580

7681
/// R^2 -> R: Bilinear interpolation
@@ -265,7 +270,6 @@ struct Linear(F, size_t N = 1, X = F)
265270

266271
///
267272
template opCall(uint derivative = 0)
268-
if (derivative <= derivativeOrder)
269273
{
270274
/++
271275
`(x)` operator.
@@ -275,62 +279,72 @@ struct Linear(F, size_t N = 1, X = F)
275279
auto opCall(X...)(const X xs) scope const @trusted
276280
if (X.length == N)
277281
{
278-
import mir.functional: AliasCall;
279-
import mir.ndslice.topology: iota;
280-
alias Kernel = AliasCall!(LinearKernel!F, "opCall", derivative);
282+
static if (derivative > derivativeOrder)
283+
{
284+
auto res = this.opCall!derivativeOrder(xs);
285+
typeof(res[0])[derivative + 1] ret = 0;
286+
ret[0 .. derivativeOrder + 1] = res;
287+
return ret;
288+
}
289+
else
290+
{
291+
import mir.functional: AliasCall;
292+
import mir.ndslice.topology: iota;
293+
alias Kernel = AliasCall!(LinearKernel!F, "opCall", derivative);
281294

282-
size_t[N] indices;
283-
Kernel[N] kernels;
295+
size_t[N] indices;
296+
Kernel[N] kernels;
284297

285-
enum rp2d = derivative;
298+
enum rp2d = derivative;
286299

287-
foreach(i; Iota!N)
288-
{
289-
static if (isInterval!(typeof(xs[i])))
300+
foreach(i; Iota!N)
290301
{
291-
indices[i] = xs[i][1];
292-
auto x = xs[i][0];
302+
static if (isInterval!(typeof(xs[i])))
303+
{
304+
indices[i] = xs[i][1];
305+
auto x = xs[i][0];
306+
}
307+
else
308+
{
309+
alias x = xs[i];
310+
indices[i] = this.findInterval!i(x);
311+
}
312+
kernels[i] = LinearKernel!F(_grid[i][indices[i]], _grid[i][indices[i] + 1], x);
293313
}
294-
else
295-
{
296-
alias x = xs[i];
297-
indices[i] = this.findInterval!i(x);
298-
}
299-
kernels[i] = LinearKernel!F(_grid[i][indices[i]], _grid[i][indices[i] + 1], x);
300-
}
301314

302-
align(64) F[2 ^^ N][derivative + 1] local;
303-
immutable strides = _data._lengths.iota.strides;
315+
align(64) F[2 ^^ N][derivative + 1] local;
316+
immutable strides = _data._lengths.iota.strides;
304317

305-
void load(sizediff_t i)(F* from, F* to)
306-
{
307-
version(LDC) pragma(inline, true);
308-
static if (i == -1)
309-
{
310-
*to = *from;
311-
}
312-
else
318+
void load(sizediff_t i)(F* from, F* to)
313319
{
314-
from += strides[i] * indices[i];
315-
load!(i - 1)(from, to);
316-
from += strides[i];
317-
enum s = 2 ^^ (N - 1 - i);
318-
to += s;
319-
load!(i - 1)(from, to);
320+
version(LDC) pragma(inline, true);
321+
static if (i == -1)
322+
{
323+
*to = *from;
324+
}
325+
else
326+
{
327+
from += strides[i] * indices[i];
328+
load!(i - 1)(from, to);
329+
from += strides[i];
330+
enum s = 2 ^^ (N - 1 - i);
331+
to += s;
332+
load!(i - 1)(from, to);
333+
}
320334
}
321-
}
322335

323-
load!(N - 1)(cast(F*) _data.ptr, cast(F*)local[0].ptr);
336+
load!(N - 1)(cast(F*) _data.ptr, cast(F*)local[0].ptr);
324337

325-
foreach(i; Iota!N)
326-
{
327-
enum P = 2 ^^ (N - 1 - i);
328-
enum L = 2 ^^ (N - i * (1 - rp2d)) / 2;
329-
vectorize(kernels[i], local[0][0 * L .. 1 * L], local[0][1 * L .. 2 * L], *cast(F[L][2 ^^ rp2d]*)local[rp2d].ptr);
330-
static if (rp2d == 1)
331-
shuffle3!1(local[1][0 .. L], local[1][L .. 2 * L], local[0][0 .. L], local[0][L .. 2 * L]);
332-
static if (i + 1 == N)
333-
return *cast(SplineReturnType!(F, N, 2 ^^ rp2d)*) local[0].ptr;
338+
foreach(i; Iota!N)
339+
{
340+
enum P = 2 ^^ (N - 1 - i);
341+
enum L = 2 ^^ (N - i * (1 - rp2d)) / 2;
342+
vectorize(kernels[i], local[0][0 * L .. 1 * L], local[0][1 * L .. 2 * L], *cast(F[L][2 ^^ rp2d]*)local[rp2d].ptr);
343+
static if (rp2d == 1)
344+
shuffle3!1(local[1][0 .. L], local[1][L .. 2 * L], local[0][0 .. L], local[0][L .. 2 * L]);
345+
static if (i + 1 == N)
346+
return *cast(SplineReturnType!(F, N, 2 ^^ rp2d)*) local[0].ptr;
347+
}
334348
}
335349
}
336350
}
@@ -486,7 +500,6 @@ struct MetaLinear(T, X)
486500

487501
///
488502
template opCall(uint derivative = 0)
489-
if (derivative <= derivativeOrder)
490503
{
491504
/++
492505
`(x)` operator.
@@ -496,22 +509,32 @@ struct MetaLinear(T, X)
496509
auto opCall(X...)(const X xs) scope const @trusted
497510
if (X.length == dimensionCount)
498511
{
499-
static if (isInterval!(typeof(xs[0])))
512+
static if (derivative > derivativeOrder)
500513
{
501-
size_t index = xs[0][1];
502-
auto x = xs[0][0];
514+
auto res = this.opCall!derivativeOrder(xs);
515+
typeof(res[0])[derivative + 1] ret = 0;
516+
ret[0 .. derivativeOrder + 1] = res;
517+
return ret;
503518
}
504519
else
505-
{
506-
alias x = xs[0];
507-
size_t index = this.findInterval(x);
520+
{
521+
static if (isInterval!(typeof(xs[0])))
522+
{
523+
size_t index = xs[0][1];
524+
auto x = xs[0][0];
525+
}
526+
else
527+
{
528+
alias x = xs[0];
529+
size_t index = this.findInterval(x);
530+
}
531+
auto lhs = data[index + 0].opCall!derivative(xs[1 .. $]);
532+
auto rhs = data[index + 1].opCall!derivative(xs[1 .. $]);
533+
alias E = typeof(lhs);
534+
alias F = DeepType!E;
535+
auto kernel = LinearKernel!F(grid[index], grid[index + 1], x);
536+
return kernel.opCall!derivative(lhs, rhs);
508537
}
509-
auto lhs = data[index + 0].opCall!derivative(xs[1 .. $]);
510-
auto rhs = data[index + 1].opCall!derivative(xs[1 .. $]);
511-
alias E = typeof(lhs);
512-
alias F = DeepType!E;
513-
auto kernel = LinearKernel!F(grid[index], grid[index + 1], x);
514-
return kernel.opCall!derivative(lhs, rhs);
515538
}
516539
}
517540

@@ -520,6 +543,7 @@ struct MetaLinear(T, X)
520543
}
521544

522545
/// 2D trapezoid-like (not rectilinear) linear interpolation
546+
version(mir_test)
523547
unittest
524548
{
525549
auto x = [
@@ -549,6 +573,7 @@ unittest
549573
auto valWithDerivative = trapezoidInterpolator.withDerivative(9.0, 1.8);
550574
}
551575

576+
version(mir_test)
552577
unittest
553578
{
554579
import mir.math.common: approxEqual;

0 commit comments

Comments
 (0)