momentum weightdecay fusion

pull/10159/head
wilfChen 4 years ago
parent 27b337a4d2
commit 09e10e18bb

@@ -75,9 +75,9 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con
}
template <typename T, typename S>
__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
T *accumulation, const T *learning_rate, const S *gradient,
const T *momentum) {
__global__ void FusedMomentumWeightDecayScaleKernel(const size_t element_num, T *weight_decay, T *scale, T *variable,
T *accumulation, const T *learning_rate, const S *gradient,
const T *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
accumulation[i] = momentum[0] * accumulation[i] + grad;
@@ -91,13 +91,13 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
FusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
FusedMomentumWeightDecayScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
}
template <typename T, typename S>
__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum) {
__global__ void FusedMomentumScaleKernel(const size_t element_num, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0];
variable[i] -= learning_rate[0] * accumulation[i];
@@ -109,15 +109,33 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu
const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
FusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
FusedMomentumScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
}
template <typename T, typename S>
__global__ void FusedWeightDecayMomentumKernel(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
T grad = variable[i] * weight_decay[0] + static_cast<T>(gradient[i]);
accumulation[i] = momentum[0] * accumulation[i] + grad;
variable[i] -= learning_rate[0] * accumulation[i];
}
}
template <typename T, typename S>
void FusedWeightDecayMomentum(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
FusedWeightDecayMomentumKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
element_num, weight_decay, variable, accumulation, learning_rate, gradient, momentum);
}
// CombineFusedScaleMomentum
template <typename T, typename S>
__global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t *element_num,
T **scale, T **variable, T **accumulation,
T **learning_rate, S **gradient, T **momentum) {
__global__ void CombineFusedMomentumScaleKernel(const size_t num, const size_t *element_num, T **scale, T **variable,
T **accumulation, T **learning_rate, S **gradient, T **momentum) {
for (size_t idx = 0; idx < num; idx++) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast<T>(gradient[idx][i]) * scale[idx][0];
@@ -127,22 +145,21 @@ __global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t
}
template <typename T, typename S>
void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale,
T **variable, T **accumulation, T **learning_rate, S **gradient,
T **momentum, cudaStream_t cuda_stream) {
void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale, T **variable,
T **accumulation, T **learning_rate, S **gradient, T **momentum,
cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
CombineFusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
CombineFusedMomentumScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
num, elements, scale, variable, accumulation, learning_rate, gradient, momentum);
}
// end CombineFusedScaleMomentum
// CombineFusedWeightDecayScaleMomentum
template <typename T, typename S>
__global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, const size_t *element_num,
T **weight_decay, T **scale, T **variable,
T **accumulation, T **learning_rate, S **gradient,
T **momentum) {
__global__ void CombineFusedMomentumWeightDecayScaleKernel(const size_t num, const size_t *element_num,
T **weight_decay, T **scale, T **variable, T **accumulation,
T **learning_rate, S **gradient, T **momentum) {
for (size_t idx = 0; idx < num; idx++) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
T grad = (variable[idx][i] * weight_decay[idx][0] + static_cast<T>(gradient[idx][i])) * scale[idx][0];
@@ -155,11 +172,10 @@ __global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, c
template <typename T, typename S>
void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element_num,
T **weight_decay, T **scale, T **variable, T **accumulation,
T **learning_rate, S **gradient, T **momentum,
cudaStream_t cuda_stream) {
T **learning_rate, S **gradient, T **momentum, cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
CombineFusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
CombineFusedMomentumWeightDecayScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
num, element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
}
// end CombineFusedWeightDecayScaleMomentum
@@ -186,6 +202,12 @@ template void FusedWeightDecayScaleMomentum(const size_t element_num, float *wei
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
float *variable, float *accumulation, const float *learning_rate,
const half *gradient, const float *momentum, cudaStream_t cuda_stream);
template void FusedWeightDecayMomentum(const size_t element_num, float *weight_decay, float *variable,
float *accumulation, const float *learning_rate, const float *gradient,
const float *momentum, cudaStream_t cuda_stream);
template void FusedWeightDecayMomentum(const size_t element_num, float *weight_decay, float *variable,
float *accumulation, const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
const float *learning_rate, const float *gradient, const float *momentum,
cudaStream_t cuda_stream);
@@ -193,16 +215,16 @@ template void FusedScaleMomentum(const size_t element_num, float *scale, float *
const float *learning_rate, const half *gradient, const float *momentum,
cudaStream_t cuda_stream);
template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
float **weight_decay, float **scale, float **variable,
float **accumulation, float **learning_rate, float **gradient,
float **momentum, cudaStream_t cuda_stream);
float **weight_decay, float **scale, float **variable,
float **accumulation, float **learning_rate, float **gradient,
float **momentum, cudaStream_t cuda_stream);
template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
float **weight_decay, float **scale, float **variable,
float **accumulation, float **learning_rate, half **gradient,
float **momentum, cudaStream_t cuda_stream);
float **weight_decay, float **scale, float **variable,
float **accumulation, float **learning_rate, half **gradient,
float **momentum, cudaStream_t cuda_stream);
template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
float **variable, float **accumulation, float **learning_rate,
float **gradient, float **momentum, cudaStream_t cuda_stream);
float **variable, float **accumulation, float **learning_rate, float **gradient,
float **momentum, cudaStream_t cuda_stream);
template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
float **variable, float **accumulation, float **learning_rate,
half **gradient, float **momentum, cudaStream_t cuda_stream);
float **variable, float **accumulation, float **learning_rate, half **gradient,
float **momentum, cudaStream_t cuda_stream);
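
For reference, the element-wise update implemented by the new FusedWeightDecayMomentumKernel can be sketched in NumPy as below. This is an illustrative reference only (the function name and argument names are made up for this note, not MindSpore API); weight_decay, learning_rate and momentum are treated as scalars, matching the kernel's reads of index 0.

import numpy as np

def fused_weight_decay_momentum_ref(variable, accumulation, learning_rate, gradient, momentum, weight_decay):
    # Fold weight decay into the gradient: grad = variable * weight_decay + gradient.
    grad = variable * weight_decay + gradient.astype(variable.dtype)
    # Momentum accumulation: accumulation = momentum * accumulation + grad.
    accumulation = momentum * accumulation + grad
    # Parameter update: variable -= learning_rate * accumulation.
    variable = variable - learning_rate * accumulation
    return variable, accumulation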

@@ -26,6 +26,9 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
const T *learning_rate, const S *gradient, const T *momentum,
cudaStream_t cuda_stream);
template <typename T, typename S>
void FusedWeightDecayMomentum(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum, cudaStream_t cuda_stream);
template <typename T, typename S>
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
const S *gradient, const T *momentum, cudaStream_t cuda_stream);
template <typename T, typename S>

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(FusedWeightApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // weight decay
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat32) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
FusedWeightDecayMomentumGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(FusedWeightApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // weight decay
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat16) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
FusedWeightDecayMomentumGpuKernel, float, half)
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class FusedWeightDecayMomentumGpuKernel : public GpuKernel {
public:
FusedWeightDecayMomentumGpuKernel() : element_num_(1) {}
~FusedWeightDecayMomentumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *weight_decay = GetDeviceAddress<T>(inputs, 0);
T *variable = GetDeviceAddress<T>(inputs, 1);
T *accumulation = GetDeviceAddress<T>(inputs, 2);
T *learning_rate = GetDeviceAddress<T>(inputs, 3);
S *gradient = GetDeviceAddress<S>(inputs, 4);
T *momentum = GetDeviceAddress<T>(inputs, 5);
FusedWeightDecayMomentum(element_num_, weight_decay, variable, accumulation, learning_rate, gradient, momentum,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedWeightDecayMomentum needs 6 inputs.";
return false;
}
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < variable_shape.size(); i++) {
element_num_ *= variable_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(sizeof(T));  // weight_decay (scalar)
input_size_list_.push_back(element_num_ * sizeof(T));  // variable
input_size_list_.push_back(element_num_ * sizeof(T));  // accumulation
input_size_list_.push_back(sizeof(T));  // learning_rate (scalar)
input_size_list_.push_back(element_num_ * sizeof(S));  // gradient
input_size_list_.push_back(sizeof(T));  // momentum (scalar)
output_size_list_.push_back(element_num_ * sizeof(T));  // updated variable
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_

@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}
const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
VectorRef weight_decay =
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_});
return apply_momentum;
}
const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
MS_EXCEPTION_IF_NULL(weight_decay);
MS_EXCEPTION_IF_NULL(variable);
MS_EXCEPTION_IF_NULL(accumulation);
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(momentum);
auto prim = std::make_shared<Primitive>(kFusedWeightApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, variable, accumulation,
learning_rate, gradient, momentum};
auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
replace_node->set_scope(node->scope());
return replace_node;
}
} // namespace opt
} // namespace mindspore
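
A quick NumPy sanity check (illustrative only, not part of the patch) of the semantics the rewrite must preserve: the Mul -> AddN -> ApplyMomentum subgraph matched by DefinePattern and the fused single-pass update produce the same result, assuming the standard ApplyMomentum update rule (use_nesterov=False).

import numpy as np

rng = np.random.default_rng(0)
var = rng.standard_normal((10, 20)).astype(np.float32)
accum = rng.standard_normal((10, 20)).astype(np.float32)
grad = rng.standard_normal((10, 20)).astype(np.float32)
lr, weight_decay, momentum = np.float32(0.1), np.float32(0.002), np.float32(0.98)

# Unfused pattern: g = Mul(var, weight_decay) + grad, then ApplyMomentum(var, accum, lr, g, momentum).
g = var * weight_decay + grad
accum_unfused = momentum * accum + g
var_unfused = var - lr * accum_unfused

# Fused node: one pass over the data, same arithmetic.
accum_fused = momentum * accum + (var * weight_decay + grad)
var_fused = var - lr * accum_fused

assert np.allclose(var_unfused, var_fused)
assert np.allclose(accum_unfused, accum_fused)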

@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
public:
explicit ApplyMomentumWeightDecayFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_fusion", multigraph) {
weight_decay_ = std::make_shared<Var>();
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
momentum_ = std::make_shared<Var>();
}
~ApplyMomentumWeightDecayFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
static bool IsScalar(const BaseRef &n);
VarPtr weight_decay_;
VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr momentum_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_

@@ -22,6 +22,7 @@
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
@@ -125,6 +126,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}

@@ -234,6 +234,7 @@ constexpr auto kReduceMeanOpName = "ReduceMean";
constexpr auto kReduceAnyOpName = "ReduceAny";
constexpr auto kReduceAllOpName = "ReduceAll";
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
constexpr auto kFusedWeightApplyMomentum = "FusedWeightApplyMomentum";
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";

@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.ops import operations as P
class MomentumFusionNet(nn.Cell):
def __init__(self, var, accum):
super(MomentumFusionNet, self).__init__()
self.op = P.ApplyMomentum()
self.add = P.AddN()
self.mul = P.Mul()
self.var = Parameter(var, name="variable")
self.accum = Parameter(accum, name="accumulate")
self.lr = 0.1
self.weight_decay = 0.002
self.moment = 0.98
def construct(self, grad):
wd = self.mul(self.var, self.weight_decay)
g = self.add((wd, grad))
return self.op(self.var, self.accum, self.lr, g, self.moment)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_momentum_fusion():
np.random.seed(42)
var = Tensor(np.random.randn(10, 20).astype(np.float32))
accum = Tensor(np.random.randn(10, 20).astype(np.float32))
grad = Tensor(np.random.randn(10, 20).astype(np.float32))
context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
net1 = MomentumFusionNet(var, accum)
_ = net1(grad)
context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
net2 = MomentumFusionNet(var, accum)
_ = net2(grad)
assert np.allclose(net1.var.data.asnumpy(), net2.var.data.asnumpy(), atol=1e-5)
assert np.allclose(net1.accum.data.asnumpy(), net2.accum.data.asnumpy(), atol=1e-5)