|
|
|
@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper
|
|
|
|
|
from . import amp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _transfer_tensor_to_tuple(inputs):
|
|
|
|
|
"""
|
|
|
|
|
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(inputs, Tensor):
|
|
|
|
|
return (inputs,)
|
|
|
|
|
|
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model:
|
|
|
|
|
"""
|
|
|
|
|
High-Level API for Training or Testing.
|
|
|
|
@ -476,6 +486,7 @@ class Model:
|
|
|
|
|
|
|
|
|
|
for next_element in dataset_helper:
|
|
|
|
|
len_element = len(next_element)
|
|
|
|
|
next_element = _transfer_tensor_to_tuple(next_element)
|
|
|
|
|
if self._loss_fn and len_element != 2:
|
|
|
|
|
raise ValueError("when loss_fn is not None, train_dataset should"
|
|
|
|
|
"return two elements, but got {}".format(len_element))
|
|
|
|
@ -630,6 +641,7 @@ class Model:
|
|
|
|
|
for next_element in dataset_helper:
|
|
|
|
|
cb_params.cur_step_num += 1
|
|
|
|
|
list_callback.step_begin(run_context)
|
|
|
|
|
next_element = _transfer_tensor_to_tuple(next_element)
|
|
|
|
|
outputs = self._eval_network(*next_element)
|
|
|
|
|
cb_params.net_outputs = outputs
|
|
|
|
|
list_callback.step_end(run_context)
|
|
|
|
|