@@ -157,6 +157,12 @@ class Model:
             if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                 raise ValueError(f"Unsupported arg '{arg}'")
 
+    def _check_reuse_dataset(self, dataset):
+        if not hasattr(dataset, '__model_hash__'):
+            dataset.__model_hash__ = hash(self)
+        if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self):
+            raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.')
+
     def _build_train_network(self):
         """Build train network"""
         network = self._network
@@ -388,6 +394,7 @@ class Model:
                            "So the training process will be performed with dataset not sink.")
             self._train_process(epoch, train_dataset, list_callback, cb_params)
         else:
+            self._check_reuse_dataset(train_dataset)
             self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
 
     @staticmethod
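Note: the new _check_reuse_dataset guard stamps a dataset object with the hash of the first Model that consumes it and refuses reuse by any other Model; in the patch it is invoked only on the dataset-sink path, just before _train_dataset_sink_process. A minimal self-contained sketch of that behaviour follows (the _ToyModel and _ToyDataset names are illustrative stand-ins, not part of MindSpore):

class _ToyModel:
    def _check_reuse_dataset(self, dataset):
        # First use: bind the dataset to this model instance.
        if not hasattr(dataset, '__model_hash__'):
            dataset.__model_hash__ = hash(self)
        # Reuse by a different model: reject, mirroring the patched check.
        if dataset.__model_hash__ != hash(self):
            raise RuntimeError('The Dataset cannot be bound to different models, '
                               'please create a new dataset.')

class _ToyDataset:
    pass

ds = _ToyDataset()
m1, m2 = _ToyModel(), _ToyModel()
m1._check_reuse_dataset(ds)      # first use: stamps ds.__model_hash__ with hash(m1)
m1._check_reuse_dataset(ds)      # same model again: passes silently
try:
    m2._check_reuse_dataset(ds)  # different model: raises RuntimeError
except RuntimeError as err:
    print(err)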