From 6222e3252fa1a0c073c510d7e8e02654a2aff1df Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Fri, 5 Mar 2021 11:32:25 +0800 Subject: [PATCH] fix map object of dataset pickle failed issue --- mindspore/train/dataset_helper.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 92a61965d1..e145ccd775 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -102,24 +102,24 @@ def connect_network_with_dataset(network, dataset_helper): and context.get_context("mode") == context.GRAPH_MODE \ and ms_role != "MS_WORKER": - if not hasattr(dataset, '__network__'): - dataset.__network__ = network - network = dataset.__network__ + if not hasattr(dataset_iter, '__network__'): + dataset_iter.__network__ = network + network = dataset_iter.__network__ dataset_types, dataset_shapes = dataset_helper.get_data_info() dataset_types = [pytype_to_dtype(x) for x in dataset_types] key = str(dataset_types) + str(dataset_shapes) - if hasattr(dataset, '__network_manage__') and key in dataset.__network_manage__: - network = dataset.__network_manage__[key] + if hasattr(dataset_iter, '__network_manage__') and key in dataset_iter.__network_manage__: + network = dataset_iter.__network_manage__[key] else: if _need_to_full(): device_num = _get_device_num() dataset_shapes = _to_full_shapes(dataset_shapes, device_num) network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name) - dataset.__network_manage__ = dataset.__network_manage__ if hasattr( - dataset, '__network_manage__') else dict() - dataset.__network_manage__[key] = network + dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr( + dataset_iter, '__network_manage__') else dict() + dataset_iter.__network_manage__[key] = network return network