!12700 add grad accumulation combined with optimizer parallel

From: @yangzhenzhang
Committed-by: mindspore-ci-bot (via Gitee)
commit 7fcce73c51
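In brief: when grad_accumulation_step > 1, the parallel-optimizer AllGather inserted in front of a sharded parameter is replaced by the new _MiniStepAllGather primitive, whose backward adds into the matching accu_grads parameter and only all-reduces when do_mirror is set, and ParallelContext gains ParallelParameterContextInitShape/RestoreShape/CkptShape so the accumulation graph keeps its sliced shapes. A minimal, hedged usage sketch follows; the context keyword names are assumptions inferred from the ParallelContext fields this commit touches, not something this diff defines.

```python
# Hedged sketch: enabling optimizer parallel together with gradient accumulation.
# The keyword names (enable_parallel_optimizer, grad_accumulation_step) are
# assumptions based on the ParallelContext fields in this commit.
from mindspore import context

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  device_num=8,
                                  enable_parallel_optimizer=True,
                                  grad_accumulation_step=4)
```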

@ -86,6 +86,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =

@ -52,6 +52,7 @@ class OptimizeIRPassLib {
SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;
SubstitutionPtr mirror_mini_step_elim_;
SubstitutionPtr mini_step_allgather_replace_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;

@ -33,6 +33,7 @@
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/resolve.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -155,7 +156,7 @@ class CheckBpropEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr};
};
- // {prim::kPrimMirrorMiniStep, X, Y, Z} -> X
// {prim::kPrimMirrorMiniStep, X, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -163,11 +164,7 @@ class MirrorMiniStepEliminater : public AnfVisitor {
return nullptr;
}
- auto cnode = node->cast<CNodePtr>();
- if (cnode == nullptr) {
- return nullptr;
- }
- auto inputs = cnode->inputs();
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
@ -178,6 +175,32 @@ class MirrorMiniStepEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
std::string group = attrs[parallel::GROUP]->ToString();
parallel::Operator op = parallel::CreateAllGatherOp(group);
std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
auto func_graph = inputs[1]->func_graph();
CNodePtr new_node = func_graph->NewCNode(node_input);
return new_node;
}
void Visit(const AnfNodePtr &) override {}
};
// Reset defer_inline flag
class ResetDeferInline : public AnfVisitor {
public:
@ -328,6 +351,80 @@ class PynativeEliminater : public OptimizerCaller {
return out;
}
private:
AnfNodePtr OperatorHandle1(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
auto rep = (arg).GetNode(node);
if (rep != nullptr) {
if (rep->isa<ValueNode>()) {
auto value_node = rep->cast<ValueNodePtr>();
auto new_value_node = NewValueNode(FillZero(value_node->value()));
new_value_node->set_has_new_value(value_node->has_new_value());
MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
return new_value_node;
}
}
return nullptr;
}
AnfNodePtr OperatorHandle2(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
auto rep = (arg).GetNode(node);
if (rep != nullptr) {
if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
auto value_node = rep->cast<ValueNodePtr>();
auto new_value_node = NewValueNode(FillZero(value_node->value()));
new_value_node->set_has_new_value(value_node->has_new_value());
MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
return new_value_node;
}
}
return nullptr;
}
void OperatorHandle3(const std::vector<PatternNode<AnfNodePtr>> &args, const AnfNodePtr &node) {
for (size_t i = 0; i < 2; i++) {
auto rep = (args[i]).GetNode(node);
if (rep != nullptr && rep->isa<ValueNode>()) {
auto value_node = rep->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
// when the use count of value node equals to one, it only used in binop_grad_common function
if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
value_node->set_value(new_tensor);
}
}
}
}
AnfNodePtr OperatorHandle4(const PatternNode<AnfNodePtr> &arg, const PatternNode<AnfNodePtr> &arg1,
const AnfNodePtr &node) {
auto rep = (arg).GetNode(node);
if (rep != nullptr) {
if (rep->isa<ValueNode>()) {
MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
ValueNodePtr new_node;
auto value_node = rep->cast<ValueNodePtr>();
auto rep1 = (arg1).GetNode(node);
if (rep1 != nullptr) {
if (rep1->isa<ValueNode>()) {
auto idx = rep1->cast<ValueNodePtr>();
if (!value_node->value()->isa<ValueTuple>()) {
return nullptr;
}
new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
new_node->set_has_new_value(value_node->has_new_value());
}
}
MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
return new_node;
}
}
return nullptr;
}
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
@ -342,15 +439,9 @@ class PynativeEliminater : public OptimizerCaller {
if ((pattern).TryCapture(node) &&
(CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
- auto rep = (arg).GetNode(node);
- if (rep != nullptr) {
- if (rep->isa<ValueNode>()) {
- auto value_node = rep->cast<ValueNodePtr>();
- auto new_value_node = NewValueNode(FillZero(value_node->value()));
- new_value_node->set_has_new_value(value_node->has_new_value());
- MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
- return new_value_node;
- }
auto new_value_node = OperatorHandle1(arg, node);
if (new_value_node != nullptr) {
return new_value_node;
}
}
MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
@ -360,15 +451,9 @@ class PynativeEliminater : public OptimizerCaller {
if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
- auto rep = (arg).GetNode(node);
- if (rep != nullptr) {
- if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
- auto value_node = rep->cast<ValueNodePtr>();
- auto new_value_node = NewValueNode(FillZero(value_node->value()));
- new_value_node->set_has_new_value(value_node->has_new_value());
- MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
- return new_value_node;
- }
auto new_value_node = OperatorHandle2(arg, node);
if (new_value_node != nullptr) {
return new_value_node;
}
}
// {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
@ -379,22 +464,7 @@ class PynativeEliminater : public OptimizerCaller {
auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
- for (size_t i = 0; i < 2; i++) {
- auto rep = (args[i]).GetNode(node);
- if (rep != nullptr && rep->isa<ValueNode>()) {
- auto value_node = rep->cast<ValueNodePtr>();
- MS_EXCEPTION_IF_NULL(value_node);
- auto &value = value_node->value();
- MS_EXCEPTION_IF_NULL(value);
- // when the use count of value node equals to one, it only used in binop_grad_common function
- if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
- auto tensor = value->cast<tensor::TensorPtr>();
- MS_EXCEPTION_IF_NULL(tensor);
- auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
- value_node->set_value(new_tensor);
- }
- }
- }
OperatorHandle3(args, node);
return nullptr;
}
// resolve(CommonOPS, getitem)((tensors), 3)
@ -403,26 +473,9 @@ class PynativeEliminater : public OptimizerCaller {
auto pattern2 = PCNode(resolve2, arg, arg1);
if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
- auto rep = (arg).GetNode(node);
- if (rep != nullptr) {
- if (rep->isa<ValueNode>()) {
- MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
- ValueNodePtr new_node;
- auto value_node = rep->cast<ValueNodePtr>();
- auto rep1 = (arg1).GetNode(node);
- if (rep1 != nullptr) {
- if (rep1->isa<ValueNode>()) {
- auto idx = rep1->cast<ValueNodePtr>();
- if (!value_node->value()->isa<ValueTuple>()) {
- return nullptr;
- }
- new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
- new_node->set_has_new_value(value_node->has_new_value());
- }
- }
- MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
- return new_node;
- }
auto new_value_node = OperatorHandle4(arg, arg1, node);
if (new_value_node != nullptr) {
return new_value_node;
}
}

@ -153,25 +153,27 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
- void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
- if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
- return;
- }
- param_shapes.clear();
if (func_graph->has_flag(AUTO_PARALLEL) &&
(!func_graph->has_flag(TRAINING) ||
(ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
init_param_shape_ = false;
} else {
param_shapes.clear();
init_param_shape_ = true;
}
}
// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
- void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
- AbstractBasePtr ptr) {
void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
const ParameterPtr &param_node, AbstractBasePtr ptr) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
- if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
- func_graph->has_flag(TRAINING)) {
if (init_param_shape_) {
return;
}
auto iter = param_shapes.find(param_node->name());
if (iter == param_shapes.end()) {
MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
@ -183,16 +185,16 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
} }
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node, void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr) { const AbstractBasePtr &ptr) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node); MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr); MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { if (!init_param_shape_) {
return; return;
} }
std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape(); std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
auto ret = param_shapes.try_emplace(param_node->name(), shape); auto ret = param_shapes.try_emplace(param_node->name(), shape);
if (!ret.second) { if (!ret.second) {

@ -30,6 +30,7 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "utils/convert_utils.h" #include "utils/convert_utils.h"
#include "utils/info.h" #include "utils/info.h"
#include "pipeline/jit/pipeline.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
@ -43,6 +44,7 @@ constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";
class ParallelContext {
public:
@ -111,6 +113,11 @@ class ParallelContext {
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
void Reset();
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
AbstractBasePtr ptr);
void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr);
private:
ParallelContext();
@ -136,13 +143,9 @@ class ParallelContext {
std::string strategy_ckpt_save_file_;
std::string group_ckpt_save_file_;
bool enable_parallel_optimizer_;
bool init_param_shape_;
};
- void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
- void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
- AbstractBasePtr ptr);
- void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
- const AbstractBasePtr &ptr);
} // namespace parallel
} // namespace mindspore

@ -284,6 +284,39 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
return op;
}
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(comm_node);
MS_EXCEPTION_IF_NULL(param_node);
if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
return;
}
auto param = param_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto param_info = param->param_info();
if (!param_info) {
MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
return;
}
int32_t fusion_type = param_info->comm_fusion();
attrs[FUSION] = MakeValue<int64_t>(fusion_type);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}
void AddCommOpMeanFlag(const CNodePtr &comm_node) {
MS_EXCEPTION_IF_NULL(comm_node);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
prim->SetAttrs(attrs);
}
Operator CreateAllGatherOp(const std::string &group) {
OperatorName operator_name = ALL_GATHER;
ValuePtr attr0_value = MakeValue(group); // group
@ -299,6 +332,30 @@ Operator CreateAllGatherOp(const std::string &group) {
return op;
}
Operator CreateMiniStepAllGatherOp(const std::string &group) {
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
OperatorName operator_name = MINI_STEP_ALL_GATHER;
ValuePtr attr0_value = MakeValue(group); // group
Attr attr0 = std::make_pair(GROUP, attr0_value);
ValuePtr attr1_value = MakeValue(grad_accumulation_step); // grad_accumulation_step
Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value);
ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag
Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
OperatorAttrs operator_attrs;
operator_attrs.push_back(attr0);
operator_attrs.push_back(attr1);
operator_attrs.push_back(attr2);
OperatorParams operator_param;
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
Operator op = std::make_pair(operator_name, operator_arg);
MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group;
return op;
}
// use for get tensor slice
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
Shape tensor_map = tensor_layout.tensor_map().array();
@ -771,7 +828,7 @@ void OperatorInfo::ComputeBatchSplitFlagList() {
ReComputeBatchSplitFlagList();
}
- // This is a common method for checking whether the generated stragegy has the correct number of devuces.
// This is a common method for checking whether the generated strategy has the correct number of devuces.
Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
if (sp == nullptr) {
MS_LOG(ERROR) << "The strategy is null.";
@ -886,7 +943,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs
(void)input0_strategy.erase(input0_strategy.begin(),
input0_strategy.begin() + static_cast<different_type>(size_diff));
- // handel the case likes ([1, c, d], [a, b, c, d])
// handle the case likes ([1, c, d], [a, b, c, d])
for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
if (inputs_shape[0][i] == 1) {
input0_strategy[i] = 1;
@ -937,7 +994,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input
(void)input1_strategy.erase(input1_strategy.begin(),
input1_strategy.begin() + static_cast<different_type>(size_diff));
- // handel the case likes ([a, b, c, d], [1, c, d])
// handle the case likes ([a, b, c, d], [1, c, d])
for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
if (inputs_shape[1][i] == 1) {
input1_strategy[i] = 1;

@ -36,6 +36,7 @@
#include "frontend/parallel/strategy.h" #include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_info.h" #include "frontend/parallel/tensor_layout/tensor_info.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "base/core_ops.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
@ -160,7 +161,7 @@ class OperatorInfo {
void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
// When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated
- // multiple times. This method is to correct this, and makes the cost is calulated only once.
// multiple times. This method is to correct this, and makes the cost is calculated only once.
Status CorrectMemoryCost(size_t input_index);
int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; }
int64_t is_output_critical() const { return is_output_critical_; }
@ -242,7 +243,7 @@ class OperatorInfo {
bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel
// 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected.
std::vector<size_t> corrected_input_indices_;
- // Given a parallization strategy, there is a cost.
// Given a parallelization strategy, there is a cost.
std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
// For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
std::vector<bool> is_parameter_;
@ -288,6 +289,9 @@ Operator CreateVirtualDivOp(int64_t div_num);
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
Operator CreateAllGatherOp(const std::string &group);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);

@ -109,6 +109,7 @@ constexpr char END[] = "end";
constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";
@ -180,6 +181,7 @@ constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator";
constexpr char LOCAL_STEP[] = "local_step";
constexpr char STRIDED_SLICE[] = "StridedSlice";
constexpr char ALL_GATHER[] = "AllGather";
constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather";
constexpr char REDUCE_SCATTER[] = "ReduceScatter";
constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";

@ -66,8 +66,8 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
return;
}
- ValueNodePtr prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
- PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
@ -84,6 +84,19 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
}
}
void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
if (new_node_input.empty()) {
return;
}
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
prim->SetAttrs(attrs);
}
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
OperatorArgs arg_forward = op.second;
@ -158,7 +171,6 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(root->manager());
- AnfNodePtr local_step_param = nullptr;
AnfNodePtr grad_accu = nullptr;
std::string op_name = op.first;
OperatorArgs arg_forward = op.second;
@ -166,25 +178,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
if (grad_accumulation_step > 1) {
- bool find_locat_step_node = false;
auto parameters = root->parameters();
- for (auto &param : parameters) {
- auto param_ptr = param->cast<ParameterPtr>();
- MS_EXCEPTION_IF_NULL(param_ptr);
- if (param_ptr->name() == LOCAL_STEP) {
- auto param_users = root->manager()->node_users()[param];
- for (auto &user : param_users) {
- if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
- find_locat_step_node = true;
- local_step_param = user.first;
- MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
- break;
- }
- }
- break;
- }
- }
bool find_grad_accu_node = false;
for (auto &param : parameters) {
if (!ParameterIsCloned(param)) {
@ -202,10 +196,12 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
}
}
- if (op_name == MIRROR_MINI_STEP_OPERATOR) {
- if (!find_locat_step_node || !find_grad_accu_node) {
if (!find_grad_accu_node) {
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
op_name = MIRROR_OPERATOR;
arg_forward.first.pop_back();
} else if (op_name == MINI_STEP_ALL_GATHER) {
MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation.";
}
}
}
@ -215,9 +211,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
OperatorParams params = arg_forward.second;
std::vector<AnfNodePtr> new_node_input;
- if (op_name == MIRROR_MINI_STEP_OPERATOR) {
- new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
- MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) {
new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
} else {
new_node_input = {NewValueNode(pyop_instance), node};
}
@ -233,6 +229,10 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
// if the op have 'group' attr, set the rank list name for the op
SetCommunicationOpGroupLabel(new_node_input);
// gradient accumulation
if (grad_accumulation_step > 1) {
SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
}
return new_node_input;
}
@ -285,6 +285,31 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
return new_node;
}
// Replace pre_node with pre_node->op
static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node,
const FuncGraphPtr &func_graph, const std::string &instance_name,
const std::string &param_name) {
// insert new node before the node
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ScopePtr scope = pre_node->scope();
MS_EXCEPTION_IF_NULL(scope);
std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
CNodePtr new_node = func_graph->NewCNode(node_input);
MS_EXCEPTION_IF_NULL(new_node);
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
new_node->set_in_forward_flag(true); // mark forward flag
}
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
new_node_prim->set_instance_name(instance_name);
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
new_node->set_scope(scope);
node_input[0]->set_scope(scope);
manager->Replace(pre_node, new_node);
MS_LOG(INFO) << "Insert " << instance_name << " success";
return new_node;
}
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
@ -1086,29 +1111,6 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
return (type_id != kNumberTypeFloat32);
}
- static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
- MS_EXCEPTION_IF_NULL(comm_node);
- MS_EXCEPTION_IF_NULL(param_node);
- if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
- MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
- return;
- }
- auto param = param_node->cast<ParameterPtr>();
- MS_EXCEPTION_IF_NULL(param);
- auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
- MS_EXCEPTION_IF_NULL(prim);
- auto attrs = prim->attrs();
- auto param_info = param->param_info();
- if (!param_info) {
- MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
- return;
- }
- int32_t fusion_type = param_info->comm_fusion();
- attrs[FUSION] = MakeValue<int64_t>(fusion_type);
- prim->SetAttrs(attrs);
- MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
- }
static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
MS_LOG(INFO) << "Input is ValueList, skip it.";
@ -1195,7 +1197,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
// add fusion flag
- // pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
continue;
@ -1540,33 +1541,40 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
return std::make_pair(nullptr, 0);
}
- static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
- const AnfNodePtr &parameter) {
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
const AnfNodePtr &node) {
- Operator op = CreateAllGatherOp(group);
MS_EXCEPTION_IF_NULL(res.first);
- MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(node);
auto cnode = res.first->cast<CNodePtr>();
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(cnode_prim);
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
Operator op;
CNodePtr allgather;
- if (cnode_prim->name() == CAST) {
- allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
- } else {
- InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
- allgather = cnode->input(res.second)->cast<CNodePtr>();
- }
if (grad_accumulation_step > 1) {
op = CreateMiniStepAllGatherOp(group);
auto param_name = node->cast<ParameterPtr>()->name();
if (cnode_prim->name() == CAST) {
allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
} else {
InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
allgather = cnode->input(res.second)->cast<CNodePtr>();
}
} else {
op = CreateAllGatherOp(group);
if (cnode_prim->name() == CAST) {
allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
} else {
InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER);
allgather = cnode->input(res.second)->cast<CNodePtr>();
}
}
MS_EXCEPTION_IF_NULL(allgather);
// add fusion flag
- AddCommOpFusionType(allgather, parameter);
AddCommOpFusionType(allgather, node);
// add gradients mean
- auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
- auto attrs = prim->attrs();
- MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
- bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
- attrs["mean_flag"] = MakeValue<bool>(mean_flag);
- prim->SetAttrs(attrs);
AddCommOpMeanFlag(allgather);
}
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
@ -1589,7 +1597,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
<< distribute_operator->inputs_tensor_info().size();
}
// insert allgather operator between shard parameter and cnode
- InsertAllGatherOp(opt_shard_group, param_pair, parameter);
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
}
}
@ -1734,12 +1742,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
MS_EXCEPTION_IF_NULL(cloned_abstract);
- cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
if (param_name.find(ACCU_GRADS) != std::string::npos) {
auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
MS_EXCEPTION_IF_NULL(parallel_shape);
cloned_abstract->set_shape(parallel_shape);
} else {
cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
}
cloned_parameter_node->set_abstract(cloned_abstract);
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()

@ -30,6 +30,8 @@
#include "frontend/parallel/strategy.h" #include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/pipeline.h" #include "pipeline/jit/pipeline.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>; using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;

@ -258,9 +258,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec = res->args_spec();
- parallel::ParallelParameterContextInit(func_graph);
auto context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
context->ParallelParameterContextInitShape(func_graph);
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
@ -271,9 +271,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
- parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
args_spec.push_back(abs_ref);
- parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
}
}
// Analyze

@ -160,6 +160,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.replace_applicator_,
irpass.mirror_mini_step_elim_,
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);

@ -374,7 +374,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape"))
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
- .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
.def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */

@ -134,7 +134,7 @@ class Parameter(Tensor_):
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
- self._param_info = ParamInfo()
self.param_info = ParamInfo()
self.init_in_server = False
self.cache_enable = False
self.name = name
@ -230,7 +230,7 @@ class Parameter(Tensor_):
"sparse operator support initialization in server.".format(self.name)) "sparse operator support initialization in server.".format(self.name))
self.is_param_ps = True self.is_param_ps = True
self.init_in_server = init_in_server self.init_in_server = init_in_server
self._param_info.init_in_server = init_in_server self.param_info.init_in_server = init_in_server
@property @property
def inited_param(self): def inited_param(self):
@ -245,7 +245,7 @@ class Parameter(Tensor_):
@property
def name(self):
"""Get the name of the parameter."""
- return self._param_info.name
return self.param_info.name
@name.setter
def name(self, name_):
@ -272,9 +272,9 @@ class Parameter(Tensor_):
if len(self.shape) != 2:
raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
.format(self.name, len(self.shape)))
- _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1])
- self._param_info.name = name_
_reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])
self.param_info.name = name_
@property
def sliced(self):
@ -288,12 +288,12 @@ class Parameter(Tensor_):
@property
def comm_fusion(self):
"""Get the fusion type for communication operators corresponding to this parameter."""
- return self._param_info.comm_fusion
return self.param_info.comm_fusion
@comm_fusion.setter
def comm_fusion(self, comm_fusion_):
"""Set the fusion type for communication operators corresponding to this parameter."""
- self._param_info.comm_fusion = comm_fusion_
self.param_info.comm_fusion = comm_fusion_
@property
def unique(self):
@ -339,7 +339,7 @@ class Parameter(Tensor_):
""" """
x = copy(self) x = copy(self)
# pylint: disable=protected-access # pylint: disable=protected-access
x._param_info = self._param_info.clone() x.param_info = self.param_info.clone()
x.is_init = False x.is_init = False
x.init = self.init x.init = self.init
x.is_param_ps = self.is_param_ps x.is_param_ps = self.is_param_ps
@ -355,57 +355,57 @@ class Parameter(Tensor_):
@property
def layerwise_parallel(self):
- return self._param_info.layerwise_parallel
return self.param_info.layerwise_parallel
@layerwise_parallel.setter
def layerwise_parallel(self, value=True):
if not isinstance(value, bool):
raise TypeError("`layerwise_parallel` parameter must be bool type")
- self._param_info.layerwise_parallel = value
self.param_info.layerwise_parallel = value
@property
def parallel_optimizer(self):
"""Return whether the parameter requires weight shard for parallel optimizer."""
- return self._param_info.parallel_optimizer
return self.param_info.parallel_optimizer
@parallel_optimizer.setter
def parallel_optimizer(self, value=True):
if not isinstance(value, bool):
raise TypeError("`parallel_optimizer` parameter must be bool type")
- self._param_info.parallel_optimizer = value
self.param_info.parallel_optimizer = value
@property
def cache_enable(self):
"""Return whether the parameter is cache enable."""
- return self._param_info.cache_enable
return self.param_info.cache_enable
@cache_enable.setter
def cache_enable(self, value=True):
if not isinstance(value, bool):
raise TypeError("`cache_enable` parameter must be bool type")
- self._param_info.cache_enable = value
self.param_info.cache_enable = value
@property
def cache_shape(self):
"""Return the cache shape corresponding to the parameter if use cache."""
- return self._param_info.cache_shape
return self.param_info.cache_shape
@cache_shape.setter
def cache_shape(self, value):
if not isinstance(value, (tuple, list)):
raise TypeError("`cache_shape` parameter must be tuple or list type")
- self._param_info.cache_shape = value
self.param_info.cache_shape = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""
- return self._param_info.requires_grad
return self.param_info.requires_grad
@requires_grad.setter
def requires_grad(self, value=True):
if not isinstance(value, bool):
raise TypeError("`requires_grad` parameter must be bool type")
- self._param_info.requires_grad = value
self.param_info.requires_grad = value
@property
def data(self):
@ -419,7 +419,9 @@ class Parameter(Tensor_):
self.init = None
return self.assign_value(data)
# create a new tensor
- return Parameter(data, self.name, self.requires_grad)
new_param = Parameter(data, self.name, self.requires_grad)
new_param.param_info = self.param_info
return new_param
def set_data(self, data, slice_shape=False):
"""

@ -306,6 +306,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitiv
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");

@ -20,7 +20,7 @@ from mindspore.communication import get_rank, get_group_size
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
- from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
@ -150,6 +150,39 @@ def get_bprop_all_gather(self):
return bprop
@bprop_getters.register(_MiniStepAllGather)
def get_bprop_mini_step_all_gather(self):
"""Generate bprop for _MiniStepAllGather"""
fusion = self.get_attr_dict()["fusion"]
mean_flag = self.get_attr_dict()["mean_flag"]
do_mirror = self.get_attr_dict()["do_mirror"]
scale = 1 / self.rank_size
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
rank = get_rank(self.group)
dev_num = get_group_size(self.group)
split = P.Split(output_num=dev_num)
def bprop(x, z, out, dout):
if do_mirror:
if mean_flag:
tmp = z + dout
grad = all_reduce(tmp)
dx = split(grad)[rank]
dx = F.tensor_mul(dx, scale)
else:
tmp = z + dout
grad = all_reduce(tmp)
dx = split(grad)[rank]
else:
dx = dout
return (dx, zeros_like(z))
return bprop
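The bprop registered above treats z as the gradient-accumulation input and, when do_mirror is true, all-reduces z + dout across the group and keeps the local split (scaled by 1/rank_size if mean_flag is set), which amounts to a ReduceScatter of the accumulated gradient. A standalone, hedged illustration of that arithmetic (plain NumPy, hypothetical helper name, not MindSpore code):

```python
# Hedged, single-process illustration of the bprop above when do_mirror is True.
# Each rank contributes (z + dout); AllReduce(SUM) followed by Split is
# equivalent to a ReduceScatter of the accumulated gradients across the group.
import numpy as np

def simulated_mini_step_all_gather_bprop(z_per_rank, dout_per_rank, mean_flag):
    dev_num = len(z_per_rank)
    summed = sum(z + d for z, d in zip(z_per_rank, dout_per_rank))  # AllReduce(SUM)
    parts = np.split(summed, dev_num)                               # P.Split(output_num=dev_num)
    scale = 1.0 / dev_num if mean_flag else 1.0
    return [p * scale for p in parts]                               # rank i keeps parts[i]
```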
@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
"""Generate bprop for _HostAllGather"""
@ -291,18 +324,13 @@ def get_bprop_mirror_mini_step_operator(self):
group = self.group
dev_num = self.dev_num
mean_flag = self.mean_flag
- grad_accumulation_step = self.grad_accumulation_step
all_reduce = AllReduce(group=group)
all_gather = AllGather(group=group)
mul = P.Mul()
cast = P.Cast()
- equal = P.Equal()
- reshape = P.Reshape()
- fusion = 1
- if hasattr(self, 'fusion'):
- fusion = self.fusion
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
@ -311,16 +339,15 @@ def get_bprop_mirror_mini_step_operator(self):
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
do_mirror = self.get_attr_dict()["do_mirror"]
- def bprop(x, y, z, out, dout):
- do_mirror = equal(y, grad_accumulation_step)
- do_mirror = reshape(do_mirror, (()))
def bprop(x, z, out, dout):
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
- dx = real_grad - z
dx = real_grad
else:
dx = dout
float_one = F.scalar_cast(1.0, F.dtype(dx))
@ -342,7 +369,7 @@ def get_bprop_mirror_mini_step_operator(self):
if do_mirror: if do_mirror:
tmp = z + dout tmp = z + dout
real_grad = all_reduce(tmp) real_grad = all_reduce(tmp)
dx = real_grad - z dx = real_grad
else: else:
dx = dout dx = dout
else: else:
@ -354,7 +381,7 @@ def get_bprop_mirror_mini_step_operator(self):
grad = dout.values grad = dout.values
dx = RowTensor(indices, grad, dout.dense_shape) dx = RowTensor(indices, grad, dout.dense_shape)
return (dx, zeros_like(y), zeros_like(z)) return (dx, zeros_like(z))
return bprop return bprop
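The behavioural change in this hunk, summarised with the variable names used in the bprop above: do_mirror is now read as a primitive attribute instead of being recomputed from the step counter y, and the locally accumulated gradient z is no longer subtracted after the AllReduce.

# before: dx = all_reduce(z + dout) - z    # per-step delta for this rank
# after:  dx = all_reduce(z + dout)        # full reduced accumulated gradient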

@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique, GatherD, Identity, Range) Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, _VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter) _HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,

@ -200,6 +200,38 @@ class AllGather(PrimitiveWithInfer):
raise NotImplementedError raise NotImplementedError
class _MiniStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        grad_accumulation_step (int): The grad accumulation step. Default: None.
        mean_flag (bool): Whether the gradient is averaged over the group in the backward pass. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.grad_accumulation_step = grad_accumulation_step
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype
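As a quick check of the shape rule above: with rank_size = 2, an input of shape [4, 3] is reported as [8, 3] after the gather. The snippet below only mirrors the infer_shape arithmetic; the primitive itself is internal and is not constructed here:

rank_size = 2
x_shape = [4, 3]
if x_shape[0] > 0:
    x_shape[0] = x_shape[0] * rank_size
print(x_shape)   # [8, 3]: dim 0 is gathered across the 2 ranks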
class _HostAllGather(PrimitiveWithInfer): class _HostAllGather(PrimitiveWithInfer):
""" """
Gathers tensors from the specified communication group on host. Gathers tensors from the specified communication group on host.
@ -590,10 +622,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
self.mean_flag = mean_flag self.mean_flag = mean_flag
self.grad_accumulation_step = grad_accumulation_step self.grad_accumulation_step = grad_accumulation_step
def infer_shape(self, x_shape, y_shape, z_shape): def infer_shape(self, x_shape, z_shape):
return x_shape return x_shape
def infer_dtype(self, x_dtype, y_shape, z_shape): def infer_dtype(self, x_dtype, z_shape):
return x_dtype return x_dtype

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,15 +17,14 @@ import numpy as np
import mindspore as ms import mindspore as ms
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum, Norm
from mindspore.train import Model from mindspore.train import Model
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm
from mindspore.parallel._utils import _get_device_num
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
@ -142,29 +141,29 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1. batch_size * accumulation_steps. Default: 1.
""" """
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4): def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
self.accu = False
self.is_accu_step = Tensor(np.array([self.accu]))
self.network = network self.network = network
self.network.set_grad() self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.accumulation_steps = accumulation_steps self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step")
self.one = Tensor(np.array([1]).astype(np.int32)) self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32)) self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1 self.degree = 1
self.grad_reducer = F.identity
if self.reducer_flag: if self.reducer_flag:
self.degree = get_group_size() self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity self.overflow_reducer = F.identity
@ -197,34 +196,27 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
else: else:
scaling_sens = sens scaling_sens = sens
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
mean_loss = self.accu_loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
# alloc status and clear should be right before gradoperation # alloc status and clear should be right before gradoperation
init = self.alloc_status() init = self.alloc_status()
self.clear_before_grad(init) self.clear_before_grad(init)
grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32)) grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) if self.is_accu_step and self.accumulation_steps > 1:
mean_loss = F.depend(mean_loss, accu_succ) accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
loss = F.depend(loss, accu_succ)
self.get_status(init) self.get_status(init)
flag_sum = self.reduce_sum(init, (0,)) flag_sum = self.reduce_sum(init, (0,))
overflow = self.less_equal(self.base, flag_sum) overflow = self.less_equal(self.base, flag_sum)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero) accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero)
is_accu_step = self.reshape(is_accu_step, (()))
if is_accu_step: if self.is_accu_step:
succ = False succ = False
else: else:
# apply grad reducer on grads # apply grad reducer on grads
grads = self.grad_reducer(self.accu_grads) grads = self.grad_reducer(grads)
scaling = scaling_sens * self.degree * self.accumulation_steps scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads) grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
grads = ClipByGlobalNorm()(grads) grads = ClipByGlobalNorm()(grads)
@ -241,7 +233,7 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
else: else:
succ = self.optimizer(grads) succ = self.optimizer(grads)
ret = (mean_loss, overflow, scaling_sens) ret = (loss, overflow, scaling_sens)
return F.depend(ret, succ) return F.depend(ret, succ)
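The accumulation control flow after this change can be summarised as: while the graph carries the accu flag, gradients are only folded into accu_grads and the optimizer is skipped; on the apply step the reduced gradients are rescaled by scaling_sens * degree * accumulation_steps before the update. A hedged stand-alone sketch — accumulate_or_apply is a hypothetical helper, and overflow handling and global-norm clipping are omitted:

def accumulate_or_apply(is_accu_step, grads, accu_grads, degree, accumulation_steps, scaling_sens):
    # Accumulation step: fold the new gradients into the buffer, skip the optimizer.
    if is_accu_step and accumulation_steps > 1:
        accu_grads = [a + g for a, g in zip(accu_grads, grads)]
    if is_accu_step:
        return accu_grads, None
    # Apply step: undo loss scaling and the degree/accumulation factors before the update.
    scaling = scaling_sens * degree * accumulation_steps
    update = [g / scaling for g in grads]
    return accu_grads, update

print(accumulate_or_apply(False, [8.0], [0.0], degree=2, accumulation_steps=4, scaling_sens=1.0))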
@ -265,25 +257,51 @@ _b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32) _w1 = Tensor(np.ones([16]), dtype=ms.float32)
def compile_net(net, grad_accumulation_step): def compile_net(net):
context.set_context(save_graphs=True) context.set_context(enable_sparse=False)
learning_rate = 0.1 learning_rate = 0.1
momentum = 0.9 momentum = 0.9
epoch_size = 2 epoch_size = 2
dataset = Dataset(_x, _b) dataset = Dataset(_x, _b)
opt = Momentum(net.trainable_params(), learning_rate, momentum) opt = Momentum(net.trainable_params(), learning_rate, momentum)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000) update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell, net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell)
accumulation_steps=grad_accumulation_step)
model = Model(net_wrap) model = Model(net_wrap)
model.train(epoch_size, dataset, dataset_sink_mode=False) model.train(epoch_size, dataset, dataset_sink_mode=False)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
def test_grad_accumulation(): def test_grad_accumulation_accu():
grad_accumulation_step = 4 grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step) grad_accumulation_step=grad_accumulation_step)
strategy = ((2,), (2,)) strategy = ((2,), (2,))
net = Net(_w1, strategy) net = Net(_w1, strategy).add_flags_recursive(accu=True)
compile_net(net, grad_accumulation_step) compile_net(net)
def test_grad_accu_and_opt_shard_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=True)
    compile_net(net)


def test_grad_accumulation_not_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=False)
    compile_net(net)


def test_grad_accu_and_opt_shard_not_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=False)
    compile_net(net)
