Skip to content

Commit 28c6b1f

Browse files
authored
Merge pull request #2367 from davidbrochart/specialize_chunked_view
Specialize operator= when RHS is chunked
2 parents 644970a + 9936e0e commit 28c6b1f

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

include/xtensor/xchunked_array.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ namespace xt
184184
template<class E>
185185
constexpr bool is_chunked(const xexpression<E>& e);
186186

187+
template<class E>
188+
constexpr bool is_chunked();
189+
187190
/**
188191
* Creates an in-memory chunked array.
189192
* This function returns an uninitialized ``xchunked_array<xarray<T>>``.
@@ -286,6 +289,12 @@ namespace xt
286289

287290
template<class E>
288291
constexpr bool is_chunked(const xexpression<E>&)
292+
{
293+
return is_chunked<E>();
294+
}
295+
296+
template<class E>
297+
constexpr bool is_chunked()
289298
{
290299
using return_type = typename detail::chunk_helper<E>::is_chunked;
291300
return return_type::value;

include/xtensor/xchunked_view.hpp

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
#include "xnoalias.hpp"
1616
#include "xstorage.hpp"
1717
#include "xstrided_view.hpp"
18+
#include "xchunked_array.hpp"
1819

1920
namespace xt
2021
{
2122

23+
template <class E>
24+
struct is_chunked_t: detail::chunk_helper<E>::is_chunked
25+
{
26+
};
27+
2228
/*****************
2329
* xchunked_view *
2430
*****************/
@@ -30,7 +36,7 @@ namespace xt
3036
class xchunked_view
3137
{
3238
public:
33-
39+
3440
using self_type = xchunked_view<E>;
3541
using expression_type = std::decay_t<E>;
3642
using value_type = typename expression_type::value_type;
@@ -48,7 +54,15 @@ namespace xt
4854
xchunked_view(OE&& e, S&& chunk_shape);
4955

5056
template <class OE>
51-
xchunked_view<E>& operator=(const OE& e);
57+
xchunked_view(OE&& e);
58+
59+
void init();
60+
61+
template <class OE>
62+
typename std::enable_if_t<!is_chunked_t<OE>::value, xchunked_view<E>&> operator=(const OE& e);
63+
64+
template <class OE>
65+
typename std::enable_if_t<is_chunked_t<OE>::value, xchunked_view<E>&> operator=(const OE& e);
5266

5367
size_type dimension() const noexcept;
5468
const shape_type& shape() const noexcept;
@@ -92,6 +106,22 @@ namespace xt
92106
m_shape.resize(e.dimension());
93107
const auto& s = e.shape();
94108
std::copy(s.cbegin(), s.cend(), m_shape.begin());
109+
init();
110+
}
111+
112+
template <class E>
113+
template <class OE>
114+
inline xchunked_view<E>::xchunked_view(OE&& e)
115+
: m_expression(std::forward<OE>(e))
116+
{
117+
m_shape.resize(e.dimension());
118+
const auto& s = e.shape();
119+
std::copy(s.cbegin(), s.cend(), m_shape.begin());
120+
}
121+
122+
template <class E>
123+
void xchunked_view<E>::init()
124+
{
95125
// compute chunk number in each dimension
96126
m_grid_shape.resize(m_shape.size());
97127
std::transform
@@ -114,16 +144,47 @@ namespace xt
114144

115145
template <class E>
116146
template <class OE>
117-
xchunked_view<E>& xchunked_view<E>::operator=(const OE& e)
147+
typename std::enable_if_t<!is_chunked_t<OE>::value, xchunked_view<E>&> xchunked_view<E>::operator=(const OE& e)
118148
{
119-
for (auto it = chunk_begin(); it != chunk_end(); it++)
149+
auto end = chunk_end();
150+
for (auto it = chunk_begin(); it != end; ++it)
120151
{
121152
auto el = *it;
122153
noalias(el) = strided_view(e, it.get_slice_vector());
123154
}
124155
return *this;
125156
}
126157

158+
template <class E>
159+
template <class OE>
160+
typename std::enable_if_t<is_chunked_t<OE>::value, xchunked_view<E>&> xchunked_view<E>::operator=(const OE& e)
161+
{
162+
m_chunk_shape.resize(e.dimension());
163+
const auto& cs = e.chunk_shape();
164+
std::copy(cs.cbegin(), cs.cend(), m_chunk_shape.begin());
165+
init();
166+
auto it2 = e.chunks().begin();
167+
auto end1 = chunk_end();
168+
for (auto it1 = chunk_begin(); it1 != end1; ++it1, ++it2)
169+
{
170+
auto el1 = *it1;
171+
auto el2 = *it2;
172+
auto lhs_shape = el1.shape();
173+
if (lhs_shape != el2.shape())
174+
{
175+
xstrided_slice_vector esv(el2.dimension()); // element slice in edge chunk
176+
std::transform(lhs_shape.begin(), lhs_shape.end(), esv.begin(),
177+
[](auto size) { return range(0, size); });
178+
noalias(el1) = strided_view(el2, esv);
179+
}
180+
else
181+
{
182+
noalias(el1) = el2;
183+
}
184+
}
185+
return *this;
186+
}
187+
127188
template <class E>
128189
inline auto xchunked_view<E>::dimension() const noexcept -> size_type
129190
{
@@ -209,6 +270,12 @@ namespace xt
209270
{
210271
return xchunked_view<E>(std::forward<E>(e), std::forward<S>(chunk_shape));
211272
}
273+
274+
template <class E>
275+
inline xchunked_view<E> as_chunked(E&& e)
276+
{
277+
return xchunked_view<E>(std::forward<E>(e));
278+
}
212279
}
213280

214281
#endif

0 commit comments

Comments
 (0)