!12700 add grad accumulation combined with optimizer parallel

From: @yangzhenzhang
Reviewed-by: 
Signed-off-by:
pull/12700/MERGE
Committed by mindspore-ci-bot via Gitee
commit 7fcce73c51

@ -86,6 +86,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =

@ -52,6 +52,7 @@ class OptimizeIRPassLib {
SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;
SubstitutionPtr mirror_mini_step_elim_;
SubstitutionPtr mini_step_allgather_replace_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;

@ -33,6 +33,7 @@
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace opt {
@ -155,7 +156,7 @@ class CheckBpropEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X
// {prim::kPrimMirrorMiniStep, X, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -163,11 +164,7 @@ class MirrorMiniStepEliminater : public AnfVisitor {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
auto inputs = cnode->inputs();
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
@ -178,6 +175,32 @@ class MirrorMiniStepEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
std::string group = attrs[parallel::GROUP]->ToString();
parallel::Operator op = parallel::CreateAllGatherOp(group);
std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
auto func_graph = inputs[1]->func_graph();
CNodePtr new_node = func_graph->NewCNode(node_input);
return new_node;
}
void Visit(const AnfNodePtr &) override {}
};
// Reset defer_inline flag
class ResetDeferInline : public AnfVisitor {
public:
@ -328,20 +351,8 @@ class PynativeEliminater : public OptimizerCaller {
return out;
}
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
PatternNode<AnfNodePtr> symbol_str_vnode;
PatternNode<AnfNodePtr> c_vnode;
PatternNode<AnfNodePtr> zeros_like_vnode;
PatternNode<AnfNodePtr> arg;
auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
auto pattern = PCNode(getattr, arg);
// {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} ->Tensor(0, shape(Xy))
if ((pattern).TryCapture(node) &&
(CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
private:
AnfNodePtr OperatorHandle1(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
auto rep = (arg).GetNode(node);
if (rep != nullptr) {
if (rep->isa<ValueNode>()) {
@ -352,14 +363,10 @@ class PynativeEliminater : public OptimizerCaller {
return new_value_node;
}
}
return nullptr;
}
MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
// {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} ->Tensor(0, shape(Xy))
auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
auto pattern1 = PCNode(resolve1, arg);
if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
AnfNodePtr OperatorHandle2(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
auto rep = (arg).GetNode(node);
if (rep != nullptr) {
if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
@ -370,15 +377,10 @@ class PynativeEliminater : public OptimizerCaller {
return new_value_node;
}
}
return nullptr;
}
// {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
PatternNode<AnfNodePtr> binop_grad_common;
PatternNode<AnfNodePtr> getitem_vnode;
std::vector<PatternNode<AnfNodePtr>> args(4);
auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common);
auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
void OperatorHandle3(const std::vector<PatternNode<AnfNodePtr>> &args, const AnfNodePtr &node) {
for (size_t i = 0; i < 2; i++) {
auto rep = (args[i]).GetNode(node);
if (rep != nullptr && rep->isa<ValueNode>()) {
@ -395,14 +397,10 @@ class PynativeEliminater : public OptimizerCaller {
}
}
}
return nullptr;
}
// resolve(CommonOPS, getitem)((tensors), 3)
PatternNode<AnfNodePtr> arg1;
auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
auto pattern2 = PCNode(resolve2, arg, arg1);
if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
AnfNodePtr OperatorHandle4(const PatternNode<AnfNodePtr> &arg, const PatternNode<AnfNodePtr> &arg1,
const AnfNodePtr &node) {
auto rep = (arg).GetNode(node);
if (rep != nullptr) {
if (rep->isa<ValueNode>()) {
@ -424,6 +422,61 @@ class PynativeEliminater : public OptimizerCaller {
return new_node;
}
}
return nullptr;
}
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
PatternNode<AnfNodePtr> symbol_str_vnode;
PatternNode<AnfNodePtr> c_vnode;
PatternNode<AnfNodePtr> zeros_like_vnode;
PatternNode<AnfNodePtr> arg;
auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
auto pattern = PCNode(getattr, arg);
// {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} ->Tensor(0, shape(Xy))
if ((pattern).TryCapture(node) &&
(CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
auto new_value_node = OperatorHandle1(arg, node);
if (new_value_node != nullptr) {
return new_value_node;
}
}
MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
// {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} ->Tensor(0, shape(Xy))
auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
auto pattern1 = PCNode(resolve1, arg);
if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
auto new_value_node = OperatorHandle2(arg, node);
if (new_value_node != nullptr) {
return new_value_node;
}
}
// {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
PatternNode<AnfNodePtr> binop_grad_common;
PatternNode<AnfNodePtr> getitem_vnode;
std::vector<PatternNode<AnfNodePtr>> args(4);
auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common);
auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
OperatorHandle3(args, node);
return nullptr;
}
// resolve(CommonOPS, getitem)((tensors), 3)
PatternNode<AnfNodePtr> arg1;
auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
auto pattern2 = PCNode(resolve2, arg, arg1);
if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
auto new_value_node = OperatorHandle4(arg, arg1, node);
if (new_value_node != nullptr) {
return new_value_node;
}
}
MS_LOG(DEBUG) << "End Replace " << node->DebugString(4);

@ -153,25 +153,27 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
return;
}
if (func_graph->has_flag(AUTO_PARALLEL) &&
(!func_graph->has_flag(TRAINING) ||
(ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
init_param_shape_ = false;
} else {
param_shapes.clear();
init_param_shape_ = true;
}
}
// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
AbstractBasePtr ptr) {
void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
const ParameterPtr &param_node, AbstractBasePtr ptr) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
func_graph->has_flag(TRAINING)) {
if (init_param_shape_) {
return;
}
auto iter = param_shapes.find(param_node->name());
if (iter == param_shapes.end()) {
MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
@ -183,16 +185,16 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
if (!init_param_shape_) {
return;
}
std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
auto ret = param_shapes.try_emplace(param_node->name(), shape);
if (!ret.second) {

@ -30,6 +30,7 @@
#include "ir/func_graph.h"
#include "utils/convert_utils.h"
#include "utils/info.h"
#include "pipeline/jit/pipeline.h"
namespace mindspore {
namespace parallel {
@ -43,6 +44,7 @@ constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";
class ParallelContext {
public:
@ -111,6 +113,11 @@ class ParallelContext {
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
void Reset();
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
AbstractBasePtr ptr);
void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr);
private:
ParallelContext();
@ -136,13 +143,9 @@ class ParallelContext {
std::string strategy_ckpt_save_file_;
std::string group_ckpt_save_file_;
bool enable_parallel_optimizer_;
bool init_param_shape_;
};
void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
AbstractBasePtr ptr);
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr);
} // namespace parallel
} // namespace mindspore

@ -284,6 +284,39 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
return op;
}
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(comm_node);
MS_EXCEPTION_IF_NULL(param_node);
if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
return;
}
auto param = param_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto param_info = param->param_info();
if (!param_info) {
MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
return;
}
int32_t fusion_type = param_info->comm_fusion();
attrs[FUSION] = MakeValue<int64_t>(fusion_type);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}
void AddCommOpMeanFlag(const CNodePtr &comm_node) {
MS_EXCEPTION_IF_NULL(comm_node);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
prim->SetAttrs(attrs);
}
Operator CreateAllGatherOp(const std::string &group) {
OperatorName operator_name = ALL_GATHER;
ValuePtr attr0_value = MakeValue(group); // group
@ -299,6 +332,30 @@ Operator CreateAllGatherOp(const std::string &group) {
return op;
}
Operator CreateMiniStepAllGatherOp(const std::string &group) {
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
OperatorName operator_name = MINI_STEP_ALL_GATHER;
ValuePtr attr0_value = MakeValue(group); // group
Attr attr0 = std::make_pair(GROUP, attr0_value);
ValuePtr attr1_value = MakeValue(grad_accumulation_step); // grad_accumulation_step
Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value);
ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag
Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
OperatorAttrs operator_attrs;
operator_attrs.push_back(attr0);
operator_attrs.push_back(attr1);
operator_attrs.push_back(attr2);
OperatorParams operator_param;
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
Operator op = std::make_pair(operator_name, operator_arg);
MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group;
return op;
}
// use for get tensor slice
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
Shape tensor_map = tensor_layout.tensor_map().array();
@ -771,7 +828,7 @@ void OperatorInfo::ComputeBatchSplitFlagList() {
ReComputeBatchSplitFlagList();
}
// This is a common method for checking whether the generated stragegy has the correct number of devuces.
// This is a common method for checking whether the generated strategy has the correct number of devices.
Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
if (sp == nullptr) {
MS_LOG(ERROR) << "The strategy is null.";
@ -886,7 +943,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs
(void)input0_strategy.erase(input0_strategy.begin(),
input0_strategy.begin() + static_cast<different_type>(size_diff));
// handel the case likes ([1, c, d], [a, b, c, d])
// handle the case likes ([1, c, d], [a, b, c, d])
for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
if (inputs_shape[0][i] == 1) {
input0_strategy[i] = 1;
@ -937,7 +994,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input
(void)input1_strategy.erase(input1_strategy.begin(),
input1_strategy.begin() + static_cast<different_type>(size_diff));
// handel the case likes ([a, b, c, d], [1, c, d])
// handle the case likes ([a, b, c, d], [1, c, d])
for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
if (inputs_shape[1][i] == 1) {
input1_strategy[i] = 1;

@ -36,6 +36,7 @@
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "utils/log_adapter.h"
#include "base/core_ops.h"
namespace mindspore {
namespace parallel {
@ -160,7 +161,7 @@ class OperatorInfo {
void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
// When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated
// multiple times. This method is to correct this, and makes the cost is calulated only once.
// multiple times. This method is to correct this, and makes the cost is calculated only once.
Status CorrectMemoryCost(size_t input_index);
int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; }
int64_t is_output_critical() const { return is_output_critical_; }
@ -242,7 +243,7 @@ class OperatorInfo {
bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel
// 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected.
std::vector<size_t> corrected_input_indices_;
// Given a parallization strategy, there is a cost.
// Given a parallelization strategy, there is a cost.
std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
// For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
std::vector<bool> is_parameter_;
@ -288,6 +289,9 @@ Operator CreateVirtualDivOp(int64_t div_num);
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
Operator CreateAllGatherOp(const std::string &group);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);

@ -109,6 +109,7 @@ constexpr char END[] = "end";
constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";
@ -180,6 +181,7 @@ constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator";
constexpr char LOCAL_STEP[] = "local_step";
constexpr char STRIDED_SLICE[] = "StridedSlice";
constexpr char ALL_GATHER[] = "AllGather";
constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather";
constexpr char REDUCE_SCATTER[] = "ReduceScatter";
constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";

@ -66,8 +66,8 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
return;
}
ValueNodePtr prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
@ -84,6 +84,19 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
}
}
void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
if (new_node_input.empty()) {
return;
}
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
prim->SetAttrs(attrs);
}
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
OperatorArgs arg_forward = op.second;
@ -158,7 +171,6 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(root->manager());
AnfNodePtr local_step_param = nullptr;
AnfNodePtr grad_accu = nullptr;
std::string op_name = op.first;
OperatorArgs arg_forward = op.second;
@ -166,25 +178,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
if (grad_accumulation_step > 1) {
bool find_locat_step_node = false;
auto parameters = root->parameters();
for (auto &param : parameters) {
auto param_ptr = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name() == LOCAL_STEP) {
auto param_users = root->manager()->node_users()[param];
for (auto &user : param_users) {
if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
find_locat_step_node = true;
local_step_param = user.first;
MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
break;
}
}
break;
}
}
bool find_grad_accu_node = false;
for (auto &param : parameters) {
if (!ParameterIsCloned(param)) {
@ -202,10 +196,12 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
}
}
if (!find_grad_accu_node) {
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
if (!find_locat_step_node || !find_grad_accu_node) {
op_name = MIRROR_OPERATOR;
arg_forward.first.pop_back();
} else if (op_name == MINI_STEP_ALL_GATHER) {
MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation.";
}
}
}
@ -215,9 +211,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
OperatorParams params = arg_forward.second;
std::vector<AnfNodePtr> new_node_input;
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) {
new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
} else {
new_node_input = {NewValueNode(pyop_instance), node};
}
@ -233,6 +229,10 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
// if the op have 'group' attr, set the rank list name for the op
SetCommunicationOpGroupLabel(new_node_input);
// gradient accumulation
if (grad_accumulation_step > 1) {
SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
}
return new_node_input;
}
@ -285,6 +285,31 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
return new_node;
}
// Replace pre_node with pre_node->op
static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node,
const FuncGraphPtr &func_graph, const std::string &instance_name,
const std::string &param_name) {
// insert new node before the node
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ScopePtr scope = pre_node->scope();
MS_EXCEPTION_IF_NULL(scope);
std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
CNodePtr new_node = func_graph->NewCNode(node_input);
MS_EXCEPTION_IF_NULL(new_node);
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
new_node->set_in_forward_flag(true); // mark forward flag
}
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
new_node_prim->set_instance_name(instance_name);
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
new_node->set_scope(scope);
node_input[0]->set_scope(scope);
manager->Replace(pre_node, new_node);
MS_LOG(INFO) << "Insert " << instance_name << " success";
return new_node;
}
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
@ -1086,29 +1111,6 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
return (type_id != kNumberTypeFloat32);
}
static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(comm_node);
MS_EXCEPTION_IF_NULL(param_node);
if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
return;
}
auto param = param_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto param_info = param->param_info();
if (!param_info) {
MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
return;
}
int32_t fusion_type = param_info->comm_fusion();
attrs[FUSION] = MakeValue<int64_t>(fusion_type);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}
static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
MS_LOG(INFO) << "Input is ValueList, skip it.";
@ -1195,7 +1197,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
continue;
@ -1540,33 +1541,40 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
return std::make_pair(nullptr, 0);
}
static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
const AnfNodePtr &parameter) {
Operator op = CreateAllGatherOp(group);
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(res.first);
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(node);
auto cnode = res.first->cast<CNodePtr>();
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(cnode_prim);
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
Operator op;
CNodePtr allgather;
if (grad_accumulation_step > 1) {
op = CreateMiniStepAllGatherOp(group);
auto param_name = node->cast<ParameterPtr>()->name();
if (cnode_prim->name() == CAST) {
allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
} else {
InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
allgather = cnode->input(res.second)->cast<CNodePtr>();
}
} else {
op = CreateAllGatherOp(group);
if (cnode_prim->name() == CAST) {
allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
} else {
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER);
allgather = cnode->input(res.second)->cast<CNodePtr>();
}
MS_EXCEPTION_IF_NULL(allgather);
}
// add fusion flag
AddCommOpFusionType(allgather, parameter);
AddCommOpFusionType(allgather, node);
// add gradients mean
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
attrs["mean_flag"] = MakeValue<bool>(mean_flag);
prim->SetAttrs(attrs);
AddCommOpMeanFlag(allgather);
}
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
@ -1589,7 +1597,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
<< distribute_operator->inputs_tensor_info().size();
}
// insert allgather operator between shard parameter and cnode
InsertAllGatherOp(opt_shard_group, param_pair, parameter);
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
}
}
@ -1734,12 +1742,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
MS_EXCEPTION_IF_NULL(cloned_abstract);
if (param_name.find(ACCU_GRADS) != std::string::npos) {
auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
MS_EXCEPTION_IF_NULL(parallel_shape);
cloned_abstract->set_shape(parallel_shape);
} else {
cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
}
cloned_parameter_node->set_abstract(cloned_abstract);
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()

@ -30,6 +30,8 @@
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/pipeline.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;

@ -258,9 +258,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec = res->args_spec();
parallel::ParallelParameterContextInit(func_graph);
auto context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
context->ParallelParameterContextInitShape(func_graph);
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
@ -271,9 +271,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
args_spec.push_back(abs_ref);
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
}
}
// Analyze

@ -160,6 +160,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.replace_applicator_,
irpass.mirror_mini_step_elim_,
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);

@ -374,7 +374,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape"))
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
.def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
.def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */

@ -134,7 +134,7 @@ class Parameter(Tensor_):
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
self._param_info = ParamInfo()
self.param_info = ParamInfo()
self.init_in_server = False
self.cache_enable = False
self.name = name
@ -230,7 +230,7 @@ class Parameter(Tensor_):
"sparse operator support initialization in server.".format(self.name))
self.is_param_ps = True
self.init_in_server = init_in_server
self._param_info.init_in_server = init_in_server
self.param_info.init_in_server = init_in_server
@property
def inited_param(self):
@ -245,7 +245,7 @@ class Parameter(Tensor_):
@property
def name(self):
"""Get the name of the parameter."""
return self._param_info.name
return self.param_info.name
@name.setter
def name(self, name_):
@ -272,9 +272,9 @@ class Parameter(Tensor_):
if len(self.shape) != 2:
raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
.format(self.name, len(self.shape)))
_reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1])
_reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])
self._param_info.name = name_
self.param_info.name = name_
@property
def sliced(self):
@ -288,12 +288,12 @@ class Parameter(Tensor_):
@property
def comm_fusion(self):
"""Get the fusion type for communication operators corresponding to this parameter."""
return self._param_info.comm_fusion
return self.param_info.comm_fusion
@comm_fusion.setter
def comm_fusion(self, comm_fusion_):
"""Set the fusion type for communication operators corresponding to this parameter."""
self._param_info.comm_fusion = comm_fusion_
self.param_info.comm_fusion = comm_fusion_
@property
def unique(self):
@ -339,7 +339,7 @@ class Parameter(Tensor_):
"""
x = copy(self)
# pylint: disable=protected-access
x._param_info = self._param_info.clone()
x.param_info = self.param_info.clone()
x.is_init = False
x.init = self.init
x.is_param_ps = self.is_param_ps
@ -355,57 +355,57 @@ class Parameter(Tensor_):
@property
def layerwise_parallel(self):
return self._param_info.layerwise_parallel
return self.param_info.layerwise_parallel
@layerwise_parallel.setter
def layerwise_parallel(self, value=True):
if not isinstance(value, bool):
raise TypeError("`layerwise_parallel` parameter must be bool type")
self._param_info.layerwise_parallel = value
self.param_info.layerwise_parallel = value
@property
def parallel_optimizer(self):
"""Return whether the parameter requires weight shard for parallel optimizer."""
return self._param_info.parallel_optimizer
return self.param_info.parallel_optimizer
@parallel_optimizer.setter
def parallel_optimizer(self, value=True):
if not isinstance(value, bool):
raise TypeError("`parallel_optimizer` parameter must be bool type")
self._param_info.parallel_optimizer = value
self.param_info.parallel_optimizer = value
@property
def cache_enable(self):
"""Return whether the parameter is cache enable."""
return self._param_info.cache_enable
return self.param_info.cache_enable
@cache_enable.setter
def cache_enable(self, value=True):
if not isinstance(value, bool):
raise TypeError("`cache_enable` parameter must be bool type")
self._param_info.cache_enable = value
self.param_info.cache_enable = value
@property
def cache_shape(self):
"""Return the cache shape corresponding to the parameter if use cache."""
return self._param_info.cache_shape
return self.param_info.cache_shape
@cache_shape.setter
def cache_shape(self, value):
if not isinstance(value, (tuple, list)):
raise TypeError("`cache_shape` parameter must be tuple or list type")
self._param_info.cache_shape = value
self.param_info.cache_shape = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""
return self._param_info.requires_grad
return self.param_info.requires_grad
@requires_grad.setter
def requires_grad(self, value=True):
if not isinstance(value, bool):
raise TypeError("`requires_grad` parameter must be bool type")
self._param_info.requires_grad = value
self.param_info.requires_grad = value
@property
def data(self):
@ -419,7 +419,9 @@ class Parameter(Tensor_):
self.init = None
return self.assign_value(data)
# create a new tensor
return Parameter(data, self.name, self.requires_grad)
new_param = Parameter(data, self.name, self.requires_grad)
new_param.param_info = self.param_info
return new_param
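As a side note on the `_param_info` → `param_info` rename, a hypothetical snippet (not part of the diff) showing that user-visible attributes such as `comm_fusion` live on the now-public `param_info` object, which the method above carries over to the replacement Parameter:

```python
# Hypothetical sketch: param_info is now a public attribute of Parameter.
import numpy as np
from mindspore import Parameter, Tensor

p = Parameter(Tensor(np.ones([4]).astype(np.float32)), name="w", requires_grad=True)
p.comm_fusion = 2   # forwarded to p.param_info.comm_fusion
print(p.param_info.name, p.param_info.requires_grad, p.param_info.comm_fusion)
```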
def set_data(self, data, slice_shape=False):
"""

@ -306,6 +306,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitiv
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");

@ -20,7 +20,7 @@ from mindspore.communication import get_rank, get_group_size
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
@ -150,6 +150,39 @@ def get_bprop_all_gather(self):
return bprop
@bprop_getters.register(_MiniStepAllGather)
def get_bprop_mini_step_all_gather(self):
"""Generate bprop for _MiniStepAllGather"""
fusion = self.get_attr_dict()["fusion"]
mean_flag = self.get_attr_dict()["mean_flag"]
do_mirror = self.get_attr_dict()["do_mirror"]
scale = 1 / self.rank_size
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
rank = get_rank(self.group)
dev_num = get_group_size(self.group)
split = P.Split(output_num=dev_num)
def bprop(x, z, out, dout):
if do_mirror:
if mean_flag:
tmp = z + dout
grad = all_reduce(tmp)
dx = split(grad)[rank]
dx = F.tensor_mul(dx, scale)
else:
tmp = z + dout
grad = all_reduce(tmp)
dx = split(grad)[rank]
else:
dx = dout
return (dx, zeros_like(z))
return bprop
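To make the new backward rule concrete, here is a small numpy sketch (hypothetical, not MindSpore code) of what the bprop above computes when `do_mirror` is true: summing `accu_grad + dout` across the group and taking this rank's slice is a reduce-scatter of the accumulated gradient, optionally scaled by `1 / rank_size` when `mean_flag` is set.

```python
# Hypothetical numpy illustration of get_bprop_mini_step_all_gather with do_mirror=True.
import numpy as np

rank_size, rank = 4, 1
# per-rank (accu_grad z + incoming dout) for a gathered tensor of length 8
tmp = [np.arange(8, dtype=np.float32) + r for r in range(rank_size)]
grad = np.sum(tmp, axis=0)             # AllReduce(ReduceOp.SUM) over the group
dx = np.split(grad, rank_size)[rank]   # Split(output_num=rank_size)(grad)[rank] -> this rank's shard
dx_mean = dx * (1.0 / rank_size)       # extra scaling applied when mean_flag is True
print(dx, dx_mean)
```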
@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
"""Generate bprop for _HostAllGather"""
@ -291,18 +324,13 @@ def get_bprop_mirror_mini_step_operator(self):
group = self.group
dev_num = self.dev_num
mean_flag = self.mean_flag
grad_accumulation_step = self.grad_accumulation_step
all_reduce = AllReduce(group=group)
all_gather = AllGather(group=group)
mul = P.Mul()
cast = P.Cast()
equal = P.Equal()
reshape = P.Reshape()
fusion = 1
if hasattr(self, 'fusion'):
fusion = self.fusion
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
@ -311,16 +339,15 @@ def get_bprop_mirror_mini_step_operator(self):
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
do_mirror = self.get_attr_dict()["do_mirror"]
def bprop(x, y, z, out, dout):
do_mirror = equal(y, grad_accumulation_step)
do_mirror = reshape(do_mirror, (()))
def bprop(x, z, out, dout):
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
dx = real_grad - z
dx = real_grad
else:
dx = dout
float_one = F.scalar_cast(1.0, F.dtype(dx))
@ -342,7 +369,7 @@ def get_bprop_mirror_mini_step_operator(self):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
dx = real_grad - z
dx = real_grad
else:
dx = dout
else:
@ -354,7 +381,7 @@ def get_bprop_mirror_mini_step_operator(self):
grad = dout.values
dx = RowTensor(indices, grad, dout.dense_shape)
return (dx, zeros_like(y), zeros_like(z))
return (dx, zeros_like(z))
return bprop

@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,

@ -200,6 +200,38 @@ class AllGather(PrimitiveWithInfer):
raise NotImplementedError
class _MiniStepAllGather(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
internal use of parallel modules and cannot be called by users.
Args:
group (str): The communication group to work on. Default: None.
grad_accumulation_step (int): The grad accumulation step. Default: None.
"""
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 1)
self.grad_accumulation_step = grad_accumulation_step
self.mean_flag = mean_flag
def infer_shape(self, x_shape, z_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
if x_shape[0] > 0:
x_shape[0] = x_shape[0] * self.rank_size
return x_shape
def infer_dtype(self, x_dtype, z_shape):
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
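A quick, hypothetical sanity check of the shape rule above (not part of the diff): the gathered output restores the full first dimension of an optimizer-sharded parameter slice.

```python
# Hypothetical check mirroring _MiniStepAllGather.infer_shape: an 8-device group turns
# a [2]-shaped optimizer-sharded slice back into the full [16]-shaped parameter.
rank_size = 8
x_shape = [2]
out_shape = [x_shape[0] * rank_size] + x_shape[1:]
assert out_shape == [16]
```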
class _HostAllGather(PrimitiveWithInfer):
"""
Gathers tensors from the specified communication group on host.
@ -590,10 +622,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
self.mean_flag = mean_flag
self.grad_accumulation_step = grad_accumulation_step
def infer_shape(self, x_shape, y_shape, z_shape):
def infer_shape(self, x_shape, z_shape):
return x_shape
def infer_dtype(self, x_dtype, y_shape, z_shape):
def infer_dtype(self, x_dtype, z_shape):
return x_dtype

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,15 +17,14 @@ import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum, Norm
from mindspore.train import Model
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.context import ParallelMode
from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm
from mindspore.parallel._utils import _get_device_num
from tests.dataset_mock import MindData
@ -142,29 +141,29 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4):
def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
self.accu = False
self.is_accu_step = Tensor(np.array([self.accu]))
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step")
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
self.grad_reducer = F.identity
if self.reducer_flag:
self.degree = get_group_size()
self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
@ -197,34 +196,27 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
else:
scaling_sens = sens
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
mean_loss = self.accu_loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
if self.is_accu_step and self.accumulation_steps > 1:
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
mean_loss = F.depend(mean_loss, accu_succ)
loss = F.depend(loss, accu_succ)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
overflow = self.less_equal(self.base, flag_sum)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
is_accu_step = self.reshape(is_accu_step, (()))
self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero)
if is_accu_step:
if self.is_accu_step:
succ = False
else:
# apply grad reducer on grads
grads = self.grad_reducer(self.accu_grads)
grads = self.grad_reducer(grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
grads = ClipByGlobalNorm()(grads)
@ -241,7 +233,7 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
else:
succ = self.optimizer(grads)
ret = (mean_loss, overflow, scaling_sens)
ret = (loss, overflow, scaling_sens)
return F.depend(ret, succ)
@ -265,25 +257,51 @@ _b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)
def compile_net(net, grad_accumulation_step):
context.set_context(save_graphs=True)
def compile_net(net):
context.set_context(enable_sparse=False)
learning_rate = 0.1
momentum = 0.9
epoch_size = 2
dataset = Dataset(_x, _b)
opt = Momentum(net.trainable_params(), learning_rate, momentum)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell,
accumulation_steps=grad_accumulation_step)
net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell)
model = Model(net_wrap)
model.train(epoch_size, dataset, dataset_sink_mode=False)
context.reset_auto_parallel_context()
def test_grad_accumulation():
def test_grad_accumulation_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step)
strategy = ((2,), (2,))
net = Net(_w1, strategy)
compile_net(net, grad_accumulation_step)
net = Net(_w1, strategy).add_flags_recursive(accu=True)
compile_net(net)
def test_grad_accu_and_opt_shard_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=True)
compile_net(net)
def test_grad_accumulation_not_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=False)
compile_net(net)
def test_grad_accu_and_opt_shard_not_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=False)
compile_net(net)
