@@ -172,9 +172,12 @@ namespace xt
172172 bool operator !=(const self_type& rhs) const ;
173173
174174 const slice_vector& get_slice_vector () const ;
175+ slice_vector get_chunk_slice_vector () const ;
175176
176177 private:
177178
179+ void fill_slice_vector (size_type index);
180+
178181 E* p_chunked_expression;
179182 shape_type m_chunk_index;
180183 size_type m_chunk_linear_index;
@@ -198,9 +201,20 @@ namespace xt
198201 inline auto xchunked_semantic<D>::assign_xexpression(const xexpression<E>& e) -> derived_type&
199202 {
200203 auto & d = this ->derived_cast ();
201- for (auto it = d.chunk_begin (); it != d.chunk_end (); ++it)
204+ const auto & chunk_shape = d.chunk_shape ();
205+ size_t i = 0 ;
206+ auto it_end = d.chunk_end ();
207+ for (auto it = d.chunk_begin (); it != it_end; ++it, ++i)
202208 {
203- noalias (*it) = strided_view (e.derived_cast (), it.get_slice_vector ());
209+ auto rhs = strided_view (e.derived_cast (), it.get_slice_vector ());
210+ if (rhs.shape () != chunk_shape)
211+ {
212+ noalias (strided_view (*it, it.get_chunk_slice_vector ())) = rhs;
213+ }
214+ else
215+ {
216+ noalias (*it) = rhs;
217+ }
204218 }
205219
206220 return this ->derived_cast ();
@@ -262,17 +276,7 @@ namespace xt
262276 {
263277 for (size_type i = 0 ; i < m_chunk_index.size (); ++i)
264278 {
265- if (m_chunk_index[i] == 0 )
266- {
267- m_slice_vector[i] = range (0 , p_chunked_expression->chunk_shape ()[i]);
268- }
269- else
270- {
271- size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape ()[i];
272- size_type range_end = std::min ((m_chunk_index[i] + 1 ) * p_chunked_expression->chunk_shape ()[i],
273- p_chunked_expression->shape ()[i]);
274- m_slice_vector[i] = range (range_start, range_end);
275- }
279+ fill_slice_vector (i);
276280 }
277281 }
278282
@@ -288,15 +292,12 @@ namespace xt
288292 if (m_chunk_index[i] + 1u == p_chunked_expression->grid_shape ()[i])
289293 {
290294 m_chunk_index[i] = 0 ;
291- m_slice_vector[i] = range ( 0 , p_chunked_expression-> chunk_shape ()[i] );
295+ fill_slice_vector (i );
292296 }
293297 else
294298 {
295299 m_chunk_index[i] += 1 ;
296- size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape ()[i];
297- size_type range_end = std::min ((m_chunk_index[i] + 1 ) * p_chunked_expression->chunk_shape ()[i],
298- p_chunked_expression->shape ()[i]);
299- m_slice_vector[i] = range (range_start, range_end);
300+ fill_slice_vector (i);
300301 break ;
301302 }
302303 }
@@ -336,6 +337,29 @@ namespace xt
336337 {
337338 return m_slice_vector;
338339 }
340+
341+ template <class E >
342+ inline auto xchunk_iterator<E>::get_chunk_slice_vector() const -> slice_vector
343+ {
344+ slice_vector slices (m_chunk_index.size ());
345+ for (size_type i = 0 ; i < m_chunk_index.size (); ++i)
346+ {
347+ size_type chunk_shape = p_chunked_expression->chunk_shape ()[i];
348+ size_type end = std::min (chunk_shape,
349+ p_chunked_expression->shape ()[i] - m_chunk_index[i] * chunk_shape);
350+ slices[i] = range (0u , end);
351+ }
352+ return slices;
353+ }
354+
355+ template <class E >
356+ inline void xchunk_iterator<E>::fill_slice_vector(size_type i)
357+ {
358+ size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape ()[i];
359+ size_type range_end = std::min ((m_chunk_index[i] + 1 ) * p_chunked_expression->chunk_shape ()[i],
360+ p_chunked_expression->shape ()[i]);
361+ m_slice_vector[i] = range (range_start, range_end);
362+ }
339363}
340364
341365#endif
0 commit comments