Fix PoolingGrad in transpose mode

pull/8252/head
VectorSL 4 years ago
parent 1df488c1b1
commit a3fbbe22ee

@@ -85,10 +85,10 @@ class PoolingGradGpuKernel : public GpuKernel {
     auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
     auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
     auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-    data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
-    auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
-    if (format_attr == kOpFormat_NHWC) {
-      data_format_ = kOpFormat_NHWC;
+    auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
+    format_attr_ = GetAttr<std::string>(kernel_node, "data_format");
+    if (format_attr_ == kOpFormat_NHWC) {
+      data_format = kOpFormat_NHWC;
     }
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
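
Note on the hunk above: the kernel now keeps two layouts separate. The local data_format reflects the device format of input 0 and drives the cuDNN descriptor setup below, while the new format_attr_ member stores the op's "data_format" attribute for interpreting the kernel_size/strides attributes. A minimal sketch of that split, using hypothetical names (ResolveLayouts, descriptor_format) that are not part of this file:

#include <string>

// Hypothetical illustration, not the kernel's own code: keep the device-side
// layout and the attribute-side layout as two separate values.
struct Layouts {
  std::string descriptor_format;  // feeds the dim/stride setup for cuDNN descriptors
  std::string attr_format;        // used to index kernel_size/strides attributes
};

Layouts ResolveLayouts(const std::string &device_format, const std::string &attr_format) {
  Layouts out{device_format, attr_format};
  if (attr_format == "NHWC") {
    // Mirrors the diff: an NHWC attribute forces the descriptor layout to NHWC as well.
    out.descriptor_format = "NHWC";
  }
  return out;
}
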
@@ -97,7 +97,7 @@ class PoolingGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
-    SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
+    SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
     const int nbDims = 4;
     int dimA[4];
     int strideAin[4];
@@ -107,14 +107,14 @@ class PoolingGradGpuKernel : public GpuKernel {
     int strideAdy[4];
     int dimAout[4];
     int strideAout[4];
-    SetDimA(input_shape, dimA, 4, data_format_);
-    SetStrideA(input_shape, strideAin, 4, data_format_);
-    SetDimA(input_mask, dimAy, 4, data_format_);
-    SetStrideA(input_mask, strideAiny, 4, data_format_);
-    SetDimA(dout_shape, dimAdy, 4, data_format_);
-    SetStrideA(dout_shape, strideAdy, 4, data_format_);
-    SetDimA(output_shape, dimAout, 4, data_format_);
-    SetStrideA(output_shape, strideAout, 4, data_format_);
+    SetDimA(input_shape, dimA, 4, data_format);
+    SetStrideA(input_shape, strideAin, 4, data_format);
+    SetDimA(input_mask, dimAy, 4, data_format);
+    SetStrideA(input_mask, strideAiny, 4, data_format);
+    SetDimA(dout_shape, dimAdy, 4, data_format);
+    SetStrideA(dout_shape, strideAdy, 4, data_format);
+    SetDimA(output_shape, dimAout, 4, data_format);
+    SetStrideA(output_shape, strideAout, 4, data_format);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
                                 "cudnnSetTensor4dDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
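
The SetDimA/SetStrideA helpers are not shown in this hunk. The sketch below is an assumption about how such helpers typically fill cuDNN Nd descriptor arrays (cuDNN expects dims in NCHW order, with NHWC expressed only through strides); it uses stand-alone hypothetical functions FillDims/FillStrides rather than the ones this file actually calls:

#include <string>
#include <vector>

// Assumed behavior, for illustration only: fill a 4-element dim array in NCHW
// order from a shape given in either NCHW or NHWC layout.
void FillDims(const std::vector<size_t> &shape, int *dimA, const std::string &format) {
  if (format == "NHWC") {  // shape = {N, H, W, C}
    dimA[0] = static_cast<int>(shape[0]);
    dimA[1] = static_cast<int>(shape[3]);
    dimA[2] = static_cast<int>(shape[1]);
    dimA[3] = static_cast<int>(shape[2]);
  } else {  // NCHW / default: copy through
    for (size_t i = 0; i < 4; ++i) dimA[i] = static_cast<int>(shape[i]);
  }
}

// Assumed behavior: element strides matching the NCHW-ordered dim array above.
void FillStrides(const std::vector<size_t> &shape, int *strideA, const std::string &format) {
  if (format == "NHWC") {  // memory order N, H, W, C
    strideA[0] = static_cast<int>(shape[1] * shape[2] * shape[3]);  // N
    strideA[1] = 1;                                                 // C is innermost
    strideA[2] = static_cast<int>(shape[2] * shape[3]);             // H
    strideA[3] = static_cast<int>(shape[3]);                        // W
  } else {  // memory order N, C, H, W
    strideA[0] = static_cast<int>(shape[1] * shape[2] * shape[3]);
    strideA[1] = static_cast<int>(shape[2] * shape[3]);
    strideA[2] = static_cast<int>(shape[3]);
    strideA[3] = 1;
  }
}
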
@@ -180,7 +180,7 @@ class PoolingGradGpuKernel : public GpuKernel {
     int window_width = window[3];
     int stride_h = stride_[2];
     int stride_w = stride_[3];
-    if (data_format_ == kOpFormat_NHWC) {
+    if (format_attr_ == kOpFormat_NHWC) {
       window_height = window[1];
       window_width = window[2];
       stride_h = stride_[1];
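
The window/stride indexing above now keys off format_attr_ rather than the device format, since the kernel_size/strides attributes are written in the layout named by the op's data_format attribute. A small hypothetical helper (SpatialFromAttr is not part of this file) makes the index mapping explicit:

#include <array>
#include <string>
#include <utility>

// Hypothetical helper: pick the spatial (height, width) entries out of a
// 4-element attribute such as kernel_size or strides.
//   NCHW attribute: {1, 1, h, w} -> indices 2 and 3
//   NHWC attribute: {1, h, w, 1} -> indices 1 and 2
std::pair<int, int> SpatialFromAttr(const std::array<int, 4> &attr, const std::string &attr_format) {
  if (attr_format == "NHWC") {
    return {attr[1], attr[2]};
  }
  return {attr[2], attr[3]};
}
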
@@ -247,7 +247,7 @@ class PoolingGradGpuKernel : public GpuKernel {
   std::vector<size_t> workspace_size_list_;
   std::string mode_;
   std::string pad_mode_;
-  std::string data_format_ = kOpFormat_NCHW;
+  std::string format_attr_ = kOpFormat_NCHW;
   cudnnDataType_t cudnn_data_type_;
   cudnnTensorFormat_t compute_format_;
   int old_height_;
