|
|
|
@ -30,6 +30,8 @@ from ..nn.metrics import Loss
|
|
|
|
|
from .. import nn
|
|
|
|
|
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
|
|
|
|
from .parallel_utils import ParallelMode
|
|
|
|
|
from ._utils import _to_full_tensor
|
|
|
|
|
from ..parallel._utils import _need_to_full
|
|
|
|
|
from ..common import dtype as mstype
|
|
|
|
|
from .dataset_helper import DatasetHelper
|
|
|
|
|
from . import amp
|
|
|
|
@ -418,6 +420,8 @@ class Model:
|
|
|
|
|
|
|
|
|
|
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
|
|
|
|
for inputs in dataset_helper:
|
|
|
|
|
if _need_to_full():
|
|
|
|
|
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
|
|
|
|
list_callback.step_begin(run_context)
|
|
|
|
|
outputs = self._train_network(*inputs)
|
|
|
|
|
cb_params.cur_step_num += dataset_helper.sink_size()
|
|
|
|
|