Skip to content

Commit 5b74468

Browse files
authored
Merge pull request #2362 from JohanMabille/chunk_iterator
Added const chunk iterators
2 parents ca4c2e6 + 60ce400 commit 5b74468

File tree

4 files changed

+126
-15
lines changed

4 files changed

+126
-15
lines changed

include/xtensor/xchunked_array.hpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ namespace xt
7575
using bool_load_type = xt::bool_load_type<value_type>;
7676
static constexpr layout_type static_layout = layout_type::dynamic;
7777
static constexpr bool contiguous_layout = false;
78-
using chunk_iterator_type = xchunk_iterator<self_type>;
78+
using chunk_iterator = xchunk_iterator<self_type>;
79+
using const_chunk_iterator = xchunk_iterator<const self_type>;
7980

8081
template <class S>
8182
xchunked_array(chunk_storage_type&& chunks, S&& shape, S&& chunk_shape, layout_type chunk_memory_layout = XTENSOR_DEFAULT_LAYOUT);
@@ -136,8 +137,13 @@ namespace xt
136137
chunk_storage_type& chunks();
137138
const chunk_storage_type& chunks() const;
138139

139-
chunk_iterator_type chunk_begin();
140-
chunk_iterator_type chunk_end();
140+
chunk_iterator chunk_begin();
141+
chunk_iterator chunk_end();
142+
143+
const_chunk_iterator chunk_begin() const;
144+
const_chunk_iterator chunk_end() const;
145+
const_chunk_iterator chunk_cbegin() const;
146+
const_chunk_iterator chunk_cend() const;
141147

142148
private:
143149

@@ -489,17 +495,44 @@ namespace xt
489495
}
490496

491497
template <class CS>
492-
inline auto xchunked_array<CS>::chunk_begin() -> chunk_iterator_type
498+
inline auto xchunked_array<CS>::chunk_begin() -> chunk_iterator
499+
{
500+
shape_type chunk_index(m_shape.size(), size_type(0));
501+
return chunk_iterator(*this, std::move(chunk_index), 0u);
502+
}
503+
504+
template <class CS>
505+
inline auto xchunked_array<CS>::chunk_end() -> chunk_iterator
506+
{
507+
shape_type sh = xtl::forward_sequence<shape_type, const grid_shape_type>(grid_shape());
508+
return chunk_iterator(*this, std::move(sh), grid_size());
509+
}
510+
511+
template <class CS>
512+
inline auto xchunked_array<CS>::chunk_begin() const -> const_chunk_iterator
493513
{
494514
shape_type chunk_index(m_shape.size(), size_type(0));
495-
return chunk_iterator_type(*this, std::move(chunk_index), 0u);
515+
return const_chunk_iterator(*this, std::move(chunk_index), 0u);
496516
}
497517

498518
template <class CS>
499-
inline auto xchunked_array<CS>::chunk_end() -> chunk_iterator_type
519+
inline auto xchunked_array<CS>::chunk_end() const -> const_chunk_iterator
500520
{
501521
shape_type sh = xtl::forward_sequence<shape_type, const grid_shape_type>(grid_shape());
502-
return chunk_iterator_type(*this, std::move(sh), grid_size());
522+
return const_chunk_iterator(*this, std::move(sh), grid_size());
523+
}
524+
525+
template <class CS>
526+
inline auto xchunked_array<CS>::chunk_cbegin() const -> const_chunk_iterator
527+
{
528+
return chunk_begin();
529+
}
530+
531+
template <class CS>
532+
inline auto xchunked_array<CS>::chunk_cend() const -> const_chunk_iterator
533+
{
534+
return chunk_end();
535+
503536
}
504537

505538
template <class CS>

