|
|
|
@ -22,6 +22,47 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
|
|
|
|
|
inline const char* cudnnGetErrorString(cudnnStatus_t status) {
|
|
|
|
|
switch (status) {
|
|
|
|
|
case CUDNN_STATUS_SUCCESS:
|
|
|
|
|
return "CUDNN_STATUS_SUCCESS";
|
|
|
|
|
case CUDNN_STATUS_NOT_INITIALIZED:
|
|
|
|
|
return "CUDNN_STATUS_NOT_INITIALIZED";
|
|
|
|
|
case CUDNN_STATUS_ALLOC_FAILED:
|
|
|
|
|
return "CUDNN_STATUS_ALLOC_FAILED";
|
|
|
|
|
case CUDNN_STATUS_BAD_PARAM:
|
|
|
|
|
return "CUDNN_STATUS_BAD_PARAM";
|
|
|
|
|
case CUDNN_STATUS_INTERNAL_ERROR:
|
|
|
|
|
return "CUDNN_STATUS_INTERNAL_ERROR";
|
|
|
|
|
case CUDNN_STATUS_INVALID_VALUE:
|
|
|
|
|
return "CUDNN_STATUS_INVALID_VALUE";
|
|
|
|
|
case CUDNN_STATUS_ARCH_MISMATCH:
|
|
|
|
|
return "CUDNN_STATUS_ARCH_MISMATCH";
|
|
|
|
|
case CUDNN_STATUS_MAPPING_ERROR:
|
|
|
|
|
return "CUDNN_STATUS_MAPPING_ERROR";
|
|
|
|
|
case CUDNN_STATUS_EXECUTION_FAILED:
|
|
|
|
|
return "CUDNN_STATUS_EXECUTION_FAILED";
|
|
|
|
|
case CUDNN_STATUS_NOT_SUPPORTED:
|
|
|
|
|
return "CUDNN_STATUS_NOT_SUPPORTED";
|
|
|
|
|
case CUDNN_STATUS_LICENSE_ERROR:
|
|
|
|
|
return "CUDNN_STATUS_LICENSE_ERROR";
|
|
|
|
|
default:
|
|
|
|
|
return "Unknown cudnn error number";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define CUDNN_VERSION_MIN(major, minor, patch) \
|
|
|
|
|
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
|
|
|
|
|
|
|
|
|
|
#define CUDNN_ENFORCE(condition) \
|
|
|
|
|
do { \
|
|
|
|
|
cudnnStatus_t status = condition; \
|
|
|
|
|
if (status != CUDNN_STATUS_SUCCESS) { \
|
|
|
|
|
VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \
|
|
|
|
|
PADDLE_THROW("cuDNN call failed"); \
|
|
|
|
|
} \
|
|
|
|
|
} while (false)
|
|
|
|
|
|
|
|
|
|
enum class DataLayout {
|
|
|
|
|
kNHWC,
|
|
|
|
|
kNCHW,
|
|
|
|
@ -40,12 +81,30 @@ template <>
|
|
|
|
|
class CudnnDataType<float> {
|
|
|
|
|
public:
|
|
|
|
|
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
|
|
|
|
|
typedef const float ScalingParamType;
|
|
|
|
|
static ScalingParamType* kOne() {
|
|
|
|
|
static ScalingParamType v = 1.0;
|
|
|
|
|
return &v;
|
|
|
|
|
}
|
|
|
|
|
static ScalingParamType* kZero() {
|
|
|
|
|
static ScalingParamType v = 0.0;
|
|
|
|
|
return &v;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
class CudnnDataType<double> {
|
|
|
|
|
public:
|
|
|
|
|
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
|
|
|
|
|
typedef const double ScalingParamType;
|
|
|
|
|
static ScalingParamType* kOne() {
|
|
|
|
|
static ScalingParamType v = 1.0;
|
|
|
|
|
return &v;
|
|
|
|
|
}
|
|
|
|
|
static ScalingParamType* kZero() {
|
|
|
|
|
static ScalingParamType v = 0.0;
|
|
|
|
|
return &v;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
|
|
|
|
|