From edd7e184d8aaa73bc607bf5d4f2028b6b1f46b06 Mon Sep 17 00:00:00 2001
From: ms_yan <6576637+ms_yan@user.noreply.gitee.com>
Date: Tue, 7 Jul 2020 21:52:03 +0800
Subject: [PATCH] modify config api

---
 mindspore/dataset/__init__.py                 |   2 +-
 mindspore/dataset/core/config.py              | 195 ++++++++++++++++++
 mindspore/dataset/core/configuration.py       | 195 ------------------
 mindspore/dataset/engine/__init__.py          |   5 +-
 .../dataset/engine/serializer_deserializer.py |   2 +-
 5 files changed, 199 insertions(+), 200 deletions(-)
 create mode 100644 mindspore/dataset/core/config.py
 delete mode 100644 mindspore/dataset/core/configuration.py

diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py
index f0070b428d..971915f27e 100644
--- a/mindspore/dataset/__init__.py
+++ b/mindspore/dataset/__init__.py
@@ -18,7 +18,7 @@ datasets in special format, including mindrecord, tfrecord, manifest. Users
 can also create samplers with this module to sample data.
 """
-from .core.configuration import config
+from .core import config
 from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
     TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py
new file mode 100644
index 0000000000..c863186d97
--- /dev/null
+++ b/mindspore/dataset/core/config.py
@@ -0,0 +1,195 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+The configuration manager.
+"""
+import random
+import numpy
+import mindspore._c_dataengine as cde
+
+__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
+           'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']
+
+INT32_MAX = 2147483647
+UINT32_MAX = 4294967295
+
+_config = cde.GlobalContext.config_manager()
+
+
+def set_seed(seed):
+    """
+    Set the seed to be used in any random generator. This is used to produce deterministic results.
+
+    Note:
+        This set_seed function sets the seed in the Python random library and the numpy.random library
+        for deterministic Python augmentations that use randomness. This set_seed function should
+        be called with every iterator created to reset the random seed. In our pipeline this
+        does not guarantee deterministic results with num_parallel_workers > 1.
+
+    Args:
+        seed (int): seed to be set.
+
+    Raises:
+        ValueError: If seed is invalid (< 0 or > UINT32_MAX).
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>> # sets the new seed value; operators with a random seed will now use the new seed value.
+        >>> ds.config.set_seed(1000)
+    """
+    if seed < 0 or seed > UINT32_MAX:
+        raise ValueError("Seed given is not within the required range.")
+    _config.set_seed(seed)
+    random.seed(seed)
+    # numpy.random isn't thread safe
+    numpy.random.seed(seed)
+
+
+def get_seed():
+    """
+    Get the seed.
+
+    Returns:
+        Int, seed.
+    """
+    return _config.get_seed()
+
+
+def set_prefetch_size(size):
+    """
+    Set the number of rows to be prefetched.
+
+    Args:
+        size (int): total number of rows to be prefetched.
+
+    Raises:
+        ValueError: If prefetch_size is invalid (<= 0 or > INT32_MAX).
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>> # sets the new prefetch value.
+        >>> ds.config.set_prefetch_size(1000)
+    """
+    if size <= 0 or size > INT32_MAX:
+        raise ValueError("Prefetch size given is not within the required range.")
+    _config.set_op_connector_size(size)
+
+
+def get_prefetch_size():
+    """
+    Get the prefetch size in number of rows.
+
+    Returns:
+        Size, total number of rows to be prefetched.
+    """
+    return _config.get_op_connector_size()
+
+
+def set_num_parallel_workers(num):
+    """
+    Set the default number of parallel workers.
+
+    Args:
+        num (int): number of parallel workers to be used as the default for each operation.
+
+    Raises:
+        ValueError: If num_parallel_workers is invalid (<= 0 or > INT32_MAX).
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>> # sets the new parallel_workers value; parallel dataset operators will now run with 8 workers.
+        >>> ds.config.set_num_parallel_workers(8)
+    """
+    if num <= 0 or num > INT32_MAX:
+        raise ValueError("Num workers given is not within the required range.")
+    _config.set_num_parallel_workers(num)
+
+
+def get_num_parallel_workers():
+    """
+    Get the default number of parallel workers.
+
+    Returns:
+        Int, number of parallel workers to be used as the default for each operation.
+    """
+    return _config.get_num_parallel_workers()
+
+
+def set_monitor_sampling_interval(interval):
+    """
+    Set the default interval (in ms) of monitor sampling.
+
+    Args:
+        interval (int): interval (in ms) to be used for performance monitor sampling.
+
+    Raises:
+        ValueError: If interval is invalid (<= 0 or > INT32_MAX).
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>> # sets the new interval value.
+        >>> ds.config.set_monitor_sampling_interval(100)
+    """
+    if interval <= 0 or interval > INT32_MAX:
+        raise ValueError("Interval given is not within the required range.")
+    _config.set_monitor_sampling_interval(interval)
+
+
+def get_monitor_sampling_interval():
+    """
+    Get the default interval of performance monitor sampling.
+
+    Returns:
+        Interval: interval (in ms) of performance monitor sampling.
+    """
+    return _config.get_monitor_sampling_interval()
+
+
+def __str__():
+    """
+    String representation of the configurations.
+
+    Returns:
+        Str, configurations.
+    """
+    return str(_config)
+
+
+def load(file):
+    """
+    Load configuration from a file.
+
+    Args:
+        file (str): path of the config file to be loaded.
+
+    Raises:
+        RuntimeError: If file is invalid and parsing fails.
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>> # sets the default values according to values in the configuration file.
+        >>> ds.config.load("path/to/config/file")
+        >>> # example config file:
+        >>> # {
+        >>> #     "logFilePath": "/tmp",
+        >>> #     "rowsPerBuffer": 32,
+        >>> #     "numParallelWorkers": 4,
+        >>> #     "workerConnectorSize": 16,
+        >>> #     "opConnectorSize": 16,
+        >>> #     "seed": 5489,
+        >>> #     "monitorSamplingInterval": 30
+        >>> # }
+    """
+    _config.load(file)
diff --git a/mindspore/dataset/core/configuration.py b/mindspore/dataset/core/configuration.py
deleted file mode 100644
index 5376c668c4..0000000000
--- a/mindspore/dataset/core/configuration.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""
-The configuration manager.
-"""
-import random
-import numpy
-import mindspore._c_dataengine as cde
-
-INT32_MAX = 2147483647
-UINT32_MAX = 4294967295
-
-
-class ConfigurationManager:
-    """The configuration manager"""
-
-    def __init__(self):
-        self.config = cde.GlobalContext.config_manager()
-
-    def set_seed(self, seed):
-        """
-        Set the seed to be used in any random generator. This is used to produce deterministic results.
-
-        Note:
-            This set_seed function sets the seed in the python random library and numpy.random library
-            for deterministic python augmentations using randomness. This set_seed function should
-            be called with every iterator created to reset the random seed. In our pipeline this
-            does not guarantee deterministic results with num_parallel_workers > 1.
-
-        Args:
-            seed(int): seed to be set
-
-        Raises:
-            ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
-
-        Examples:
-            >>> import mindspore.dataset as ds
-            >>> con = ds.engine.ConfigurationManager()
-            >>> # sets the new seed value, now operators with a random seed will use new seed value.
-            >>> con.set_seed(1000)
-        """
-        if seed < 0 or seed > UINT32_MAX:
-            raise ValueError("Seed given is not within the required range")
-        self.config.set_seed(seed)
-        random.seed(seed)
-        # numpy.random isn't thread safe
-        numpy.random.seed(seed)
-
-    def get_seed(self):
-        """
-        Get the seed
-
-        Returns:
-            Int, seed.
-        """
-        return self.config.get_seed()
-
-    def set_prefetch_size(self, size):
-        """
-        Set the number of rows to be prefetched.
-
-        Args:
-            size: total number of rows to be prefetched.
-
-        Raises:
-            ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
-
-        Examples:
-            >>> import mindspore.dataset as ds
-            >>> con = ds.engine.ConfigurationManager()
-            >>> # sets the new prefetch value.
-            >>> con.set_prefetch_size(1000)
-        """
-        if size <= 0 or size > INT32_MAX:
-            raise ValueError("Prefetch size given is not within the required range")
-        self.config.set_op_connector_size(size)
-
-    def get_prefetch_size(self):
-        """
-        Get the prefetch size in number of rows.
-
-        Returns:
-            Size, total number of rows to be prefetched.
- """ - return self.config.get_op_connector_size() - - def set_num_parallel_workers(self, num): - """ - Set the default number of parallel workers - - Args: - num: number of parallel workers to be used as a default for each operation - - Raises: - ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32). - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers. - >>> con.set_num_parallel_workers(8) - """ - if num <= 0 or num > INT32_MAX: - raise ValueError("Num workers given is not within the required range") - self.config.set_num_parallel_workers(num) - - def get_num_parallel_workers(self): - """ - Get the default number of parallel workers. - - Returns: - Int, number of parallel workers to be used as a default for each operation - """ - return self.config.get_num_parallel_workers() - - def set_monitor_sampling_interval(self, interval): - """ - Set the default interval(ms) of monitor sampling. - - Args: - interval: interval(ms) to be used to performance monitor sampling. - - Raises: - ValueError: If interval is invalid (<= 0 or > MAX_INT_32). - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the new interval value. - >>> con.set_monitor_sampling_interval(100) - """ - if interval <= 0 or interval > INT32_MAX: - raise ValueError("Interval given is not within the required range") - self.config.set_monitor_sampling_interval(interval) - - def get_monitor_sampling_interval(self): - """ - Get the default interval of performance monitor sampling. - - Returns: - Interval: interval(ms) of performance monitor sampling. - """ - return self.config.get_monitor_sampling_interval() - - def __str__(self): - """ - String representation of the configurations. - - Returns: - Str, configurations. - """ - return str(self.config) - - def load(self, file): - """ - Load configuration from a file. - - Args: - file: path the config file to be loaded - - Raises: - RuntimeError: If file is invalid and parsing fails. - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the default value according to values in configuration file. 
-            >>> con.load("path/to/config/file")
-            >>> # example config file:
-            >>> # {
-            >>> #     "logFilePath": "/tmp",
-            >>> #     "rowsPerBuffer": 32,
-            >>> #     "numParallelWorkers": 4,
-            >>> #     "workerConnectorSize": 16,
-            >>> #     "opConnectorSize": 16,
-            >>> #     "seed": 5489,
-            >>> #     "monitorSamplingInterval": 30
-            >>> # }
-        """
-        self.config.load(file)
-
-
-config = ConfigurationManager()
diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py
index 674848f156..b3624e1ca3 100644
--- a/mindspore/dataset/engine/__init__.py
+++ b/mindspore/dataset/engine/__init__.py
@@ -26,10 +26,9 @@ from .datasets import *
 from .iterators import *
 from .serializer_deserializer import serialize, deserialize, show, compare
 from .samplers import *
-from ..core.configuration import config, ConfigurationManager
+from ..core import config
 
-__all__ = ["config", "ConfigurationManager", "zip",
-           "ImageFolderDatasetV2", "MnistDataset",
+__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset",
            "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
            "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
            "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler",
diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py
index 9d3339e26d..a1b9e908f3 100644
--- a/mindspore/dataset/engine/serializer_deserializer.py
+++ b/mindspore/dataset/engine/serializer_deserializer.py
@@ -22,7 +22,7 @@ import sys
 from mindspore import log as logger
 from . import datasets as de
 from ..transforms.vision.utils import Inter, Border
-from ..core.configuration import config
+from ..core import config
 
 
 def serialize(dataset, json_filepath=None):
     """
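
Usage after this patch: the ConfigurationManager class is removed, and callers go through the
module-level functions that mindspore.dataset.core.config now exports (re-exported as ds.config
by mindspore/dataset/__init__.py). A minimal migration sketch, assuming a build with this patch
applied; the numeric values are illustrative, not defaults:

    import mindspore.dataset as ds

    # Before this patch:
    #   con = ds.engine.ConfigurationManager()
    #   con.set_seed(1000)
    # After this patch, the same settings are module-level calls:
    ds.config.set_seed(1000)                       # also seeds Python's random and numpy.random
    ds.config.set_num_parallel_workers(8)          # default worker count for each dataset op
    ds.config.set_prefetch_size(1000)              # op connector size, in rows
    ds.config.set_monitor_sampling_interval(100)   # performance monitor interval, in ms

    # Each setter has a matching getter.
    assert ds.config.get_seed() == 1000
    assert ds.config.get_num_parallel_workers() == 8

Because the module binds a single _config handle from cde.GlobalContext.config_manager(), these
calls mutate the same global pipeline state that the shared ConfigurationManager instance did before.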