|
|
|
@ -36,6 +36,21 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
|
|
|
|
|
struct EnforceNotMet : public std::exception {
|
|
|
|
|
std::exception_ptr exp_;
|
|
|
|
|
std::string err_str_;
|
|
|
|
|
|
|
|
|
|
EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
|
|
|
|
|
try {
|
|
|
|
|
std::rethrow_exception(exp_);
|
|
|
|
|
} catch (const std::exception& exp) {
|
|
|
|
|
err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char* what() const noexcept { return err_str_.c_str(); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Because most enforce conditions would evaluate to true, we can use
|
|
|
|
|
// __builtin_expect to instruct the C++ compiler to generate code that
|
|
|
|
|
// always forces branch prediction of true.
|
|
|
|
@ -52,9 +67,7 @@ template <typename... Args>
|
|
|
|
|
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
|
|
|
|
|
int stat, const Args&... args) {
|
|
|
|
|
if (UNLIKELY(!(stat))) {
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
string::Sprintf(args...) +
|
|
|
|
|
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
|
|
|
|
|
throw std::runtime_error(string::Sprintf(args...));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -64,12 +77,8 @@ template <typename... Args>
|
|
|
|
|
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
|
|
|
|
|
cudaError_t e, const Args&... args) {
|
|
|
|
|
if (UNLIKELY(e)) {
|
|
|
|
|
// clang-format off
|
|
|
|
|
throw thrust::system_error(
|
|
|
|
|
e, thrust::cuda_category(),
|
|
|
|
|
string::Sprintf(args...) +
|
|
|
|
|
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
|
|
|
|
|
// clang-format on
|
|
|
|
|
throw thrust::system_error(e, thrust::cuda_category(),
|
|
|
|
|
string::Sprintf(args...));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -77,12 +86,8 @@ template <typename... Args>
|
|
|
|
|
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
|
|
|
|
|
curandStatus_t stat, const Args&... args) {
|
|
|
|
|
if (stat != CURAND_STATUS_SUCCESS) {
|
|
|
|
|
// clang-format off
|
|
|
|
|
throw thrust::system_error(
|
|
|
|
|
cudaErrorLaunchFailure, thrust::cuda_category(),
|
|
|
|
|
string::Sprintf(args...) +
|
|
|
|
|
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
|
|
|
|
|
// clang-format on
|
|
|
|
|
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
|
|
|
|
|
string::Sprintf(args...));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
|
|
|
|
|
if (stat == CUDNN_STATUS_SUCCESS) {
|
|
|
|
|
return;
|
|
|
|
|
} else {
|
|
|
|
|
// clang-format off
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
platform::dynload::cudnnGetErrorString(stat) +
|
|
|
|
|
string::Sprintf(args...) +
|
|
|
|
|
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
|
|
|
|
|
// clang-format on
|
|
|
|
|
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
|
|
|
|
|
string::Sprintf(args...));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
|
|
|
|
|
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
|
|
|
|
|
err = "CUBLAS: license error, ";
|
|
|
|
|
}
|
|
|
|
|
throw std::runtime_error(err + string::Sprintf(args...) +
|
|
|
|
|
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
|
|
|
|
|
throw std::runtime_error(err + string::Sprintf(args...));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#endif // PADDLE_ONLY_CPU
|
|
|
|
|
|
|
|
|
|
#define PADDLE_THROW(...) \
|
|
|
|
|
do { \
|
|
|
|
|
throw std::runtime_error( \
|
|
|
|
|
string::Sprintf(__VA_ARGS__) + \
|
|
|
|
|
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
|
|
|
|
|
#define PADDLE_THROW(...) \
|
|
|
|
|
do { \
|
|
|
|
|
throw ::paddle::platform::EnforceNotMet( \
|
|
|
|
|
std::make_exception_ptr( \
|
|
|
|
|
std::runtime_error(string::Sprintf(__VA_ARGS__))), \
|
|
|
|
|
__FILE__, __LINE__); \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
#define PADDLE_ENFORCE(...) \
|
|
|
|
|
do { \
|
|
|
|
|
::paddle::platform::throw_on_error(__VA_ARGS__); \
|
|
|
|
|
#define PADDLE_ENFORCE(...) \
|
|
|
|
|
do { \
|
|
|
|
|
try { \
|
|
|
|
|
::paddle::platform::throw_on_error(__VA_ARGS__); \
|
|
|
|
|
} catch (...) { \
|
|
|
|
|
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
|
|
|
|
|
__FILE__, __LINE__); \
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
} // namespace platform
|
|
|
|
|