|
|
|
@ -484,46 +484,5 @@ struct BinaryCompareMessageConverter<false> {
|
|
|
|
|
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
#define __PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL1, __VAL2, __CMP, \
|
|
|
|
|
__INV_CMP, ...) \
|
|
|
|
|
do { \
|
|
|
|
|
auto __val1 = (__VAL1); \
|
|
|
|
|
auto __val2 = (__VAL2); \
|
|
|
|
|
if (!__CTX->IsRuntime()) { \
|
|
|
|
|
if (__val1 == -1 || __val2 == -1) { \
|
|
|
|
|
break; \
|
|
|
|
|
} \
|
|
|
|
|
} \
|
|
|
|
|
using __TYPE1__ = decltype(__val1); \
|
|
|
|
|
using __TYPE2__ = decltype(__val2); \
|
|
|
|
|
using __COMMON_TYPE1__ = \
|
|
|
|
|
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
|
|
|
|
|
using __COMMON_TYPE2__ = \
|
|
|
|
|
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
|
|
|
|
|
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
|
|
|
|
|
static_cast<__COMMON_TYPE2__>(__val2)); \
|
|
|
|
|
if (UNLIKELY(!__is_not_error)) { \
|
|
|
|
|
PADDLE_THROW("Expected %s " #__CMP " %s, but received %s:%s " #__INV_CMP \
|
|
|
|
|
" %s:%s.\n%s", \
|
|
|
|
|
#__VAL1, #__VAL2, #__VAL1, \
|
|
|
|
|
::paddle::string::to_string(__val1), #__VAL2, \
|
|
|
|
|
::paddle::string::to_string(__val2), \
|
|
|
|
|
::paddle::string::Sprintf(__VA_ARGS__)); \
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
#define PADDLE_INFERSHAPE_ENFORCE_EQ(__CTX, __VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, ==, !=, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_INFERSHAPE_ENFORCE_NE(__CTX, __VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, !=, ==, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_INFERSHAPE_ENFORCE_GT(__CTX, __VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >, <=, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_INFERSHAPE_ENFORCE_GE(__CTX, __VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >=, <, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_INFERSHAPE_ENFORCE_LT(__CTX, __VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <, >=, __VA_ARGS__)
|
|
|
|
|
#define PADDLE_INFERSHAPE_ENFORCE_LE(__CTX, __VAL0, __VAL1, ...) \
|
|
|
|
|
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <=, >, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
} // namespace platform
|
|
|
|
|
} // namespace paddle
|
|
|
|
|