Removed arrow_proxy class, fixes bugs in reverse bounds_iterator

This commit is contained in:
Anna Gringauze 2015-10-16 12:15:22 -07:00
parent c973e82dff
commit a4654a46b5
2 changed files with 199 additions and 133 deletions

View File

@ -74,25 +74,6 @@ namespace details
{ {
static const SizeType max_value = std::is_signed<SizeType>::value ? static_cast<typename std::make_unsigned<SizeType>::type>(-1) / 2 : static_cast<SizeType>(-1); static const SizeType max_value = std::is_signed<SizeType>::value ? static_cast<typename std::make_unsigned<SizeType>::type>(-1) / 2 : static_cast<SizeType>(-1);
}; };
template <typename T>
class arrow_proxy
{
public:
explicit arrow_proxy(T t)
: val(t)
{}
const T operator*() const noexcept
{
return val;
}
const T* operator->() const noexcept
{
return &val;
}
private:
T val;
};
} }
template <size_t Rank, typename ValueType = size_t> template <size_t Rank, typename ValueType = size_t>
@ -730,8 +711,9 @@ public:
using size_type = SizeType; using size_type = SizeType;
using index_type = index<rank, size_type>; using index_type = index<rank, size_type>;
using iterator = bounds_iterator<index_type>; using const_index_type = std::add_const_t<index_type>;
using const_iterator = bounds_iterator<index_type>; using iterator = bounds_iterator<const_index_type>;
using const_iterator = bounds_iterator<const_index_type>;
using difference_type = ptrdiff_t; using difference_type = ptrdiff_t;
using sliced_type = static_bounds<SizeType, RestRanges...>; using sliced_type = static_bounds<SizeType, RestRanges...>;
using mapping_type = contiguous_mapping_tag; using mapping_type = contiguous_mapping_tag;
@ -822,7 +804,7 @@ public:
constexpr const_iterator begin() const noexcept constexpr const_iterator begin() const noexcept
{ {
return const_iterator(*this); return const_iterator(*this, index_type{});
} }
constexpr const_iterator end() const noexcept constexpr const_iterator end() const noexcept
@ -845,8 +827,9 @@ public:
using difference_type = SizeType; using difference_type = SizeType;
using value_type = SizeType; using value_type = SizeType;
using index_type = index<rank, size_type>; using index_type = index<rank, size_type>;
using iterator = bounds_iterator<index_type>; using const_index_type = std::add_const_t<index_type>;
using const_iterator = bounds_iterator<index_type>; using iterator = bounds_iterator<const_index_type>;
using const_iterator = bounds_iterator<const_index_type>;
static const int dynamic_rank = rank; static const int dynamic_rank = rank;
static const size_t static_size = dynamic_range; static const size_t static_size = dynamic_range;
using sliced_type = std::conditional_t<rank != 0, strided_bounds<rank - 1>, void>; using sliced_type = std::conditional_t<rank != 0, strided_bounds<rank - 1>, void>;
@ -920,11 +903,11 @@ public:
{ {
return m_extents; return m_extents;
} }
const_iterator begin() const noexcept constexpr const_iterator begin() const noexcept
{ {
return const_iterator{ *this }; return const_iterator{ *this, index_type{} };
} }
const_iterator end() const noexcept constexpr const_iterator end() const noexcept
{ {
return const_iterator{ *this, index_bounds() }; return const_iterator{ *this, index_bounds() };
} }
@ -941,15 +924,11 @@ template <size_t Rank, typename SizeType>
struct is_bounds<strided_bounds<Rank, SizeType>> : std::integral_constant<bool, true> {}; struct is_bounds<strided_bounds<Rank, SizeType>> : std::integral_constant<bool, true> {};
template <typename IndexType> template <typename IndexType>
class bounds_iterator class bounds_iterator: public std::iterator<std::random_access_iterator_tag, IndexType>
: public std::iterator<std::random_access_iterator_tag,
IndexType,
ptrdiff_t,
const details::arrow_proxy<IndexType>,
const IndexType>
{ {
private: private:
using Base = std::iterator <std::random_access_iterator_tag, IndexType, ptrdiff_t, const details::arrow_proxy<IndexType>, const IndexType>; using Base = std::iterator <std::random_access_iterator_tag, IndexType>;
public: public:
static const size_t rank = IndexType::rank; static const size_t rank = IndexType::rank;
using typename Base::reference; using typename Base::reference;
@ -959,79 +938,88 @@ public:
using index_type = value_type; using index_type = value_type;
using index_size_type = typename IndexType::value_type; using index_size_type = typename IndexType::value_type;
template <typename Bounds> template <typename Bounds>
explicit bounds_iterator(const Bounds& bnd, value_type curr = value_type{}) noexcept explicit bounds_iterator(const Bounds& bnd, value_type curr) noexcept
: boundary(bnd.index_bounds()) : boundary(bnd.index_bounds()), curr(std::move(curr))
, curr(std::move(curr))
{ {
static_assert(is_bounds<Bounds>::value, "Bounds type must be provided"); static_assert(is_bounds<Bounds>::value, "Bounds type must be provided");
} }
reference operator*() const noexcept
constexpr reference operator*() const noexcept
{ {
return curr; return curr;
} }
pointer operator->() const noexcept
constexpr pointer operator->() const noexcept
{ {
return details::arrow_proxy<value_type>{ curr }; return &curr;
} }
bounds_iterator& operator++() noexcept
constexpr bounds_iterator& operator++() noexcept
{ {
for (size_t i = rank; i-- > 0;) for (size_t i = rank; i-- > 0;)
{ {
if (++curr[i] < boundary[i]) if (curr[i] < boundary[i] - 1)
{ {
curr[i]++;
return *this; return *this;
} }
else
{
curr[i] = 0; curr[i] = 0;
} }
}
// If we're here we've wrapped over - set to past-the-end. // If we're here we've wrapped over - set to past-the-end.
for (size_t i = 0; i < rank; ++i) curr = boundary;
{
curr[i] = boundary[i];
}
return *this; return *this;
} }
bounds_iterator operator++(int) noexcept
constexpr bounds_iterator operator++(int) noexcept
{ {
auto ret = *this; auto ret = *this;
++(*this); ++(*this);
return ret; return ret;
} }
bounds_iterator& operator--() noexcept
constexpr bounds_iterator& operator--() noexcept
{ {
for (size_t i = rank; i-- > 0;) if (!less(curr, boundary))
{ {
if (curr[i]-- > 0) // if at the past-the-end, set to last element
{ for (size_t i = 0; i < rank; ++i)
return *this;
}
else
{ {
curr[i] = boundary[i] - 1; curr[i] = boundary[i] - 1;
} }
return *this;
}
for (size_t i = rank; i-- > 0;)
{
if (curr[i] >= 1)
{
curr[i]--;
return *this;
}
curr[i] = boundary[i] - 1;
} }
// If we're here the preconditions were violated // If we're here the preconditions were violated
// "pre: there exists s such that r == ++s" // "pre: there exists s such that r == ++s"
fail_fast_assert(false); fail_fast_assert(false);
return *this; return *this;
} }
bounds_iterator operator--(int) noexcept
constexpr bounds_iterator operator--(int) noexcept
{ {
auto ret = *this; auto ret = *this;
--(*this); --(*this);
return ret; return ret;
} }
bounds_iterator operator+(difference_type n) const noexcept
constexpr bounds_iterator operator+(difference_type n) const noexcept
{ {
bounds_iterator ret{ *this }; bounds_iterator ret{ *this };
return ret += n; return ret += n;
} }
bounds_iterator& operator+=(difference_type n) noexcept
constexpr bounds_iterator& operator+=(difference_type n) noexcept
{ {
auto linear_idx = linearize(curr) + n; auto linear_idx = linearize(curr) + n;
value_type stride; std::remove_const_t<value_type> stride;
stride[rank - 1] = 1; stride[rank - 1] = 1;
for (size_t i = rank - 1; i-- > 0;) for (size_t i = rank - 1; i-- > 0;)
{ {
@ -1042,76 +1030,84 @@ public:
curr[i] = linear_idx / stride[i]; curr[i] = linear_idx / stride[i];
linear_idx = linear_idx % stride[i]; linear_idx = linear_idx % stride[i];
} }
fail_fast_assert(!less(curr, index_type{}) && !less(boundary, curr), "index is out of bounds of the array");
return *this; return *this;
} }
bounds_iterator operator-(difference_type n) const noexcept
constexpr bounds_iterator operator-(difference_type n) const noexcept
{ {
bounds_iterator ret{ *this }; bounds_iterator ret{ *this };
return ret -= n; return ret -= n;
} }
bounds_iterator& operator-=(difference_type n) noexcept
constexpr bounds_iterator& operator-=(difference_type n) noexcept
{ {
return *this += -n; return *this += -n;
} }
difference_type operator-(const bounds_iterator& rhs) const noexcept
constexpr difference_type operator-(const bounds_iterator& rhs) const noexcept
{ {
return linearize(curr) - linearize(rhs.curr); return linearize(curr) - linearize(rhs.curr);
} }
reference operator[](difference_type n) const noexcept
constexpr reference operator[](difference_type n) const noexcept
{ {
return *(*this + n); return *(*this + n);
} }
bool operator==(const bounds_iterator& rhs) const noexcept
constexpr bool operator==(const bounds_iterator& rhs) const noexcept
{ {
return curr == rhs.curr; return curr == rhs.curr;
} }
bool operator!=(const bounds_iterator& rhs) const noexcept
constexpr bool operator!=(const bounds_iterator& rhs) const noexcept
{ {
return !(*this == rhs); return !(*this == rhs);
} }
bool operator<(const bounds_iterator& rhs) const noexcept
constexpr bool operator<(const bounds_iterator& rhs) const noexcept
{ {
for (size_t i = 0; i < rank; ++i) return less(curr, rhs.curr);
{
if (curr[i] < rhs.curr[i])
return true;
} }
return false;
} constexpr bool operator<=(const bounds_iterator& rhs) const noexcept
bool operator<=(const bounds_iterator& rhs) const noexcept
{ {
return !(rhs < *this); return !(rhs < *this);
} }
bool operator>(const bounds_iterator& rhs) const noexcept
constexpr bool operator>(const bounds_iterator& rhs) const noexcept
{ {
return rhs < *this; return rhs < *this;
} }
bool operator>=(const bounds_iterator& rhs) const noexcept
constexpr bool operator>=(const bounds_iterator& rhs) const noexcept
{ {
return !(rhs > *this); return !(rhs > *this);
} }
void swap(bounds_iterator& rhs) noexcept void swap(bounds_iterator& rhs) noexcept
{ {
std::swap(boundary, rhs.boundary); std::swap(boundary, rhs.boundary);
std::swap(curr, rhs.curr); std::swap(curr, rhs.curr);
} }
private: private:
index_size_type linearize(const value_type& idx) const noexcept constexpr bool less(index_type& one, index_type& other) const noexcept
{
for (size_t i = 0; i < rank; ++i)
{
if (one[i] < other[i])
return true;
}
return false;
}
constexpr index_size_type linearize(const value_type& idx) const noexcept
{ {
// TODO: Smarter impl. // TODO: Smarter impl.
// Check if past-the-end // Check if past-the-end
bool pte = true;
for (size_t i = 0; i < rank; ++i)
{
if (idx[i] != boundary[i])
{
pte = false;
break;
}
}
index_size_type multiplier = 1; index_size_type multiplier = 1;
index_size_type res = 0; index_size_type res = 0;
if (pte) if (!less(idx, boundary))
{ {
res = 1; res = 1;
for (size_t i = rank; i-- > 0;) for (size_t i = rank; i-- > 0;)
@ -1130,19 +1126,15 @@ private:
} }
return res; return res;
} }
value_type boundary; value_type boundary;
value_type curr; std::remove_const_t<value_type> curr;
}; };
template <typename SizeType> template <typename SizeType>
class bounds_iterator<index<1, SizeType>> class bounds_iterator<index<1, SizeType>> : public std::iterator<std::random_access_iterator_tag, index<1, SizeType>>
: public std::iterator<std::random_access_iterator_tag,
index<1, SizeType>,
ptrdiff_t,
const details::arrow_proxy<index<1, SizeType>>,
const index<1, SizeType>>
{ {
using Base = std::iterator<std::random_access_iterator_tag, index<1, SizeType>, ptrdiff_t, const details::arrow_proxy<index<1, SizeType>>, const index<1, SizeType>>; using Base = std::iterator<std::random_access_iterator_tag, index<1, SizeType>>;
public: public:
using typename Base::reference; using typename Base::reference;
@ -1153,96 +1145,116 @@ public:
using index_size_type = typename index_type::value_type; using index_size_type = typename index_type::value_type;
template <typename Bounds> template <typename Bounds>
explicit bounds_iterator(const Bounds &, value_type curr = value_type{}) noexcept constexpr explicit bounds_iterator(const Bounds&, value_type curr) noexcept
: curr(std::move(curr)) : curr(std::move(curr))
{} {}
reference operator*() const noexcept
constexpr reference operator*() const noexcept
{ {
return curr; return curr;
} }
pointer operator->() const noexcept
constexpr pointer operator->() const noexcept
{ {
return details::arrow_proxy<value_type>{ curr }; &curr;
} }
bounds_iterator& operator++() noexcept
constexpr bounds_iterator& operator++() noexcept
{ {
++curr; ++curr;
return *this; return *this;
} }
bounds_iterator operator++(int) noexcept
constexpr bounds_iterator operator++(int) noexcept
{ {
auto ret = *this; auto ret = *this;
++(*this); ++(*this);
return ret; return ret;
} }
bounds_iterator& operator--() noexcept
constexpr bounds_iterator& operator--() noexcept
{ {
curr--; curr--;
return *this; return *this;
} }
bounds_iterator operator--(int) noexcept
constexpr bounds_iterator operator--(int) noexcept
{ {
auto ret = *this; auto ret = *this;
--(*this); --(*this);
return ret; return ret;
} }
bounds_iterator operator+(difference_type n) const noexcept
constexpr bounds_iterator operator+(difference_type n) const noexcept
{ {
bounds_iterator ret{ *this }; bounds_iterator ret{ *this };
return ret += n; return ret += n;
} }
bounds_iterator& operator+=(difference_type n) noexcept
constexpr bounds_iterator& operator+=(difference_type n) noexcept
{ {
curr += n; curr += n;
return *this; return *this;
} }
bounds_iterator operator-(difference_type n) const noexcept
constexpr bounds_iterator operator-(difference_type n) const noexcept
{ {
bounds_iterator ret{ *this }; bounds_iterator ret{ *this };
return ret -= n; return ret -= n;
} }
bounds_iterator& operator-=(difference_type n) noexcept
constexpr bounds_iterator& operator-=(difference_type n) noexcept
{ {
return *this += -n; return *this += -n;
} }
difference_type operator-(const bounds_iterator& rhs) const noexcept
constexpr difference_type operator-(const bounds_iterator& rhs) const noexcept
{ {
return curr[0] - rhs.curr[0]; return curr[0] - rhs.curr[0];
} }
reference operator[](difference_type n) const noexcept
constexpr reference operator[](difference_type n) const noexcept
{ {
return curr + n; return curr + n;
} }
bool operator==(const bounds_iterator& rhs) const noexcept
constexpr bool operator==(const bounds_iterator& rhs) const noexcept
{ {
return curr == rhs.curr; return curr == rhs.curr;
} }
bool operator!=(const bounds_iterator& rhs) const noexcept
constexpr bool operator!=(const bounds_iterator& rhs) const noexcept
{ {
return !(*this == rhs); return !(*this == rhs);
} }
bool operator<(const bounds_iterator& rhs) const noexcept
constexpr bool operator<(const bounds_iterator& rhs) const noexcept
{ {
return curr[0] < rhs.curr[0]; return curr[0] < rhs.curr[0];
} }
bool operator<=(const bounds_iterator& rhs) const noexcept
constexpr bool operator<=(const bounds_iterator& rhs) const noexcept
{ {
return !(rhs < *this); return !(rhs < *this);
} }
bool operator>(const bounds_iterator& rhs) const noexcept
constexpr bool operator>(const bounds_iterator& rhs) const noexcept
{ {
return rhs < *this; return rhs < *this;
} }
bool operator>=(const bounds_iterator& rhs) const noexcept
constexpr bool operator>=(const bounds_iterator& rhs) const noexcept
{ {
return !(rhs > *this); return !(rhs > *this);
} }
void swap(bounds_iterator& rhs) noexcept
constexpr void swap(bounds_iterator& rhs) noexcept
{ {
std::swap(curr, rhs.curr); std::swap(curr, rhs.curr);
} }
private: private:
value_type curr; std::remove_const_t<value_type> curr;
}; };
template <typename IndexType> template <typename IndexType>
@ -1304,10 +1316,11 @@ public:
using size_type = typename bounds_type::size_type; using size_type = typename bounds_type::size_type;
using index_type = typename bounds_type::index_type; using index_type = typename bounds_type::index_type;
using value_type = ValueType; using value_type = ValueType;
using const_value_type = std::add_const_t<value_type>;
using pointer = ValueType*; using pointer = ValueType*;
using reference = ValueType&; using reference = ValueType&;
using iterator = std::conditional_t<std::is_same<typename BoundsType::mapping_type, contiguous_mapping_tag>::value, contiguous_array_view_iterator<basic_array_view>, general_array_view_iterator<basic_array_view>>; using iterator = std::conditional_t<std::is_same<typename BoundsType::mapping_type, contiguous_mapping_tag>::value, contiguous_array_view_iterator<basic_array_view>, general_array_view_iterator<basic_array_view>>;
using const_iterator = std::conditional_t<std::is_same<typename BoundsType::mapping_type, contiguous_mapping_tag>::value, contiguous_array_view_iterator<basic_array_view<const ValueType, BoundsType>>, general_array_view_iterator<basic_array_view<const ValueType, BoundsType>>>; using const_iterator = std::conditional_t<std::is_same<typename BoundsType::mapping_type, contiguous_mapping_tag>::value, contiguous_array_view_iterator<basic_array_view<const_value_type, BoundsType>>, general_array_view_iterator<basic_array_view<const_value_type, BoundsType>>>;
using reverse_iterator = std::reverse_iterator<iterator>; using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>; using const_reverse_iterator = std::reverse_iterator<const_iterator>;
using sliced_type = std::conditional_t<rank == 1, value_type, basic_array_view<value_type, typename BoundsType::sliced_type>>; using sliced_type = std::conditional_t<rank == 1, value_type, basic_array_view<value_type, typename BoundsType::sliced_type>>;
@ -1360,7 +1373,7 @@ public:
} }
constexpr iterator end() const constexpr iterator end() const
{ {
return iterator {this}; return iterator {this, false};
} }
constexpr const_iterator cbegin() const constexpr const_iterator cbegin() const
{ {
@ -1368,7 +1381,7 @@ public:
} }
constexpr const_iterator cend() const constexpr const_iterator cend() const
{ {
return const_iterator {reinterpret_cast<const basic_array_view<const value_type, bounds_type> *>(this)}; return const_iterator {reinterpret_cast<const basic_array_view<const value_type, bounds_type> *>(this), false};
} }
constexpr reverse_iterator rbegin() const constexpr reverse_iterator rbegin() const
@ -1999,7 +2012,7 @@ private:
{ {
fail_fast_assert(m_pdata >= m_validator->m_pdata && m_pdata < m_validator->m_pdata + m_validator->size(), "iterator is out of range of the array"); fail_fast_assert(m_pdata >= m_validator->m_pdata && m_pdata < m_validator->m_pdata + m_validator->size(), "iterator is out of range of the array");
} }
contiguous_array_view_iterator (const ArrayView *container, bool isbegin = false) : contiguous_array_view_iterator (const ArrayView *container, bool isbegin) :
m_pdata(isbegin ? container->m_pdata : container->m_pdata + container->size()), m_validator(container) {} m_pdata(isbegin ? container->m_pdata : container->m_pdata + container->size()), m_validator(container) {}
public: public:
reference operator*() const noexcept reference operator*() const noexcept
@ -2115,16 +2128,16 @@ private:
friend class basic_array_view; friend class basic_array_view;
const ArrayView * m_container; const ArrayView * m_container;
typename ArrayView::bounds_type::iterator m_itr; typename ArrayView::bounds_type::iterator m_itr;
general_array_view_iterator(const ArrayView *container, bool isbegin = false) : general_array_view_iterator(const ArrayView *container, bool isbegin) :
m_container(container), m_itr(isbegin ? m_container->bounds().begin() : m_container->bounds().end()) m_container(container), m_itr(isbegin ? m_container->bounds().begin() : m_container->bounds().end())
{ {
} }
public: public:
reference operator*() const noexcept reference operator*() noexcept
{ {
return (*m_container)[*m_itr]; return (*m_container)[*m_itr];
} }
pointer operator->() const noexcept pointer operator->() noexcept
{ {
return &(*m_container)[*m_itr]; return &(*m_container)[*m_itr];
} }

