!1150 GPU update shape infer

Merge pull request !1150 from VectorSL/gpu-update-shapeinfer
pull/1150/MERGE
mindspore-ci-bot committed 5 years ago (via Gitee)
commit aacb03563e

@@ -184,10 +184,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
   if (!result) {
     auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    std::string build_type = "in [";
+    std::for_each(std::begin(inputs_type), std::end(inputs_type),
+                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+    build_type += "] out [";
+    std::for_each(std::begin(outputs_type), std::end(outputs_type),
+                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+    build_type += "]";
     auto supported_type_lists = SupportedTypeList(kernel_node);
-    MS_LOG(EXCEPTION) << "Select GPU kernel op[" << kernel_name
-                      << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists;
+    MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
+                            << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
+                            << ", but get " << build_type;
   }
   builder->SetKernelType(kernel_type);
   builder->SetProcessor(kernel::Processor::CUDA);
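For illustration, a self-contained sketch of how the new diagnostic string is assembled; `TypeId2String` below is a hypothetical stub standing in for `mindspore::kernel::TypeId2String`, and the type-id values are invented:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for mindspore::kernel::TypeId2String, illustration only.
std::string TypeId2String(int id) { return id == 0 ? "float32" : "int32"; }

int main() {
  std::vector<int> inputs_type = {0, 1};  // e.g. float32, int32
  std::vector<int> outputs_type = {0};    // e.g. float32
  std::string build_type = "in [";
  std::for_each(std::begin(inputs_type), std::end(inputs_type),
                [&build_type](auto i) { build_type += TypeId2String(i) + " "; });
  build_type += "] out [";
  std::for_each(std::begin(outputs_type), std::end(outputs_type),
                [&build_type](auto i) { build_type += TypeId2String(i) + " "; });
  build_type += "]";
  std::cout << build_type << std::endl;  // prints: in [float32 int32 ] out [float32 ]
  return 0;
}
```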

@@ -178,46 +178,33 @@ class ArrayReduceGpuKernel : public GpuKernel {
     return;
   }
   void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
-    std::vector<size_t> inputA_shape = input_shape;
+    std::vector<int> inputA;
     std::vector<size_t> outputC_shape = output_shape;
-    std::vector<int> real_input_shape;
-    int shapeA_n, shapeA_c, shapeA_h, shapeA_w;
-    shapeA_n = inputA_shape.size() < 4 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 4]);
-    shapeA_c = inputA_shape.size() < 3 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 3]);
-    shapeA_h = inputA_shape.size() < 2 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 2]);
-    shapeA_w = inputA_shape.size() == 0 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 1]);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, shapeA_n,
-                                                           shapeA_c, shapeA_h, shapeA_w),
+    ShapeNdTo4d(input_shape, &inputA);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0],
+                                                           inputA[1], inputA[2], inputA[3]),
                                 "cudnnSetTensor4dDescriptor failed");
-    int shapeC_n, shapeC_c, shapeC_h, shapeC_w;
     if (axis_[0] == -1) {
-      shapeC_n = 1;
-      shapeC_c = 1;
-      shapeC_h = 1;
-      shapeC_w = 1;
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
-                                                             shapeC_n, shapeC_c, shapeC_h, shapeC_w),
-                                  "cudnnSetTensor4dDescriptor failed");
-      if (shapeA_n == shapeC_n && shapeA_c == shapeC_c && shapeA_h == shapeC_h && shapeA_w == shapeC_w) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
+        "cudnnSetTensor4dDescriptor failed");
+      if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
         all_match_ = true;
       }
       return;
     }
     if (!keep_dims_) {
       for (auto i : axis_) {
         (void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
       }
     }
-    shapeC_n = outputC_shape.size() < 4 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 4]);
-    shapeC_c = outputC_shape.size() < 3 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 3]);
-    shapeC_h = outputC_shape.size() < 2 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 2]);
-    shapeC_w = outputC_shape.size() == 0 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 1]);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, shapeC_n,
-                                                           shapeC_c, shapeC_h, shapeC_w),
+    std::vector<int> outputC;
+    ShapeNdTo4d(outputC_shape, &outputC);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
+                                                           outputC[0], outputC[1], outputC[2], outputC[3]),
                                 "cudnnSetTensor4dDescriptor failed");
-    if (shapeA_n == shapeC_n && shapeA_c == shapeC_c && shapeA_h == shapeC_h && shapeA_w == shapeC_w) {
+    if (inputA == outputC) {
       all_match_ = true;
     }
     return;

@@ -52,15 +52,7 @@ class SliceGpuFwdKernel : public GpuKernel {
       return false;
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    int shape_n = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int shape_c = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int shape_h = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int shape_w = SizeToInt(input_shape[input_shape.size() - 1]);
-    input_shape_.push_back(shape_n);
-    input_shape_.push_back(shape_c);
-    input_shape_.push_back(shape_h);
-    input_shape_.push_back(shape_w);
+    ShapeNdTo4d(input_shape, &input_shape_);
     auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides");
     if (strides) {
       strides_ = GetAttr<std::vector<int>>(kernel_node, "strides");
@@ -89,7 +81,7 @@ class SliceGpuFwdKernel : public GpuKernel {
       }
     }
-    input_size_ = IntToSize(shape_n * shape_c * shape_h * shape_w) * sizeof(T);
+    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
     auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     output_size_ = sizeof(T);
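Since `ShapeNdTo4d` guarantees exactly four entries, the byte-size computation reduces to a product over the 4d shape. A sketch with the hypothetical helper name `InputSizeBytes`:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the input_size_ computation above: after
// ShapeNdTo4d, the shape vector always has exactly four entries, so the
// element count is simply their product.
template <typename T>
size_t InputSizeBytes(const std::vector<int> &shape4d) {
  return static_cast<size_t>(shape4d[0] * shape4d[1] * shape4d[2] * shape4d[3]) * sizeof(T);
}
```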

@@ -66,19 +66,12 @@ class SliceGradGpuKernel : public GpuKernel {
       size_ = GetAttr<std::vector<int>>(kernel_node, "end");
     } else {
       auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-      input_shape_.push_back(input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]));
-      input_shape_.push_back(input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]));
-      input_shape_.push_back(input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]));
-      input_shape_.push_back(SizeToInt(input_shape[input_shape.size() - 1]));
+      ShapeNdTo4d(input_shape, &input_shape_);
       size_ = GetAttr<std::vector<int>>(kernel_node, "size");
     }
     auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    dy_shape_.push_back(dy_shape.size() < 4 ? 1 : SizeToInt(dy_shape[dy_shape.size() - 4]));
-    dy_shape_.push_back(dy_shape.size() < 3 ? 1 : SizeToInt(dy_shape[dy_shape.size() - 3]));
-    dy_shape_.push_back(dy_shape.size() < 2 ? 1 : SizeToInt(dy_shape[dy_shape.size() - 2]));
-    dy_shape_.push_back(SizeToInt(dy_shape[dy_shape.size() - 1]));
+    ShapeNdTo4d(dy_shape, &dy_shape_);
     begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
     DealParam();
     input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);

@@ -39,7 +39,6 @@ class GpuKernel : public KernelMod {
   virtual void InitSizeLists() = 0;
   template <typename T>
   inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
     if (index >= addr_list.size()) {
       MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
@@ -62,6 +61,24 @@ class GpuKernel : public KernelMod {
     }
     return GetValue<T>(attr);
   }
+  // expand Nd Shape to 4d (N in [0,4])
+  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
+    dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4]));
+    dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3]));
+    dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2]));
+    dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
+  }
+  inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
+                                      const std::vector<int> &Out) {
+    if (A != Out && B != Out) {
+      MS_EXCEPTION(ValueError)
+        << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n"
+           "InputA must match the corresponding dimension of the destination tensor outC, and each "
+           "dimension of the inputB "
+           "must match the corresponding dimension of outC or must be equal to 1.";
+    }
+  }
 };
 }  // namespace kernel
 }  // namespace mindspore
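To see what the new `ShapeNdTo4d` helper does, here is a standalone sketch with `SizeToInt` replaced by a plain `static_cast` so it compiles outside MindSpore:

```cpp
#include <cassert>
#include <vector>

// Standalone copy of the ShapeNdTo4d helper added to GpuKernel above.
// Missing leading dimensions are padded with 1, so {5, 3} -> {1, 1, 5, 3}
// and an empty (scalar) shape -> {1, 1, 1, 1}.
void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
  dst->push_back(src.size() < 4 ? 1 : static_cast<int>(src[src.size() - 4]));
  dst->push_back(src.size() < 3 ? 1 : static_cast<int>(src[src.size() - 3]));
  dst->push_back(src.size() < 2 ? 1 : static_cast<int>(src[src.size() - 2]));
  dst->push_back(src.size() == 0 ? 1 : static_cast<int>(src[src.size() - 1]));
}

int main() {
  std::vector<int> out;
  ShapeNdTo4d({5, 3}, &out);
  assert((out == std::vector<int>{1, 1, 5, 3}));  // 2d padded to NCHW
  out.clear();
  ShapeNdTo4d({}, &out);
  assert((out == std::vector<int>{1, 1, 1, 1}));  // scalar padded to NCHW
  return 0;
}
```

Centralizing this in the `GpuKernel` base class removes the four-line ternary pattern that was copy-pasted across the reduce, slice, add, and relu kernels in this diff.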

@@ -87,32 +87,24 @@ class TensorAddGpuFwdKernel : public GpuKernel {
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto input_shapeB = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    if (input_shape != output_shape && input_shapeB != output_shape) {
-      MS_LOG(ERROR) << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n"
-                       "InputA must match the corresponding dimension of the destination tensor outC, and each "
-                       "dimension of the inputB "
-                       "must match the corresponding dimension of outC or must be equal to 1.";
-      return false;
-    }
     is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_shapeB);
     if (is_null_input_) {
       MS_LOG(WARNING) << "TensorAddGpuFwdKernel input is null";
       InitSizeLists();
       return true;
     }
-    int shape_n = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int shape_c = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int shape_h = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int shape_w = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
+    std::vector<int> shapeA;
+    std::vector<int> shapeB;
+    std::vector<int> shapeOut;
+    ShapeNdTo4d(input_shape, &shapeA);
+    ShapeNdTo4d(input_shapeB, &shapeB);
+    ShapeNdTo4d(output_shape, &shapeOut);
+    CheckBroadcast4TensorOp(shapeA, shapeB, shapeOut);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           shape_n, shape_c, shape_h, shape_w),
+                                                           shapeA[0], shapeA[1], shapeA[2], shapeA[3]),
                                 "cudnnSetTensor4dDescriptor failed");
-    int shapeB_n = input_shapeB.size() < 4 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 4]);
-    int shapeB_c = input_shapeB.size() < 3 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 3]);
-    int shapeB_h = input_shapeB.size() < 2 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 2]);
-    int shapeB_w = input_shapeB.size() == 0 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 1]);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputB_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           shapeB_n, shapeB_c, shapeB_h, shapeB_w),
+                                                           shapeB[0], shapeB[1], shapeB[2], shapeB[3]),
                                 "cudnnSetTensor4dDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(
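The broadcast guard moved behind the null-input check and now throws via `CheckBroadcast4TensorOp` (after 4d expansion) instead of logging an error and returning false. The rule it enforces, one-sided broadcast only as cudnnOpTensor requires, can be sketched like this; `BroadcastOk` is an illustrative name, not part of the patch:

```cpp
#include <iostream>
#include <vector>

// Sketch of the one-sided broadcast rule enforced by CheckBroadcast4TensorOp:
// at least one input shape must equal the output shape exactly; only the other
// input may carry size-1 dimensions that broadcast.
bool BroadcastOk(const std::vector<int> &A, const std::vector<int> &B,
                 const std::vector<int> &Out) {
  return A == Out || B == Out;
}

int main() {
  std::vector<int> out = {2, 3, 4, 4};
  std::cout << BroadcastOk({2, 3, 4, 4}, {1, 3, 1, 1}, out) << "\n";  // 1: B broadcasts
  std::cout << BroadcastOk({2, 1, 4, 4}, {1, 3, 4, 4}, out) << "\n";  // 0: double-sided, rejected
  return 0;
}
```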

@@ -107,8 +107,8 @@ class PoolingGpuFwdKernel : public GpuKernel {
                                                            SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
                                 "cudnnSetTensor4dDescriptor failed");
     auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
-    int window_height = window[3];
-    int window_width = window[2];
+    int window_height = window[2];
+    int window_width = window[3];
     stride_ = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
     SetPoolingMode(kernel_node);
     if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
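On the ksize fix: with NCHW layout the attribute is ordered {N, C, H, W}, so index 2 is the window height and index 3 the width; the old `window[3]`/`window[2]` indexing silently swapped the two, which square windows masked. A tiny check, with invented values:

```cpp
#include <vector>

// "ksize" follows the NCHW data layout, so a window of height 2 and width 3
// arrives as {1, 1, 2, 3}. Indexing window[3]/window[2] (the old code) swapped
// H and W for any non-square window.
int main() {
  std::vector<int> window = {1, 1, 2, 3};  // N, C, H, W
  int window_height = window[2];           // 2
  int window_width = window[3];            // 3
  return window_height == 2 && window_width == 3 ? 0 : 1;
}
```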

@@ -101,8 +101,8 @@ class PoolingGradGpuFwdKernel : public GpuKernel {
       return false;
     }
     auto window = GetAttr<std::vector<int>>(kernel_node, "ksize");
-    int window_height = window[3];
-    int window_width = window[2];
+    int window_height = window[2];
+    int window_width = window[3];
     SetPoolingMode(kernel_node);
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);

@@ -31,8 +31,7 @@ class ReLUGpuFwdKernel : public GpuKernel {
       : cudnn_handle_(nullptr),
         activation_desc_(nullptr),
         mode_(CUDNN_ACTIVATION_RELU),
-        input_descriptor_(nullptr),
-        output_descriptor_(nullptr),
+        data_descriptor_(nullptr),
         is_null_input_(false),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
         input_size_(0),
@@ -53,8 +52,8 @@ class ReLUGpuFwdKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, input_descriptor_,
-                                                       input, &beta, output_descriptor_, output),
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
+                                                       &beta, data_descriptor_, output),
                                 "ReLUGpuFwdKernel failed");
     return true;
@@ -75,18 +74,12 @@ class ReLUGpuFwdKernel : public GpuKernel {
       return true;
     }
     mode_ = CUDNN_ACTIVATION_RELU;
-    int batch_size = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int channel_size = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int height = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int width = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
+    std::vector<int> shape;
+    ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
                                 "SetActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size, channel_size, height, width),
-                                "SetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size, channel_size, height, width),
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           shape[0], shape[1], shape[2], shape[3]),
                                 "SetTensor4dDescriptor failed");
     InitSizeLists();
     return true;
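Because ReLU is elementwise, input and output always share geometry, layout, and data type, which is why the separate input/output descriptors collapse into a single `data_descriptor_`. A minimal cuDNN sketch of the same idea, with error handling elided and a hypothetical wrapper name `ReluForward`:

```cpp
#include <cudnn.h>

// Minimal sketch: x and y of an elementwise activation always share shape,
// so one cudnnTensorDescriptor_t can be passed for both xDesc and yDesc.
void ReluForward(cudnnHandle_t handle, cudnnActivationDescriptor_t act,
                 const float *x, float *y, int n, int c, int h, int w) {
  cudnnTensorDescriptor_t data_desc;
  cudnnCreateTensorDescriptor(&data_desc);
  cudnnSetTensor4dDescriptor(data_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w);
  const float alpha = 1.0f, beta = 0.0f;
  cudnnActivationForward(handle, act, &alpha, data_desc, x, &beta, data_desc, y);
  cudnnDestroyTensorDescriptor(data_desc);
}
```

The same merge also lets `InitSizeLists` query the tensor size once and set `output_size_ = input_size_`, as the next hunk shows.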
@@ -95,18 +88,16 @@ class ReLUGpuFwdKernel : public GpuKernel {
  protected:
   void InitResource() override {
     cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_),
                                 "cudnnCreateActivationDescriptor failed");
   }
   void InitSizeLists() override {
     if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_descriptor_, &output_size_),
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
                                   "cudnnGetTensorSizeInBytes failed");
+      output_size_ = input_size_;
     }
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
@@ -116,15 +107,13 @@ class ReLUGpuFwdKernel : public GpuKernel {
   void DestroyResource() noexcept {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
                                "cudnnDestroyActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
   cudnnActivationMode_t mode_;
-  cudnnTensorDescriptor_t input_descriptor_;
-  cudnnTensorDescriptor_t output_descriptor_;
+  cudnnTensorDescriptor_t data_descriptor_;
   bool is_null_input_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;

@@ -31,7 +31,7 @@ class ReluGradGpuFwdKernel : public GpuKernel {
       : cudnn_handle_(nullptr),
         activation_desc_(nullptr),
         mode_(CUDNN_ACTIVATION_RELU),
-        input_descriptor_(nullptr),
+        data_descriptor_(nullptr),
         is_null_input_(false),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
         input_size_(0) {}
@@ -52,8 +52,8 @@ class ReluGradGpuFwdKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
     CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, input_descriptor_, y, input_descriptor_, dy,
-                              input_descriptor_, y, &beta, input_descriptor_, dx),
+      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
+                              data_descriptor_, y, &beta, data_descriptor_, dx),
       "cudnnActivationBackward failed");
     return true;
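All four tensor arguments of `cudnnActivationBackward` (y, dy, x, dx) share one geometry in an elementwise gradient, so the single `data_descriptor_` is reused, and the code passes `y` in the x slot, which is safe for ReLU because the gradient only depends on whether the activation was positive. A hedged sketch with the hypothetical wrapper name `ReluBackward`:

```cpp
#include <cudnn.h>

// Sketch of the call above: one descriptor serves yDesc, dyDesc, xDesc, and
// dxDesc, and y stands in for x since relu'(x) only needs the sign, which y
// preserves.
void ReluBackward(cudnnHandle_t handle, cudnnActivationDescriptor_t act,
                  cudnnTensorDescriptor_t data_desc, const float *y,
                  const float *dy, float *dx) {
  const float alpha = 1.0f, beta = 0.0f;
  cudnnActivationBackward(handle, act, &alpha, data_desc, y, data_desc, dy,
                          data_desc, y, &beta, data_desc, dx);
}
```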
@@ -74,14 +74,12 @@ class ReluGradGpuFwdKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
-    int batch_size = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int channel_size = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int height = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int width = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
+    std::vector<int> shape;
+    ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
                                 "SetActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size, channel_size, height, width),
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           shape[0], shape[1], shape[2], shape[3]),
                                 "SetTensor4dDescriptor failed");
     InitSizeLists();
@@ -91,13 +89,13 @@ class ReluGradGpuFwdKernel : public GpuKernel {
  protected:
   void InitResource() override {
     cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_),
                                 "cudnnCreateActivationDescriptor failed");
   }
   void InitSizeLists() override {
     if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
                                   "cudnnGetTensorSizeInBytes failed");
     }
     input_size_list_.push_back(input_size_);
@@ -109,13 +107,13 @@ class ReluGradGpuFwdKernel : public GpuKernel {
   void DestroyResource() noexcept {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
                                "cudnnDestroyActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
   cudnnActivationMode_t mode_;
-  cudnnTensorDescriptor_t input_descriptor_;
+  cudnnTensorDescriptor_t data_descriptor_;
   bool is_null_input_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
