/////////////////////////////////////////////////////////////////////////////// // // Copyright (c) 2015 Microsoft Corporation. All rights reserved. // // This code is licensed under the MIT License (MIT). // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // /////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include using namespace gsl; struct MyBase { }; struct MyDerived : public MyBase { }; struct Unrelated { }; // stand-in for a user-defined ref-counted class template struct RefCounted { RefCounted(T* p) : p_(p) {} operator T*() { return p_; } T* p_; }; // user defined smart pointer with comparison operators returning non bool value template struct CustomPtr { CustomPtr(T* p) : p_(p) {} operator T*() { return p_; } bool operator!=(std::nullptr_t) const { return p_ != nullptr; } T* p_ = nullptr; }; template std::string operator==(CustomPtr const& lhs, CustomPtr const& rhs) { return reinterpret_cast(lhs.p_) == reinterpret_cast(rhs.p_) ? "true" : "false"; } template std::string operator!=(CustomPtr const& lhs, CustomPtr const& rhs) { return reinterpret_cast(lhs.p_) != reinterpret_cast(rhs.p_) ? "true" : "false"; } template std::string operator<(CustomPtr const& lhs, CustomPtr const& rhs) { return reinterpret_cast(lhs.p_) < reinterpret_cast(rhs.p_) ? "true" : "false"; } template std::string operator>(CustomPtr const& lhs, CustomPtr const& rhs) { return reinterpret_cast(lhs.p_) > reinterpret_cast(rhs.p_) ? "true" : "false"; } template std::string operator<=(CustomPtr const& lhs, CustomPtr const& rhs) { return reinterpret_cast(lhs.p_) <= reinterpret_cast(rhs.p_) ? "true" : "false"; } template std::string operator>=(CustomPtr const& lhs, CustomPtr const& rhs) { return reinterpret_cast(lhs.p_) >= reinterpret_cast(rhs.p_) ? "true" : "false"; } bool helper(not_null p) { return *p == 12; } TEST_CASE("TestNotNullConstructors") { #ifdef CONFIRM_COMPILATION_ERRORS not_null p = nullptr; // yay...does not compile! not_null*> p = 0; // yay...does not compile! not_null p; // yay...does not compile! std::unique_ptr up = std::make_unique(120); not_null p = up; // Forbid non-nullptr assignable types not_null> f(std::vector{1}); not_null z(10); not_null> y({1, 2}); #endif int i = 12; auto rp = RefCounted(&i); not_null p(rp); CHECK(p.get() == &i); not_null> x( std::make_shared(10)); // shared_ptr is nullptr assignable } TEST_CASE("TestNotNullCasting") { MyBase base; MyDerived derived; Unrelated unrelated; not_null u = &unrelated; (void) u; not_null p = &derived; not_null q = &base; q = p; // allowed with heterogeneous copy ctor CHECK(q == p); #ifdef CONFIRM_COMPILATION_ERRORS q = u; // no viable conversion possible between MyBase* and Unrelated* p = q; // not possible to implicitly convert MyBase* to MyDerived* not_null r = p; not_null s = reinterpret_cast(p); #endif not_null t = reinterpret_cast(p.get()); CHECK(reinterpret_cast(p.get()) == reinterpret_cast(t.get())); } TEST_CASE("TestNotNullAssignment") { int i = 12; not_null p = &i; CHECK(helper(p)); int* q = nullptr; CHECK_THROWS_AS(p = q, fail_fast); } TEST_CASE("TestNotNullRawPointerComparison") { int ints[2] = {42, 43}; int* p1 = &ints[0]; const int* p2 = &ints[1]; using NotNull1 = not_null; using NotNull2 = not_null; CHECK((NotNull1(p1) == NotNull1(p1)) == true); CHECK((NotNull1(p1) == NotNull2(p2)) == false); CHECK((NotNull1(p1) != NotNull1(p1)) == false); CHECK((NotNull1(p1) != NotNull2(p2)) == true); CHECK((NotNull1(p1) < NotNull1(p1)) == false); CHECK((NotNull1(p1) < NotNull2(p2)) == (p1 < p2)); CHECK((NotNull2(p2) < NotNull1(p1)) == (p2 < p1)); CHECK((NotNull1(p1) > NotNull1(p1)) == false); CHECK((NotNull1(p1) > NotNull2(p2)) == (p1 > p2)); CHECK((NotNull2(p2) > NotNull1(p1)) == (p2 > p1)); CHECK((NotNull1(p1) <= NotNull1(p1)) == true); CHECK((NotNull1(p1) <= NotNull2(p2)) == (p1 <= p2)); CHECK((NotNull2(p2) <= NotNull1(p1)) == (p2 <= p1)); CHECK((NotNull1(p1) >= NotNull1(p1)) == true); CHECK((NotNull1(p1) >= NotNull2(p2)) == (p1 >= p2)); CHECK((NotNull2(p2) >= NotNull1(p1)) == (p2 >= p1)); } TEST_CASE("TestNotNullSharedPtrComparison") { auto sp1 = std::make_shared(42); auto sp2 = std::make_shared(43); using NotNullSp1 = not_null; using NotNullSp2 = not_null; CHECK((NotNullSp1(sp1) == NotNullSp1(sp1)) == true); CHECK((NotNullSp1(sp1) == NotNullSp2(sp2)) == false); CHECK((NotNullSp1(sp1) != NotNullSp1(sp1)) == false); CHECK((NotNullSp1(sp1) != NotNullSp2(sp2)) == true); CHECK((NotNullSp1(sp1) < NotNullSp1(sp1)) == false); CHECK((NotNullSp1(sp1) < NotNullSp2(sp2)) == (sp1 < sp2)); CHECK((NotNullSp2(sp2) < NotNullSp1(sp1)) == (sp2 < sp1)); CHECK((NotNullSp1(sp1) > NotNullSp1(sp1)) == false); CHECK((NotNullSp1(sp1) > NotNullSp2(sp2)) == (sp1 > sp2)); CHECK((NotNullSp2(sp2) > NotNullSp1(sp1)) == (sp2 > sp1)); CHECK((NotNullSp1(sp1) <= NotNullSp1(sp1)) == true); CHECK((NotNullSp1(sp1) <= NotNullSp2(sp2)) == (sp1 <= sp2)); CHECK((NotNullSp2(sp2) <= NotNullSp1(sp1)) == (sp2 <= sp1)); CHECK((NotNullSp1(sp1) >= NotNullSp1(sp1)) == true); CHECK((NotNullSp1(sp1) >= NotNullSp2(sp2)) == (sp1 >= sp2)); CHECK((NotNullSp2(sp2) >= NotNullSp1(sp1)) == (sp2 >= sp1)); } TEST_CASE("TestNotNullCustomPtrComparison") { int ints[2] = {42, 43}; CustomPtr p1(&ints[0]); CustomPtr p2(&ints[1]); using NotNull1 = not_null; using NotNull2 = not_null; CHECK((NotNull1(p1) == NotNull1(p1)) == "true"); CHECK((NotNull1(p1) == NotNull2(p2)) == "false"); CHECK((NotNull1(p1) != NotNull1(p1)) == "false"); CHECK((NotNull1(p1) != NotNull2(p2)) == "true"); CHECK((NotNull1(p1) < NotNull1(p1)) == "false"); CHECK((NotNull1(p1) < NotNull2(p2)) == (p1 < p2)); CHECK((NotNull2(p2) < NotNull1(p1)) == (p2 < p1)); CHECK((NotNull1(p1) > NotNull1(p1)) == "false"); CHECK((NotNull1(p1) > NotNull2(p2)) == (p1 > p2)); CHECK((NotNull2(p2) > NotNull1(p1)) == (p2 > p1)); CHECK((NotNull1(p1) <= NotNull1(p1)) == "true"); CHECK((NotNull1(p1) <= NotNull2(p2)) == (p1 <= p2)); CHECK((NotNull2(p2) <= NotNull1(p1)) == (p2 <= p1)); CHECK((NotNull1(p1) >= NotNull1(p1)) == "true"); CHECK((NotNull1(p1) >= NotNull2(p2)) == (p1 >= p2)); CHECK((NotNull2(p2) >= NotNull1(p1)) == (p2 >= p1)); }