|
|
|
@ -162,11 +162,50 @@ inline void throw_on_error(T e) {
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
|
|
|
|
|
PADDLE_ENFORCE((__VAL0) == (__VAL1), \
|
|
|
|
|
"enforce %s == %s failed, %s != %s\n%s", #__VAL0, #__VAL1, \
|
|
|
|
|
std::to_string(__VAL0), std::to_string(__VAL1), \
|
|
|
|
|
/*
|
|
|
|
|
* Some enforce helpers here, usage:
|
|
|
|
|
* int a = 1;
|
|
|
|
|
* int b = 2;
|
|
|
|
|
* PADDLE_ENFORCE_EQ(a, b);
|
|
|
|
|
*
|
|
|
|
|
* will raise an expression described as follows:
|
|
|
|
|
* "enforce a == b failed, 1 != 2" with detailed stack infomation.
|
|
|
|
|
*
|
|
|
|
|
* extra messages is also supported, for example:
|
|
|
|
|
* PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2)
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
// if two values have different data types, choose a compatible type for them.
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
struct CompatibleType {
|
|
|
|
|
static constexpr const bool& t1_to_t2 = std::is_convertible<T1, T2>::value;
|
|
|
|
|
typedef typename std::conditional<t1_to_t2, T2, T1>::type type;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
|
|
|
|
|
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \
|
|
|
|
|
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
|
|
|
|
|
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
|
|
|
|
|
#__VAL0, #__VAL1, std::to_string(__VAL0), \
|
|
|
|
|
std::to_string(__VAL1), \
|
|
|
|
|
paddle::string::Sprintf("" __VA_ARGS__));
|
|
|
|
|
|
|
|
|
|
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
|
|
|
|
|
typename paddle::platform::CompatibleType<decltype(__VAL0), \
|
|
|
|
|
decltype(__VAL1)>::type(__VAL)
|
|
|
|
|
|
|
|
|
|
} // namespace platform
|
|
|
|
|
} // namespace paddle
|
|
|
|
|