[1.1] [project] train imagenet using large batch size (#13766)

* fix nccl2 lars dist support

* put lars in momentum op

* add tests lars

* fix ci

* fix cpu kernel

* soft warning

* remove lars in test_recognize_digits.py

* move to another op

* add file

* update api.spec test=develop

* update test=develop

* fix api.spec test=develop

* wip

* wip, finish grad merge ops

* wip, finish graph build

* wip test running

* work on 1 gpu

* workable version

* update

* fix tests

* fuse broadcast op

* fix compile failed

* refine

* add batch merge test mnist

* fix CI test=develop

* fix build

* use independent bn params for batch merge test=develop

* update api.spec

* follow comments and for test

* wip

* refine tests test=develop

* follow comments test=develop

* remove startup bn modify test=develop

* follow comments test=develop

* fix merge test=develop
fix_recordio_link
Wu Yi 7 years ago committed by GitHub
parent 0a80f06ec4
commit 26200f2e42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -142,5 +142,10 @@ def parse_args():
choices=['reduce', 'all_reduce'], choices=['reduce', 'all_reduce'],
default='all_reduce', default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce') help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--fuse_broadcast_op',
action='store_true',
help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.'
)
args = parser.parse_args() args = parser.parse_args()
return args return args

@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
else: else:
build_strategy.reduce_strategy = fluid.BuildStrategy( build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.AllReduce ).ReduceStrategy.AllReduce
build_strategy.fuse_broadcast_op = args.fuse_broadcast_op
avg_loss = train_args[0] avg_loss = train_args[0]
@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
if args.use_fake_data or args.use_reader_op: if args.use_fake_data or args.use_reader_op:
try: try:
fetch_ret = exe.run(fetch_list) fetch_ret = exe.run(fetch_list)
except fluid.core.EOFException as eof: except fluid.core.EOFException as eof:
break break

@ -355,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind
paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None) paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None))
paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)) paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)) paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))

@ -16,12 +16,14 @@ if(WITH_GPU)
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else() else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor) variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif() endif()
cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
@ -34,7 +36,7 @@ if(WITH_GPU)
endif() endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU) if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
cc_library(build_strategy SRCS build_strategy.cc DEPS cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass) fuse_elewise_add_act_pass multi_batch_merge_pass)

@ -48,16 +48,23 @@ void BroadcastOpHandle::RunImpl() {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>()); var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
} }
BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
}
void BroadcastOpHandle::BroadcastOneVar(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) {
auto *in_var = auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
InitOutputValue(*in_var_handle, out_var_handles); InitOutputValue(in_var_handle, out_var_handles);
if (platform::is_cpu_place(in_tensor.place())) { if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) { for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) { if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue; continue;
} }
auto &out_p = out_var_handle->place_; auto &out_p = out_var_handle->place_;
@ -114,12 +121,12 @@ void BroadcastOpHandle::RunImpl() {
} }
} }
if (!out_handle->IsTheSameVar(*in_var_handle)) { if (!out_handle->IsTheSameVar(in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_) auto out_var = var_scopes.at(in_var_handle.scope_idx_)
->FindVar(out_var_handles[0]->name_); ->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy( paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_, in_tensor, in_var_handle.place_,
*(dev_ctxes_.at(in_var_handle->place_)), *(dev_ctxes_.at(in_var_handle.place_)),
&VariableVisitor::GetMutableTensor(out_var)); &VariableVisitor::GetMutableTensor(out_var));
} }
}); });

@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
private: void BroadcastOneVar(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes);
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA

@ -121,6 +121,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS(fuse_elewise_add_act_pass); USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass); USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);

@ -69,6 +69,8 @@ struct BuildStrategy {
bool enable_data_balance_{false}; bool enable_data_balance_{false};
bool fuse_broadcast_op_{false};
// User normally doesn't need to call this API. // User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes // The PassBuilder allows for more customized insert, remove of passes
// from python side. // from python side.

@ -0,0 +1,55 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
void FusedBroadcastOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
if (places_.size() == 1UL) return;
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
for (size_t i = 0; i < in_var_handles.size(); ++i) {
BroadcastOneVar(
*in_var_handles[i],
std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
out_var_handles.begin() + (i + 1) * place_num),
var_scopes);
}
}
std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; }
} // namespace details
} // namespace framework
} // namespace paddle