include/xtensor/xchunked_assign.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ namespace xt
110110
template <class A>
111111
struct xchunk_iterator_array
112112
{
113+
using reference = decltype(*(std::declval<A>().chunks().begin()));
114+
113115
inline decltype(auto) get_chunk(A& arr, typename A::size_type i, const xstrided_slice_vector&) const
114116
{
115117
return *(arr.chunks().begin() + i);
@@ -119,6 +121,8 @@ namespace xt
119121
template <class V>
120122
struct xchunk_iterator_view
121123
{
124+
using reference = decltype(xt::strided_view(std::declval<V>().expression(), std::declval<xstrided_slice_vector>()));
125+
122126
inline auto get_chunk(V& view, typename V::size_type, const xstrided_slice_vector& sv) const
123127
{
124128
return xt::strided_view(view.expression(), sv);
@@ -148,6 +152,13 @@ namespace xt
148152
using shape_type = typename E::shape_type;
149153
using slice_vector = xstrided_slice_vector;
150154

155+
using reference = typename base_type::reference;
156+
using value_type = std::remove_reference_t<reference>;
157+
using pointer = value_type*;
158+
using difference_type = typename E::difference_type;
159+
using iterator_category = std::forward_iterator_tag;
160+
161+
151162
xchunk_iterator() = default;
152163
xchunk_iterator(E& chunked_expression,
153164
shape_type&& chunk_index,

include/xtensor/xchunked_view.hpp

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,16 @@ namespace xt
3333

3434
using self_type = xchunked_view<E>;
3535
using expression_type = std::decay_t<E>;
36-
using size_type = size_t;
36+
using value_type = typename expression_type::value_type;
37+
using reference = typename expression_type::reference;
38+
using const_reference = typename expression_type::const_reference;
39+
using pointer = typename expression_type::pointer;
40+
using const_pointer = typename expression_type::const_pointer;
41+
using size_type = typename expression_type::size_type;
42+
using difference_type = typename expression_type::difference_type;
3743
using shape_type = svector<size_type>;
38-
using chunk_iterator_type = xchunk_iterator<self_type>;
44+
using chunk_iterator = xchunk_iterator<self_type>;
45+
using const_chunk_iterator = xchunk_iterator<const self_type>;
3946

4047
template <class OE, class S>
4148
xchunked_view(OE&& e, S&& chunk_shape);
@@ -52,8 +59,13 @@ namespace xt
5259
expression_type& expression() noexcept;
5360
const expression_type& expression() const noexcept;
5461

55-
chunk_iterator_type chunk_begin();
56-
chunk_iterator_type chunk_end();
62+
chunk_iterator chunk_begin();
63+
chunk_iterator chunk_end();
64+
65+
const_chunk_iterator chunk_begin() const;
66+
const_chunk_iterator chunk_end() const;
67+
const_chunk_iterator chunk_cbegin() const;
68+
const_chunk_iterator chunk_cend() const;
5769

5870
private:
5971

@@ -155,16 +167,41 @@ namespace xt
155167
}
156168

157169
template <class E>
158-
inline auto xchunked_view<E>::chunk_begin() -> chunk_iterator_type
170+
inline auto xchunked_view<E>::chunk_begin() -> chunk_iterator
171+
{
172+
shape_type chunk_index(m_shape.size(), size_type(0));
173+
return chunk_iterator(*this, std::move(chunk_index), 0u);
174+
}
175+
176+
template <class E>
177+
inline auto xchunked_view<E>::chunk_end() -> chunk_iterator
178+
{
179+
return chunk_iterator(*this, shape_type(grid_shape()), grid_size());
180+
}
181+
182+
template <class E>
183+
inline auto xchunked_view<E>::chunk_begin() const -> const_chunk_iterator
159184
{
160185
shape_type chunk_index(m_shape.size(), size_type(0));
161-
return chunk_iterator_type(*this, std::move(chunk_index), 0u);
186+
return const_chunk_iterator(*this, std::move(chunk_index), 0u);
187+
}
188+
189+
template <class E>
190+
inline auto xchunked_view<E>::chunk_end() const -> const_chunk_iterator
191+
{
192+
return const_chunk_iterator(*this, shape_type(grid_shape()), grid_size());
193+
}
194+
195+
template <class E>
196+
inline auto xchunked_view<E>::chunk_cbegin() const -> const_chunk_iterator
197+
{
198+
return chunk_begin();
162199
}
163200

164201
template <class E>
165-
inline auto xchunked_view<E>::chunk_end() -> chunk_iterator_type
202+
inline auto xchunked_view<E>::chunk_cend() const -> const_chunk_iterator
166203
{
167-
return chunk_iterator_type(*this, shape_type(grid_shape()), grid_size());
204+
return chunk_end();
168205
}
169206

170207
template <class E, class S>

test/test_xchunked_array.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,34 @@ namespace xt
123123

124124
EXPECT_EQ(a, b);
125125
}
126+
127+
TEST(xchunked_array, chunk_iterator)
128+
{
129+
std::vector<std::size_t> shape = {10, 10, 10};
130+
std::vector<std::size_t> chunk_shape = {2, 2, 2};
131+
auto a = chunked_array<double>(shape, chunk_shape);
132+
xt::xarray<double> b = arange(1000).reshape({10, 10, 10});
133+
noalias(a) = b;
134+
135+
auto it = a.chunk_begin();
136+
auto cit = a.chunk_cbegin();
137+
138+
for (size_t i = 0; i < 5; ++i)
139+
{
140+
for (size_t j = 0; j < 5; ++j)
141+
{
142+
for (size_t k = 0; k < 5; ++k)
143+
{
144+
EXPECT_EQ(*((*it).begin()), a(2*i, 2*j, 2*k));
145+
EXPECT_EQ(*((*cit).cbegin()), a(2*i, 2*j, 2*k));
146+
++it;
147+
++cit;
148+
}
149+
}
150+
}
151+
152+
it = a.chunk_begin();
153+
std::advance(it, 2);
154+
EXPECT_EQ(*((*it).begin()), a(0, 0, 4));
155+
}
126156
}

0 commit comments

Comments
 (0)