!12853 handle the fully split parameter for grad accumulation

From: @yangzhenzhang
Reviewed-by: 
Signed-off-by:
pull/12853/MERGE
Committed by mindspore-ci-bot on Gitee, 4 years ago
commit c12abe7a46

@@ -51,7 +51,7 @@
namespace mindspore {
namespace session {
static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
static std::shared_ptr<std::map<ParamInfoPtr, ParameterPtr>> python_paras;
void ClearPythonParasMap() { python_paras = nullptr; }
namespace {
const int kSummaryGetItem = 2;
@@ -106,7 +106,7 @@ bool CheckIfNeedCreateOutputTensor(const AnfNodePtr &node) {
return false;
}
ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
@@ -114,7 +114,7 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
if (parameter == nullptr || !parameter->has_default()) {
return nullptr;
}
return parameter->default_param();
return parameter->param_info();
}
tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
@@ -747,7 +747,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
ParameterPtr new_parameter = nullptr;
// if the parameter's python parameter already has a corresponding backend parameter, reuse the existing one
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
}
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {
@@ -1217,7 +1217,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
auto param_value = GetParamDefaultValue(anf);
ParameterPtr new_parameter = nullptr;
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
}
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {

@@ -88,6 +88,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =

@@ -52,6 +52,7 @@ class OptimizeIRPassLib {
SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;
SubstitutionPtr mirror_mini_step_elim_;
SubstitutionPtr virtual_add_elim_;
SubstitutionPtr mini_step_allgather_replace_;
// Env Item Eliminate

@@ -175,6 +175,25 @@ class MirrorMiniStepEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
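To make the substitution concrete: below is a minimal stand-alone Python sketch of the same rewrite on a toy node type (Node and eliminate_virtual_add are illustrative names, not MindSpore APIs). In the C++ visitor above, inputs[0] holds the primitive value node, so the first real operand is inputs[1]; the sketch keeps the primitive name separate, so the operand sits at inputs[0].

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    op: str                              # primitive name, e.g. "_VirtualAdd"
    inputs: List["Node"] = field(default_factory=list)

def eliminate_virtual_add(node: Node) -> Optional[Node]:
    """Mimic the rewrite {_VirtualAdd, X, Z} -> X: keep X, drop the accumulation input Z."""
    if node.op != "_VirtualAdd" or not node.inputs:
        return None                      # pattern does not match, leave the node alone
    return node.inputs[0]                # forward the first real operand

param = Node("Parameter")
accu_grad = Node("accu_grad")
assert eliminate_virtual_add(Node("_VirtualAdd", [param, accu_grad])) is param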
// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
public:
@@ -191,8 +210,15 @@ class MiniStepAllGatherPass : public AnfVisitor {
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
std::string group = attrs[parallel::GROUP]->ToString();
auto fusion = attrs[parallel::FUSION];
parallel::Operator op = parallel::CreateAllGatherOp(group);
std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
attrs = prim->attrs();
attrs[parallel::FUSION] = fusion;
prim->SetAttrs(attrs);
auto func_graph = inputs[1]->func_graph();
CNodePtr new_node = func_graph->NewCNode(node_input);
return new_node;

@@ -155,13 +155,23 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_flag(AUTO_PARALLEL) &&
(!func_graph->has_flag(TRAINING) ||
(ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
if (!func_graph->has_flag(AUTO_PARALLEL)) {
return;
}
if (!func_graph->has_flag(TRAINING)) {
init_param_shape_ = false;
MS_LOG(INFO) << "In parallel evaluation or prediction, may be need to restore the parameter shape";
return;
}
if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
init_param_shape_ = false;
MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
} else {
param_shapes.clear();
init_param_shape_ = true;
MS_LOG(INFO) << "Init the parameter shape dict";
}
}
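The new control flow above reduces to a three-way decision. A minimal sketch, assuming the flag semantics described by the log messages (the function and parameter names are illustrative):

def decide_param_shape_action(auto_parallel, training, grad_accumulation_step, is_accumulation_graph):
    """Return which action ParallelParameterContextInitShape takes (sketch, assumed semantics)."""
    if not auto_parallel:
        return "skip"       # nothing to do outside auto/semi-auto parallel
    if not training:
        return "restore"    # evaluation/prediction graphs may need the recorded shapes
    if grad_accumulation_step > 1 and not is_accumulation_graph:
        return "restore"    # second graph of grad accumulation restores the shapes
    return "init"           # normal training graph (re)builds the shape dict

assert decide_param_shape_action(True, True, 4, False) == "restore"
assert decide_param_shape_action(True, True, 4, True) == "init"
assert decide_param_shape_action(True, False, 1, False) == "restore"
assert decide_param_shape_action(False, True, 1, False) == "skip"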
@@ -171,6 +181,10 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL)) {
return;
}
if (init_param_shape_) {
return;
}
@@ -182,7 +196,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
Shape shape = iter->second;
std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
ptr->set_shape(base_shape);
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
@@ -192,6 +206,10 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL)) {
return;
}
if (!init_param_shape_) {
return;
}

@@ -110,6 +110,8 @@ constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char RECOMPUTE[] = "recompute";
constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";

@@ -97,6 +97,27 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
prim->SetAttrs(attrs);
}
void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
if (new_node_input.empty()) {
return;
}
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto anf_node = node->input(0)->cast<ValueNodePtr>();
auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
MS_EXCEPTION_IF_NULL(prim_node);
auto node_attrs = prim_node->attrs();
if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
attrs[RECOMPUTE] = MakeValue<bool>(false);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
}
}
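A small sketch of the flag propagation implemented above, assuming the RECOMPUTE / RECOMPUTE_COMM_OP attributes behave as plain booleans (dict-based and illustrative only):

def propagate_recompute_flag(comm_op_attrs, matched_op_attrs):
    """If the matched forward operator sets recompute_comm_op=False, mark the inserted
    communication operator (e.g. the forward AllReduce) as recompute=False."""
    if matched_op_attrs.get("recompute_comm_op") is False:
        comm_op_attrs = dict(comm_op_attrs, recompute=False)
    return comm_op_attrs

assert propagate_recompute_flag({}, {"recompute_comm_op": False}) == {"recompute": False}
assert propagate_recompute_flag({}, {"recompute_comm_op": True}) == {}
assert propagate_recompute_flag({}, {}) == {}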
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
OperatorArgs arg_forward = op.second;
@@ -353,6 +374,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
std::string instance_name_base = FORWARD_OP;
std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
SetAllReduceRecomputeFlag(forward_input, node_to_insert);
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to create anfnode
MS_EXCEPTION_IF_NULL(forward_node);
ScopePtr scope = node->scope();
@@ -1165,7 +1187,14 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
// not a RefKey
if (!param_node_pair.second) {
auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
std::string mirror_op_name;
if (grad_accumulation_step > 1) {
mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
} else {
mirror_op_name = MIRROR_OPERATOR;
}
auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
// if there is already a MirrorOp in the same graph, use the MirrorOp CNode as an input instead
if (next_cnode.first) {
MS_EXCEPTION_IF_NULL(next_cnode.second);
@@ -1743,6 +1772,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it";
continue;
}
cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
@@ -3298,6 +3331,97 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
}
}
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
return false;
}
auto dev_mat_shape = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
int64_t rank = g_device_manager->global_rank();
RankList rank_list = g_device_manager->GetDeviceListInThisStage();
DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
RankList group_devices;
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
return false;
}
if (group_devices.size() == 1) {
MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
return true;
}
return false;
}
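A "fully split" parameter is one whose tensor layout leaves no replica group: every device in the stage holds a different slice, so the group returned by GetDevicesByTensorMap has size 1. A stand-alone sketch of that check, assuming the usual convention that tensor_map values index the device matrix from its last dimension and -1 marks an unsplit dimension:

from math import prod

def replica_group_size(dev_mat, tensor_map):
    """Number of devices that hold the same slice of the parameter (sketch)."""
    used = {len(dev_mat) - 1 - m for m in tensor_map if m != -1}
    unused = [dim for i, dim in enumerate(dev_mat) if i not in used]
    return prod(unused) if unused else 1

def is_fully_split(dev_mat, tensor_map):
    # Fully split <=> the replica group contains only the current rank.
    return replica_group_size(dev_mat, tensor_map) == 1

assert is_fully_split([2, 4], [1, 0])          # both device dimensions are consumed by the layout
assert not is_fully_split([2, 4], [1, -1])     # 4 devices still share each slice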
static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name() == name) {
continue;
}
if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
return parameter;
}
}
return nullptr;
}
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
auto cnode = node_user.first->cast<CNodePtr>();
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(WARNING) << cnode->DebugString() << " cannot insert the fully split param grad accumulation node";
return;
}
OperatorAttrs attrs;
auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
auto value_node = NewValueNode(py_instance);
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
auto graph = cnode->func_graph();
auto virtual_node = graph->NewCNode(virtual_node_input);
manager->SetEdge(cnode, node_user.second, virtual_node);
}
static void HandleFullySplitParameters(const FuncGraphPtr &root) {
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
return;
}
auto parameters = root->parameters();
auto node_users_map = root->manager()->node_users();
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (!IsFullySplitParameter(param_ptr)) {
continue;
}
auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
if (!accu_parameter) {
continue; // some parameters do not need handling, such as the parameter itself or lr
}
auto node_users = node_users_map[parameter];
for (auto &user : node_users) {
auto node = user.first;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!cnode->in_forward_flag()) {
continue;
}
InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
MS_LOG(INFO) << "Insert full split assign add node for " << param_ptr->name();
break; // only need to insert once, even if the parameter has many users
}
}
}
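Putting the helpers together, the flow is: on the non-ACCUMULATION graph, for every fully split parameter that has an "accu_grad" twin, rewire its first forward user through a _VirtualAdd whose bprop folds the accumulated gradient back in. A self-contained sketch under those assumptions (all names are illustrative, not MindSpore APIs):

def find_grad_accu_name(param_names, name):
    """Assumed naming convention: the accumulation copy of `name` contains both the
    original name and the substring 'accu_grad' (e.g. 'accu_grads.weight')."""
    for candidate in param_names:
        if candidate != name and name in candidate and "accu_grad" in candidate:
            return candidate
    return None

def handle_fully_split(params, is_fully_split, forward_users,
                       grad_accumulation_step, is_accumulation_graph):
    """Return the edits that would be applied: one _VirtualAdd insertion per fully split parameter."""
    if grad_accumulation_step <= 1 or is_accumulation_graph:
        return []                                   # only the non-ACCUMULATION graph is patched
    edits = []
    for name in params:
        if not is_fully_split(name):
            continue
        accu = find_grad_accu_name(params, name)
        if accu is None:
            continue                                # e.g. lr, or the accumulation buffer itself
        for user in forward_users.get(name, []):
            edits.append(("insert _VirtualAdd", user, name, accu))
            break                                   # inserting once per parameter is enough
    return edits

params = ["weight", "accu_grads.weight", "lr"]
edits = handle_fully_split(params, lambda n: n == "weight", {"weight": ["MatMul_0"]},
                           grad_accumulation_step=4, is_accumulation_graph=False)
assert edits == [("insert _VirtualAdd", "MatMul_0", "weight", "accu_grads.weight")]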
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@@ -3390,6 +3514,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
MS_LOG(EXCEPTION) << "Save group info failed";
}
// handle fully split parameters in grad accumulation; parameters handled by optimizer sharding are not included
HandleFullySplitParameters(root);
DumpGraph(root, std::string(STEP_PARALLEL_END));
// step parallel only run once

@@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.switch_layer_defer_inline_,
irpass.replace_applicator_,
irpass.mirror_mini_step_elim_,
irpass.virtual_add_elim_,
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
});

@@ -307,6 +307,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");

@@ -22,7 +22,7 @@ from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@@ -108,6 +108,14 @@ def get_bprop_receive(self):
return bprop
@bprop_getters.register(_VirtualAdd)
def get_bprop_virtual_add(self):
    """Generate bprop for _VirtualAdd"""
    def bprop(x, grad_accu, out, dout):
        return (dout + grad_accu, zeros_like(grad_accu))

    return bprop
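_VirtualAdd is an identity in the forward pass; only its bprop matters: the incoming gradient picks up the accumulation buffer, while the buffer itself receives a zero gradient so it is not disturbed. A plain-number sketch of that gradient rule (illustrative only, not the MindSpore execution model):

def virtual_add_bprop(x, grad_accu, out, dout):
    # Same shape as the registered bprop above: (gradient w.r.t. x, gradient w.r.t. grad_accu)
    return dout + grad_accu, 0.0 * grad_accu

# e.g. 1.5 has been accumulated over earlier micro-batches and the current micro-batch
# contributes 0.5, so the parameter finally sees 2.0 while the buffer gets a zero gradient:
assert virtual_add_bprop(x=2.0, grad_accu=1.5, out=2.0, dout=0.5) == (2.0, 0.0)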
@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
"""Generate bprop for Broadcast."""
@@ -168,13 +176,13 @@ def get_bprop_mini_step_all_gather(self):
def bprop(x, z, out, dout):
if do_mirror:
if mean_flag:
tmp = z + dout
grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dx = split(grad)[rank]
dx = F.tensor_mul(dx, scale)
else:
tmp = z + dout
grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dx = split(grad)[rank]
else:
dx = dout
@@ -326,7 +334,6 @@ def get_bprop_mirror_mini_step_operator(self):
mean_flag = self.mean_flag
all_reduce = AllReduce(group=group)
all_gather = AllGather(group=group)
mul = P.Mul()
cast = P.Cast()
@@ -345,8 +352,8 @@ def get_bprop_mirror_mini_step_operator(self):
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
dx = real_grad
else:
dx = dout
@@ -354,32 +361,17 @@ def get_bprop_mirror_mini_step_operator(self):
num = F.scalar_cast(dev_num, F.dtype(dx))
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
else:
if do_mirror:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
else:
indices = dout.indices
grad = dout.values
float_one = F.scalar_cast(1.0, F.dtype(grad))
num = F.scalar_cast(dev_num, F.dtype(grad))
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
dx = RowTensor(indices, grad, dout.dense_shape)
dx = zeros_like(x) # Grad accumulation does not support RowTensor yet
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
dx = real_grad
else:
dx = dout
else:
if do_mirror:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
else:
indices = dout.indices
grad = dout.values
dx = RowTensor(indices, grad, dout.dense_shape)
dx = zeros_like(x) # Grad accumulation does not support RowTensor yet
return (dx, zeros_like(z))
return bprop
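The other recurring change in this file replaces the out-of-place tmp = z + dout with an in-place F.assign_add(z, dout) wrapped in F.depend, so the persistent accumulation buffer z is updated before the collective reads it, and sparse RowTensor gradients now fall back to zeros_like(x) under grad accumulation. A toy sketch of the in-place pattern (Buffer, assign_add and depend here are stand-ins, not MindSpore ops):

class Buffer:
    """Stand-in for a persistent Parameter used as the gradient accumulation buffer."""
    def __init__(self, value):
        self.value = value

def assign_add(buf, delta):
    buf.value += delta        # in-place update, like F.assign_add
    return buf

def depend(value, expr):
    return value              # only enforces ordering; the side effect already happened

z = Buffer(1.5)               # gradient accumulated over earlier micro-batches
dout = 0.5                    # gradient of the current micro-batch
z = depend(z, assign_add(z, dout))
assert z.value == 2.0         # the AllReduce/AllGather now reads the updated buffer directly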

@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
_VirtualDiv, _GetTensorSlice, _VirtualAdd,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)

@@ -653,6 +653,19 @@ class _VirtualDiv(PrimitiveWithInfer):
virtual_div = _VirtualDiv()
class _VirtualAdd(PrimitiveWithInfer):
    """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""

    @prim_attr_register
    def __init__(self):
        """init"""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype
class _VirtualDataset(PrimitiveWithInfer):
"""
Auto parallel virtual dataset operator.

@@ -25,6 +25,7 @@ from mindspore.common.initializer import TruncatedNormal, initializer, Normal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class LayerNorm(nn.Cell):
"""
Layer Normalization

@@ -47,6 +47,7 @@ def test_get_parameter_layout():
net = Net(strategy1, strategy2, weight)
net.set_auto_parallel()
net.set_train()
exe = me._executor
exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [1, -1]

File diff suppressed because it is too large.