|
|
|
@ -147,13 +147,13 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|
|
|
|
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
|
|
|
|
|
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
|
|
|
|
|
IntToSize(old_width_ + pad_width_)};
|
|
|
|
|
SetDimA(padded_shape, dimA, data_format_);
|
|
|
|
|
SetStrideA(padded_shape, strideApadded, data_format_);
|
|
|
|
|
SetDimA(padded_shape, dimA, 4, data_format_);
|
|
|
|
|
SetStrideA(padded_shape, strideApadded, 4, data_format_);
|
|
|
|
|
} else if (data_format_ == "NHWC") {
|
|
|
|
|
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
|
|
|
|
|
IntToSize(c_)};
|
|
|
|
|
SetDimA(padded_shape, dimA, data_format_);
|
|
|
|
|
SetStrideA(padded_shape, strideApadded, data_format_);
|
|
|
|
|
SetDimA(padded_shape, dimA, 4, data_format_);
|
|
|
|
|
SetStrideA(padded_shape, strideApadded, 4, data_format_);
|
|
|
|
|
}
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded),
|
|
|
|
|
"cudnnSetTensor4dDescriptor failed");
|
|
|
|
@ -259,18 +259,18 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|
|
|
|
|
|
|
|
|
void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape,
|
|
|
|
|
const std::vector<size_t> &output_shape) {
|
|
|
|
|
int nbDims = 4;
|
|
|
|
|
const int nbDims = 4;
|
|
|
|
|
int dimA[4];
|
|
|
|
|
int strideAin[4];
|
|
|
|
|
int dimAout[4];
|
|
|
|
|
int strideAout[4];
|
|
|
|
|
SetDimA(in_shape, dimA, data_format_);
|
|
|
|
|
SetStrideA(in_shape, strideAin, data_format_);
|
|
|
|
|
SetDimA(output_shape, dimAout, data_format_);
|
|
|
|
|
SetStrideA(output_shape, strideAout, data_format_);
|
|
|
|
|
SetDimA(in_shape, dimA, 4, data_format_);
|
|
|
|
|
SetStrideA(in_shape, strideAin, 4, data_format_);
|
|
|
|
|
SetDimA(output_shape, dimAout, 4, data_format_);
|
|
|
|
|
SetStrideA(output_shape, strideAout, 4, data_format_);
|
|
|
|
|
int filterDimA[4];
|
|
|
|
|
// OHWI for NHWC; OIHW for NCHW
|
|
|
|
|
SetDimA(filter_shape, filterDimA, data_format_);
|
|
|
|
|
SetDimA(filter_shape, filterDimA, 4, data_format_);
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
|
|
|
|
|
"cudnnSetTensor4dDescriptor failed");
|
|
|
|
|
|
|
|
|
|