diff --git a/include/gsl/gsl b/include/gsl/gsl index 26742be..b3dbe33 100644 --- a/include/gsl/gsl +++ b/include/gsl/gsl @@ -117,6 +117,42 @@ std::ostream& operator<<(std::ostream& os, const not_null& val) return os; } +template +auto operator==(const not_null& lhs, const not_null& rhs) -> decltype(lhs.get() == rhs.get()) +{ + return lhs.get() == rhs.get(); +} + +template +auto operator!=(const not_null& lhs, const not_null& rhs) -> decltype(lhs.get() != rhs.get()) +{ + return lhs.get() != rhs.get(); +} + +template +auto operator<(const not_null& lhs, const not_null& rhs) -> decltype(lhs.get() < rhs.get()) +{ + return lhs.get() < rhs.get(); +} + +template +auto operator<=(const not_null& lhs, const not_null& rhs) -> decltype(lhs.get() <= rhs.get()) +{ + return lhs.get() <= rhs.get(); +} + +template +auto operator>(const not_null& lhs, const not_null& rhs) -> decltype(lhs.get() > rhs.get()) +{ + return lhs.get() > rhs.get(); +} + +template +auto operator>=(const not_null& lhs, const not_null& rhs) -> decltype(lhs.get() >= rhs.get()) +{ + return lhs.get() >= rhs.get(); +} + // more unwanted operators template std::ptrdiff_t operator-(const not_null&, const not_null&) = delete; diff --git a/tests/notnull_tests.cpp b/tests/notnull_tests.cpp index 201b797..1c2b50e 100644 --- a/tests/notnull_tests.cpp +++ b/tests/notnull_tests.cpp @@ -17,6 +17,8 @@ #include #include #include +#include +#include using namespace gsl; @@ -33,6 +35,54 @@ struct RefCounted T* p_; }; +// user defined smart pointer with comparison operators returning non bool value +template +struct CustomPtr +{ + CustomPtr(T* p) : p_(p) {} + operator T*() { return p_; } + bool operator !=(std::nullptr_t)const { return p_ != nullptr; } + T* p_ = nullptr; +}; + +template +std::string operator==(CustomPtr const& lhs, CustomPtr const& rhs) +{ + return reinterpret_cast(lhs.p_) == reinterpret_cast(rhs.p_) ? "true" : "false"; +} + +template +std::string operator!=(CustomPtr const& lhs, CustomPtr const& rhs) +{ + return reinterpret_cast(lhs.p_) != reinterpret_cast(rhs.p_) ? "true" : "false"; +} + +template +std::string operator<(CustomPtr const& lhs, CustomPtr const& rhs) +{ + return reinterpret_cast(lhs.p_) < reinterpret_cast(rhs.p_) ? "true" : "false"; +} + +template +std::string operator>(CustomPtr const& lhs, CustomPtr const& rhs) +{ + return reinterpret_cast(lhs.p_) > reinterpret_cast(rhs.p_) ? "true" : "false"; +} + +template +std::string operator<=(CustomPtr const& lhs, CustomPtr const& rhs) +{ + return reinterpret_cast(lhs.p_) <= reinterpret_cast(rhs.p_) ? "true" : "false"; +} + +template +std::string operator>=(CustomPtr const& lhs, CustomPtr const& rhs) +{ + return reinterpret_cast(lhs.p_) >= reinterpret_cast(rhs.p_) ? "true" : "false"; +} + + + SUITE(NotNullTests) { @@ -95,6 +145,101 @@ SUITE(NotNullTests) int* q = nullptr; CHECK_THROW(p = q, fail_fast); } + + TEST(TestNotNullRawPointerComparison) + { + int ints[2] = {42, 43}; + int* p1 = &ints[0]; + const int* p2 = &ints[1]; + + using NotNull1 = not_null; + using NotNull2 = not_null; + + CHECK((NotNull1(p1) == NotNull1(p1)) == true); + CHECK((NotNull1(p1) == NotNull2(p2)) == false); + + CHECK((NotNull1(p1) != NotNull1(p1)) == false); + CHECK((NotNull1(p1) != NotNull2(p2)) == true); + + CHECK((NotNull1(p1) < NotNull1(p1)) == false); + CHECK((NotNull1(p1) < NotNull2(p2)) == (p1 < p2)); + CHECK((NotNull2(p2) < NotNull1(p1)) == (p2 < p1)); + + CHECK((NotNull1(p1) > NotNull1(p1)) == false); + CHECK((NotNull1(p1) > NotNull2(p2)) == (p1 > p2)); + CHECK((NotNull2(p2) > NotNull1(p1)) == (p2 > p1)); + + CHECK((NotNull1(p1) <= NotNull1(p1)) == true); + CHECK((NotNull1(p1) <= NotNull2(p2)) == (p1 <= p2)); + CHECK((NotNull2(p2) <= NotNull1(p1)) == (p2 <= p1)); + + CHECK((NotNull1(p1) >= NotNull1(p1)) == true); + CHECK((NotNull1(p1) >= NotNull2(p2)) == (p1 >= p2)); + CHECK((NotNull2(p2) >= NotNull1(p1)) == (p2 >= p1)); + } + + TEST(TestNotNullSharedPtrComparison) + { + auto sp1 = std::make_shared(42); + auto sp2 = std::make_shared(43); + + using NotNullSp1 = not_null; + using NotNullSp2 = not_null; + + CHECK((NotNullSp1(sp1) == NotNullSp1(sp1)) == true); + CHECK((NotNullSp1(sp1) == NotNullSp2(sp2)) == false); + + CHECK((NotNullSp1(sp1) != NotNullSp1(sp1)) == false); + CHECK((NotNullSp1(sp1) != NotNullSp2(sp2)) == true); + + CHECK((NotNullSp1(sp1) < NotNullSp1(sp1)) == false); + CHECK((NotNullSp1(sp1) < NotNullSp2(sp2)) == (sp1 < sp2)); + CHECK((NotNullSp2(sp2) < NotNullSp1(sp1)) == (sp2 < sp1)); + + CHECK((NotNullSp1(sp1) > NotNullSp1(sp1)) == false); + CHECK((NotNullSp1(sp1) > NotNullSp2(sp2)) == (sp1 > sp2)); + CHECK((NotNullSp2(sp2) > NotNullSp1(sp1)) == (sp2 > sp1)); + + CHECK((NotNullSp1(sp1) <= NotNullSp1(sp1)) == true); + CHECK((NotNullSp1(sp1) <= NotNullSp2(sp2)) == (sp1 <= sp2)); + CHECK((NotNullSp2(sp2) <= NotNullSp1(sp1)) == (sp2 <= sp1)); + + CHECK((NotNullSp1(sp1) >= NotNullSp1(sp1)) == true); + CHECK((NotNullSp1(sp1) >= NotNullSp2(sp2)) == (sp1 >= sp2)); + CHECK((NotNullSp2(sp2) >= NotNullSp1(sp1)) == (sp2 >= sp1)); + } + + TEST(TestNotNullCustomPtrComparison) + { + int ints[2] = { 42, 43 }; + CustomPtr p1(&ints[0]); + CustomPtr p2(&ints[1]); + + using NotNull1 = not_null; + using NotNull2 = not_null; + + CHECK((NotNull1(p1) == NotNull1(p1)) == "true"); + CHECK((NotNull1(p1) == NotNull2(p2)) == "false"); + + CHECK((NotNull1(p1) != NotNull1(p1)) == "false"); + CHECK((NotNull1(p1) != NotNull2(p2)) == "true"); + + CHECK((NotNull1(p1) < NotNull1(p1)) == "false"); + CHECK((NotNull1(p1) < NotNull2(p2)) == (p1 < p2)); + CHECK((NotNull2(p2) < NotNull1(p1)) == (p2 < p1)); + + CHECK((NotNull1(p1) > NotNull1(p1)) == "false"); + CHECK((NotNull1(p1) > NotNull2(p2)) == (p1 > p2)); + CHECK((NotNull2(p2) > NotNull1(p1)) == (p2 > p1)); + + CHECK((NotNull1(p1) <= NotNull1(p1)) == "true"); + CHECK((NotNull1(p1) <= NotNull2(p2)) == (p1 <= p2)); + CHECK((NotNull2(p2) <= NotNull1(p1)) == (p2 <= p1)); + + CHECK((NotNull1(p1) >= NotNull1(p1)) == "true"); + CHECK((NotNull1(p1) >= NotNull2(p2)) == (p1 >= p2)); + CHECK((NotNull2(p2) >= NotNull1(p1)) == (p2 >= p1)); + } } int main(int, const char *[])