diff --git a/include/gsl.h b/include/gsl.h index 1357d76..b78c6e2 100644 --- a/include/gsl.h +++ b/include/gsl.h @@ -162,7 +162,11 @@ private: template class maybe_null_dbg { + template + friend class maybe_null_dbg; public: + static_assert(std::is_constructible::value, "maybe_null's template parameter must be constructible from nullptr"); + maybe_null_dbg() : ptr_(nullptr), tested_(false) {} maybe_null_dbg(const T& p) : ptr_(p), tested_(false) {} @@ -202,8 +206,10 @@ public: bool operator==(const T& rhs) const { tested_ = true; return ptr_ == rhs; } bool operator!=(const T& rhs) const { return !(*this == rhs); } - bool operator==(const maybe_null_dbg& rhs) const { tested_ = true; rhs.tested_ = true; return ptr_ == rhs.ptr_; } - bool operator!=(const maybe_null_dbg& rhs) const { return !(*this == rhs); } + template ::value>> + bool operator==(const maybe_null_dbg& rhs) const { tested_ = true; rhs.tested_ = true; return ptr_ == rhs.ptr_; } + template ::value>> + bool operator!=(const maybe_null_dbg& rhs) const { return !(*this == rhs); } T get() const { fail_fast_assert(tested_); @@ -217,8 +223,6 @@ public: T operator->() const { return get(); } private: - const size_t ptee_size_ = sizeof(*ptr_); // T must be a pointer type - // unwanted operators...pointers only point to single objects! // TODO ensure all arithmetic ops on this type are unavailable maybe_null_dbg& operator++() = delete; @@ -238,6 +242,8 @@ template class maybe_null_ret { public: + static_assert(std::is_constructible::value, "maybe_null's template parameter must be constructible from nullptr"); + maybe_null_ret() : ptr_(nullptr) {} maybe_null_ret(std::nullptr_t) : ptr_(nullptr) {} maybe_null_ret(const T& p) : ptr_(p) {} @@ -280,7 +286,6 @@ private: maybe_null_ret& operator-(size_t) = delete; maybe_null_ret& operator-=(size_t) = delete; - const size_t ptee_size_ = sizeof(*ptr_); // T must be a pointer type T ptr_; }; diff --git a/tests/maybenull_tests.cpp b/tests/maybenull_tests.cpp index 0a9d891..e1244fd 100644 --- a/tests/maybenull_tests.cpp +++ b/tests/maybenull_tests.cpp @@ -241,6 +241,20 @@ SUITE(MaybeNullTests) // Make sure we no longer throw here CHECK(p1.get() != nullptr); } + + TEST(TestMaybeNullPtrT) + { + maybe_null p1; + maybe_null p2; + + CHECK_THROW(p1.get(), fail_fast); + + CHECK(p1 == p2); + + // Make sure we no longer throw here + CHECK(p1.get() == nullptr); + CHECK(p2.get() == nullptr); + } } int main(int, const char *[])