Skip to content

Commit 388d4f9

Browse files
committed
extend interpolation derivatives api
1 parent c778cd3 commit 388d4f9

File tree

2 files changed

+78
-43
lines changed

2 files changed

+78
-43
lines changed

source/mir/interpolate/linear.d

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ struct LinearKernel(X)
386386

387387
///
388388
template opCall(uint derivative = 0)
389-
if (derivative <= 1)
389+
// if (derivative <= 1)
390390
{
391391
///
392392
auto opCall(Y)(const Y y0, const Y y1)
@@ -509,37 +509,30 @@ struct MetaLinear(T, X)
509509
auto opCall(X...)(const X xs) scope const @trusted
510510
if (X.length == dimensionCount)
511511
{
512-
static if (derivative > derivativeOrder)
512+
static if (isInterval!(typeof(xs[0])))
513513
{
514-
auto res = this.opCall!derivativeOrder(xs);
515-
typeof(res[0])[derivative + 1] ret = 0;
516-
ret[0 .. derivativeOrder + 1] = res;
517-
return ret;
514+
size_t index = xs[0][1];
515+
auto x = xs[0][0];
518516
}
519517
else
520-
{
521-
static if (isInterval!(typeof(xs[0])))
522-
{
523-
size_t index = xs[0][1];
524-
auto x = xs[0][0];
525-
}
526-
else
527-
{
528-
alias x = xs[0];
529-
size_t index = this.findInterval(x);
530-
}
531-
auto lhs = data[index + 0].opCall!derivative(xs[1 .. $]);
532-
auto rhs = data[index + 1].opCall!derivative(xs[1 .. $]);
533-
alias E = typeof(lhs);
534-
alias F = DeepType!E;
535-
auto kernel = LinearKernel!F(grid[index], grid[index + 1], x);
536-
return kernel.opCall!derivative(lhs, rhs);
518+
{
519+
alias x = xs[0];
520+
size_t index = this.findInterval(x);
537521
}
522+
auto lhs = data[index + 0].opCall!derivative(xs[1 .. $]);
523+
auto rhs = data[index + 1].opCall!derivative(xs[1 .. $]);
524+
alias E = typeof(lhs);
525+
alias F = DeepType!E;
526+
auto kernel = LinearKernel!F(grid[index], grid[index + 1], x);
527+
return kernel.opCall!derivative(lhs, rhs);
538528
}
539529
}
540530

541531
///
542532
alias withDerivative = opCall!1;
533+
534+
///
535+
alias withTwoDerivatives = opCall!2;
543536
}
544537

545538
/// 2D trapezoid-like (not rectilinear) linear interpolation

source/mir/interpolate/spline.d

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import core.lifetime: move;
1919
import mir.functional;
2020
import mir.internal.utility;
2121
import mir.interpolate;
22-
import mir.interpolate: Repeat;
2322
import mir.math.common;
2423
import mir.ndslice.slice;
2524
import mir.primitives;
@@ -1618,7 +1617,7 @@ struct MetaSpline(T, X)
16181617
// alias ElementInterpolator = Linear!(F, N, X);
16191618
alias F = ValueType!(T, X);
16201619
///
1621-
private Spline!F spline;
1620+
private Repeat!(3, Spline!F) splines;
16221621
///
16231622
RCArray!(const T) data;
16241623
//
@@ -1649,7 +1648,9 @@ struct MetaSpline(T, X)
16491648

16501649
this.data = data.move;
16511650
this._temp = grid.length;
1652-
this.spline = grid.moveToSlice;
1651+
this.splines[0] = grid.asSlice;
1652+
this.splines[1] = grid.asSlice;
1653+
this.splines[2] = grid.moveToSlice;
16531654
this.configuration = configuration;
16541655
}
16551656

@@ -1668,7 +1669,7 @@ struct MetaSpline(T, X)
16681669
immutable(X)[] gridScopeView(size_t dimension = 0)() scope return const @property @trusted
16691670
if (dimension == 0)
16701671
{
1671-
return spline.gridScopeView;
1672+
return splines[0].gridScopeView;
16721673
}
16731674