@ -0,0 +1,57 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
struct FusedBroadcastOpHandle : public BroadcastOpHandle {
public:
#ifdef PADDLE_WITH_CUDA
FusedBroadcastOpHandle(ir::Node *node,
const std::vector<Scope *> local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctx)
: BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {}
#else
FusedBroadcastOpHandle(ir::Node* node, const std::vector<Scope*> local_scopes,
const std::vector<platform::Place>& places)
: BroadcastOpHandle(node, local_scopes, places) {}
#endif
std::string Name() const override;
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
@ -347,7 +348,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle? // TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
CreateScaleLossGradOp(&result, loss_grad_name); CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
} }
// This assumes the backward generating code will ensure IsScaleLossOp // This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss. // is true only for the op that scale the final scalar loss.
@ -436,10 +437,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
if ((use_gpu && if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) || strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) { is_dist_train) {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { if (strategy_.fuse_broadcast_op_) {
auto &to_bcast_set = bcast_var_name_set[dev_id]; CreateFusedBroadcastOp(&result, bcast_var_name_set);
for (auto &bcast_name : to_bcast_set) { } else {
CreateBroadcastOp(&result, bcast_name, dev_id); for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
} }
} }
} }
@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
} }
} }
void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new FusedBroadcastOpHandle(
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_);
#else
auto *op_handle = new FusedBroadcastOpHandle(
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
local_scopes_, places_);
#endif
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
}
for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
for (auto &p_name : bcast_varnames[dev_id]) {
auto *in =
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
auto &p = places_[out_dev_id];
auto &vars =
result->Get<GraphVars>(kGraphVars).at(out_dev_id).at(p_name);
auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable),
vars.size(), out_dev_id, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
}
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { int dev_id) const {
@ -602,7 +645,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name) const { ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
@ -617,10 +661,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
// loss->pending_ops_.emplace_back(op_handle); // loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss); // op_handle->inputs_.emplace_back(loss);
CreateOpOutput( CreateOpOutput(result, op_handle,
result, op_handle, result->CreateVarNode(out_var_node->Var()), places_[i], i);
result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
places_[i], i);
} }
} }

@ -61,7 +61,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(ir::Graph *result, void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name) const; const std::string &loss_grad_name,
ir::Node *out_var_node) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
@ -78,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
void CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID( size_t GetAppropriateDeviceID(

@ -36,6 +36,7 @@ pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference) pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN) if(WITH_MKLDNN)

