The GPU and CPU IR operators are unified as BatchNorm

pull/12115/head
dingpeifei 4 years ago
parent 172a28fe14
commit 87e41aaeee

@@ -16,6 +16,10 @@ Previously MakeRefKey is an external interface that is not used, now make it an
Previously the number of outputs of these operator is different on different backends. To unify their definition we change their output on Ascend backend from multiple to a single.
+##### `P.FusedBatchNorm`, `P.FusedBatchNormEx` deleted ([!12115](https://gitee.com/mindspore/mindspore/pulls/12115))
+The FusedBatchNorm and FusedBatchNormEx interfaces have been deleted. Please use the BatchNorm operator instead (see the migration sketch below).
# MindSpore 1.1.1 Release Notes
## MindSpore
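For users migrating from the deleted operators, here is a minimal sketch of the replacement call, assuming the post-unification Python API (`mindspore.ops.BatchNorm` with an `is_training` flag, which is the attribute the kernels in this commit read). The argument values, tensor shapes, and the direct-invocation style are illustrative only and are not taken from this change.

```python
# Minimal migration sketch (assumption: MindSpore >= 1.2 operator set).
# Before this commit:  bn = ops.FusedBatchNorm(epsilon=1e-5, momentum=0.9)
# After this commit the single BatchNorm primitive covers training and inference.
import numpy as np
from mindspore import Tensor, ops

bn = ops.BatchNorm(is_training=True, epsilon=1e-5, momentum=0.9)

x = Tensor(np.random.randn(2, 3, 4, 4).astype(np.float32))  # NCHW input
scale = Tensor(np.ones(3, np.float32))
bias = Tensor(np.zeros(3, np.float32))
mean = Tensor(np.zeros(3, np.float32))
variance = Tensor(np.ones(3, np.float32))

# BatchNorm returns (y, batch_mean, batch_variance, reserve_space_1, reserve_space_2).
y, _, _, _, _ = bn(x, scale, bias, mean, variance)
print(y.shape)  # (2, 3, 4, 4)
```

At the network level, `mindspore.nn.BatchNorm2d` remains the usual layer wrapper and manages the moving statistics itself.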

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,14 +14,14 @@
* limitations under the License.
*/
#include <string>
-#include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
-void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+void BatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
size_t type_size = sizeof(float);
@@ -30,16 +30,13 @@ void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
workspace_size_list_.emplace_back(tensor_size);
}
-void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
-auto node_name = AnfAlgo::GetCNodeName(kernel_node);
-if (node_name == "FusedBatchNorm") {
+is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
-is_train = true;
-}
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) {
-MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
+MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
}
batch_size = x_shape[0];
channel = x_shape[1];
@@ -66,7 +63,7 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DST, x_desc);
}
-bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+bool BatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 5 || outputs.empty()) {

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,18 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
-class FusedBatchNormCPUKernel : public MKLCPUKernel {
+class BatchNormCPUKernel : public MKLCPUKernel {
public:
-FusedBatchNormCPUKernel() = default;
-~FusedBatchNormCPUKernel() override = default;
+BatchNormCPUKernel() = default;
+~BatchNormCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@@ -43,20 +43,6 @@ class FusedBatchNormCPUKernel : public MKLCPUKernel {
size_t nhw_size{0};
};
-MS_REG_CPU_KERNEL(FusedBatchNorm,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormCPUKernel)
MS_REG_CPU_KERNEL(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
@@ -69,7 +55,7 @@ MS_REG_CPU_KERNEL(BatchNorm,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormCPUKernel)
+BatchNormCPUKernel)
} // namespace kernel
} // namespace mindspore

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_gard_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_gard_cpu_kernel.h"
#include <string>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
@@ -22,19 +22,20 @@
namespace mindspore {
namespace kernel {
-void FusedBatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
size_t type_size = sizeof(float);
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t tensor_size = shape[1] * 2 * type_size;
+input_size_list_.pop_back();
// [2, c] to store scale and bias
workspace_size_list_.emplace_back(tensor_size);
// [2, c] to store diff_scale and diff_bias
workspace_size_list_.emplace_back(tensor_size);
}
-void FusedBatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) {
@@ -72,7 +73,7 @@ void FusedBatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DIFF_SCALE_SHIFT, scale_bias_desc);
}
-bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 5 || outputs.empty()) {
@@ -81,16 +82,16 @@ bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
auto wksp_in = reinterpret_cast<float *>(workspace[0]->addr);
auto scale_ret = memcpy_s(wksp_in, workspace[0]->size, inputs[2]->addr, inputs[2]->size);
auto max_size = workspace[0]->size - inputs[2]->size;
-auto bias_ret = memcpy_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, inputs[3]->addr, inputs[3]->size);
-if (scale_ret != 0 || bias_ret != 0) {
+auto bias_ret = memset_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, 0., max_size);
+if (scale_ret != 0 && bias_ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy_s error.";
return false;
}
SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr);
-SetArgumentHandle(DNNL_ARG_MEAN, inputs[4]->addr);
-SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[5]->addr);
+SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr);
+SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr);
SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SCALE_SHIFT, workspace[1]->addr);
@@ -99,7 +100,7 @@ bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
auto wksp_out = reinterpret_cast<float *>(workspace[1]->addr);
auto diff_scale_ret = memcpy_s(outputs[1]->addr, outputs[1]->size, wksp_out, inputs[2]->size);
auto diff_bias_ret =
-memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), inputs[3]->size);
+memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), outputs[2]->size);
if (diff_scale_ret != 0 || diff_bias_ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy_s error.";
return false;

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,18 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_GRAD_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_GRAD_CPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_GRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_GRAD_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
-class FusedBatchNormGradCPUKernel : public MKLCPUKernel {
+class BatchNormGradCPUKernel : public MKLCPUKernel {
public:
-FusedBatchNormGradCPUKernel() = default;
-~FusedBatchNormGradCPUKernel() override = default;
+BatchNormGradCPUKernel() = default;
+~BatchNormGradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@@ -42,7 +42,7 @@ class FusedBatchNormGradCPUKernel : public MKLCPUKernel {
size_t nhw_size{0};
};
-MS_REG_CPU_KERNEL(FusedBatchNormGradCPU,
+MS_REG_CPU_KERNEL(BatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
@@ -53,7 +53,7 @@ MS_REG_CPU_KERNEL(FusedBatchNormGradCPU,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormGradCPUKernel)
+BatchNormGradCPUKernel)
} // namespace kernel
} // namespace mindspore

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,11 +14,11 @@
* limitations under the License.
*/
-#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h"
namespace mindspore {
namespace kernel {
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
+MS_REG_GPU_KERNEL_ONE(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
@@ -29,10 +29,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormExGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
+BatchNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
@@ -43,11 +42,10 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormExGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
+BatchNormGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(BatchNormWithActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
@@ -58,10 +56,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormExGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
+BatchNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BatchNormWithActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
@@ -72,11 +69,10 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormExGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
+BatchNormGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(BatchNormWithAddAndActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
@@ -88,10 +84,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormExGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
+BatchNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BatchNormWithAddAndActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
@@ -103,8 +98,7 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
-FusedBatchNormExGpuKernel, half)
+BatchNormGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GPU_KERNEL_H_
#include <string>
#include <vector>
@@ -27,10 +27,10 @@
namespace mindspore {
namespace kernel {
template <typename T>
-class FusedBatchNormExGpuKernel : public GpuKernel {
+class BatchNormGpuKernel : public GpuKernel {
public:
-FusedBatchNormExGpuKernel() { ResetResource(); }
-~FusedBatchNormExGpuKernel() override { DestroyResource(); }
+BatchNormGpuKernel() { ResetResource(); }
+~BatchNormGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -46,30 +46,38 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
auto x = GetDeviceAddress<T>(inputs, 0);
auto scale = GetDeviceAddress<float>(inputs, 1);
auto bias = GetDeviceAddress<float>(inputs, 2);
-auto runing_mean = GetDeviceAddress<float>(inputs, 3);
-auto runnig_variance = GetDeviceAddress<float>(inputs, 4);
+auto running_mean = GetDeviceAddress<float>(inputs, 3);
+auto running_variance = GetDeviceAddress<float>(inputs, 4);
T *z = nullptr;
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
z = GetDeviceAddress<T>(inputs, 5);
}
auto y = GetDeviceAddress<T>(outputs, 0);
-auto save_mean = GetDeviceAddress<float>(outputs, 3);
-auto save_variance = GetDeviceAddress<float>(outputs, 4);
-auto reserve_addr = GetDeviceAddress<float>(outputs, 5);
+auto reserve_addr = GetDeviceAddress<float>(outputs, 2);
T *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
const float alpha = 1;
const float beta = 0;
+if (is_train_) {
+auto save_mean = GetDeviceAddress<float>(outputs, 3);
+auto save_variance = GetDeviceAddress<float>(outputs, 4);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
-cudnnBatchNormalizationForwardTrainingEx(handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x, z_desc_, z, y_desc_,
-y, scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean,
-runnig_variance, epsilon_, save_mean, save_variance, activation_desc_,
+cudnnBatchNormalizationForwardTrainingEx(
+handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x, z_desc_, z, y_desc_, y, scale_bias_mean_var_desc_, scale,
+bias, exp_avg_factor_, running_mean, running_variance, epsilon_, save_mean, save_variance, activation_desc_,
workspace_addr, workspace_size_, reserve_addr, reserve_size_),
"Kernel launch failed");
+} else {
+CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+cudnnBatchNormalizationForwardInference(
+handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, scale_bias_mean_var_desc_,
+scale, bias, running_mean, running_variance, epsilon_),
+"Kernel launch failed");
+}
return true;
}
@@ -77,18 +85,22 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
kernel_node_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-if (kernel_name == kFusedBatchNormEx) {
+if (kernel_name == kBatchNorm) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
-} else if (kernel_name == kFusedBatchNormExWithActivation) {
+} else if (kernel_name == kBatchNormWithActivation) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
-} else if (kernel_name == kFusedBatchNormExWithAddAndActivation) {
+} else if (kernel_name == kBatchNormWithAddAndActivation) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name;
}
InitResource();
+if (is_train_) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+} else {
+mode_ = CUDNN_BATCHNORM_SPATIAL;
+}
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
@@ -106,11 +118,11 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (shape.size() != 4) {
-MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormExGpuKernel should be 4";
+MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 4";
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
-MS_LOG(WARNING) << "FusedBatchNormExGpuKernel input is null";
+MS_LOG(WARNING) << "BatchNormGpuKernel input is null";
InitSizeLists();
return true;
}
@@ -121,6 +133,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
}
SetTensorDescriptor(format, shape);
InitSizeLists();
+is_train_ = GetAttr<bool>(kernel_node, "is_training");
return true;
}
@@ -135,6 +148,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
epsilon_ = 10e-5;
exp_avg_factor_ = 0.1;
+is_train_ = false;
is_null_input_ = false;
x_desc_ = nullptr;
y_desc_ = nullptr;
@@ -215,11 +229,10 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
}
output_size_list_.push_back(output_size_); // output
-output_size_list_.push_back(reserve_size_); // reserve space
output_size_list_.push_back(para_size_); // save scale
-output_size_list_.push_back(para_size_); // save bias
output_size_list_.push_back(para_size_); // save mean
output_size_list_.push_back(para_size_); // save variance
+output_size_list_.push_back(reserve_size_); // reserve space
workspace_size_list_.push_back(workspace_size_);
}
@@ -280,6 +293,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
double exp_avg_factor_;
+bool is_train_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,11 +14,11 @@
* limitations under the License.
*/
-#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
+MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
@@ -29,8 +29,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
-FusedBatchNormGradExGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
+BatchNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
@@ -41,9 +41,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
-FusedBatchNormGradExGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
+BatchNormGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(BatchNormGradWithActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
@@ -56,8 +56,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
-FusedBatchNormGradExGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
+BatchNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BatchNormGradWithActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
@@ -70,9 +70,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
-FusedBatchNormGradExGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
+BatchNormGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(BatchNormGradWithAddAndActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
@@ -86,8 +86,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32) // dbias
.AddOutputAttr(kNumberTypeFloat32), // dz
-FusedBatchNormGradExGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
+BatchNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BatchNormGradWithAddAndActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
@@ -101,6 +101,6 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32) // dbias
.AddOutputAttr(kNumberTypeFloat16), // dz
-FusedBatchNormGradExGpuKernel, half)
+BatchNormGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@@ -1,5 +1,5 @@
/**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GRAD_GPU_KERNEL_H_
#include <string>
#include <vector>
@@ -24,13 +24,14 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
-class FusedBatchNormGradExGpuKernel : public GpuKernel {
+class BatchNormGradGpuKernel : public GpuKernel {
public:
-FusedBatchNormGradExGpuKernel()
+BatchNormGradGpuKernel()
: x_size_(0),
para_size_(0),
workspace_size_(0),
@@ -38,6 +39,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
+is_train_(false),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
@@ -49,7 +51,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
beta_data_diff_(0) {}
-~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }
+~BatchNormGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -88,17 +90,22 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
+if (is_train_) {
const float alpha_data_diff = 1;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
-CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-cudnnBatchNormalizationBackwardEx(
-handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
-&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
-scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance,
-activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_),
+CHECK_CUDNN_RET_WITH_EXCEPT(
+kernel_node_,
+cudnnBatchNormalizationBackwardEx(handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_,
+&alpha_param_diff, &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy,
+dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, scale, bias, dscale, dbias,
+epsilon_, save_mean, save_variance, activation_desc_, workspace_addr,
+workspace_size_, reserve_addr, reserve_size_),
"Kernel launch failed");
+} else {
+CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, dscale, dbias, epsilon_, batch_, channel_, height_,
+width_, reinterpret_cast<cudaStream_t>(stream_ptr));
+}
return true;
}
@@ -106,11 +113,11 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
kernel_node_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-if (kernel_name == kFusedBatchNormGradEx) {
+if (kernel_name == kBatchNormGradOpName) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
-} else if (kernel_name == kFusedBatchNormGradExWithActivation) {
+} else if (kernel_name == kBatchNormGradWithActivation) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
-} else if (kernel_name == kFusedBatchNormGradExWithAddAndActivation) {
+} else if (kernel_name == kBatchNormGradWithAddAndActivation) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name;
@@ -134,11 +141,11 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (shape.size() != 4) {
-MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradExGpuKernel should be 4";
+MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
-MS_LOG(WARNING) << "FusedBatchNormGradExGpuKernel input is null";
+MS_LOG(WARNING) << "BatchNormGradGpuKernel input is null";
InitSizeLists();
return true;
}
@@ -150,6 +157,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
SetTensorDescriptor(format, shape);
InitSizeLists();
+is_train_ = GetAttr<bool>(kernel_node, "is_training");
return true;
}
@@ -225,50 +233,52 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
private:
void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
cudnnTensorFormat_t cudnn_format;
-int batch, channel, height, width;
if (format == kOpFormat_NHWC) {
-batch = SizeToInt(shape[0]);
-height = SizeToInt(shape[1]);
-width = SizeToInt(shape[2]);
-channel = SizeToInt(shape[3]);
+batch_ = SizeToInt(shape[0]);
+height_ = SizeToInt(shape[1]);
+width_ = SizeToInt(shape[2]);
+channel_ = SizeToInt(shape[3]);
cudnn_format = CUDNN_TENSOR_NHWC;
} else {
-batch = SizeToInt(shape[0]);
-channel = SizeToInt(shape[1]);
-height = SizeToInt(shape[2]);
-width = SizeToInt(shape[3]);
+batch_ = SizeToInt(shape[0]);
+channel_ = SizeToInt(shape[1]);
+height_ = SizeToInt(shape[2]);
+width_ = SizeToInt(shape[3]);
cudnn_format = CUDNN_TENSOR_NCHW;
}
CHECK_CUDNN_RET_WITH_EXCEPT(
-kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+kernel_node_,
+cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
-cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set z desc failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT(
-kernel_node_, cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+kernel_node_,
+cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
-kernel_node_, cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+kernel_node_,
+cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dx desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
-cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set z desc failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
-cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
+cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
@@ -278,7 +288,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
"cudnnSetActivationDescriptor failed");
}
}
+int batch_;
+int channel_;
+int height_;
+int width_;
size_t x_size_;
size_t para_size_;
size_t workspace_size_;
@@ -286,6 +299,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
+bool is_train_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;

@@ -1,48 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BatchNormGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BatchNormGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@@ -1,202 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class BatchNormGradGpuKernel : public GpuKernel {
public:
BatchNormGradGpuKernel()
: batch_(0),
channel_(0),
height_(0),
width_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
epsilon_(10e-5),
is_null_input_(false),
x_desc_(nullptr),
dy_desc_(nullptr),
dx_desc_(nullptr),
scale_bias_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
~BatchNormGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
auto dy = GetDeviceAddress<T>(inputs, 0);
auto x = GetDeviceAddress<T>(inputs, 1);
auto scale = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
auto dx = GetDeviceAddress<T>(outputs, 0);
auto bn_scale = GetDeviceAddress<float>(outputs, 1);
auto bn_bias = GetDeviceAddress<float>(outputs, 2);
auto reserve_1 = GetDeviceAddress<T>(outputs, 3);
auto reserve_2 = GetDeviceAddress<T>(outputs, 4);
// For CI only, reserved vars can not be unused.
MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2); // NOLINT
if (is_training_) {
const float alpha_data_diff = 1;
const float beta_data_diff = 0;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
&beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_,
scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance),
"Kernel Launch Failed.");
} else {
CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_,
height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", BatchNormGradGpuKernel should be 5";
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
return false;
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
MS_LOG(WARNING) << "BatchNormGradGpuKernel input is null";
InitSizeLists();
return true;
}
batch_ = SizeToInt(shape[0]);
channel_ = SizeToInt(shape[1]);
height_ = SizeToInt(shape[2]);
width_ = SizeToInt(shape[3]);
mode_ = CUDNN_BATCHNORM_SPATIAL;
is_training_ = GetAttr<bool>(kernel_node, "is_training");
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed");
InitSizeLists();
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_desc_),
"Destroy para desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
size_t input_size = 0;
size_t para_size = 0;
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size),
"Get input size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_desc_, &para_size),
"Get input size failed");
}
input_size_list_.push_back(input_size);
input_size_list_.push_back(input_size);
input_size_list_.push_back(para_size);
input_size_list_.push_back(para_size);
input_size_list_.push_back(para_size);
output_size_list_.push_back(input_size);
output_size_list_.push_back(para_size);
output_size_list_.push_back(para_size);
output_size_list_.push_back(input_size);
output_size_list_.push_back(input_size);
}
private:
int batch_;
int channel_;
int height_;
int width_;
cudnnBatchNormMode_t mode_;
bool is_training_;
double epsilon_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t scale_bias_desc_;
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_

@@ -1,74 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@@ -1,204 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedBatchNormGpuKernel : public GpuKernel {
public:
FusedBatchNormGpuKernel()
: batch_(0),
channel_(0),
height_(0),
width_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
epsilon_(10e-5),
exp_avg_factor_(0.1),
is_train_(false),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
scale_bias_mean_var_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
~FusedBatchNormGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
auto x = GetDeviceAddress<T>(inputs, 0);
auto scale = GetDeviceAddress<float>(inputs, 1);
auto bias = GetDeviceAddress<float>(inputs, 2);
    auto running_mean = GetDeviceAddress<float>(inputs, 3);
    auto running_variance = GetDeviceAddress<float>(inputs, 4);
auto y = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
if (is_train_) {
auto save_mean = GetDeviceAddress<float>(outputs, 3);
auto save_variance = GetDeviceAddress<float>(outputs, 4);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y,
                                               scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, running_mean,
                                               running_variance, epsilon_, save_mean, save_variance),
"Kernel launch failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x,
y_desc_, y, scale_bias_mean_var_desc_, scale,
                                                                          bias, running_mean, running_variance, epsilon_),
"Kernel launch failed");
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5";
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGpuKernel should be >= 4";
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null";
InitSizeLists();
return true;
}
cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;
auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "format");
if (format_attr == kOpFormat_NHWC) {
format = kOpFormat_NHWC;
cudnn_format = CUDNN_TENSOR_NHWC;
}
SetNCHW(shape, &batch_, &channel_, &height_, &width_, format);
mode_ = CUDNN_BATCHNORM_SPATIAL;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
// P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "FusedBatchNorm") {
is_train_ = true;
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, cudnn_format, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed");
InitSizeLists();
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
"Destroy para desc failed");
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
size_t input_size = 0;
size_t para_size = 0;
size_t output_size = 0;
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size),
"Get input size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size),
"Get para size failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size),
                                  "Get output size failed");
}
input_size_list_.push_back(input_size);
input_size_list_.push_back(para_size); // scale
input_size_list_.push_back(para_size); // bias
input_size_list_.push_back(para_size); // mean
input_size_list_.push_back(para_size); // variance
output_size_list_.push_back(output_size);
output_size_list_.push_back(para_size); // running mean
output_size_list_.push_back(para_size); // running variance
output_size_list_.push_back(para_size); // save mean
output_size_list_.push_back(para_size); // save variance
return;
}
private:
int batch_;
int channel_;
int height_;
int width_;
cudnnBatchNormMode_t mode_;
double epsilon_;
double exp_avg_factor_;
bool is_train_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;
cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
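The deleted kernel above distinguishes training from inference by inspecting the node name (FusedBatchNorm vs. BatchNorm) in Init(). With the primitives unified under BatchNorm, that signal has to come from somewhere else. Below is a minimal sketch of the likely replacement, assuming the decision moves to an "is_training" node attribute; the attribute name and the assumption that the rest of Init()/Launch() stays as above are not shown in this diff.

```cpp
// Sketch only (assumed replacement logic, not code from this commit): instead of
// branching on the node name as the deleted Init() above does, a unified kernel
// would read the training flag from a node attribute.
bool Init(const CNodePtr &kernel_node) override {
  kernel_node_ = kernel_node;
  InitResource();
  // ... cudnn descriptor, shape and format setup as in the deleted kernel above ...
  is_train_ = GetAttr<bool>(kernel_node, "is_training");  // assumed attribute name
  if (is_train_) {
    exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
  }
  epsilon_ = GetAttr<float>(kernel_node, "epsilon");
  InitSizeLists();
  return true;
}
// Launch() would keep the same branch as the deleted version: forward-training
// when is_train_ is true, forward-inference otherwise.
```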

@ -1,44 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore
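These deleted FusedBatchNormGrad registrations pair with the forward-kernel removals above. A hedged sketch of what the replacement backward registration presumably looks like follows; the class name BatchNormGradGpuKernel is an assumption, and the sixth input is the reserve tensor implied by kBNGradInputTensorNum = 6 and by the BatchNormGrad pattern (dy, x, scale, save_mean, save_var, reserve) used in the fusion passes later in this diff.

```cpp
// Sketch only: a plausible replacement for the deleted FusedBatchNormGrad entries.
// "BatchNormGradGpuKernel" is an assumed class name; input/output roles are noted
// from the grad pattern used elsewhere in this commit.
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)    // dy
                        .AddInputAttr(kNumberTypeFloat32)    // x
                        .AddInputAttr(kNumberTypeFloat32)    // scale
                        .AddInputAttr(kNumberTypeFloat32)    // save mean
                        .AddInputAttr(kNumberTypeFloat32)    // save variance
                        .AddInputAttr(kNumberTypeFloat32)    // reserve
                        .AddOutputAttr(kNumberTypeFloat32)   // dx
                        .AddOutputAttr(kNumberTypeFloat32)   // dscale
                        .AddOutputAttr(kNumberTypeFloat32),  // dbias
                      BatchNormGradGpuKernel, float)
```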

@ -1,188 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedBatchNormGradGpuKernel : public GpuKernel {
public:
FusedBatchNormGradGpuKernel()
: batch_(0),
channel_(0),
height_(0),
width_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
epsilon_(10e-5),
is_null_input_(false),
x_desc_(nullptr),
dy_desc_(nullptr),
dx_desc_(nullptr),
scale_bias_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
~FusedBatchNormGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
auto dy = GetDeviceAddress<T>(inputs, 0);
auto x = GetDeviceAddress<T>(inputs, 1);
auto scale = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
auto dx = GetDeviceAddress<T>(outputs, 0);
auto bn_scale = GetDeviceAddress<float>(outputs, 1);
auto bn_bias = GetDeviceAddress<float>(outputs, 2);
const float alpha_data_diff = 1;
const float beta_data_diff = 0;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
&beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale,
bn_scale, bn_bias, epsilon_, save_mean, save_variance),
"Kernel Launch Failed.");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5";
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4";
return false;
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null";
InitSizeLists();
return true;
}
batch_ = SizeToInt(shape[0]);
channel_ = SizeToInt(shape[1]);
height_ = SizeToInt(shape[2]);
width_ = SizeToInt(shape[3]);
mode_ = CUDNN_BATCHNORM_SPATIAL;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed");
InitSizeLists();
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_desc_),
"Destroy para desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
size_t input_size = 0;
size_t para_size = 0;
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size),
"Get input size failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_desc_, &para_size),
                                  "Get para size failed");
}
input_size_list_.push_back(input_size);
input_size_list_.push_back(input_size);
input_size_list_.push_back(para_size);
input_size_list_.push_back(para_size);
input_size_list_.push_back(para_size);
output_size_list_.push_back(input_size);
output_size_list_.push_back(para_size);
output_size_list_.push_back(para_size);
}
private:
int batch_;
int channel_;
int height_;
int width_;
cudnnBatchNormMode_t mode_;
double epsilon_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t scale_bias_desc_;
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_

@ -34,7 +34,12 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_grad_node);
  auto bn_grad_inputs = bn_grad_node->inputs();
-  CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
+  if (AnfAlgo::CheckPrimitiveType(bn_grad_node, prim::kPrimBatchNormGrad)) {
+    CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
+  } else {
+    CheckCNodeInputSize(bn_grad_node, kSyncBNGradInputTensorNum);
+  }
  std::vector<AnfNodePtr> bn_update_grad_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
    bn_grad_inputs[4], bn_grad_inputs[5]};
@ -57,7 +62,12 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_grad_node);
  auto bn_grad_inputs = bn_grad_node->inputs();
-  CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
+  if (AnfAlgo::CheckPrimitiveType(bn_grad_node, prim::kPrimBatchNormGrad)) {
+    CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
+  } else {
+    CheckCNodeInputSize(bn_grad_node, kSyncBNGradInputTensorNum);
+  }
  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
  }
@ -110,6 +120,7 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> bn_update_grad_outputs;
  CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"

@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -38,7 +38,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_cnode);
  if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
-    MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
+    MS_LOG(INFO) << "BatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
    return false;
  }
  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
@ -51,7 +51,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
  bn_training_reduce->set_kernel_info(kernel_info);
  std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
  if (bn_shape_i0.size() < kShape2dDims) {
-    MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
+    MS_LOG(INFO) << "The BatchNorm's first input's shape dims less than " << kShape2dDims;
    return false;
  }
  std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};

@ -33,7 +33,7 @@ CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchn
  MS_EXCEPTION_IF_NULL(batchnorm_grad);
  auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
-  for (size_t i = 1; i < batchnorm_grad->size(); ++i) {
+  for (size_t i = 1; i < batchnorm_grad->size() - 1; ++i) {
    inputs.push_back(batchnorm_grad->input(i));
  }
  auto new_node = graph->NewCNode(inputs);

@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -56,7 +56,8 @@ constexpr size_t kBN1OutputNum = 2;
constexpr size_t kBN2OutputNum = 3;
constexpr size_t kBN3OutputNum = 1;
-constexpr size_t kBNGradInputTensorNum = 5;
+constexpr size_t kBNGradInputTensorNum = 6;
+constexpr size_t kSyncBNGradInputTensorNum = 5;
constexpr size_t kBNGradOutputNum = 3;
constexpr size_t kBNGrad1OutputNum = 3;

@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -28,8 +28,8 @@
namespace mindspore {
namespace opt {
const BaseRef BatchNormAddReluFusion::DefinePattern() const {
-  VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
-  VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
  VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
  VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
  return relu;
@ -44,24 +44,24 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
  MS_EXCEPTION_IF_NULL(tensor_add);
  auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0);
  MS_EXCEPTION_IF_NULL(tuple_get_item);
-  auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
-  MS_EXCEPTION_IF_NULL(batch_norm_ex);
-  auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
+  auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
+  MS_EXCEPTION_IF_NULL(batch_norm);
+  auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format");
  MS_EXCEPTION_IF_NULL(format_attr);
  auto format = GetValue<std::string>(format_attr);
-  if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
+  if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") {
    return nullptr;
  }
-  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
+  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0);
  if (shape.back() % kBNChannelMultipleFactor != 0) {
    return nullptr;
  }
-  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
-  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
-  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
-  auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
-  auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
+  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0);
+  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1);
+  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
+  auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3);
+  auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4);
  auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1);
  MS_EXCEPTION_IF_NULL(x);
@ -71,7 +71,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
  MS_EXCEPTION_IF_NULL(var);
  MS_EXCEPTION_IF_NULL(z);
-  auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation);
+  auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
  auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
@ -79,17 +79,17 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
-  auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
+  auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
  for (size_t i = 0; i < output_num; i++) {
-    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
-    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
-  AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu);
+  AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
+  manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
  device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
  return tuple_get_item;
}

@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -85,14 +85,14 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
  std::vector<AnfNodePtr> bn_add_relu_grad_output;
  CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
  if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
-    MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be "
+    MS_LOG(EXCEPTION) << "The output size of node " << kBatchNormGradWithAddAndActivation << " must be "
                      << kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
  }
  // Get bn outputs
  std::vector<AnfNodePtr> bn_outputs;
  if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
-    MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx
+    MS_LOG(INFO) << "The " << prim::kPrimBatchNormGrad
                 << " node should only have output 0, 1 and 2. The node should not be changed";
    return;
  }
@ -139,7 +139,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
    return false;
  }
  auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
-  if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
+  if (AnfAlgo::GetCNodeName(forward_node) != kBatchNormWithAddAndActivation) {
    return false;
  }
@ -150,7 +150,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
  VectorRef batch_norm_grad =
-    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
+    VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
  return batch_norm_grad;
}
@ -184,7 +184,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
  MS_EXCEPTION_IF_NULL(bias);
-  auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation);
+  auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
  auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);

@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -28,8 +28,8 @@
namespace mindspore {
namespace opt {
const BaseRef BatchNormReluFusion::DefinePattern() const {
-  VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
-  VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
  VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get});
  return relu;
}
@ -41,24 +41,24 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
  auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(tuple_get_item);
-  auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
-  MS_EXCEPTION_IF_NULL(batch_norm_ex);
-  auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
+  auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
+  MS_EXCEPTION_IF_NULL(batch_norm);
+  auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format");
  MS_EXCEPTION_IF_NULL(format_attr);
  auto format = GetValue<std::string>(format_attr);
-  if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
+  if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") {
    return nullptr;
  }
-  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
+  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0);
  if (shape.back() % kBNChannelMultipleFactor != 0) {
    return nullptr;
  }
-  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
-  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
-  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
-  auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
-  auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
+  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0);
+  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1);
+  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
+  auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3);
+  auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(scale);
@ -66,7 +66,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
  MS_EXCEPTION_IF_NULL(mean);
  MS_EXCEPTION_IF_NULL(var);
-  auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithActivation);
+  auto prim = std::make_shared<Primitive>(kBatchNormWithActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var};
  auto fused_batch_norm_with_relu = graph->NewCNode(inputs);
@ -74,17 +74,17 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
-  auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
+  auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
  for (size_t i = 0; i < output_num; i++) {
-    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
-    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_relu.get());
-  AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_relu);
+  AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_relu);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
+  manager->Replace(batch_norm, fused_batch_norm_with_relu);
  device::gpu::SetKernelInfo(fused_batch_norm_with_relu);
  return tuple_get_item;
}

@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -31,7 +31,7 @@ namespace opt {
const BaseRef BatchNormReluGradFusion::DefinePattern() const {
  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
  VectorRef batch_norm_grad =
-    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
+    VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
  return batch_norm_grad;
}
@ -82,7 +82,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
  MS_EXCEPTION_IF_NULL(bias);
-  auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithActivation);
+  auto prim = std::make_shared<Primitive>(kBatchNormGradWithActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
  auto fused_batch_norm_grad_with_relu = graph->NewCNode(inputs);

@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -42,8 +42,7 @@ struct AnfNodeIndex {
};
// opname, output idx
-std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0},
-                                              {kFusedBatchNormGradExWithAddAndActivation, 3}};
+std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0}, {kBatchNormGradWithAddAndActivation, 3}};
std::set<string> kSkipOpNames = {
  kTensorAddOpName,
@ -51,7 +50,7 @@ std::set<string> kSkipOpNames = {
// opname, input idx
std::map<string, uint32_t> kAggregatesOpNames = {
-  {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};
+  {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kBatchNormGradWithAddAndActivation, 0}};
constexpr size_t inplace_node_size = 2;

@ -28,8 +28,8 @@
namespace mindspore {
namespace opt {
const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
-  VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
-  VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
+  VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
+  VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
  VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
  VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
  return relu;
@ -44,24 +44,24 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
  MS_EXCEPTION_IF_NULL(tensor_add);
  auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1);
  MS_EXCEPTION_IF_NULL(tuple_get_item);
-  auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
-  MS_EXCEPTION_IF_NULL(batch_norm_ex);
-  auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
+  auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
+  MS_EXCEPTION_IF_NULL(batch_norm);
+  auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format");
  MS_EXCEPTION_IF_NULL(format_attr);
  auto format = GetValue<std::string>(format_attr);
-  if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
+  if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") {
    return nullptr;
  }
-  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
+  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0);
  if (shape.back() % kBNChannelMultipleFactor != 0) {
    return nullptr;
  }
-  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
-  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
-  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
-  auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
-  auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
+  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0);
+  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1);
+  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
+  auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3);
+  auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4);
  auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0);
  MS_EXCEPTION_IF_NULL(x);
@ -71,7 +71,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
  MS_EXCEPTION_IF_NULL(var);
  MS_EXCEPTION_IF_NULL(z);
-  auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation);
+  auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
  auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
@ -79,17 +79,17 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
-  auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
+  auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
  for (size_t i = 0; i < output_num; i++) {
-    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
-    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
-  AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu);
+  AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
+  manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
  device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
  return tuple_get_item;
}

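Taken together, the fusion-pass hunks above amount to a consistent rename of the GPU batch-norm primitives and fused-op identifiers; the mapping below is collected from the replaced identifiers in this diff.

```cpp
// Identifier renames applied across the GPU passes in this commit:
//   prim::kPrimFusedBatchNormEx                -> prim::kPrimBatchNorm
//   prim::kPrimFusedBatchNormGradEx            -> prim::kPrimBatchNormGrad
//   kFusedBatchNormExWithActivation            -> kBatchNormWithActivation
//   kFusedBatchNormExWithAddAndActivation      -> kBatchNormWithAddAndActivation
//   kFusedBatchNormGradExWithActivation        -> kBatchNormGradWithActivation
//   kFusedBatchNormGradExWithAddAndActivation  -> kBatchNormGradWithAddAndActivation
```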
Some files were not shown because too many files have changed in this diff.