!9777 fix auto parallet full batch

From: @limingqi107
Reviewed-by: @chujinjin,@cristoval
Signed-off-by: @cristoval
pull/9777/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f7c339fce1

@ -438,8 +438,9 @@ class PredictWithSigmoid(nn.Cell):
self.network = network
self.sigmoid = P.Sigmoid()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
full_batch = context.get_auto_parallel_context("full_batch")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
if is_auto_parallel and full_batch:
self.sigmoid.shard(((1, 1),))
def construct(self, batch_ids, batch_wts, labels):

Loading…
Cancel
Save