Fix PoolingGrad in transpose mode

pull/8252/head
VectorSL 4 years ago
parent 1df488c1b1
commit a3fbbe22ee

@@ -85,10 +85,10 @@ class PoolingGradGpuKernel : public GpuKernel {
auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1); auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format"); format_attr_ = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) { if (format_attr_ == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC; data_format = kOpFormat_NHWC;
} }
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
@@ -97,7 +97,7 @@ class PoolingGradGpuKernel : public GpuKernel {
InitSizeLists(); InitSizeLists();
return true; return true;
} }
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
const int nbDims = 4; const int nbDims = 4;
int dimA[4]; int dimA[4];
int strideAin[4]; int strideAin[4];
@@ -107,14 +107,14 @@ class PoolingGradGpuKernel : public GpuKernel {
int strideAdy[4]; int strideAdy[4];
int dimAout[4]; int dimAout[4];
int strideAout[4]; int strideAout[4];
SetDimA(input_shape, dimA, 4, data_format_); SetDimA(input_shape, dimA, 4, data_format);
SetStrideA(input_shape, strideAin, 4, data_format_); SetStrideA(input_shape, strideAin, 4, data_format);
SetDimA(input_mask, dimAy, 4, data_format_); SetDimA(input_mask, dimAy, 4, data_format);
SetStrideA(input_mask, strideAiny, 4, data_format_); SetStrideA(input_mask, strideAiny, 4, data_format);
SetDimA(dout_shape, dimAdy, 4, data_format_); SetDimA(dout_shape, dimAdy, 4, data_format);
SetStrideA(dout_shape, strideAdy, 4, data_format_); SetStrideA(dout_shape, strideAdy, 4, data_format);
SetDimA(output_shape, dimAout, 4, data_format_); SetDimA(output_shape, dimAout, 4, data_format);
SetStrideA(output_shape, strideAout, 4, data_format_); SetStrideA(output_shape, strideAout, 4, data_format);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
@@ -180,7 +180,7 @@ class PoolingGradGpuKernel : public GpuKernel {
int window_width = window[3]; int window_width = window[3];
int stride_h = stride_[2]; int stride_h = stride_[2];
int stride_w = stride_[3]; int stride_w = stride_[3];
if (data_format_ == kOpFormat_NHWC) { if (format_attr_ == kOpFormat_NHWC) {
window_height = window[1]; window_height = window[1];
window_width = window[2]; window_width = window[2];
stride_h = stride_[1]; stride_h = stride_[1];
@@ -247,7 +247,7 @@ class PoolingGradGpuKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_; std::vector<size_t> workspace_size_list_;
std::string mode_; std::string mode_;
std::string pad_mode_; std::string pad_mode_;
std::string data_format_ = kOpFormat_NCHW; std::string format_attr_ = kOpFormat_NCHW;
cudnnDataType_t cudnn_data_type_; cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_; cudnnTensorFormat_t compute_format_;
int old_height_; int old_height_;

Loading…
Cancel
Save