diff --git a/include/xtensor/misc/xmanipulation.hpp b/include/xtensor/misc/xmanipulation.hpp index f100744ac..9b95015a7 100644 --- a/include/xtensor/misc/xmanipulation.hpp +++ b/include/xtensor/misc/xmanipulation.hpp @@ -1052,17 +1052,15 @@ namespace xt { auto cpy = empty_like(e); const auto& shape = cpy.shape(); - std::size_t saxis = static_cast(axis); - if (axis < 0) - { - axis += std::ptrdiff_t(cpy.dimension()); - } + const auto dim = cpy.dimension(); - if (saxis >= cpy.dimension() || axis < 0) + if (axis < -static_cast(dim) || axis >= static_cast(dim)) { - XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension."); + XTENSOR_THROW(std::runtime_error, "axis is not within shape dimension."); } + std::size_t saxis = normalize_axis(dim, axis); + const auto axis_dim = static_cast(shape[saxis]); while (shift < 0) { diff --git a/test/test_xmanipulation.cpp b/test/test_xmanipulation.cpp index 5e1d43cc5..f105a5a7d 100644 --- a/test/test_xmanipulation.cpp +++ b/test/test_xmanipulation.cpp @@ -502,6 +502,18 @@ namespace xt xarray expected8 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}}; ASSERT_EQ(expected8, xt::roll(e2, -2, /*axis*/ 2)); + + EXPECT_THROW(xt::roll(e2, 1, /*axis*/ 3), std::runtime_error); + EXPECT_THROW(xt::roll(e2, 1, /*axis*/ -4), std::runtime_error); + + xarray expected9 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}}; + ASSERT_EQ(expected9, xt::roll(e2, -2, /*axis*/ -1)); + + xarray expected10 = {{{1, 2, 3}}, {{4, 5, 6}}, {{7, 8, 9}}}; + ASSERT_EQ(expected10, xt::roll(e2, -2, /*axis*/ -2)); + + xarray expected11 = {{{4, 5, 6}}, {{7, 8, 9}}, {{1, 2, 3}}}; + ASSERT_EQ(expected11, xt::roll(e2, 2, /*axis*/ -3)); } TEST(xmanipulation, repeat_all_elements_of_axis_0_of_int_array_2_times)