Added support of not_null<smart_ptr> comparison (#473)

* Added support of not_null<smart_ptr> comparison

* The return type of not_null comparison operators is determined using SFINAE
#474

* tests for gsl::not_null comparison were added

* not_null comparison tests were rewritten to compare pointers to objects located in the same array

* not_null<shared_ptr> comparison was simplified
This commit is contained in:
Alexey Malov 2017-04-13 03:34:39 +03:00 committed by Neil MacIntosh
parent ebab8cab7f
commit 534bb4c663
2 changed files with 181 additions and 0 deletions

View File

@ -117,6 +117,42 @@ std::ostream& operator<<(std::ostream& os, const not_null<T>& val)
return os;
}
template <class T, class U>
auto operator==(const not_null<T>& lhs, const not_null<U>& rhs) -> decltype(lhs.get() == rhs.get())
{
return lhs.get() == rhs.get();
}
template <class T, class U>
auto operator!=(const not_null<T>& lhs, const not_null<U>& rhs) -> decltype(lhs.get() != rhs.get())
{
return lhs.get() != rhs.get();
}
template <class T, class U>
auto operator<(const not_null<T>& lhs, const not_null<U>& rhs) -> decltype(lhs.get() < rhs.get())
{
return lhs.get() < rhs.get();
}
template <class T, class U>
auto operator<=(const not_null<T>& lhs, const not_null<U>& rhs) -> decltype(lhs.get() <= rhs.get())
{
return lhs.get() <= rhs.get();
}
template <class T, class U>
auto operator>(const not_null<T>& lhs, const not_null<U>& rhs) -> decltype(lhs.get() > rhs.get())
{
return lhs.get() > rhs.get();
}
template <class T, class U>
auto operator>=(const not_null<T>& lhs, const not_null<U>& rhs) -> decltype(lhs.get() >= rhs.get())
{
return lhs.get() >= rhs.get();
}
// more unwanted operators
template <class T, class U>
std::ptrdiff_t operator-(const not_null<T>&, const not_null<U>&) = delete;

View File

