diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 24ada37807..03802f3853 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -27,6 +27,7 @@ limitations under the License. */ #endif // PADDLE_WITH_CUDA #include +#include #include #include #include @@ -307,7 +308,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ __LINE__); \ - } while (false) + } while (0) /* * Some enforce helpers here, usage: @@ -366,28 +367,72 @@ using CommonType1 = typename std::add_lvalue_reference< template using CommonType2 = typename std::add_lvalue_reference< typename std::add_const::Type2>::type>::type; + +// Here, we use SFINAE to check whether T can be converted to std::string +template +struct CanToString { + private: + using YesType = uint8_t; + using NoType = uint16_t; + + template + static YesType Check(decltype(std::cout << std::declval())) { + return 0; + } + + template + static NoType Check(...) { + return 0; + } + + public: + static constexpr bool kValue = + std::is_same(std::cout))>::value; +}; + +template +struct BinaryCompareMessageConverter { + template + static std::string Convert(const char* expression, const T& value) { + return expression + std::string(":") + string::to_string(value); + } +}; + +template <> +struct BinaryCompareMessageConverter { + template + static const char* Convert(const char* expression, const T& value) { + return expression; + } +}; + } // namespace details -#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \ - do { \ - auto __val1 = (__VAL1); \ - auto __val2 = (__VAL2); \ - using __TYPE1__ = decltype(__val1); \ - using __TYPE2__ = decltype(__val2); \ - using __COMMON_TYPE1__ = \ - ::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \ - using __COMMON_TYPE2__ = \ - ::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \ - bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \ - static_cast<__COMMON_TYPE2__>(__val2)); \ - if (UNLIKELY(!__is_not_error)) { \ - PADDLE_THROW("Enforce failed. Expected %s " #__CMP \ - " %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \ - #__VAL1, #__VAL2, #__VAL1, \ - ::paddle::string::to_string(__val1), #__VAL2, \ - ::paddle::string::to_string(__val2), \ - ::paddle::string::Sprintf(__VA_ARGS__)); \ - } \ +#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \ + do { \ + auto __val1 = (__VAL1); \ + auto __val2 = (__VAL2); \ + using __TYPE1__ = decltype(__val1); \ + using __TYPE2__ = decltype(__val2); \ + using __COMMON_TYPE1__ = \ + ::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \ + using __COMMON_TYPE2__ = \ + ::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \ + bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \ + static_cast<__COMMON_TYPE2__>(__val2)); \ + if (UNLIKELY(!__is_not_error)) { \ + constexpr bool __kCanToString__ = \ + ::paddle::platform::details::CanToString<__TYPE1__>::kValue && \ + ::paddle::platform::details::CanToString<__TYPE2__>::kValue; \ + PADDLE_THROW("Enforce failed. Expected %s " #__CMP \ + " %s, but received %s " #__INV_CMP " %s.\n%s", \ + #__VAL1, #__VAL2, \ + ::paddle::platform::details::BinaryCompareMessageConverter< \ + __kCanToString__>::Convert(#__VAL1, __val1), \ + ::paddle::platform::details::BinaryCompareMessageConverter< \ + __kCanToString__>::Convert(#__VAL2, __val2), \ + ::paddle::string::Sprintf(__VA_ARGS__)); \ + } \ } while (0) #define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index ceba13b4d6..4e34f3cbf5 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -11,7 +11,9 @@ limitations under the License. */ #include #include +#include #include +#include #include "gtest/gtest.h" #include "paddle/fluid/platform/enforce.h" @@ -296,3 +298,64 @@ TEST(enforce, cuda_success) { #endif } #endif + +struct CannotToStringType { + explicit CannotToStringType(int num) : num_(num) {} + + bool operator==(const CannotToStringType& other) const { + return num_ == other.num_; + } + + bool operator!=(const CannotToStringType& other) const { + return num_ != other.num_; + } + + private: + int num_; +}; + +TEST(enforce, cannot_to_string_type) { + static_assert( + !paddle::platform::details::CanToString::kValue, + "CannotToStringType must not be converted to string"); + static_assert(paddle::platform::details::CanToString::kValue, + "int can be converted to string"); + CannotToStringType obj1(3), obj2(4), obj3(3); + + PADDLE_ENFORCE_NE(obj1, obj2, "Object 1 is not equal to Object 2"); + PADDLE_ENFORCE_EQ(obj1, obj3, "Object 1 is equal to Object 3"); + + std::string msg = "Compare obj1 with obj2"; + try { + PADDLE_ENFORCE_EQ(obj1, obj2, msg); + } catch (paddle::platform::EnforceNotMet& error) { + std::string ex_msg = error.what(); + LOG(INFO) << ex_msg; + EXPECT_TRUE(ex_msg.find(msg) != std::string::npos); + EXPECT_TRUE( + ex_msg.find("Expected obj1 == obj2, but received obj1 != obj2") != + std::string::npos); + } + + msg = "Compare x with y"; + try { + int x = 3, y = 2; + PADDLE_ENFORCE_EQ(x, y, msg); + } catch (paddle::platform::EnforceNotMet& error) { + std::string ex_msg = error.what(); + LOG(INFO) << ex_msg; + EXPECT_TRUE(ex_msg.find(msg) != std::string::npos); + EXPECT_TRUE(ex_msg.find("Expected x == y, but received x:3 != y:2") != + std::string::npos); + } + + std::set set; + PADDLE_ENFORCE_EQ(set.begin(), set.end()); + set.insert(3); + PADDLE_ENFORCE_NE(set.begin(), set.end()); + + std::list list; + PADDLE_ENFORCE_EQ(list.begin(), list.end()); + list.push_back(4); + PADDLE_ENFORCE_NE(list.begin(), list.end()); +}