mini step grad accumulation

pull/10221/head
yangzhenzhang 5 years ago
parent 03e655f14a
commit 9da3f9bec9

@ -80,6 +80,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
prim::kPrimMirrorMiniStep);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =

@ -51,6 +51,7 @@ class OptimizeIRPassLib {
SubstitutionPtr reset_defer_inline_;
SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;
SubstitutionPtr mirror_mini_step_elim_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;

@ -155,6 +155,29 @@ class CheckBpropEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
auto inputs = cnode->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
// Reset defer_inline flag
class ResetDeferInline : public AnfVisitor {
public:

@ -64,6 +64,7 @@ void ParallelContext::Reset() {
all_reduce_fusion_split_sizes_.clear();
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
pipeline_stage_split_num_ = 1;
grad_accumulation_step_ = 1;
}
void ParallelContext::set_device_num(int64_t device_num) {
@ -80,6 +81,10 @@ void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_
void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
grad_accumulation_step_ = grad_accumulation_step;
}
void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

@ -73,6 +73,9 @@ class ParallelContext {
void set_global_rank(int64_t global_rank);
int64_t global_rank() const { return global_rank_; }
void set_grad_accumulation_step(int64_t grad_accumulation_step);
int64_t grad_accumulation_step() const { return grad_accumulation_step_; }
bool set_parallel_mode(const std::string &parallel_mode);
std::string parallel_mode() const { return parallel_mode_; }
@ -116,6 +119,7 @@ class ParallelContext {
bool loss_repeated_mean_;
int64_t device_num_;
int64_t global_rank_;
int64_t grad_accumulation_step_;
std::string parallel_mode_;
std::string strategy_search_mode_;
int64_t pipeline_stage_split_num_;

@ -285,8 +285,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
}
OperatorVector op_for_weight;
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
OperatorName operator_name = MIRROR_OPERATOR;
ValuePtr attr0_value = MakeValue(group_name);
ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
ValuePtr attr2_value = MakeValue(mean_flag);
@ -300,6 +300,17 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
operator_attrs.push_back(attr1);
operator_attrs.push_back(attr2);
OperatorName operator_name;
if (grad_accumulation_step > 1) {
operator_name = MIRROR_MINI_STEP_OPERATOR;
ValuePtr attr3_value = MakeValue(grad_accumulation_step);
Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
operator_attrs.push_back(attr3);
MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
} else {
operator_name = MIRROR_OPERATOR;
}
OperatorParams operator_param;
OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);

@ -146,8 +146,10 @@ constexpr char IS_IN_FORWARD[] = "is_in_forward";
constexpr char DTYPE[] = "DType";
constexpr char DEV_NUM[] = "dev_num";
constexpr char MEAN_FLAG[] = "mean_flag";
constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step";
constexpr char TYPES[] = "types";
constexpr char SHAPES[] = "shapes";
constexpr char ACCU_GRADS[] = "accu_grads";
constexpr char GETNEXT_NUM[] = "output_num";
constexpr char SHARED_NAME[] = "shared_name";
constexpr char MIRROR_OP[] = "mirror_op";
@ -171,6 +173,8 @@ constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
constexpr char ALL_REDUCE[] = "AllReduce";
constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator";
constexpr char LOCAL_STEP[] = "local_step";
constexpr char STRIDED_SLICE[] = "StridedSlice";
constexpr char ALL_GATHER[] = "AllGather";
constexpr char REDUCE_SCATTER[] = "ReduceScatter";

@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
MS_LOG(INFO) << "Insert " << instance_name << " success";
}
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
MS_EXCEPTION_IF_NULL(parameter_node);
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter);
// find the clone parameter
if (!cloned_parameter->has_default()) {
return false;
}
auto param_value = cloned_parameter->param_info();
if (param_value == nullptr) {
return false;
}
bool cloned = param_value->cloned();
if (!cloned) {
return false;
}
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
return true;
}
std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
const std::string &instance_name, const std::string &weight_name) {
MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(root->manager());
AnfNodePtr local_step_param = nullptr;
AnfNodePtr grad_accu = nullptr;
std::string op_name = op.first;
OperatorArgs arg_forward = op.second;
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
if (grad_accumulation_step > 1) {
bool find_locat_step_node = false;
auto parameters = root->parameters();
for (auto &param : parameters) {
auto param_ptr = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name() == LOCAL_STEP) {
auto param_users = root->manager()->node_users()[param];
for (auto &user : param_users) {
if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
find_locat_step_node = true;
local_step_param = user.first;
MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
break;
}
}
break;
}
}
bool find_grad_accu_node = false;
for (auto &param : parameters) {
if (!ParameterIsCloned(param)) {
continue;
}
auto param_ptr = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name().find(weight_name) != std::string::npos &&
param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
find_grad_accu_node = true;
grad_accu = param;
MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
break;
}
}
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
if (!find_locat_step_node || !find_grad_accu_node) {
op_name = MIRROR_OPERATOR;
arg_forward.first.pop_back();
}
}
}
ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
MS_EXCEPTION_IF_NULL(pyop_instance);
OperatorParams params = arg_forward.second;
std::vector<AnfNodePtr> new_node_input;
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
} else {
new_node_input = {NewValueNode(pyop_instance), node};
}
if (!params.empty()) {
for (auto &param : params) {
AnfNodePtr val = NewValueNode(param.first.second);
MS_EXCEPTION_IF_NULL(val);
int64_t position = param.second;
(void)new_node_input.insert(new_node_input.begin() + position, val);
}
}
// if the op have 'group' attr, set the rank list name for the op
SetCommunicationOpGroupLabel(new_node_input);
return new_node_input;
}
void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index,
const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name,
const std::string &param_name) {
// insert new node before the node
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
CNodePtr new_node = func_graph->NewCNode(node_input);
MS_EXCEPTION_IF_NULL(new_node);
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
new_node->set_in_forward_flag(true); // mark forward flag
}
auto new_node_value = node_input[0]->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(new_node_value);
PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
new_node_prim->set_instance_name(instance_name);
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
new_node->set_scope(scope);
node_input[0]->set_scope(scope);
manager->SetEdge(node, SizeToLong(index), new_node);
MS_LOG(INFO) << "Insert " << instance_name << " success";
}
// Replace pre_node with pre_node->op
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
const std::string &instance_name) {
@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}
void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
size_t node_size = node->inputs().size();
FuncGraphPtr func_graph = node->func_graph();
@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
if (!param_node_pair.first) {
continue;
}
auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
std::string param_name;
if (param_ptr != nullptr) {
param_name = param_ptr->name();
}
// not a RefKey
if (!param_node_pair.second) {
auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr pre_node = cnode->input(1);
InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
} else {
for (auto &op : backward_op) {
AnfNodePtr pre_node = node->input(index);
InsertNode(op, node, index, pre_node, func_graph, instance_name);
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
auto comm_op = node->input(index)->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
}
}
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(node);
@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
// insert mirror op
if (!mirror_ops.empty()) {
MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
InsertMirrorOps(mirror_ops, node);
InsertMirrorOps(root, mirror_ops, node);
}
// insert virtual div op
if (!virtual_div_op.empty() && is_loss_cnode) {
@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
g_RefMap.clear();
}
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
MS_EXCEPTION_IF_NULL(parameter_node);
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter);
// find the clone parameter
if (!cloned_parameter->has_default()) {
return false;
}
auto param_value = cloned_parameter->param_info();
if (param_value == nullptr) {
return false;
}
bool cloned = param_value->cloned();
if (!cloned) {
return false;
}
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
return true;
}
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
for (auto &cloned_parameter_node : root->parameters()) {
@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
// insert backward ops
if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
}
HandleSpecialNode(distribute_operator, cnode);

