|
|
@ -438,8 +438,9 @@ class PredictWithSigmoid(nn.Cell):
|
|
|
|
self.network = network
|
|
|
|
self.network = network
|
|
|
|
self.sigmoid = P.Sigmoid()
|
|
|
|
self.sigmoid = P.Sigmoid()
|
|
|
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|
|
|
|
|
full_batch = context.get_auto_parallel_context("full_batch")
|
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
if is_auto_parallel:
|
|
|
|
if is_auto_parallel and full_batch:
|
|
|
|
self.sigmoid.shard(((1, 1),))
|
|
|
|
self.sigmoid.shard(((1, 1),))
|
|
|
|
|
|
|
|
|
|
|
|
def construct(self, batch_ids, batch_wts, labels):
|
|
|
|
def construct(self, batch_ids, batch_wts, labels):
|
|
|
|