View File

@ -925,12 +925,36 @@ SUITE(array_view_tests)
} }
} }
size_t check_sum = 0;
for (size_t i = 0; i < length; ++i)
{
check_sum += av[i][1];
}
{
size_t idx = 0; size_t idx = 0;
size_t sum = 0;
for (auto num : section) for (auto num : section)
{ {
CHECK(num == av[idx][1]); CHECK(num == av[idx][1]);
sum += num;
idx++; idx++;
} }
CHECK(sum == check_sum);
}
{
size_t idx = length - 1;
size_t sum = 0;
for (auto iter = section.rbegin(); iter != section.rend(); ++iter)
{
CHECK(*iter == av[idx][1]);
sum += *iter;
idx--;
}
CHECK(sum == check_sum);
}
} }
TEST(array_view_section_iteration) TEST(array_view_section_iteration)
@ -1714,7 +1738,36 @@ SUITE(array_view_tests)
CHECK(wav.data() == (byte*)&a[0]); CHECK(wav.data() == (byte*)&a[0]);
CHECK(wav.length() == sizeof(a)); CHECK(wav.length() == sizeof(a));
} }
}
TEST(NonConstIterator)
{
int a[] = { 1, 2, 3, 4 };
{
array_view<int, dynamic_range> av = a;
auto wav = av.as_writeable_bytes();
for (auto& b : wav)
{
b = byte(0);
}
for (size_t i = 0; i < 4; ++i)
{
CHECK(a[i] == 0);
}
}
{
array_view<int, dynamic_range> av = a;
for (auto& n : av)
{
n = 1;
}
for (size_t i = 0; i < 4; ++i)
{
CHECK(a[i] == 1);
}
}
} }
TEST(ArrayViewComparison) TEST(ArrayViewComparison)