!14372 GPU update addn

From: @VectorSL
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
pull/14372/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b4fbf5b134

@ -115,20 +115,20 @@ class AddNGpuFwdKernel : public GpuKernel {
for (size_t i = input_shape.size(); i < 4; i++) {
(void)input_shape.insert(input_shape.begin(), 1);
}
int dimA[4];
std::vector<int> dimA;
for (size_t i = 0; i < input_shape.size(); i++) {
dimA[i] = SizeToInt(input_shape[i]);
dimA.push_back(SizeToInt(input_shape[i]));
}
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
if (input_format == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
SizeToInt(input_shape.size()), dimA),
SizeToInt(input_shape.size()), dimA.data()),
"cudnnSetTensorNdDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
SizeToInt(input_shape.size()), dimA),
SizeToInt(input_shape.size()), dimA.data()),
"cudnnSetTensorNdDescriptor failed");
}
InitSizeLists();

Loading…
Cancel
Save