|
|
|
@ -27,6 +27,7 @@ limitations under the License. */
|
|
|
|
|
#endif // PADDLE_WITH_CUDA
|
|
|
|
|
|
|
|
|
|
#include <iomanip>
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <stdexcept>
|
|
|
|
@ -307,7 +308,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
|
|
|
|
|
do { \
|
|
|
|
|
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
|
|
|
|
|
__LINE__); \
|
|
|
|
|
} while (false)
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* Some enforce helpers here, usage:
|
|
|
|
@ -366,28 +367,72 @@ using CommonType1 = typename std::add_lvalue_reference<
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
using CommonType2 = typename std::add_lvalue_reference<
|
|
|
|
|
typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;
|
|
|
|
|
|
|
|
|
|
// Here, we use SFINAE to check whether T can be converted to std::string
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct CanToString {
|
|
|
|
|
private:
|
|
|
|
|
using YesType = uint8_t;
|
|
|
|
|
using NoType = uint16_t;
|
|
|
|
|
|
|
|
|
|
template <typename U>
|
|
|
|
|
static YesType Check(decltype(std::cout << std::declval<U>())) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename U>
|
|
|
|
|
static NoType Check(...) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
static constexpr bool kValue =
|
|
|
|
|
std::is_same<YesType, decltype(Check<T>(std::cout))>::value;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <bool kCanToString /* = true */>
|
|
|
|
|
struct BinaryCompareMessageConverter {
|
|
|
|
|
template <typename T>
|
|
|
|
|
static std::string Convert(const char* expression, const T& value) {
|
|
|
|
|
return expression + std::string(":") + string::to_string(value);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct BinaryCompareMessageConverter<false> {
|
|
|
|
|
template <typename T>
|
|
|
|
|
static const char* Convert(const char* expression, const T& value) {
|
|
|
|
|
return expression;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|
|
|
|
|
|
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
|
|
|
|
|
do { \
|
|
|
|
|
auto __val1 = (__VAL1); \
|
|
|
|
|
auto __val2 = (__VAL2); \
|
|
|
|
|
using __TYPE1__ = decltype(__val1); \
|
|
|
|
|
using __TYPE2__ = decltype(__val2); \
|
|
|
|
|
using __COMMON_TYPE1__ = \
|
|
|
|
|
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
|
|
|
|
|
using __COMMON_TYPE2__ = \
|
|
|
|
|
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
|
|
|
|
|
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
|
|
|
|
|
static_cast<__COMMON_TYPE2__>(__val2)); \
|
|
|
|
|
if (UNLIKELY(!__is_not_error)) { \
|
|
|
|
|
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
|
|
|
|
|
" %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \
|
|
|
|
|
#__VAL1, #__VAL2, #__VAL1, \
|
|
|
|
|
::paddle::string::to_string(__val1), #__VAL2, \
|
|
|
|
|
::paddle::string::to_string(__val2), \
|
|
|
|
|
::paddle::string::Sprintf(__VA_ARGS__)); \
|
|
|
|
|
} \
|
|
|
|
|
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
|
|
|
|
|
do { \
|
|
|
|
|
auto __val1 = (__VAL1); \
|
|
|
|
|
auto __val2 = (__VAL2); \
|
|
|
|
|
using __TYPE1__ = decltype(__val1); \
|
|
|
|
|
using __TYPE2__ = decltype(__val2); \
|
|
|
|
|
using __COMMON_TYPE1__ = \
|
|
|
|
|
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
|
|
|
|
|
using __COMMON_TYPE2__ = \
|
|
|
|
|
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
|
|
|
|
|
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
|
|
|
|
|
static_cast<__COMMON_TYPE2__>(__val2)); \
|
|
|
|
|
if (UNLIKELY(!__is_not_error)) { \
|
|
|
|
|
constexpr bool __kCanToString__ = \
|
|
|
|
|
::paddle::platform::details::CanToString<__TYPE1__>::kValue && \
|
|
|
|
|
::paddle::platform::details::CanToString<__TYPE2__>::kValue; \
|
|
|
|
|
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
|
|
|
|
|
" %s, but received %s " #__INV_CMP " %s.\n%s", \
|
|
|
|
|
#__VAL1, #__VAL2, \
|
|
|
|
|
::paddle::platform::details::BinaryCompareMessageConverter< \
|
|
|
|
|
__kCanToString__>::Convert(#__VAL1, __val1), \
|
|
|
|
|
::paddle::platform::details::BinaryCompareMessageConverter< \
|
|
|
|
|
__kCanToString__>::Convert(#__VAL2, __val2), \
|
|
|
|
|
::paddle::string::Sprintf(__VA_ARGS__)); \
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
|
|
|
|
|