From d8b9442ab862b1af3a4b302c7e1bae29d63d6e7b Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Tue, 7 Apr 2020 17:44:07 +0800 Subject: [PATCH] dataset_sink_mode is supported in model.eval() and not in model.train() in pynative mode --- mindspore/train/model.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index bcfd897f58..657c84de65 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -206,6 +206,8 @@ class Model: function respectively. callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + Configure pynative mode, the training process will be performed with + dataset not sink. """ epoch = check_int_positive(epoch) self._train_network.set_train() @@ -227,8 +229,13 @@ class Model: cb_params.train_dataset = train_dataset cb_params.list_callback = list_callback - if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE: - self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) + if dataset_sink_mode: + if context.get_context("mode") == context.PYNATIVE_MODE: + logger.warning("The pynative mode cannot support dataset sink mode currently." + "So the training process will be performed with dataset not sink.") + self._train_process(epoch, train_dataset, list_callback, cb_params) + else: + self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) else: self._train_process(epoch, train_dataset, list_callback, cb_params) @@ -349,7 +356,7 @@ class Model: """ Training API where the iteration is controlled by python front-end. - Configure to pynative mode, the training will be performed with dataset non-sink mode. + When setting pynative mode, the training process will be performed with dataset not sink. Note: CPU is not supported when dataset_sink_mode is true. @@ -363,6 +370,8 @@ class Model: function respectively. callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + Configure pynative mode, the training process will be performed with + dataset not sink. Examples: @@ -508,7 +517,7 @@ class Model: self._clear_metrics() - if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE: + if dataset_sink_mode: return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) return self._eval_process(valid_dataset, list_callback, cb_params)