!1393 Gpu Support AdamWeightDecay optimizer fusion
Merge pull request !1393 from chenweifeng/adam_weight_decaypull/1393/MERGE
commit
4c6bff75af
@ -0,0 +1,50 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "adam_weight_decay_impl.cuh"
|
||||||
|
#include "device/gpu/cuda_common.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1,
|
||||||
|
const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
|
||||||
|
const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v,
|
||||||
|
T *param, T *gradient) {
|
||||||
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) {
|
||||||
|
float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i];
|
||||||
|
float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i];
|
||||||
|
float update = next_m / (sqrt(next_v) + epsilon[0]);
|
||||||
|
if (need_decay && weight_decay != nullptr) {
|
||||||
|
update += weight_decay[0] * param[i];
|
||||||
|
}
|
||||||
|
param[i] -= lr[0] * update;
|
||||||
|
m[i] = next_m;
|
||||||
|
v[i] = next_v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
|
||||||
|
const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
|
||||||
|
const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) {
|
||||||
|
AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>(
|
||||||
|
element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param,
|
||||||
|
gradient);
|
||||||
|
}
|
||||||
|
|
||||||
|
template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1,
|
||||||
|
const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
|
||||||
|
const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v,
|
||||||
|
float *param, float *gradient, cudaStream_t stream);
|
@ -0,0 +1,24 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
|
||||||
|
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
|
||||||
|
template <typename T>
|
||||||
|
void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
|
||||||
|
const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
|
||||||
|
const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream);
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
|
@ -0,0 +1,52 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "kernel/gpu/nn/fused_adam_weight_decay.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
FusedAdamWeightDecayGpuKernel, float)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(FusedAdam,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
FusedAdamWeightDecayGpuKernel, float)
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,103 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "kernel/gpu/gpu_kernel.h"
|
||||||
|
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||||
|
#include "kernel/gpu/kernel_constants.h"
|
||||||
|
#include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
template <typename T>
|
||||||
|
class FusedAdamWeightDecayGpuKernel : public GpuKernel {
|
||||||
|
public:
|
||||||
|
FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {}
|
||||||
|
~FusedAdamWeightDecayGpuKernel() override = default;
|
||||||
|
|
||||||
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
|
if (node_name == "AdamWeighDecay") {
|
||||||
|
weight_decay_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7);
|
||||||
|
element_nums_ = 1;
|
||||||
|
for (auto i : shape) {
|
||||||
|
element_nums_ *= i;
|
||||||
|
}
|
||||||
|
|
||||||
|
InitSizeLists();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||||
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||||
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
float *beta1 = GetDeviceAddress<float>(inputs, 0);
|
||||||
|
float *one_sub_beta1 = GetDeviceAddress<float>(inputs, 1);
|
||||||
|
float *beta2 = GetDeviceAddress<float>(inputs, 2);
|
||||||
|
float *one_sub_beta2 = GetDeviceAddress<float>(inputs, 3);
|
||||||
|
float *epsilon = GetDeviceAddress<float>(inputs, 4);
|
||||||
|
float *lr = GetDeviceAddress<float>(inputs, 5);
|
||||||
|
T *param = GetDeviceAddress<T>(inputs, 6);
|
||||||
|
T *m = GetDeviceAddress<T>(inputs, 7);
|
||||||
|
T *v = GetDeviceAddress<T>(inputs, 8);
|
||||||
|
T *gradient = GetDeviceAddress<T>(inputs, 9);
|
||||||
|
float *weight_decay = nullptr;
|
||||||
|
if (weight_decay_) {
|
||||||
|
weight_decay = GetDeviceAddress<float>(inputs, 10);
|
||||||
|
}
|
||||||
|
AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v,
|
||||||
|
param, gradient, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InitResource() override{};
|
||||||
|
void InitSizeLists() override {
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
input_size_list_.push_back(element_nums_ * sizeof(T));
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
input_size_list_.push_back(element_nums_ * sizeof(T));
|
||||||
|
if (weight_decay_) {
|
||||||
|
input_size_list_.push_back(sizeof(float));
|
||||||
|
}
|
||||||
|
output_size_list_.push_back(element_nums_ * sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<size_t> input_size_list_;
|
||||||
|
std::vector<size_t> output_size_list_;
|
||||||
|
std::vector<size_t> workspace_size_list_;
|
||||||
|
|
||||||
|
int element_nums_;
|
||||||
|
bool weight_decay_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
|
@ -0,0 +1,112 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pre_activate/gpu/adam_fusion.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "ir/primitive.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "pre_activate/common/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
||||||
|
std::vector<std::string> inputs_format;
|
||||||
|
std::vector<std::string> outputs_format;
|
||||||
|
std::vector<TypeId> inputs_type;
|
||||||
|
std::vector<TypeId> outputs_type;
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||||
|
|
||||||
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||||
|
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||||
|
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||||
|
}
|
||||||
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||||
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||||
|
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||||
|
}
|
||||||
|
builder.SetInputsDeviceType(inputs_type);
|
||||||
|
builder.SetInputsFormat(inputs_format);
|
||||||
|
builder.SetOutputsDeviceType(outputs_type);
|
||||||
|
builder.SetOutputsFormat(outputs_format);
|
||||||
|
return builder.Build();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef AdamFusion::DefinePattern() const {
|
||||||
|
VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
|
||||||
|
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
|
||||||
|
VectorRef next_v =
|
||||||
|
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
|
||||||
|
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
|
||||||
|
VectorRef update = VectorRef(
|
||||||
|
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
|
||||||
|
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
|
||||||
|
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
|
||||||
|
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
|
||||||
|
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
|
||||||
|
return depend3;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(equiv);
|
||||||
|
auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
|
||||||
|
auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
|
||||||
|
auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
|
||||||
|
auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
|
||||||
|
auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
|
||||||
|
auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
|
||||||
|
auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
|
||||||
|
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
|
||||||
|
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
|
||||||
|
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||||
|
MS_EXCEPTION_IF_NULL(beta1_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(beta2_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(eps_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(lr_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(param_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(m_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(v_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(gradient_input);
|
||||||
|
|
||||||
|
auto prim = std::make_shared<Primitive>(kFusedAdamName);
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
std::vector<AnfNodePtr> inputs = {
|
||||||
|
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
|
||||||
|
eps_input, lr_input, param_input, m_input, v_input,
|
||||||
|
gradient_input};
|
||||||
|
auto adam = graph->NewCNode(inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(adam);
|
||||||
|
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||||
|
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
|
||||||
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
|
||||||
|
adam->set_scope(node->scope());
|
||||||
|
|
||||||
|
auto build_info = GenerateKernelBuildInfo(adam);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
|
||||||
|
return adam;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,56 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "pre_activate/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class AdamFusion : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) {
|
||||||
|
beta1_ = std::make_shared<Var>();
|
||||||
|
one_sub_beta1_ = std::make_shared<Var>();
|
||||||
|
beta2_ = std::make_shared<Var>();
|
||||||
|
one_sub_beta2_ = std::make_shared<Var>();
|
||||||
|
eps_ = std::make_shared<Var>();
|
||||||
|
lr_ = std::make_shared<Var>();
|
||||||
|
param_ = std::make_shared<Var>();
|
||||||
|
m_ = std::make_shared<Var>();
|
||||||
|
v_ = std::make_shared<Var>();
|
||||||
|
gradient_ = std::make_shared<Var>();
|
||||||
|
}
|
||||||
|
~AdamFusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarPtr beta1_;
|
||||||
|
VarPtr one_sub_beta1_;
|
||||||
|
VarPtr beta2_;
|
||||||
|
VarPtr one_sub_beta2_;
|
||||||
|
VarPtr eps_;
|
||||||
|
VarPtr lr_;
|
||||||
|
VarPtr param_;
|
||||||
|
VarPtr m_;
|
||||||
|
VarPtr v_;
|
||||||
|
VarPtr gradient_;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
|
@ -0,0 +1,117 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pre_activate/gpu/adam_weight_decay_fusion.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "ir/primitive.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "pre_activate/common/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
||||||
|
std::vector<std::string> inputs_format;
|
||||||
|
std::vector<std::string> outputs_format;
|
||||||
|
std::vector<TypeId> inputs_type;
|
||||||
|
std::vector<TypeId> outputs_type;
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||||
|
|
||||||
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||||
|
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||||
|
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||||
|
}
|
||||||
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||||
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||||
|
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||||
|
}
|
||||||
|
builder.SetInputsDeviceType(inputs_type);
|
||||||
|
builder.SetInputsFormat(inputs_format);
|
||||||
|
builder.SetOutputsDeviceType(outputs_type);
|
||||||
|
builder.SetOutputsFormat(outputs_format);
|
||||||
|
return builder.Build();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
|
||||||
|
VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
|
||||||
|
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
|
||||||
|
VectorRef next_v =
|
||||||
|
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
|
||||||
|
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
|
||||||
|
VectorRef update = VectorRef(
|
||||||
|
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
|
||||||
|
VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
|
||||||
|
|
||||||
|
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
|
||||||
|
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
|
||||||
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
|
||||||
|
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
|
||||||
|
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
|
||||||
|
return depend3;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &equiv) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(equiv);
|
||||||
|
auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
|
||||||
|
auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
|
||||||
|
auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
|
||||||
|
auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
|
||||||
|
auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
|
||||||
|
auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
|
||||||
|
auto weight_decay_input = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
|
||||||
|
auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
|
||||||
|
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
|
||||||
|
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
|
||||||
|
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||||
|
MS_EXCEPTION_IF_NULL(beta1_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(beta2_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(eps_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(lr_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(weight_decay_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(param_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(m_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(v_input);
|
||||||
|
MS_EXCEPTION_IF_NULL(gradient_input);
|
||||||
|
|
||||||
|
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
std::vector<AnfNodePtr> inputs = {
|
||||||
|
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
|
||||||
|
eps_input, lr_input, param_input, m_input, v_input,
|
||||||
|
gradient_input, weight_decay_input};
|
||||||
|
auto adam_weight_decay = graph->NewCNode(inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(adam_weight_decay);
|
||||||
|
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||||
|
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
|
||||||
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get());
|
||||||
|
adam_weight_decay->set_scope(node->scope());
|
||||||
|
|
||||||
|
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
|
||||||
|
return adam_weight_decay;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,58 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "pre_activate/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class AdamWeightDecayFusion : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) {
|
||||||
|
beta1_ = std::make_shared<Var>();
|
||||||
|
one_sub_beta1_ = std::make_shared<Var>();
|
||||||
|
beta2_ = std::make_shared<Var>();
|
||||||
|
one_sub_beta2_ = std::make_shared<Var>();
|
||||||
|
eps_ = std::make_shared<Var>();
|
||||||
|
lr_ = std::make_shared<Var>();
|
||||||
|
weight_decay_ = std::make_shared<Var>();
|
||||||
|
param_ = std::make_shared<Var>();
|
||||||
|
m_ = std::make_shared<Var>();
|
||||||
|
v_ = std::make_shared<Var>();
|
||||||
|
gradient_ = std::make_shared<Var>();
|
||||||
|
}
|
||||||
|
~AdamWeightDecayFusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarPtr beta1_;
|
||||||
|
VarPtr one_sub_beta1_;
|
||||||
|
VarPtr beta2_;
|
||||||
|
VarPtr one_sub_beta2_;
|
||||||
|
VarPtr eps_;
|
||||||
|
VarPtr lr_;
|
||||||
|
VarPtr weight_decay_;
|
||||||
|
VarPtr param_;
|
||||||
|
VarPtr m_;
|
||||||
|
VarPtr v_;
|
||||||
|
VarPtr gradient_;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
|
@ -0,0 +1,87 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common.api import ms_function
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.common.parameter import Parameter
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, decay_flag=True):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.decay_flag = decay_flag
|
||||||
|
self.op_mul = P.Mul()
|
||||||
|
self.op_square = P.Square()
|
||||||
|
self.op_sqrt = P.Sqrt()
|
||||||
|
self.op_cast = P.Cast()
|
||||||
|
self.op_reshape = P.Reshape()
|
||||||
|
self.op_shape = P.Shape()
|
||||||
|
self.param = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='param')
|
||||||
|
self.m = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='m')
|
||||||
|
self.v = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='v')
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def construct(self, beta1, beta2, gradient, eps, weight_decay_tensor, lr):
|
||||||
|
param_fp32 = self.op_cast(self.param, mstype.float32)
|
||||||
|
m_fp32 = self.op_cast(self.m, mstype.float32)
|
||||||
|
v_fp32 = self.op_cast(self.v, mstype.float32)
|
||||||
|
gradient_fp32 = self.op_cast(gradient, mstype.float32)
|
||||||
|
|
||||||
|
next_m = self.op_mul(beta1, m_fp32) + \
|
||||||
|
self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
|
||||||
|
next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - \
|
||||||
|
beta2, self.op_square(gradient_fp32))
|
||||||
|
update = next_m / (eps + self.op_sqrt(next_v))
|
||||||
|
if self.decay_flag:
|
||||||
|
update = self.op_mul(weight_decay_tensor, param_fp32) + update
|
||||||
|
update_with_lr = self.op_mul(lr, update)
|
||||||
|
next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32))
|
||||||
|
|
||||||
|
next_v = F.depend(next_v, F.assign(self.param, next_param))
|
||||||
|
next_v = F.depend(next_v, F.assign(self.m, next_m))
|
||||||
|
next_v = F.depend(next_v, F.assign(self.v, next_v))
|
||||||
|
return next_v
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test():
|
||||||
|
beta1 = Tensor(np.array([0.9]).astype(np.float32))
|
||||||
|
beta2 = Tensor(np.array([0.999]).astype(np.float32))
|
||||||
|
lr = Tensor(np.array([0.001]).astype(np.float32))
|
||||||
|
eps = Tensor(np.array([1e-6]).astype(np.float32))
|
||||||
|
weight_decay_tensor = Tensor(np.array([0.001]).astype(np.float32))
|
||||||
|
|
||||||
|
gradient = Tensor(np.array([0.01, 0.03, 0.05]).astype(np.float32))
|
||||||
|
opt = Net(True)
|
||||||
|
_ = opt(beta1, beta2, gradient, eps, weight_decay_tensor, lr)
|
||||||
|
|
||||||
|
param_expect = np.array([0.09971199, 0.29950103, 0.4993557]).astype(np.float32)
|
||||||
|
m_expect = np.array([0.091, 0.273, 0.45499998]).astype(np.float32)
|
||||||
|
v_expect = np.array([0.0999001, 0.29970092, 0.4995025]).astype(np.float32)
|
||||||
|
assert np.allclose(opt.param.data.asnumpy(), param_expect)
|
||||||
|
assert np.allclose(opt.m.data.asnumpy(), m_expect)
|
||||||
|
assert np.allclose(opt.v.data.asnumpy(), v_expect)
|
Loading…
Reference in new issue