@@ -23,7 +23,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, Validator
-from .callback import _InternalCallbackParam, RunContext, _CallbackManager
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -35,6 +35,7 @@ from ..context import ParallelMode
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
+from ..common.api import _pynative_exec
 
 
 def _transfer_tensor_to_tuple(inputs):
@@ -47,6 +48,11 @@ def _transfer_tensor_to_tuple(inputs):
     return inputs
 
 
+class _StepSync(Callback):
+    def step_end(self, run_context):
+        _pynative_exec.sync()
+
+
 class Model:
     """
     High-Level API for Training or Testing.
@@ -365,6 +371,9 @@ class Model:
         cb_params.device_number = self._device_number
         cb_params.train_dataset = train_dataset
         cb_params.list_callback = self._transform_callbacks(callbacks)
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            cb_params.list_callback.insert(0, _StepSync())
+            callbacks = cb_params.list_callback
         cb_params.train_dataset_element = None
         cb_params.network = self._network
         if _is_role_pserver() or _is_role_sched():
@@ -374,8 +383,8 @@ class Model:
         with _CallbackManager(callbacks) as list_callback:
             if not dataset_sink_mode:
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
-            elif context.get_context("device_target") == "CPU" or context.get_context("mode") == context.PYNATIVE_MODE:
-                logger.warning("The CPU or PyNative mode cannot support dataset sink mode currently."
+            elif context.get_context("device_target") == "CPU":
+                logger.warning("The CPU cannot support dataset sink mode currently."
                                "So the training process will be performed with dataset not sink.")
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
             else:
@@ -417,7 +426,7 @@ class Model:
 
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
-
+        is_graph = (context.get_context("mode") == context.GRAPH_MODE)
         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
         dataset_helper = None
@@ -441,7 +450,10 @@ class Model:
                 cb_params.train_dataset_element = inputs
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
-                cb_params.cur_step_num += dataset_helper.sink_size()
+                if is_graph:
+                    cb_params.cur_step_num += dataset_helper.sink_size()
+                else:
+                    cb_params.cur_step_num += 1
                 cb_params.net_outputs = outputs
                 list_callback.step_end(run_context)
 
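Taken together, these hunks let PyNative mode use dataset sink training: a `_StepSync` callback is prepended so every `step_end` waits for asynchronous PyNative execution to finish, and `cur_step_num` advances by one per step instead of by `dataset_helper.sink_size()`. Below is a minimal usage sketch of the behaviour after the patch; the toy network, dataset, and `StepLogger` callback are illustrative assumptions rather than part of the change, and a GPU or Ascend backend is assumed since CPU still falls back to non-sink training.

```python
import numpy as np
import mindspore.dataset as ds
from mindspore import context, nn
from mindspore.train import Model
from mindspore.train.callback import Callback


class DemoNet(nn.Cell):
    """Toy network used only to exercise the training loop (not from the patch)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(4, 3)

    def construct(self, x):
        return self.fc(x)


class StepLogger(Callback):
    """Runs after the implicitly prepended _StepSync, so step results are ready."""
    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # With this patch, cur_step_num grows by 1 per step in PyNative mode.
        print("step", cb_params.cur_step_num, "loss", cb_params.net_outputs)


# PyNative mode; a GPU/Ascend backend is assumed for sinking to take effect,
# since the patch keeps the non-sink fallback for CPU.
context.set_context(mode=context.PYNATIVE_MODE)

x = np.random.randn(32, 4).astype(np.float32)
y = np.random.randint(0, 3, (32,)).astype(np.int32)
dataset = ds.NumpySlicesDataset((x, y), column_names=["data", "label"]).batch(8)

net = DemoNet()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=opt)

# Before this change, dataset_sink_mode=True fell back to non-sink training
# in PyNative mode with a warning; now the sink path is taken.
model.train(2, dataset, callbacks=[StepLogger()], dataset_sink_mode=True)
```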