@ -17,6 +17,8 @@
#include <UnitTest++/UnitTest++.h>
#include <gsl/gsl>
#include <vector>
#include <memory>
#include <string>
using namespace gsl;
@ -33,6 +35,54 @@ struct RefCounted
T* p_;
};
// user defined smart pointer with comparison operators returning non bool value
template <typename T>
struct CustomPtr
{
CustomPtr(T* p) : p_(p) {}
operator T*() { return p_; }
bool operator !=(std::nullptr_t)const { return p_ != nullptr; }
T* p_ = nullptr;
};
template <typename T, typename U>
std::string operator==(CustomPtr<T> const& lhs, CustomPtr<U> const& rhs)
{
return reinterpret_cast<const void*>(lhs.p_) == reinterpret_cast<const void*>(rhs.p_) ? "true" : "false";
}
template <typename T, typename U>
std::string operator!=(CustomPtr<T> const& lhs, CustomPtr<U> const& rhs)
{
return reinterpret_cast<const void*>(lhs.p_) != reinterpret_cast<const void*>(rhs.p_) ? "true" : "false";
}
template <typename T, typename U>
std::string operator<(CustomPtr<T> const& lhs, CustomPtr<U> const& rhs)
{
return reinterpret_cast<const void*>(lhs.p_) < reinterpret_cast<const void*>(rhs.p_) ? "true" : "false";
}
template <typename T, typename U>
std::string operator>(CustomPtr<T> const& lhs, CustomPtr<U> const& rhs)
{
return reinterpret_cast<const void*>(lhs.p_) > reinterpret_cast<const void*>(rhs.p_) ? "true" : "false";
}
template <typename T, typename U>
std::string operator<=(CustomPtr<T> const& lhs, CustomPtr<U> const& rhs)
{
return reinterpret_cast<const void*>(lhs.p_) <= reinterpret_cast<const void*>(rhs.p_) ? "true" : "false";
}
template <typename T, typename U>
std::string operator>=(CustomPtr<T> const& lhs, CustomPtr<U> const& rhs)
{
return reinterpret_cast<const void*>(lhs.p_) >= reinterpret_cast<const void*>(rhs.p_) ? "true" : "false";
}
SUITE(NotNullTests)
{
@ -95,6 +145,101 @@ SUITE(NotNullTests)
int* q = nullptr;
CHECK_THROW(p = q, fail_fast);
}
TEST(TestNotNullRawPointerComparison)
{
int ints[2] = {42, 43};
int* p1 = &ints[0];
const int* p2 = &ints[1];
using NotNull1 = not_null<decltype(p1)>;
using NotNull2 = not_null<decltype(p2)>;
CHECK((NotNull1(p1) == NotNull1(p1)) == true);
CHECK((NotNull1(p1) == NotNull2(p2)) == false);
CHECK((NotNull1(p1) != NotNull1(p1)) == false);
CHECK((NotNull1(p1) != NotNull2(p2)) == true);
CHECK((NotNull1(p1) < NotNull1(p1)) == false);
CHECK((NotNull1(p1) < NotNull2(p2)) == (p1 < p2));
CHECK((NotNull2(p2) < NotNull1(p1)) == (p2 < p1));
CHECK((NotNull1(p1) > NotNull1(p1)) == false);
CHECK((NotNull1(p1) > NotNull2(p2)) == (p1 > p2));
CHECK((NotNull2(p2) > NotNull1(p1)) == (p2 > p1));
CHECK((NotNull1(p1) <= NotNull1(p1)) == true);
CHECK((NotNull1(p1) <= NotNull2(p2)) == (p1 <= p2));
CHECK((NotNull2(p2) <= NotNull1(p1)) == (p2 <= p1));
CHECK((NotNull1(p1) >= NotNull1(p1)) == true);
CHECK((NotNull1(p1) >= NotNull2(p2)) == (p1 >= p2));
CHECK((NotNull2(p2) >= NotNull1(p1)) == (p2 >= p1));
}
TEST(TestNotNullSharedPtrComparison)
{
auto sp1 = std::make_shared<int>(42);
auto sp2 = std::make_shared<const int>(43);
using NotNullSp1 = not_null<decltype(sp1)>;
using NotNullSp2 = not_null<decltype(sp2)>;
CHECK((NotNullSp1(sp1) == NotNullSp1(sp1)) == true);
CHECK((NotNullSp1(sp1) == NotNullSp2(sp2)) == false);
CHECK((NotNullSp1(sp1) != NotNullSp1(sp1)) == false);
CHECK((NotNullSp1(sp1) != NotNullSp2(sp2)) == true);
CHECK((NotNullSp1(sp1) < NotNullSp1(sp1)) == false);
CHECK((NotNullSp1(sp1) < NotNullSp2(sp2)) == (sp1 < sp2));
CHECK((NotNullSp2(sp2) < NotNullSp1(sp1)) == (sp2 < sp1));
CHECK((NotNullSp1(sp1) > NotNullSp1(sp1)) == false);
CHECK((NotNullSp1(sp1) > NotNullSp2(sp2)) == (sp1 > sp2));
CHECK((NotNullSp2(sp2) > NotNullSp1(sp1)) == (sp2 > sp1));
CHECK((NotNullSp1(sp1) <= NotNullSp1(sp1)) == true);
CHECK((NotNullSp1(sp1) <= NotNullSp2(sp2)) == (sp1 <= sp2));
CHECK((NotNullSp2(sp2) <= NotNullSp1(sp1)) == (sp2 <= sp1));
CHECK((NotNullSp1(sp1) >= NotNullSp1(sp1)) == true);
CHECK((NotNullSp1(sp1) >= NotNullSp2(sp2)) == (sp1 >= sp2));
CHECK((NotNullSp2(sp2) >= NotNullSp1(sp1)) == (sp2 >= sp1));
}
TEST(TestNotNullCustomPtrComparison)
{
int ints[2] = { 42, 43 };
CustomPtr<int> p1(&ints[0]);
CustomPtr<const int> p2(&ints[1]);
using NotNull1 = not_null<decltype(p1)>;
using NotNull2 = not_null<decltype(p2)>;
CHECK((NotNull1(p1) == NotNull1(p1)) == "true");
CHECK((NotNull1(p1) == NotNull2(p2)) == "false");
CHECK((NotNull1(p1) != NotNull1(p1)) == "false");
CHECK((NotNull1(p1) != NotNull2(p2)) == "true");
CHECK((NotNull1(p1) < NotNull1(p1)) == "false");
CHECK((NotNull1(p1) < NotNull2(p2)) == (p1 < p2));
CHECK((NotNull2(p2) < NotNull1(p1)) == (p2 < p1));
CHECK((NotNull1(p1) > NotNull1(p1)) == "false");
CHECK((NotNull1(p1) > NotNull2(p2)) == (p1 > p2));
CHECK((NotNull2(p2) > NotNull1(p1)) == (p2 > p1));
CHECK((NotNull1(p1) <= NotNull1(p1)) == "true");
CHECK((NotNull1(p1) <= NotNull2(p2)) == (p1 <= p2));
CHECK((NotNull2(p2) <= NotNull1(p1)) == (p2 <= p1));
CHECK((NotNull1(p1) >= NotNull1(p1)) == "true");
CHECK((NotNull1(p1) >= NotNull2(p2)) == (p1 >= p2));
CHECK((NotNull2(p2) >= NotNull1(p1)) == (p2 >= p1));
}
}
int main(int, const char *[])