From 5316061fa33315f972d3c628294308011c5e79e2 Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Tue, 1 Sep 2020 17:28:51 +0800
Subject: [PATCH] gpu resnet50 fusion

---
 .../gpu/cuda_impl/momentum_impl.cu            |  77 ++++++--
 .../gpu/cuda_impl/momentum_impl.cuh           |   9 +-
 .../gpu/nn/fused_scale_momentum_gpu_kernel.cc |  42 +++++
 .../gpu/nn/fused_scale_momentum_gpu_kernel.h  |  85 +++++++++
 ...d_weightdecay_scale_momentum_gpu_kernel.cc |  44 +++++
 ...ed_weightdecay_scale_momentum_gpu_kernel.h |  87 +++++++++
 .../gpu/apply_momentum_scale_fusion.cc        |  67 +++++++
 .../gpu/apply_momentum_scale_fusion.h         |  48 +++++
 .../gpu/apply_momentum_weight_scale_fusion.cc |  71 +++++++
 .../gpu/apply_momentum_weight_scale_fusion.h  |  52 ++++++
 .../gpu/batch_norm_add_relu_grad_fusion.cc    | 175 ++++++++++++++++++
 .../gpu/batch_norm_add_relu_grad_fusion.h     |  57 ++++++
 .../ccsrc/backend/session/gpu_session.cc      |   6 +
 mindspore/ccsrc/utils/utils.h                 |   2 +
 14 files changed, 808 insertions(+), 14 deletions(-)
 mode change 100755 => 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
 mode change 100755 => 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
old mode 100755
new mode 100644
index 03a4ccb617..a91a9138b6
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,8 +26,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *
 }
 template <>
 __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
-                                             const float *learning_rate, const half *gradient,
-                                             const float *momentum) {
+                                             const float *learning_rate, const half *gradient, const float *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
     accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
     variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
@@ -36,8 +35,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
 }
 template <>
 __global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
-                                             const float *learning_rate, const half *gradient,
-                                             const float *momentum) {
+                                             const float *learning_rate, const half *gradient, const float *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
     accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
     variable[i] -= learning_rate[0] * accumulation[i];
@@ -51,15 +49,68 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con
                                                                  learning_rate, gradient, momentum);
   return;
 }
+
+template <typename T, typename S>
+__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
+                                                      T *accumulation, const T *learning_rate, const S *gradient,
+                                                      const T *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
+    T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
+    accumulation[i] = momentum[0] * accumulation[i] + grad;
+    variable[i] -= learning_rate[0] * accumulation[i];
+  }
+}
+
+template <typename T, typename S>
+void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
+                                   const T *learning_rate, const S *gradient, const T *momentum,
+                                   cudaStream_t cuda_stream) {
+  size_t thread_per_block = 256;
+  size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
+  FusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+    element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
+}
+
+template <typename T, typename S>
+__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
+                                           const T *learning_rate, const S *gradient, const T *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
+    accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0];
+    variable[i] -= learning_rate[0] * accumulation[i];
+  }
+}
+
+template <typename T, typename S>
+void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
+                        const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
+  size_t thread_per_block = 256;
+  size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
+  FusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+    element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
+}
+
 template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
-                                                const float *learning_rate, const float *gradient,
-                                                const float *momentum, cudaStream_t cuda_stream);
+                                                           const float *learning_rate, const float *gradient,
+                                                           const float *momentum, cudaStream_t cuda_stream);
 template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
-                                             const half *learning_rate, const half *gradient,
-                                             const half *momentum, cudaStream_t cuda_stream);
+                                                        const half *learning_rate, const half *gradient,
+                                                        const half *momentum, cudaStream_t cuda_stream);
 template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
-                                             const float *learning_rate, const half *gradient,
-                                             const float *momentum, cudaStream_t cuda_stream);
+                                                         const float *learning_rate, const half *gradient,
+                                                         const float *momentum, cudaStream_t cuda_stream);
 template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
-                                             const float *learning_rate, const half *gradient,
-                                             const float *momentum, cudaStream_t cuda_stream);
+                                                          const float *learning_rate, const half *gradient,
+                                                          const float *momentum, cudaStream_t cuda_stream);
+
+template void FusedWeightDecayScaleMomentum<float, float>(const size_t element_num, float *weight_decay, float *scale,
+                                                          float *variable, float *accumulation,
+                                                          const float *learning_rate, const float *gradient,
+                                                          const float *momentum, cudaStream_t cuda_stream);
+template void FusedWeightDecayScaleMomentum<float, half>(const size_t element_num, float *weight_decay, float *scale,
+                                                         float *variable, float *accumulation,
+                                                         const float *learning_rate, const half *gradient,
+                                                         const float *momentum, cudaStream_t cuda_stream);
+template void FusedScaleMomentum<float, float>(const size_t element_num, float *scale, float *variable,
+                                               float *accumulation, const float *learning_rate, const float *gradient,
+                                               const float *momentum, cudaStream_t cuda_stream);
+template void FusedScaleMomentum<float, half>(const size_t element_num, float *scale, float *variable,
+                                              float *accumulation, const float *learning_rate, const half *gradient,
+                                              const float *momentum, cudaStream_t cuda_stream);
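[Editor's note] The two fused kernels above fold the loss-scale multiply (and, in the
weight-decay variant, the decay term) into the momentum update, so one launch replaces
the Mul/AddN/Cast/ApplyMomentum chain ResNet50 otherwise emits. A minimal scalar
reference of the update they implement, usable as a CPU oracle in a unit test (the
function name is illustrative, not part of the patch):

  // Reference semantics for one element. With weight_decay == 0 this reduces to
  // FusedMomentumScaleMomentum; otherwise it matches FusedMomentumWeightDecayScaleMomentum.
  void FusedMomentumReference(float weight_decay, float scale, float momentum, float learning_rate,
                              float gradient, float *variable, float *accumulation) {
    float grad = (*variable * weight_decay + gradient) * scale;  // decayed, scaled gradient
    *accumulation = momentum * *accumulation + grad;             // momentum accumulation
    *variable -= learning_rate * *accumulation;                  // SGD step
  }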
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
old mode 100755
new mode 100644
index e5a22e4791..00fa7afb2a
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,5 +21,12 @@
 template <typename T, typename S, typename G>
 void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate,
                             const G *gradient, const S *momentum, cudaStream_t cuda_stream);
+template <typename T, typename S>
+void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
+                                   const T *learning_rate, const S *gradient, const T *momentum,
+                                   cudaStream_t cuda_stream);
+template <typename T, typename S>
+void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
+                        const S *gradient, const T *momentum, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.cc
new file mode 100644
index 0000000000..b6c1d55533
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.cc
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // scale
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat32)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedScaleMomentumGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // scale
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat16)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedScaleMomentumGpuKernel, float, half)
+}  // namespace kernel
+}  // namespace mindspore
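[Editor's note] Both registrations share one input layout (scale, variable,
accumulation, learning_rate, gradient, momentum) and differ only in the gradient
dtype. A minimal standalone smoke test for the fp32 path, assuming momentum_impl.cuh
is on the include path; error handling and cudaFree are elided for brevity:

  #include <cstdio>
  #include <cuda_runtime.h>
  #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"

  int main() {
    constexpr size_t n = 4;
    float h_var[n] = {1.f, 1.f, 1.f, 1.f};
    float h_acc[n] = {0.f, 0.f, 0.f, 0.f};
    float h_grad[n] = {0.5f, 0.5f, 0.5f, 0.5f};
    float h_scale = 2.f, h_lr = 0.1f, h_mom = 0.9f;

    float *d_var, *d_acc, *d_grad, *d_scale, *d_lr, *d_mom;
    cudaMalloc(&d_var, sizeof(h_var));
    cudaMalloc(&d_acc, sizeof(h_acc));
    cudaMalloc(&d_grad, sizeof(h_grad));
    cudaMalloc(&d_scale, sizeof(float));
    cudaMalloc(&d_lr, sizeof(float));
    cudaMalloc(&d_mom, sizeof(float));
    cudaMemcpy(d_var, h_var, sizeof(h_var), cudaMemcpyHostToDevice);
    cudaMemcpy(d_acc, h_acc, sizeof(h_acc), cudaMemcpyHostToDevice);
    cudaMemcpy(d_grad, h_grad, sizeof(h_grad), cudaMemcpyHostToDevice);
    cudaMemcpy(d_scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_lr, &h_lr, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_mom, &h_mom, sizeof(float), cudaMemcpyHostToDevice);

    // Expected: acc = 0.9 * 0 + 0.5 * 2 = 1.0; var = 1.0 - 0.1 * 1.0 = 0.9
    FusedScaleMomentum(n, d_scale, d_var, d_acc, d_lr, d_grad, d_mom, nullptr);  // default stream
    cudaMemcpy(h_var, d_var, sizeof(h_var), cudaMemcpyDeviceToHost);
    printf("var[0] = %f (expect 0.9)\n", h_var[0]);
    return 0;
  }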
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h
new file mode 100644
index 0000000000..4cfd7d6548
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class FusedScaleMomentumGpuKernel : public GpuKernel {
+ public:
+  FusedScaleMomentumGpuKernel() : element_num_(1) {}
+  ~FusedScaleMomentumGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *scale = GetDeviceAddress<T>(inputs, 0);
+    T *variable = GetDeviceAddress<T>(inputs, 1);
+    T *accumulation = GetDeviceAddress<T>(inputs, 2);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 3);
+    S *gradient = GetDeviceAddress<S>(inputs, 4);
+    T *momentum = GetDeviceAddress<T>(inputs, 5);
+
+    FusedScaleMomentum(element_num_, scale, variable, accumulation, learning_rate, gradient, momentum,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 6) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedScaleApplyMomentum needs 6 inputs.";
+      return false;
+    }
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      element_num_ *= variable_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(S));
+    input_size_list_.push_back(sizeof(T));
+    output_size_list_.push_back(element_num_ * sizeof(T));
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.cc
new file mode 100644
index 0000000000..721c2e7486
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // weight decay
+                        .AddInputAttr(kNumberTypeFloat32)  // scale
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat32)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedWeightDecayScaleMomentumGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // weight decay
+                        .AddInputAttr(kNumberTypeFloat32)  // scale
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat16)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedWeightDecayScaleMomentumGpuKernel, float, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h
new file mode 100644
index 0000000000..22f5a711a1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel {
+ public:
+  FusedWeightDecayScaleMomentumGpuKernel() : element_num_(1) {}
+  ~FusedWeightDecayScaleMomentumGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *weight_decay = GetDeviceAddress<T>(inputs, 0);
+    T *scale = GetDeviceAddress<T>(inputs, 1);
+    T *variable = GetDeviceAddress<T>(inputs, 2);
+    T *accumulation = GetDeviceAddress<T>(inputs, 3);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 4);
+    S *gradient = GetDeviceAddress<S>(inputs, 5);
+    T *momentum = GetDeviceAddress<T>(inputs, 6);
+
+    FusedWeightDecayScaleMomentum(element_num_, weight_decay, scale, variable, accumulation, learning_rate, gradient,
+                                  momentum, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 7) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedWeightScaleApplyMomentum needs 7 inputs.";
+      return false;
+    }
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      element_num_ *= variable_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(S));
+    input_size_list_.push_back(sizeof(T));
+    output_size_list_.push_back(element_num_ * sizeof(T));
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
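[Editor's note] In both kernels' InitSizeLists(), only variable, accumulation and
gradient scale with element_num_; the remaining inputs are single-element tensors.
For reference, the layout the FusedWeightScaleApplyMomentum registration implies:

  // input index -> contents      -> size
  // 0           -> weight_decay  -> sizeof(T)
  // 1           -> scale         -> sizeof(T)
  // 2           -> variable      -> element_num_ * sizeof(T)
  // 3           -> accumulation  -> element_num_ * sizeof(T)
  // 4           -> learning_rate -> sizeof(T)
  // 5           -> gradient      -> element_num_ * sizeof(S)
  // 6           -> momentum      -> sizeof(T)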
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
new file mode 100644
index 0000000000..6ce2d3a72a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
+  VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
+  VectorRef apply_momentum =
+    VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
+  return apply_momentum;
+}
+
+const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
+  auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
+  auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
+  auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
+  auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
+  MS_EXCEPTION_IF_NULL(scale);
+  MS_EXCEPTION_IF_NULL(variable);
+  MS_EXCEPTION_IF_NULL(accumulation);
+  MS_EXCEPTION_IF_NULL(learning_rate);
+  MS_EXCEPTION_IF_NULL(gradient);
+  MS_EXCEPTION_IF_NULL(momentum);
+
+  auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale,    variable, accumulation,
+                                    learning_rate,      gradient, momentum};
+  auto replace_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(replace_node);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
+  replace_node->set_scope(node->scope());
+  return replace_node;
+}
+}  // namespace opt
+}  // namespace mindspore
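[Editor's note] This pass matches Mul(gradient, scale) feeding ApplyMomentum's
gradient slot and rewrites the pair into a single FusedScaleApplyMomentum node,
reordering the operands so scale comes first, as the kernel registration above
expects. Wiring it into the GPU pass manager mirrors the gpu_session.cc change
later in this patch:

  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());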
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
new file mode 100644
index 0000000000..c9112ab6e9
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ApplyMomentumScaleFusion : public PatternProcessPass {
+ public:
+  explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
+    scale_ = std::make_shared<Var>();
+    variable_ = std::make_shared<Var>();
+    accumulation_ = std::make_shared<Var>();
+    learning_rate_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+    momentum_ = std::make_shared<Var>();
+  }
+  ~ApplyMomentumScaleFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr scale_;
+  VarPtr variable_;
+  VarPtr accumulation_;
+  VarPtr learning_rate_;
+  VarPtr gradient_;
+  VarPtr momentum_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
new file mode 100644
index 0000000000..9e235a756f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
+  VectorRef weight = VectorRef(
+    {prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
+  VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
+  VectorRef apply_momentum =
+    VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
+  return apply_momentum;
+}
+
+const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                              const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
+  auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
+  auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
+  auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
+  auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
+  auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
+  MS_EXCEPTION_IF_NULL(weight_decay);
+  MS_EXCEPTION_IF_NULL(scale);
+  MS_EXCEPTION_IF_NULL(variable);
+  MS_EXCEPTION_IF_NULL(accumulation);
+  MS_EXCEPTION_IF_NULL(learning_rate);
+  MS_EXCEPTION_IF_NULL(gradient);
+  MS_EXCEPTION_IF_NULL(momentum);
+
+  auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay,  scale,    variable,
+                                    accumulation,       learning_rate, gradient, momentum};
+  auto replace_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(replace_node);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
+  replace_node->set_scope(node->scope());
+  return replace_node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
new file mode 100644
index 0000000000..f047881d81
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
+ public:
+  explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
+      : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
+    weight_decay_ = std::make_shared<Var>();
+    scale_ = std::make_shared<Var>();
+    variable_ = std::make_shared<Var>();
+    accumulation_ = std::make_shared<Var>();
+    learning_rate_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+    momentum_ = std::make_shared<Var>();
+  }
+  ~ApplyMomentumWeightDecayScaleFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr weight_decay_;
+  VarPtr scale_;
+
+  VarPtr variable_;
+  VarPtr accumulation_;
+  VarPtr learning_rate_;
+  VarPtr gradient_;
+  VarPtr momentum_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
new file mode 100644
index 0000000000..cb5f62c668
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
@@ -0,0 +1,175 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+#include <algorithm>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+const std::vector<int> kOutputIndex{0, 1, 2};
+constexpr size_t kBNGradOutputNum = 3;
+constexpr size_t kBNAddReluGradOutputNum = 4;
+
+bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn_outputs);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(bn) == manager->node_users().end()) {
+    return false;
+  }
+  size_t output_num = 0;
+  for (const auto &node_index : manager->node_users()[bn]) {
+    const AnfNodePtr &output = node_index.first;
+    MS_EXCEPTION_IF_NULL(output);
+    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      continue;
+    }
+    auto tuple_getitem_cnode = output->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
+    auto index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
+    MS_EXCEPTION_IF_NULL(index_node);
+    auto value_node = index_node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    int index = GetValue<int>(value_node->value());
+    if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
+      return false;
+    }
+    bn_outputs->push_back(output);
+    output_num++;
+  }
+  return output_num == kBNGradOutputNum;
+}
+
+void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad) {
+  // set output shape and dtype
+  std::vector<TypeId> outputs_type;
+  std::vector<std::vector<size_t>> outputs_shape;
+  auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
+  for (size_t i = 0; i < output_num; ++i) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(bn_grad, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(bn_grad, i));
+  }
+
+  outputs_type.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, 0));
+  outputs_shape.push_back(AnfAlgo::GetOutputInferShape(relu_grad, 0));
+  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
+}
+
+void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad,
+                   const CNodePtr &bn_add_relu_grad) {
+  // Create outputs
+  std::vector<AnfNodePtr> bn_add_relu_grad_output;
+  CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
+  if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be "
+                      << kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
+  }
+
+  // Get bn outputs
+  std::vector<AnfNodePtr> bn_outputs;
+  if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
+    MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx
+                 << " node should only have output 0, 1 and 2. The node should not be changed";
+    return;
+  }
+
+  // Replace original output
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
+  size_t output_index = 0;
+  for (const auto &output : bn_outputs) {
+    (void)manager->Replace(output, bn_add_relu_grad_output[output_index]);
+    output_index++;
+  }
+
+  (void)manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
+  return;
+}
+}  // namespace
+
+const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
+  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
+  VectorRef batch_norm_grad =
+    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
+  return batch_norm_grad;
+}
+
+const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
+    return nullptr;
+  }
+
+  auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
+  MS_EXCEPTION_IF_NULL(relu_grad);
+  auto relu_users = GetRealNodeUsedList(graph, relu_grad);
+  if (relu_users->size() != 2) {
+    return nullptr;
+  }
+
+  // process pattern as Relu(TensorAdd(BN#0, BN#1))
+  auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
+  MS_EXCEPTION_IF_NULL(tuple_getitem);
+  auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
+  if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
+    return nullptr;
+  }
+
+  auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
+  MS_EXCEPTION_IF_NULL(dy);
+  auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
+  MS_EXCEPTION_IF_NULL(y);
+  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
+  MS_EXCEPTION_IF_NULL(x);
+  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 2);
+  MS_EXCEPTION_IF_NULL(scale);
+  auto save_mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
+  MS_EXCEPTION_IF_NULL(save_mean);
+  auto save_var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 4);
+  MS_EXCEPTION_IF_NULL(save_var);
+  auto reserve = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
+  MS_EXCEPTION_IF_NULL(reserve);
+  auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(save_mean), 0);
+  MS_EXCEPTION_IF_NULL(batch_norm);
+  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
+  MS_EXCEPTION_IF_NULL(bias);
+
+  auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
+  auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(fused_batch_norm_add_relu_grad);
+  AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
+  SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
+  ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
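[Editor's note] The fused grad op produces four outputs: 0..2 are the usual
dx/dscale/dbias of FusedBatchNormGradEx, and output 3 is the gradient for the
residual-add branch, which previously came out of ReluGrad. ReplaceOutput rewires
both sets of consumers itself, which is why Process returns nullptr instead of a
replacement node. Sketched as pseudo-IR (names illustrative):

  // before:
  //   %dy1 = ReluGrad(%dy, %y)
  //   %dbn = FusedBatchNormGradEx(%dy1, %x, %scale, %save_mean, %save_var, %reserve)
  // after:
  //   %t = FusedBatchNormGradExWithAddAndActivation(%dy, %x, %scale, %save_mean,
  //                                                 %save_var, %reserve, %bias, %y)
  //   consumers of %dbn.0..2 -> %t.0..2  (dx, dscale, dbias)
  //   consumers of %dy1      -> %t.3     (dz for the TensorAdd branch)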
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h
new file mode 100644
index 0000000000..80a0fc1b67
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class BatchNormAddReluGradFusion : public PatternProcessPass {
+ public:
+  explicit BatchNormAddReluGradFusion(bool multigraph = true)
+      : PatternProcessPass("batch_norm_add_relu_grad_fusion", multigraph) {
+    dy_ = std::make_shared<Var>();
+    y_ = std::make_shared<Var>();
+    x_ = std::make_shared<Var>();
+    scale_ = std::make_shared<Var>();
+    bias_ = std::make_shared<Var>();
+    mean_ = std::make_shared<Var>();
+    var_ = std::make_shared<Var>();
+    save_mean_ = std::make_shared<Var>();
+    save_var_ = std::make_shared<Var>();
+    reserve_ = std::make_shared<Var>();
+  }
+  ~BatchNormAddReluGradFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr dy_;
+  VarPtr y_;
+  VarPtr x_;
+  VarPtr scale_;
+  VarPtr bias_;
+  VarPtr mean_;
+  VarPtr var_;
+  VarPtr save_mean_;
+  VarPtr save_var_;
+  VarPtr reserve_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index d4475da1c3..08f896e4dd 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -26,11 +26,14 @@
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
 #include "backend/optimizer/gpu/adam_fusion.h"
+#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
+#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
 #include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
 #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
+#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
 #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
 #include "backend/optimizer/gpu/replace_addn_fusion.h"
 #include "backend/optimizer/gpu/insert_format_transform_op.h"
@@ -71,6 +74,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
   pm->AddPass(std::make_shared<opt::AdamFusion>());
+  // pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
+  // pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
   pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
   pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
   pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
@@ -79,6 +84,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
   pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
   pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
+  // pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
 }
 optimizer->AddPassManager(pm);
 (void)optimizer->Optimize(kernel_graph);
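[Editor's note] The three new passes are added to GPUSession::Optimize commented
out, so this patch ships the kernels and passes without enabling them by default.
Once validated, enabling them is a one-line change per pass:

  pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
  pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
  pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());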
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index da7fa32ae9..6a0b01dfd4 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -192,6 +192,8 @@ constexpr auto kPaddingOpName = "Padding";
 constexpr auto kAvgPoolOpName = "AvgPool";
 constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
 constexpr auto kTensorAddOpName = "TensorAdd";
+constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
+constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";