Skip to content

Commit b681c4b

Browse files
committed
Fix build for pytorch 0.4.1
THTensor is defined as an opaque data type, thus direct access to members is not possible anymore. Instead, use functions from C API to retrieve stride. Fixes #4
1 parent 6ce5c2e commit b681c4b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

aten/TH/generic/THGrid.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@ void THTensor_(grid)(THLongTensor *self, THTensor *pos, THTensor *size, THLongTe
66
int64_t *selfData = THLongTensor_data(self);
77
real *posData = THTensor_(data)(pos);
88
real *sizeData = THTensor_(data)(size);
9+
int64_t posStride0 = THTensor_(stride)(pos, 0);
10+
int64_t posStride1 = THTensor_(stride)(pos, 1);
911
int64_t *countData = THLongTensor_data(count);
1012

1113
ptrdiff_t n, d; int64_t coef, value;
1214
for (n = 0; n < THTensor_(size)(pos, 0); n++) {
1315
coef = 1; value = 0;
1416
for (d = 0; d < THTensor_(size)(pos, 1); d++) {
15-
value += coef * (int64_t) (posData[d * pos->stride[1]] / sizeData[d]);
17+
value += coef * (int64_t) (posData[d * posStride1] / sizeData[d]);
1618
coef *= countData[d];
1719
}
18-
posData += pos->stride[0];
20+
posData += posStride0;
1921
selfData[n] = value;
2022
}
2123
}

0 commit comments

Comments
 (0)