|
|
|
@ -54,22 +54,26 @@ DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size)
|
|
|
|
|
#define WARPCTC_GET_VERSION dynload::get_warpctc_version
|
|
|
|
|
#define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString
|
|
|
|
|
|
|
|
|
|
static int g_warpctcVersion = -1;
|
|
|
|
|
#ifndef PADDLE_TYPE_DOUBLE
|
|
|
|
|
#define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss
|
|
|
|
|
#define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size
|
|
|
|
|
#else
|
|
|
|
|
#define WARPCTC_LOG_FATAL \
|
|
|
|
|
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion \
|
|
|
|
|
<< "] Error: not support double precision."
|
|
|
|
|
#define WARPCTC_COMPUTE_LOSS(...) WARPCTC_LOG_FATAL(__VA_ARGS__)
|
|
|
|
|
#define WARPCTC_GET_WORKSPACE_SIZE(...) WARPCTC_LOG_FATAL(__VA_ARGS__)
|
|
|
|
|
ctcStatus_t fatal(...) {
|
|
|
|
|
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion
|
|
|
|
|
<< "] Error: not support double precision.";
|
|
|
|
|
// both of get_warpctc_version() and get_workspace_size() return an ctcStatus
|
|
|
|
|
// type value
|
|
|
|
|
return CTC_STATUS_EXECUTION_FAILED;
|
|
|
|
|
}
|
|
|
|
|
#define WARPCTC_COMPUTE_LOSS fatal
|
|
|
|
|
#define WARPCTC_GET_WORKSPACE_SIZE fatal
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Check build-in warp-ctc function using glog and it also
|
|
|
|
|
* support << operator for more details error info.
|
|
|
|
|
*/
|
|
|
|
|
static int g_warpctcVersion = -1;
|
|
|
|
|
#define CHECK_WARPCTC(warpctcStat) \
|
|
|
|
|
CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \
|
|
|
|
|
<< "warp-ctc [version " << g_warpctcVersion \
|
|
|
|
|