!8987 support getnext in pynative mode

From: @chujinjin
Reviewed-by: 
Signed-off-by:
pull/8987/MERGE
Author: mindspore-ci-bot, committed by Gitee
commit d38f8205dc
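
In outline: the C++ hunks add a SyncStream() override to AscendSession and GPUSession and expose it to Python as PynativeExecutor::Sync (bound as "sync"). On the Python side, dataset_helper.py gains a _DatasetIterPyNative iterator so DatasetHelper can select the GetNext sink path outside graph mode, while model.py drops the PyNative fallback warning, inserts a _StepSync callback that syncs the stream at every step end, and advances cur_step_num by one per step instead of by sink_size.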

@@ -1383,5 +1383,14 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const
   RunInfer(func_graph, inputs);
   return CompileGraphImpl(func_graph);
 }
+
+void AscendSession::SyncStream() {
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  auto ret = runtime_instance->SyncStream();
+  if (!ret) {
+    MS_LOG(ERROR) << "Sync stream error!";
+  }
+}
 } // namespace session
 } // namespace mindspore

@@ -48,6 +48,7 @@ class AscendSession : public SessionBasic {
   void Init(uint32_t device_id) override;
   // get graph id of final graph
   GraphId GetFinalRunGraph() const override { return final_graph_id_; }
+  void SyncStream() override;
 
  protected:
   GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;

@@ -508,6 +508,15 @@ void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &ke
   MS_EXCEPTION_IF_NULL(runtime_instance);
   runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get());
 }
+
+void GPUSession::SyncStream() {
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  auto ret = runtime_instance->SyncStream();
+  if (!ret) {
+    MS_LOG(ERROR) << "Sync stream error!";
+  }
+}
 } // namespace gpu
 } // namespace session
 } // namespace mindspore

@@ -32,6 +32,7 @@ class GPUSession : public SessionBasic {
   GPUSession() = default;
   ~GPUSession() override = default;
   void Init(uint32_t device_id) override;
+  void SyncStream() override;
 
  protected:
   GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;

@@ -66,9 +66,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   }
   virtual void Init(uint32_t device_id) { device_id_ = device_id; }
   void InitExecutor(const std::string &device_name, uint32_t device_id);
+  virtual void SyncStream() {}
   virtual ~SessionBasic() { summary_callback_ = nullptr; }
   GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);

@@ -2113,6 +2113,13 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
   PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
 }
 
+void PynativeExecutor::Sync() {
+  if (session == nullptr) {
+    MS_EXCEPTION(NotExistsError) << "No session has been created!";
+  }
+  session->SyncStream();
+}
+
 REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                          (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                            .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
@@ -2121,6 +2128,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                            .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
                            .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                            .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
+                           .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
                            .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
                                 "Executor run function.")
                            .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),

@@ -96,6 +96,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void Clean();
   // Destruct call
   void ClearRes();
+  // Sync stream
+  void Sync();
 
  private:
   PynativeExecutor() = default;

@@ -314,6 +314,9 @@ class _PynativeExecutor:
     def clear(self, flag=""):
         self._executor.clear(flag)
 
+    def sync(self):
+        self._executor.sync()
+
     def set_grad_flag(self, flag):
         self._executor.set_grad_flag(flag)
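
The "sync" binding above is what user-level code (and model.py later in this patch) calls to wait for the device stream. A minimal PyNative sketch, assuming a build that contains this change and a GPU or Ascend device; the op and tensor values are only illustrative:

import numpy as np
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.common.api import _pynative_exec  # internal executor handle, imported the same way in model.py below

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")  # or "Ascend"

add = P.TensorAdd()
out = add(Tensor(np.ones((2, 2), np.float32)), Tensor(np.ones((2, 2), np.float32)))

# PynativeExecutor::Sync -> session->SyncStream(): block until outstanding device work is done.
_pynative_exec.sync()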

@@ -67,6 +67,7 @@ def connect_network_with_dataset(network, dataset_helper):
         >>> net = Net()
         >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
     """
     class _DataWrapper(nn.Cell):
         """
         Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
@@ -163,16 +164,20 @@ class DatasetHelper:
             if context.get_context("enable_ge"):
                 iterclass = _DatasetIterGE
             else:
-                if context.get_context("device_target") == "Ascend":
-                    iterclass = _DatasetIterMSLoopSink
-                elif context.get_context("device_target") == "GPU":
-                    ms_role = os.getenv("MS_ROLE")
-                    if ms_role in ("MS_PSERVER", "MS_SCHED"):
-                        iterclass = _DatasetIterPSLite
-                    else:
-                        iterclass = _DatasetIterMSLoopSink
-                elif context.get_context("device_target") == "CPU":
-                    raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
+                if context.get_context("mode") == context.GRAPH_MODE:
+                    if context.get_context("device_target") == "Ascend":
+                        iterclass = _DatasetIterMSLoopSink
+                    elif context.get_context("device_target") == "GPU":
+                        ms_role = os.getenv("MS_ROLE")
+                        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+                            iterclass = _DatasetIterPSLite
+                        else:
+                            iterclass = _DatasetIterMSLoopSink
+                    elif context.get_context("device_target") == "CPU":
+                        raise RuntimeError(
+                            "Currently dataset sink mode is not supported when the device target is CPU.")
+                else:
+                    iterclass = _DatasetIterPyNative
             self.iter = iterclass(dataset, sink_size, epoch_num)
         else:
             iterclass = _DatasetIterNormal
@@ -281,6 +286,20 @@ class _DatasetIterGE(_DatasetIter):
         self.op = op
 
 
+class _DatasetIterPyNative(_DatasetIter):
+    """Iter for MS(enable_loop_sink=False)."""
+    def __init__(self, dataset, sink_size, epoch_num):
+        super().__init__(dataset, sink_size, epoch_num)
+        if sink_size > 0:
+            self.sink_count = sink_size
+        else:
+            self.sink_count = dataset.get_dataset_size()
+
+        def op():
+            return tuple()
+        self.op = op
+
+
 class _DatasetIterMSLoopSink(_DatasetIter):
     """Iter for context (device_target=Ascend)"""
@@ -329,6 +348,7 @@ class _DatasetIterPSLite(_DatasetIter):
         def op():
             return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
         self.op = op
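
With _DatasetIterPyNative in place, DatasetHelper can drive the same GetNext sink loop in PyNative mode that graph mode uses: the iterator only reports a sink count and yields empty tuples, because the network wrapped by connect_network_with_dataset fetches each batch on the device. A rough sketch, assuming a build with this patch and using net and ds_train as placeholders for a constructed Cell and a MindData dataset:

from mindspore import context
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

dataset_helper = DatasetHelper(ds_train, dataset_sink_mode=True, sink_size=-1, epoch_num=1)
net = connect_network_with_dataset(net, dataset_helper)   # wraps net with a GetNext front end

for inputs in dataset_helper:   # _DatasetIterPyNative yields an empty tuple per step, sink_count times
    outputs = net(*inputs)      # GetNext pulls the next batch directly on the device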

@@ -23,7 +23,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, Validator
-from .callback import _InternalCallbackParam, RunContext, _CallbackManager
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -35,6 +35,7 @@ from ..context import ParallelMode
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
+from ..common.api import _pynative_exec
 
 
 def _transfer_tensor_to_tuple(inputs):
@@ -47,6 +48,11 @@ def _transfer_tensor_to_tuple(inputs):
     return inputs
 
 
+class _StepSync(Callback):
+    def step_end(self, run_context):
+        _pynative_exec.sync()
+
+
 class Model:
     """
     High-Level API for Training or Testing.
@@ -365,6 +371,9 @@ class Model:
         cb_params.device_number = self._device_number
         cb_params.train_dataset = train_dataset
         cb_params.list_callback = self._transform_callbacks(callbacks)
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            cb_params.list_callback.insert(0, _StepSync())
+            callbacks = cb_params.list_callback
         cb_params.train_dataset_element = None
         cb_params.network = self._network
         if _is_role_pserver() or _is_role_sched():
@@ -374,8 +383,8 @@ class Model:
         with _CallbackManager(callbacks) as list_callback:
             if not dataset_sink_mode:
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
-            elif context.get_context("device_target") == "CPU" or context.get_context("mode") == context.PYNATIVE_MODE:
-                logger.warning("The CPU or PyNative mode cannot support dataset sink mode currently."
+            elif context.get_context("device_target") == "CPU":
+                logger.warning("The CPU cannot support dataset sink mode currently."
                                "So the training process will be performed with dataset not sink.")
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
             else:
@@ -417,7 +426,7 @@ class Model:
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
+        is_graph = (context.get_context("mode") == context.GRAPH_MODE)
         # used to stop training for early stop, such as stopAtTime or stopAtStep
         should_stop = False
         dataset_helper = None
@@ -441,7 +450,10 @@ class Model:
                 cb_params.train_dataset_element = inputs
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
-                cb_params.cur_step_num += dataset_helper.sink_size()
+                if is_graph:
+                    cb_params.cur_step_num += dataset_helper.sink_size()
+                else:
+                    cb_params.cur_step_num += 1
                 cb_params.net_outputs = outputs
                 list_callback.step_end(run_context)
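
Taken together, the model.py changes keep dataset_sink_mode=True usable in PyNative mode: the warn-and-fall-back branch now applies only to CPU, _StepSync forces a stream sync at each step_end so callbacks observe finished results, and cur_step_num advances by one per PyNative step rather than by sink_size. A usage sketch with net, loss, opt and ds_train as placeholders (Ascend or GPU assumed):

from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import LossMonitor

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

model = Model(net, loss_fn=loss, optimizer=opt)
# Before this patch, this combination fell back to non-sink training with a warning;
# with it, the GetNext sink path is used and LossMonitor reports loss step by step.
model.train(epoch=3, train_dataset=ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True)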