@ -82,11 +82,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph);
void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);
// Generate and init parallel operator
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
const std::vector<Shapes> &shape_list);

@ -131,6 +131,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
.def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
.def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
.def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.")
.def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.")
.def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
.def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
.def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,

@ -143,6 +143,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.check_bprop_eliminate_,
irpass.switch_layer_defer_inline_,
irpass.replace_applicator_,
irpass.mirror_mini_step_elim_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);

@ -206,6 +206,7 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorM
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");

@ -21,7 +21,7 @@ from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@ -282,6 +282,82 @@ def get_bprop_mirror_operator(self):
return bprop
@bprop_getters.register(_MirrorMiniStepOperator)
def get_bprop_mirror_mini_step_operator(self):
"""
Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group,
allgather for sparse feature.
"""
group = self.group
dev_num = self.dev_num
mean_flag = self.mean_flag
grad_accumulation_step = self.grad_accumulation_step
all_reduce = AllReduce(group=group)
all_gather = AllGather(group=group)
mul = P.Mul()
cast = P.Cast()
equal = P.Equal()
reshape = P.Reshape()
fusion = 1
if hasattr(self, 'fusion'):
fusion = self.fusion
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
all_reduce.add_prim_attr("parameter", parameter)
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
def bprop(x, y, z, out, dout):
do_mirror = equal(y, grad_accumulation_step)
do_mirror = reshape(do_mirror, (()))
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
dx = real_grad - z
else:
dx = dout
float_one = F.scalar_cast(1.0, F.dtype(dx))
num = F.scalar_cast(dev_num, F.dtype(dx))
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
else:
if do_mirror:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
else:
indices = dout.indices
grad = dout.values
float_one = F.scalar_cast(1.0, F.dtype(grad))
num = F.scalar_cast(dev_num, F.dtype(grad))
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
dx = RowTensor(indices, grad, dout.dense_shape)
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
dx = real_grad - z
else:
dx = dout
else:
if do_mirror:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
else:
indices = dout.indices
grad = dout.values
dx = RowTensor(indices, grad, dout.dense_shape)
return (dx, zeros_like(y), zeros_like(z))
return bprop
@bprop_getters.register(_VirtualDiv)
def get_bprop_virtual_div_operator(self):
"""Backpropagator for _VirtualDiv, do Div for the divisor."""

