Skip to content

Commit e5bd939

Browse files
authored
Merge pull request #2363 from JohanMabille/chunk_assign
Fixed chunk assignment
2 parents 5b74468 + ec8db28 commit e5bd939

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

include/xtensor/xchunked_assign.hpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,12 @@ namespace xt
172172
bool operator!=(const self_type& rhs) const;
173173

174174
const slice_vector& get_slice_vector() const;
175+
slice_vector get_chunk_slice_vector() const;
175176

176177
private:
177178

179+
void fill_slice_vector(size_type index);
180+
178181
E* p_chunked_expression;
179182
shape_type m_chunk_index;
180183
size_type m_chunk_linear_index;
@@ -198,9 +201,20 @@ namespace xt
198201
inline auto xchunked_semantic<D>::assign_xexpression(const xexpression<E>& e) -> derived_type&
199202
{
200203
auto& d = this->derived_cast();
201-
for (auto it = d.chunk_begin(); it != d.chunk_end(); ++it)
204+
const auto& chunk_shape = d.chunk_shape();
205+
size_t i = 0;
206+
auto it_end = d.chunk_end();
207+
for (auto it = d.chunk_begin(); it != it_end; ++it, ++i)
202208
{
203-
noalias(*it) = strided_view(e.derived_cast(), it.get_slice_vector());
209+
auto rhs = strided_view(e.derived_cast(), it.get_slice_vector());
210+
if (rhs.shape() != chunk_shape)
211+
{
212+
noalias(strided_view(*it, it.get_chunk_slice_vector())) = rhs;
213+
}
214+
else
215+
{
216+
noalias(*it) = rhs;
217+
}
204218
}
205219

206220
return this->derived_cast();
@@ -262,17 +276,7 @@ namespace xt
262276
{
263277
for (size_type i = 0; i < m_chunk_index.size(); ++i)
264278
{
265-
if (m_chunk_index[i] == 0)
266-
{
267-
m_slice_vector[i] = range(0, p_chunked_expression->chunk_shape()[i]);
268-
}
269-
else
270-
{
271-
size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape()[i];
272-
size_type range_end = std::min((m_chunk_index[i] + 1) * p_chunked_expression->chunk_shape()[i],
273-
p_chunked_expression->shape()[i]);
274-
m_slice_vector[i] = range(range_start, range_end);
275-
}
279+
fill_slice_vector(i);
276280
}
277281
}
278282

@@ -288,15 +292,12 @@ namespace xt
288292
if (m_chunk_index[i] + 1u == p_chunked_expression->grid_shape()[i])
289293
{
290294
m_chunk_index[i] = 0;
291-
m_slice_vector[i] = range(0, p_chunked_expression->chunk_shape()[i]);
295+
fill_slice_vector(i);
292296
}
293297
else
294298
{
295299
m_chunk_index[i] += 1;
296-
size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape()[i];
297-
size_type range_end = std::min((m_chunk_index[i] + 1) * p_chunked_expression->chunk_shape()[i],
298-
p_chunked_expression->shape()[i]);
299-
m_slice_vector[i] = range(range_start, range_end);
300+
fill_slice_vector(i);
300301
break;
301302
}
302303
}
@@ -336,6 +337,29 @@ namespace xt
336337
{
337338
return m_slice_vector;
338339
}
340+
341+
template <class E>
342+
inline auto xchunk_iterator<E>::get_chunk_slice_vector() const -> slice_vector
343+
{
344+
slice_vector slices(m_chunk_index.size());
345+
for (size_type i = 0; i < m_chunk_index.size(); ++i)
346+
{
347+
size_type chunk_shape = p_chunked_expression->chunk_shape()[i];
348+
size_type end = std::min(chunk_shape,
349+
p_chunked_expression->shape()[i] - m_chunk_index[i] * chunk_shape);
350+
slices[i] = range(0u, end);
351+
}
352+
return slices;
353+
}
354+
355+
template <class E>
356+
inline void xchunk_iterator<E>::fill_slice_vector(size_type i)
357+
{
358+
size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape()[i];
359+
size_type range_end = std::min((m_chunk_index[i] + 1) * p_chunked_expression->chunk_shape()[i],
360+
p_chunked_expression->shape()[i]);
361+
m_slice_vector[i] = range(range_start, range_end);
362+
}
339363
}
340364

341365
#endif

test/test_xchunked_array.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ namespace xt
7777
EXPECT_EQ(v, 2. * val + 2.);
7878
}
7979

80-
8180
xarray<double> a3
8281
{{1., 2., 3.},
8382
{4., 5., 6.},
@@ -110,6 +109,20 @@ namespace xt
110109
{
111110
EXPECT_EQ(v, 3);
112111
}
112+
113+
std::vector<size_t> shape3 = {3, 3};
114+
std::vector<size_t> chunk_shape3 = {1, 2};
115+
auto a7 = chunked_array<double>(shape3, chunk_shape3);
116+
for (auto it = a7.chunks().begin(); it != a7.chunks().end(); ++it)
117+
{
118+
it->resize(chunk_shape3);
119+
}
120+
121+
a7 = a3;
122+
for (auto it = a7.chunks().begin(); it != a7.chunks().end(); ++it)
123+
{
124+
EXPECT_EQ(it->shape(), chunk_shape3);
125+
}
113126
}
114127

115128
TEST(xchunked_array, noalias)

0 commit comments

Comments
 (0)