|
|
|
@ -85,10 +85,10 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|
|
|
|
auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
|
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
|
|
|
|
|
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
|
|
|
|
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
|
|
|
|
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
|
|
|
|
|
if (format_attr == kOpFormat_NHWC) {
|
|
|
|
|
data_format_ = kOpFormat_NHWC;
|
|
|
|
|
auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
|
|
|
|
|
format_attr_ = GetAttr<std::string>(kernel_node, "data_format");
|
|
|
|
|
if (format_attr_ == kOpFormat_NHWC) {
|
|
|
|
|
data_format = kOpFormat_NHWC;
|
|
|
|
|
}
|
|
|
|
|
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
|
|
|
|
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
|
|
|
|
@ -97,7 +97,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|
|
|
|
InitSizeLists();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
|
|
|
|
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
|
|
|
|
|
const int nbDims = 4;
|
|
|
|
|
int dimA[4];
|
|
|
|
|
int strideAin[4];
|
|
|
|
@ -107,14 +107,14 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|
|
|
|
int strideAdy[4];
|
|
|
|
|
int dimAout[4];
|
|
|
|
|
int strideAout[4];
|
|
|
|
|
SetDimA(input_shape, dimA, 4, data_format_);
|
|
|
|
|
SetStrideA(input_shape, strideAin, 4, data_format_);
|
|
|
|
|
SetDimA(input_mask, dimAy, 4, data_format_);
|
|
|
|
|
SetStrideA(input_mask, strideAiny, 4, data_format_);
|
|
|
|
|
SetDimA(dout_shape, dimAdy, 4, data_format_);
|
|
|
|
|
SetStrideA(dout_shape, strideAdy, 4, data_format_);
|
|
|
|
|
SetDimA(output_shape, dimAout, 4, data_format_);
|
|
|
|
|
SetStrideA(output_shape, strideAout, 4, data_format_);
|
|
|
|
|
SetDimA(input_shape, dimA, 4, data_format);
|
|
|
|
|
SetStrideA(input_shape, strideAin, 4, data_format);
|
|
|
|
|
SetDimA(input_mask, dimAy, 4, data_format);
|
|
|
|
|
SetStrideA(input_mask, strideAiny, 4, data_format);
|
|
|
|
|
SetDimA(dout_shape, dimAdy, 4, data_format);
|
|
|
|
|
SetStrideA(dout_shape, strideAdy, 4, data_format);
|
|
|
|
|
SetDimA(output_shape, dimAout, 4, data_format);
|
|
|
|
|
SetStrideA(output_shape, strideAout, 4, data_format);
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
|
|
|
|
|
"cudnnSetTensor4dDescriptor failed");
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
|
|
|
|
@ -180,7 +180,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|
|
|
|
int window_width = window[3];
|
|
|
|
|
int stride_h = stride_[2];
|
|
|
|
|
int stride_w = stride_[3];
|
|
|
|
|
if (data_format_ == kOpFormat_NHWC) {
|
|
|
|
|
if (format_attr_ == kOpFormat_NHWC) {
|
|
|
|
|
window_height = window[1];
|
|
|
|
|
window_width = window[2];
|
|
|
|
|
stride_h = stride_[1];
|
|
|
|
@ -247,7 +247,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|
|
|
|
std::vector<size_t> workspace_size_list_;
|
|
|
|
|
std::string mode_;
|
|
|
|
|
std::string pad_mode_;
|
|
|
|
|
std::string data_format_ = kOpFormat_NCHW;
|
|
|
|
|
std::string format_attr_ = kOpFormat_NCHW;
|
|
|
|
|
cudnnDataType_t cudnn_data_type_;
|
|
|
|
|
cudnnTensorFormat_t compute_format_;
|
|
|
|
|
int old_height_;
|
|
|
|
|