@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique, GatherD, Identity)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,

@ -567,6 +567,35 @@ class _MirrorOperator(PrimitiveWithInfer):
mirror = _MirrorOperator()
class _MirrorMiniStepOperator(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
internal use of parallel modules and cannot be called by users.
Args:
group (str): The communication group to work on. Default: None.
dev_num (int): The device number of the group. Default: None.
mean_flag (bool): Whether use mean in backward. Default: None.
grad_accumulation_step (int): The grad accumulation step. Default: None.
"""
@prim_attr_register
def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
self.group = group
self.dev_num = dev_num
self.mean_flag = mean_flag
self.grad_accumulation_step = grad_accumulation_step
def infer_shape(self, x_shape, y_shape, z_shape):
return x_shape
def infer_dtype(self, x_dtype, y_shape, z_shape):
return x_dtype
mirror_mini_step = _MirrorMiniStepOperator()
class _VirtualDiv(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, do Div in backward.

@ -249,6 +249,21 @@ class _AutoParallelContext:
return False
return self._context_handle.get_full_batch()
def set_grad_accumulation_step(self, grad_accumulation_step):
"""
Set grad accumulation step.
Args:
grad_accumulation_step (int): The grad accumulation step.
"""
self.check_context_handle()
self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
def get_grad_accumulation_step(self):
"""Get grad accumulation step."""
self.check_context_handle()
return self._context_handle.get_grad_accumulation_step()
def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
"""
Set strategy checkpoint save path.
@ -492,6 +507,7 @@ _set_auto_parallel_context_func_map = {
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
@ -509,6 +525,7 @@ _get_auto_parallel_context_func_map = {
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().get_full_batch,
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
@ -516,7 +533,7 @@ _get_auto_parallel_context_func_map = {
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list)
grad_accumulation_step=int, all_reduce_fusion_config=list)
def _set_auto_parallel_context(**kwargs):
"""

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save