!8397 fix API example errors and GetNextInfo batch shape inference

From: @gong_zi_yan
Reviewed-by: @stsuteng, @kisnwang
Signed-off-by: @stsuteng
pull/8397/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit b0fde1ede7

@@ -203,11 +203,11 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
       MS_LOG(ERROR) << name_ << " : The dev num is 0.";
       return FAILED;
     }
-    if (out_shapes[i][0] % dev_num_ != 0) {
-      MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
-      return FAILED;
-    }
+    if (!full_batch) {
+      if (out_shapes[i][0] % dev_num_ != 0) {
+        MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
+        return FAILED;
+      }
+      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+    }
   }
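
The hunk above moves the batch-divisibility check under the !full_batch branch, so the check and the per-device batch split only apply when each device consumes a slice of the global batch. The sketch below is a hypothetical pure-Python rendering of that rule for illustration; the function name and signature are made up here, and the real logic is the C++ in GetNextInfo::InferReplaceOps.

# Illustrative sketch only; mirrors the shape-inference rule shown in the diff above.
def infer_getnext_shapes(out_shapes, dev_num, full_batch):
    """Divide the batch dimension by the device count unless full_batch is set."""
    if dev_num <= 0:
        raise ValueError("The dev num is 0.")
    result = []
    for shape in out_shapes:
        shape = list(shape)
        if not full_batch:
            # The divisibility check only matters when each device gets a slice.
            if shape[0] % dev_num != 0:
                raise ValueError("batch num cannot floor div dev num.")
            shape[0] = shape[0] // dev_num
        result.append(shape)
    return result

# With full_batch=True the global batch is kept; with full_batch=False it is sliced.
print(infer_getnext_shapes([[32, 224, 224, 3]], dev_num=8, full_batch=True))   # [[32, 224, 224, 3]]
print(infer_getnext_shapes([[32, 224, 224, 3]], dev_num=8, full_batch=False))  # [[4, 224, 224, 3]]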

@@ -238,7 +238,7 @@ class Lamb(Optimizer):
     Examples:
         >>> net = Net()
         >>> #1) All parameters use the same learning rate and weight decay
-        >>> optim = nn.Lamb(params=net.trainable_params())
+        >>> optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)
         >>>
         >>> #2) Use parameter groups and set different values
         >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR()
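
The docstring fix above reflects that nn.Lamb has no default learning_rate: it must be passed to the optimizer itself, not to net.trainable_params(). Below is a minimal, self-contained sketch of both usages from the example; the Net class and the concrete group values are placeholders added here, not part of the original docstring.

from mindspore import nn

class Net(nn.Cell):
    # Minimal placeholder network so the optimizer examples are self-contained.
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(4, 2)
    def construct(self, x):
        return self.dense(x)

net = Net()

# 1) All parameters share one learning rate and weight decay:
#    learning_rate is an argument of nn.Lamb, not of net.trainable_params().
optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)

# 2) Parameter groups with different settings (keys follow MindSpore's
#    group-parameter convention; the concrete values are illustrative).
weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
bias_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
group_params = [{'params': weight_params, 'weight_decay': 0.01},
                {'params': bias_params, 'lr': 0.01}]
optim = nn.Lamb(params=group_params, learning_rate=0.1)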

@@ -254,6 +254,8 @@ class DistributedGradReducer(Cell):
         >>> from mindspore.context import ParallelMode
         >>> from mindspore import nn
         >>> from mindspore import ParameterTuple
+        >>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
+        >>>                                        _get_parallel_mode)
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
@@ -279,11 +281,8 @@ class DistributedGradReducer(Cell):
         >>>                                   ParallelMode.HYBRID_PARALLEL]:
         >>>             self.reducer_flag = True
         >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
+        >>>             mean = _get_gradients_mean()
+        >>>             degree = _get_device_num()
         >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
         >>>
         >>>     def construct(self, *args):
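
With this change, the DistributedGradReducer docstring obtains the reduction settings from the parallel helper functions instead of context.get_auto_parallel_context(). Below is a condensed, hedged sketch of that pattern; this TrainingWrapper is a simplified stand-in for the full example class, and the gradient computation in construct is omitted.

from mindspore import nn
from mindspore.context import ParallelMode
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)

class TrainingWrapper(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.optimizer = optimizer
        self.sens = sens
        # Decide whether gradients must be reduced across devices.
        self.reducer_flag = _get_parallel_mode() in (ParallelMode.DATA_PARALLEL,
                                                     ParallelMode.HYBRID_PARALLEL)
        self.grad_reducer = None
        if self.reducer_flag:
            mean = _get_gradients_mean()   # replaces context.get_auto_parallel_context("gradients_mean")
            degree = _get_device_num()     # replaces the device_num / get_group_size() branch
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *inputs):
        # The full docstring example computes gradients here and passes them
        # through self.grad_reducer; that part is omitted in this sketch.
        return self.network(*inputs)

Outside any data-parallel context, reducer_flag stays False and no reducer is created, which matches the behaviour of the original example.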
