|
|
@ -19,7 +19,7 @@ from mindspore import log as logger
|
|
|
|
from ..common.tensor import Tensor
|
|
|
|
from ..common.tensor import Tensor
|
|
|
|
from ..nn.metrics import get_metrics
|
|
|
|
from ..nn.metrics import get_metrics
|
|
|
|
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
|
|
|
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
|
|
|
from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
|
|
|
from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
|
|
|
from .. import context
|
|
|
|
from .. import context
|
|
|
|
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
|
|
|
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
|
|
|
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
|
|
|
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
|
|
@ -332,8 +332,6 @@ class Model:
|
|
|
|
if self._parameter_broadcast:
|
|
|
|
if self._parameter_broadcast:
|
|
|
|
self._train_network.set_broadcast_flag()
|
|
|
|
self._train_network.set_broadcast_flag()
|
|
|
|
|
|
|
|
|
|
|
|
# build callback list
|
|
|
|
|
|
|
|
list_callback = _build_callbacks(callbacks)
|
|
|
|
|
|
|
|
cb_params = _InternalCallbackParam()
|
|
|
|
cb_params = _InternalCallbackParam()
|
|
|
|
cb_params.train_network = self._train_network
|
|
|
|
cb_params.train_network = self._train_network
|
|
|
|
cb_params.epoch_num = epoch
|
|
|
|
cb_params.epoch_num = epoch
|
|
|
@ -344,17 +342,18 @@ class Model:
|
|
|
|
cb_params.parallel_mode = self._parallel_mode
|
|
|
|
cb_params.parallel_mode = self._parallel_mode
|
|
|
|
cb_params.device_number = self._device_number
|
|
|
|
cb_params.device_number = self._device_number
|
|
|
|
cb_params.train_dataset = train_dataset
|
|
|
|
cb_params.train_dataset = train_dataset
|
|
|
|
cb_params.list_callback = list_callback
|
|
|
|
cb_params.list_callback = callbacks
|
|
|
|
|
|
|
|
|
|
|
|
if dataset_sink_mode:
|
|
|
|
# build callback list
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
with _CallbackManager(callbacks) as list_callback:
|
|
|
|
|
|
|
|
if not dataset_sink_mode:
|
|
|
|
|
|
|
|
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
|
|
|
|
|
|
|
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
|
|
|
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
|
|
|
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
|
|
|
"So the training process will be performed with dataset not sink.")
|
|
|
|
"So the training process will be performed with dataset not sink.")
|
|
|
|
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
|
|
|
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
|
|
|
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
|
|
|
else:
|
|
|
|
|
|
|
|
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
|
|
|
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -367,7 +366,7 @@ class Model:
|
|
|
|
returned and passed to the network. Otherwise, a tuple (data, label) should
|
|
|
|
returned and passed to the network. Otherwise, a tuple (data, label) should
|
|
|
|
be returned, and the data and label are passed to the network and loss
|
|
|
|
be returned, and the data and label are passed to the network and loss
|
|
|
|
function respectively.
|
|
|
|
function respectively.
|
|
|
|
list_callback (_ListCallback): Executor of callback list. Default: None.
|
|
|
|
list_callback (Callback): Executor of callback list. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
|
|
|
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
|
|
@ -415,7 +414,7 @@ class Model:
|
|
|
|
returned and passed to the network. Otherwise, a tuple (data, label) should
|
|
|
|
returned and passed to the network. Otherwise, a tuple (data, label) should
|
|
|
|
be returned, and the data and label are passed to the network and loss
|
|
|
|
be returned, and the data and label are passed to the network and loss
|
|
|
|
function respectively.
|
|
|
|
function respectively.
|
|
|
|
list_callback (_ListCallback): Executor of callback list. Default: None.
|
|
|
|
list_callback (Callback): Executor of callback list. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
|
|
|
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
|
|
@ -522,7 +521,7 @@ class Model:
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
valid_dataset (Dataset): Dataset to evaluate the model.
|
|
|
|
valid_dataset (Dataset): Dataset to evaluate the model.
|
|
|
|
list_callback (ListCallback): Executor of callback list. Default: None.
|
|
|
|
list_callback (Callback): Executor of callback list. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
@ -561,7 +560,7 @@ class Model:
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
valid_dataset (Dataset): Dataset to evaluate the model.
|
|
|
|
valid_dataset (Dataset): Dataset to evaluate the model.
|
|
|
|
list_callback (ListCallback): Executor of callback list. Default: None.
|
|
|
|
list_callback (Callback): Executor of callback list. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
@ -620,7 +619,6 @@ class Model:
|
|
|
|
if not self._metric_fns:
|
|
|
|
if not self._metric_fns:
|
|
|
|
raise ValueError("metric fn can not be None or empty.")
|
|
|
|
raise ValueError("metric fn can not be None or empty.")
|
|
|
|
|
|
|
|
|
|
|
|
list_callback = _build_callbacks(callbacks)
|
|
|
|
|
|
|
|
cb_params = _InternalCallbackParam()
|
|
|
|
cb_params = _InternalCallbackParam()
|
|
|
|
cb_params.eval_network = self._eval_network
|
|
|
|
cb_params.eval_network = self._eval_network
|
|
|
|
cb_params.valid_dataset = valid_dataset
|
|
|
|
cb_params.valid_dataset = valid_dataset
|
|
|
@ -633,9 +631,10 @@ class Model:
|
|
|
|
|
|
|
|
|
|
|
|
self._clear_metrics()
|
|
|
|
self._clear_metrics()
|
|
|
|
|
|
|
|
|
|
|
|
if dataset_sink_mode:
|
|
|
|
with _CallbackManager(callbacks) as list_callback:
|
|
|
|
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
|
|
|
if dataset_sink_mode:
|
|
|
|
return self._eval_process(valid_dataset, list_callback, cb_params)
|
|
|
|
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
|
|
|
|
|
|
|
return self._eval_process(valid_dataset, list_callback, cb_params)
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, *predict_data):
|
|
|
|
def predict(self, *predict_data):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|