From 3e037b6e8a3a7ed56f2998e491c0e88f4e187b2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=98=89=E7=90=AA?= Date: Wed, 26 Aug 2020 14:45:46 +0800 Subject: [PATCH] dataset_return_single_value --- mindspore/train/model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 324b09b453..7e28eb9704 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper from . import amp +def _transfer_tensor_to_tuple(inputs): + """ + If the input is a tensor, convert it to a tuple. If not, the output is unchanged. + """ + if isinstance(inputs, Tensor): + return (inputs,) + + return inputs + + class Model: """ High-Level API for Training or Testing. @@ -476,6 +486,7 @@ class Model: for next_element in dataset_helper: len_element = len(next_element) + next_element = _transfer_tensor_to_tuple(next_element) if self._loss_fn and len_element != 2: raise ValueError("when loss_fn is not None, train_dataset should" "return two elements, but got {}".format(len_element)) @@ -630,6 +641,7 @@ class Model: for next_element in dataset_helper: cb_params.cur_step_num += 1 list_callback.step_begin(run_context) + next_element = _transfer_tensor_to_tuple(next_element) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs list_callback.step_end(run_context)