@@ -494,22 +494,47 @@ namespace xt
494494 using size_type = std::size_t ;
495495 using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
496496
497- template <class S >
498- inline value_type access (const tuple_type& t, size_type axis, S index ) const
497+ template <class It >
498+ inline value_type access (const tuple_type& t, size_type axis, It first, It last ) const
499499 {
500- auto match = [&index, axis](auto & arr)
500+ // trim off extra indices if provided to match behavior of containers
501+ auto dim_offset = std::distance (first, last) - std::get<0 >(t).dimension ();
502+ size_t axis_dim = *(first + axis + dim_offset);
503+ auto match = [&](auto & arr)
501504 {
502- if (index[axis] >= arr.shape ()[axis])
505+ if (axis_dim >= arr.shape ()[axis])
503506 {
504- index[axis] -= arr.shape ()[axis];
507+ axis_dim -= arr.shape ()[axis];
505508 return false ;
506509 }
507510 return true ;
508511 };
509512
510- auto get = [&index ](auto & arr)
513+ auto get = [&](auto & arr)
511514 {
512- return arr[index];
515+ size_t offset = 0 ;
516+ const size_t end = arr.dimension ();
517+ for (size_t i = 0 ; i < end; i++)
518+ {
519+ const auto & shape = arr.shape ();
520+ const size_t stride = std::accumulate (
521+ shape.begin () + i + 1 ,
522+ shape.end (),
523+ 1 ,
524+ std::multiplies<size_t >()
525+ );
526+ if (i == axis)
527+ {
528+ offset += axis_dim * stride;
529+ }
530+ else
531+ {
532+ const auto len = (*(first + i + dim_offset));
533+ offset += len * stride;
534+ }
535+ }
536+ const auto element = arr.begin () + offset;
537+ return *element;
513538 };
514539
515540 size_type i = 0 ;
@@ -533,48 +558,68 @@ namespace xt
533558 using size_type = std::size_t ;
534559 using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
535560
536- template <class S >
537- inline value_type access (const tuple_type& t, size_type axis, S index ) const
561+ template <class It >
562+ inline value_type access (const tuple_type& t, size_type axis, It first, It ) const
538563 {
539- auto get_item = [&index ](auto & arr)
564+ auto get_item = [&](auto & arr)
540565 {
541- return arr[index];
566+ size_t offset = 0 ;
567+ const size_t end = arr.dimension ();
568+ size_t after_axis = 0 ;
569+ for (size_t i = 0 ; i < end; i++)
570+ {
571+ if (i == axis)
572+ {
573+ after_axis = 1 ;
574+ }
575+ const auto & shape = arr.shape ();
576+ const size_t stride = std::accumulate (
577+ shape.begin () + i + 1 ,
578+ shape.end (),
579+ 1 ,
580+ std::multiplies<size_t >()
581+ );
582+ const auto len = (*(first + i + after_axis));
583+ offset += len * stride;
584+ }
585+ const auto element = arr.begin () + offset;
586+ return *element;
542587 };
543- size_type i = index[axis];
544- index.erase (index.begin () + std::ptrdiff_t (axis));
588+ size_type i = *(first + axis);
545589 return apply<value_type>(i, get_item, t);
546590 }
547591 };
548592
549593 template <class ... CT>
550- class vstack_access : private concatenate_access <CT...>,
551- private stack_access<CT...>
594+ class vstack_access
552595 {
553596 public:
554597
555598 using tuple_type = std::tuple<CT...>;
556599 using size_type = std::size_t ;
557600 using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
558601
559- using concatenate_base = concatenate_access<CT...>;
560- using stack_base = stack_access<CT...>;
561-
562- template <class S >
563- inline value_type access (const tuple_type& t, size_type axis, S index) const
602+ template <class It >
603+ inline value_type access (const tuple_type& t, size_type axis, It first, It last) const
564604 {
565605 if (std::get<0 >(t).dimension () == 1 )
566606 {
567- return stack_base:: access (t, axis, index );
607+ return stack. access (t, axis, first, last );
568608 }
569609 else
570610 {
571- return concatenate_base:: access (t, axis, index );
611+ return concatonate. access (t, axis, first, last );
572612 }
573613 }
614+
615+ private:
616+
617+ concatenate_access<CT...> concatonate;
618+ stack_access<CT...> stack;
574619 };
575620
576621 template <template <class ...> class F , class ... CT>
577- class concatenate_invoker : private F <CT...>
622+ class concatenate_invoker
578623 {
579624 public:
580625
@@ -592,18 +637,19 @@ namespace xt
592637 inline value_type operator ()(Args... args) const
593638 {
594639 // TODO: avoid memory allocation
595- return this ->access (m_t , m_axis, xindex ({static_cast <size_type>(args)...}));
640+ xindex index ({static_cast <size_type>(args)...});
641+ return access_method.access (m_t , m_axis, index.begin (), index.end ());
596642 }
597643
598644 template <class It >
599645 inline value_type element (It first, It last) const
600646 {
601- // TODO: avoid memory allocation
602- return this ->access (m_t , m_axis, xindex (first, last));
647+ return access_method.access (m_t , m_axis, first, last);
603648 }
604649
605650 private:
606651
652+ F<CT...> access_method;
607653 tuple_type m_t ;
608654 size_type m_axis;
609655 };
0 commit comments