fix bug in8 train twice with the same model

pull/12403/head
liyong 4 years ago
parent 8fecd185df
commit beca294b17

@ -157,6 +157,12 @@ class Model:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
raise ValueError(f"Unsupported arg '{arg}'")
def _check_reuse_dataset(self, dataset):
if not hasattr(dataset, '__model_hash__'):
dataset.__model_hash__ = hash(self)
if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self):
raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.')
def _build_train_network(self):
"""Build train network"""
network = self._network
@ -388,6 +394,7 @@ class Model:
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._check_reuse_dataset(train_dataset)
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
@staticmethod

Loading…
Cancel
Save