@ -27,14 +27,20 @@ namespace ir {
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
// Make the nodes id start from 0. // Make the nodes id start from 0.
Node::ResetId(); Node::ResetId();
auto var_nodes = InitFromProgram(program_);
ResolveHazard(var_nodes);
}
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program) {
VLOG(3) << "block in program:" << program_.Size(); VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, VarDesc *> all_vars;
// var nodes for each var name, will have multiple versions in SSA
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var); all_vars.emplace(var->Name(), var);
} }
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before. // For input args, reuse the same var name if it was created before.
@ -72,7 +78,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
var->inputs.push_back(node); var->inputs.push_back(node);
} }
} }
return std::move(var_nodes);
}
void Graph::ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes) {
/** /**
* We should handle write after read(WAR) and write after write(WAW) here. * We should handle write after read(WAR) and write after write(WAW) here.
* Because some of the operators of the program can be executed parallelly. * Because some of the operators of the program can be executed parallelly.
@ -91,6 +101,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
auto it_old = versions.rbegin(); auto it_old = versions.rbegin();
++it_old; ++it_old;
for (; it_old != versions.rend(); it_new = it_old, ++it_old) { for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
VLOG(3) << "deal with var: " << (*it_new)->Name();
ir::Node *write_op = ir::Node *write_op =
(*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0]; (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
const auto &read_ops = (*it_old)->outputs; const auto &read_ops = (*it_old)->outputs;

@ -160,6 +160,12 @@ class Graph {
return nullptr; return nullptr;
} }
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private: private:
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {

File diff suppressed because it is too large Load Diff

@ -0,0 +1,44 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
// BatchMergePass is used to copy forward and backward ops for several
// times to run several batches to simulate large batch size training
// as if we have more than 1 GPUs.
// User can define how many batches to run, gradients will be merged
// through those repeats, and then do optimization using merged gradients.
// This pass is extremely useful when doing large batch-size distributed
// sync training, we can simulate even large batch size as if we have more
// GPUs.
class BatchMergePass : public Pass {
public:
virtual ~BatchMergePass() {}
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -109,18 +109,9 @@ ParallelExecutor::ParallelExecutor(
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToDevices(bcast_vars); BCastParamsToDevices(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Create vars in each scope; // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
std::vector<details::VariableInfo> var_infos;
for (auto *var : main_program.Block(0).AllVars()) {
var_infos.emplace_back();
var_infos.back().name_ = var->Name();
var_infos.back().type_ = var->GetType();
var_infos.back().persistable_ = var->Persistable();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unique_ptr<ir::Graph> graph = build_strategy.Apply( std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
@ -156,6 +147,17 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
#endif #endif
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
std::vector<details::VariableInfo> var_infos;
for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back();
var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
// If the loss_var_name is given, the number of graph should be only one. // If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) { if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,

@ -0,0 +1,86 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lars_momentum_op.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace paddle {
namespace operators {
class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(LoDTensor, default LoDTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
.SetDefault(0.001);
AddAttr<float>("lars_weight_decay",
"(float, default 0.0005) LARS weight decay")
.SetDefault(0.0005);
AddComment(R"DOC(
Lars Momentum Optimizer.
This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each
weight using a local learning rate:
$$
local\_lr = \eta *
\frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
velocity = mu * velocity +
local\_lr * (grad + \beta * param) \\
param = param - velocity. \\
$$
Note that we use lars_weight_decay here to decay weights, you may need not to
use L2 regularizers in case of using LARS.
)DOC");
}
};
class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::LarsMomentumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(lars_momentum, ops::LarsMomentumOpKernel<float>,
ops::LarsMomentumOpKernel<double>);

@ -0,0 +1,94 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lars_momentum_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, const T lars_coeff,
const T lars_weight_decay, const T* p_norm,
const T* g_norm, T* p_out, T* v_out) {
T lr = learning_rate[0];
T local_lr = learning_rate[0];
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
if (p_norm[0] > 0 && g_norm[0] > 0) {
local_lr = lr * lars_coeff * p_norm[0] /
(g_norm[0] + lars_weight_decay * p_norm[0]);
}
T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
v_out[i] = v_new;
p_out[i] = p[i] - v_new;
}
}
template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
auto eigen_p = framework::EigenVector<T>::Flatten(*param);
auto eigen_g = framework::EigenVector<T>::Flatten(*grad);
// calculate norms using eigein and launch the kernel.
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
ep_norm.device(*place) = eigen_p.square().sum().sqrt();
eg_norm.device(*place) = eigen_g.square().sum().sqrt();
MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
p_norm_data, g_norm_data, p_out, v_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lars_momentum,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);

@ -0,0 +1,72 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class LarsMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
auto* grad_var = ctx.InputVar("Grad");
// only support dense for now.
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>());
auto grad = ctx.Input<framework::LoDTensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<T>();
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
p_norm_t.mutable_data<T>(ctx.GetPlace());
g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
ep_norm = p.square().sum().sqrt();
eg_norm = g.square().sum().sqrt();
T local_lr = lr[0];
if (ep_norm(0) > 0 && eg_norm(0) > 0) {
local_lr = lr[0] * lars_coeff * ep_norm(0) /
(eg_norm(0) + lars_weight_decay * ep_norm(0));
}
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
p_out = p - v_out;
}
};
} // namespace operators
} // namespace paddle

@ -19,54 +19,6 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Velocity"),
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
"Output(VelocityOut) of Momentum should not be null.");
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
}
PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
"Learning_rate should be a scalar");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MomentumOpInferVarType : public framework::VarTypeInference { class MomentumOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,

@ -28,6 +28,54 @@ using framework::SelectedRows;
struct NoNesterov; struct NoNesterov;
struct UseNesterov; struct UseNesterov;
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Velocity"),
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
"Output(VelocityOut) of Momentum should not be null.");
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
}
PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
"Learning_rate should be a scalar");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T> template <typename T>
class CPUDenseMomentumFunctor { class CPUDenseMomentumFunctor {
private: private:

@ -645,9 +645,13 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass"); py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init()) pass.def(py::init())
.def("set_str", [](ir::Pass &self, const std::string &name, .def(
const std::string &attr) { "set_str",
self.Set<std::string>(name, new std::string(attr)); [](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set_int", [](ir::Pass &self, const std::string &name, int val) {
self.Set<const int>(name, new int(val));
}); });
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb( py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(

@ -27,7 +27,7 @@ from . import nn
from . import ops from . import ops
from . import tensor from . import tensor
from ..initializer import init_on_cpu from ..initializer import init_on_cpu
from ..framework import default_main_program, Parameter, unique_name from ..framework import default_main_program, Parameter, unique_name, name_scope
__all__ = [ __all__ = [
'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
@ -332,14 +332,16 @@ def append_LARS(params_grads, learning_rate, weight_decay):
return grad_norm + weight_decay * param_norm return grad_norm + weight_decay * param_norm
for param, grad in params_grads: for param, grad in params_grads:
param_lr = param.optimize_attr['learning_rate'] with param.block.program.optimized_guard(
param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param))) [param, grad]), name_scope("optimizer"):
grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad))) param_lr = param.optimize_attr['learning_rate']
if type(param_lr) == float and param_lr == 1.0: param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param)))
decayed_lr = learning_rate * param_norm \ grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad)))
/ _balanced_weight(param_norm, grad_norm) if type(param_lr) == float and param_lr == 1.0:
else: decayed_lr = learning_rate * param_norm \
decayed_lr = learning_rate * param_lr * param_norm \ / _balanced_weight(param_norm, grad_norm)
/ _balanced_weight(param_norm, grad_norm) else:
# set back param local learning rate decayed_lr = learning_rate * param_lr * param_norm \
param.optimize_attr['learning_rate'] = decayed_lr / _balanced_weight(param_norm, grad_norm)
# set back param local learning rate
param.optimize_attr['learning_rate'] = decayed_lr

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save