16741675
/++
@@ -1689,8 +1690,7 @@ struct MetaSpline(T, X)
16891690

16901691
///
16911692
template opCall(uint derivative = 0)
1692-
// if (derivative <= derivativeOrder)
1693-
if (derivative <= 0) // doesn't support derivatives for now
1693+
if (derivative <= derivativeOrder)
16941694
{
16951695
/++
16961696
`(x)` operator.
@@ -1700,25 +1700,61 @@ struct MetaSpline(T, X)
17001700
auto opCall(X...)(const X xs) scope const @trusted
17011701
if (X.length == dimensionCount)
17021702
{
1703-
auto mutable = cast(F[2][]) spline._data.lightScope.field;
1704-
assert(mutable.length == data.length);
1705-
foreach (i, ref d; data)
1706-
mutable[i][0] = d(xs[1 .. $]);
1707-
(*cast(Spline!F*)&spline)._computeDerivativesTemp(
1708-
configuration.kind,
1709-
configuration.param,
1710-
configuration.leftBoundary,
1711-
configuration.rightBoundary,
1712-
(cast(F[])_temp[]).sliced);
1713-
return spline(xs[0]);
1703+
F[2][][derivative + 1] mutable;
1704+
1705+
static foreach (o; 0 .. derivative + 1)
1706+
{
1707+
mutable[o] = cast(F[2][]) splines[o]._data.lightScope.field;
1708+
assert(mutable[o].length == data.length);
1709+
}
1710+
1711+
static if (!derivative)
1712+
{
1713+
foreach (i, ref d; data)
1714+
mutable[0][i][0] = d(xs[1 .. $]);
1715+
}
1716+
else
1717+
{
1718+
foreach (i, ref d; data)
1719+
{
1720+
auto node = d.opCall!derivative(xs[1 .. $]);
1721+
static foreach (o; 0 .. derivative + 1)
1722+
mutable[o][i][0] = node[o];
1723+
}
1724+
}
1725+
1726+
static foreach (o; 0 .. derivative + 1)
1727+
{
1728+
(*cast(Spline!F*)&splines[o])._computeDerivativesTemp(
1729+
configuration.kind,
1730+
configuration.param,
1731+
configuration.leftBoundary,
1732+
configuration.rightBoundary,
1733+
(cast(F[])_temp[]).sliced);
1734+
}
1735+
1736+
static if (!derivative)
1737+
{
1738+
return splines[0](xs[0]);
1739+
}
1740+
else
1741+
{
1742+
typeof(splines[0].opCall!derivative(xs[0]))[derivative + 1] ret = void;
1743+
static foreach (o; 0 .. derivative + 1)
1744+
ret[o] = splines[o].opCall!derivative(xs[0]);
1745+
return ret;
1746+
}
17141747
}
17151748
}
17161749

1717-
// ///
1718-
// alias withDerivative = opCall!1;
1750+
///
1751+
alias withDerivative = opCall!1;
1752+
///
1753+
alias withTwoDerivatives = opCall!2;
17191754
}
17201755

17211756
/// 2D trapezoid-like (not rectilinear) linear interpolation
1757+
version(mir_test)
17221758
unittest
17231759
{
17241760
auto x = [
@@ -1745,9 +1781,15 @@ unittest
17451781
auto trapezoidInterpolator = metaSpline!double(g.rcarray!(immutable double), d.lightConst);
17461782

17471783
auto val = trapezoidInterpolator(9.0, 1.8);
1748-
1784+
auto ext = trapezoidInterpolator.opCall!2(9.0, 1.8);
1785+
assert(ext[0][0] == val);
1786+
assert(ext == [
1787+
[-0.6323361344537806, -2.2344649859943977, 1.362173669467787],
1788+
[3.6676610644257703, -2.984652194211018, 0.9911251167133525],
1789+
[2.4365546218487393, -1.556932773109244, 0.3836134453781514]]);
17491790
}
17501791

1792+
version(mir_test)
17511793
unittest
17521794
{
17531795
import mir.math.common: approxEqual;

0 commit comments

Comments
 (0)