|
|
|
@ -621,6 +621,7 @@ class DynamicGraphAdapter(object):
|
|
|
|
|
|
|
|
|
|
self._input_info = None
|
|
|
|
|
if self._nranks > 1:
|
|
|
|
|
dist.init_parallel_env()
|
|
|
|
|
stradegy = fluid.dygraph.parallel.ParallelStrategy()
|
|
|
|
|
stradegy.nranks = ParallelEnv().nranks
|
|
|
|
|
stradegy.local_rank = ParallelEnv().local_rank
|
|
|
|
@ -888,7 +889,6 @@ class Model(object):
|
|
|
|
|
|
|
|
|
|
# init backend
|
|
|
|
|
if fluid.in_dygraph_mode():
|
|
|
|
|
dist.init_parallel_env()
|
|
|
|
|
self._adapter = DynamicGraphAdapter(self)
|
|
|
|
|
else:
|
|
|
|
|
self._adapter = StaticGraphAdapter(self)
|
|
|
|
@ -943,6 +943,7 @@ class Model(object):
|
|
|
|
|
self._update_inputs()
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def eval_batch(self, inputs, labels=None):
|
|
|
|
|
"""
|
|
|
|
|
Run one evaluating step on a batch of data.
|
|
|
|
@ -994,6 +995,7 @@ class Model(object):
|
|
|
|
|
self._update_inputs()
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def predict_batch(self, inputs):
|
|
|
|
|
"""
|
|
|
|
|
Run one predicting step on a batch of data.
|
|
|
|
|