modify config api

pull/2936/head
ms_yan 5 years ago
parent 8844462e15
commit edd7e184d8
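Summary: this commit replaces the class-based ConfigurationManager in mindspore.dataset.core.configuration with module-level functions in a new mindspore.dataset.core.config module, and repoints the imports that referenced the old module. A rough caller-side before/after sketch, assembled from the docstring examples in the files below (illustrative, not part of the diff itself):

import mindspore.dataset as ds

# Before this commit: configuration went through a ConfigurationManager instance.
con = ds.engine.ConfigurationManager()
con.set_seed(1000)

# After this commit: module-level functions on ds.config.
ds.config.set_seed(1000)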

@@ -18,7 +18,7 @@ datasets in special format, including mindrecord, tfrecord, manifest. Users
can also create samplers with this module to sample data.
"""
-from .core.configuration import config
+from .core import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
    GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
    TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset

@@ -0,0 +1,195 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import random
import numpy
import mindspore._c_dataengine as cde

__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
           'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']

INT32_MAX = 2147483647
UINT32_MAX = 4294967295

_config = cde.GlobalContext.config_manager()


def set_seed(seed):
    """
    Set the seed to be used in any random generator. This is used to produce deterministic results.

    Note:
        This set_seed function sets the seed in the Python random library and the numpy.random library
        for deterministic Python augmentations that use randomness. This set_seed function should
        be called with every iterator created to reset the random seed. In our pipeline this
        does not guarantee deterministic results with num_parallel_workers > 1.

    Args:
        seed (int): seed to be set.

    Raises:
        ValueError: If seed is invalid (< 0 or > UINT32_MAX).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new seed value; operators with a random seed will now use it.
        >>> ds.config.set_seed(1000)
    """
    if seed < 0 or seed > UINT32_MAX:
        raise ValueError("Seed given is not within the required range.")
    _config.set_seed(seed)
    random.seed(seed)
    # numpy.random isn't thread safe
    numpy.random.seed(seed)


def get_seed():
    """
    Get the seed.

    Returns:
        Int, seed.
    """
    return _config.get_seed()


def set_prefetch_size(size):
    """
    Set the number of rows to be prefetched.

    Args:
        size (int): total number of rows to be prefetched.

    Raises:
        ValueError: If size is invalid (<= 0 or > INT32_MAX).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new prefetch value.
        >>> ds.config.set_prefetch_size(1000)
    """
    if size <= 0 or size > INT32_MAX:
        raise ValueError("Prefetch size given is not within the required range.")
    _config.set_op_connector_size(size)


def get_prefetch_size():
    """
    Get the prefetch size in number of rows.

    Returns:
        Size, total number of rows to be prefetched.
    """
    return _config.get_op_connector_size()


def set_num_parallel_workers(num):
    """
    Set the default number of parallel workers.

    Args:
        num (int): number of parallel workers to be used as a default for each operation.

    Raises:
        ValueError: If num is invalid (<= 0 or > INT32_MAX).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new parallel_workers value; parallel dataset operators will now run with 8 workers.
        >>> ds.config.set_num_parallel_workers(8)
    """
    if num <= 0 or num > INT32_MAX:
        raise ValueError("Num workers given is not within the required range.")
    _config.set_num_parallel_workers(num)


def get_num_parallel_workers():
    """
    Get the default number of parallel workers.

    Returns:
        Int, number of parallel workers to be used as a default for each operation.
    """
    return _config.get_num_parallel_workers()


def set_monitor_sampling_interval(interval):
    """
    Set the default interval (in milliseconds) of monitor sampling.

    Args:
        interval (int): interval (ms) to be used for performance monitor sampling.

    Raises:
        ValueError: If interval is invalid (<= 0 or > INT32_MAX).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new interval value.
        >>> ds.config.set_monitor_sampling_interval(100)
    """
    if interval <= 0 or interval > INT32_MAX:
        raise ValueError("Interval given is not within the required range.")
    _config.set_monitor_sampling_interval(interval)


def get_monitor_sampling_interval():
    """
    Get the default interval of performance monitor sampling.

    Returns:
        Int, interval (ms) of performance monitor sampling.
    """
    return _config.get_monitor_sampling_interval()


# Note: a module-level __str__ is not picked up by str(module); call it
# explicitly, e.g. ds.config.__str__(), to print the current configuration.
def __str__():
    """
    String representation of the configurations.

    Returns:
        Str, configurations.
    """
    return str(_config)


def load(file):
    """
    Load configuration from a file.

    Args:
        file (str): path of the config file to be loaded.

    Raises:
        RuntimeError: If file is invalid and parsing fails.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the default values according to values in the configuration file.
        >>> ds.config.load("path/to/config/file")
        >>> # example config file:
        >>> # {
        >>> # "logFilePath": "/tmp",
        >>> # "rowsPerBuffer": 32,
        >>> # "numParallelWorkers": 4,
        >>> # "workerConnectorSize": 16,
        >>> # "opConnectorSize": 16,
        >>> # "seed": 5489,
        >>> # "monitorSamplingInterval": 30
        >>> # }
    """
    _config.load(file)
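A quick sanity-check sketch of the new module-level API (illustrative, assuming a MindSpore build with this change installed; not part of the diff). Setters validate their range, getters mirror them, and set_seed also seeds Python's random and numpy.random as the function body above shows:

import mindspore.dataset as ds

ds.config.set_seed(1234)            # also calls random.seed(1234) and numpy.random.seed(1234)
assert ds.config.get_seed() == 1234

ds.config.set_num_parallel_workers(4)
assert ds.config.get_num_parallel_workers() == 4

try:
    ds.config.set_prefetch_size(0)  # rejected: size must be > 0 and <= INT32_MAX
except ValueError as err:
    print(err)                      # "Prefetch size given is not within the required range."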

@@ -1,195 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import random
import numpy
import mindspore._c_dataengine as cde

INT32_MAX = 2147483647
UINT32_MAX = 4294967295


class ConfigurationManager:
    """The configuration manager"""

    def __init__(self):
        self.config = cde.GlobalContext.config_manager()

    def set_seed(self, seed):
        """
        Set the seed to be used in any random generator. This is used to produce deterministic results.

        Note:
            This set_seed function sets the seed in the python random library and numpy.random library
            for deterministic python augmentations using randomness. This set_seed function should
            be called with every iterator created to reset the random seed. In our pipeline this
            does not guarantee deterministic results with num_parallel_workers > 1.

        Args:
            seed(int): seed to be set

        Raises:
            ValueError: If seed is invalid (< 0 or > MAX_UINT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new seed value, now operators with a random seed will use new seed value.
            >>> con.set_seed(1000)
        """
        if seed < 0 or seed > UINT32_MAX:
            raise ValueError("Seed given is not within the required range")
        self.config.set_seed(seed)
        random.seed(seed)
        # numpy.random isn't thread safe
        numpy.random.seed(seed)

    def get_seed(self):
        """
        Get the seed

        Returns:
            Int, seed.
        """
        return self.config.get_seed()

    def set_prefetch_size(self, size):
        """
        Set the number of rows to be prefetched.

        Args:
            size: total number of rows to be prefetched.

        Raises:
            ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new prefetch value.
            >>> con.set_prefetch_size(1000)
        """
        if size <= 0 or size > INT32_MAX:
            raise ValueError("Prefetch size given is not within the required range")
        self.config.set_op_connector_size(size)

    def get_prefetch_size(self):
        """
        Get the prefetch size in number of rows.

        Returns:
            Size, total number of rows to be prefetched.
        """
        return self.config.get_op_connector_size()

    def set_num_parallel_workers(self, num):
        """
        Set the default number of parallel workers

        Args:
            num: number of parallel workers to be used as a default for each operation

        Raises:
            ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
            >>> con.set_num_parallel_workers(8)
        """
        if num <= 0 or num > INT32_MAX:
            raise ValueError("Num workers given is not within the required range")
        self.config.set_num_parallel_workers(num)

    def get_num_parallel_workers(self):
        """
        Get the default number of parallel workers.

        Returns:
            Int, number of parallel workers to be used as a default for each operation
        """
        return self.config.get_num_parallel_workers()

    def set_monitor_sampling_interval(self, interval):
        """
        Set the default interval(ms) of monitor sampling.

        Args:
            interval: interval(ms) to be used to performance monitor sampling.

        Raises:
            ValueError: If interval is invalid (<= 0 or > MAX_INT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new interval value.
            >>> con.set_monitor_sampling_interval(100)
        """
        if interval <= 0 or interval > INT32_MAX:
            raise ValueError("Interval given is not within the required range")
        self.config.set_monitor_sampling_interval(interval)

    def get_monitor_sampling_interval(self):
        """
        Get the default interval of performance monitor sampling.

        Returns:
            Interval: interval(ms) of performance monitor sampling.
        """
        return self.config.get_monitor_sampling_interval()

    def __str__(self):
        """
        String representation of the configurations.

        Returns:
            Str, configurations.
        """
        return str(self.config)

    def load(self, file):
        """
        Load configuration from a file.

        Args:
            file: path the config file to be loaded

        Raises:
            RuntimeError: If file is invalid and parsing fails.

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the default value according to values in configuration file.
            >>> con.load("path/to/config/file")
            >>> # example config file:
            >>> # {
            >>> # "logFilePath": "/tmp",
            >>> # "rowsPerBuffer": 32,
            >>> # "numParallelWorkers": 4,
            >>> # "workerConnectorSize": 16,
            >>> # "opConnectorSize": 16,
            >>> # "seed": 5489,
            >>> # "monitorSamplingInterval": 30
            >>> # }
        """
        self.config.load(file)


config = ConfigurationManager()
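For completeness, a hedged sketch of driving the new load() with the JSON layout shown in its docstring; the file name dataset_config.json is hypothetical, and the final assert assumes the loaded seed is reflected by get_seed():

import json
import mindspore.dataset as ds

# Keys mirror the example config in the load() docstring.
cfg = {
    "logFilePath": "/tmp",
    "rowsPerBuffer": 32,
    "numParallelWorkers": 4,
    "workerConnectorSize": 16,
    "opConnectorSize": 16,
    "seed": 5489,
    "monitorSamplingInterval": 30
}
with open("dataset_config.json", "w") as f:  # hypothetical path
    json.dump(cfg, f)

ds.config.load("dataset_config.json")
assert ds.config.get_seed() == 5489          # assumption: load() applies the seed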

@@ -26,10 +26,9 @@ from .datasets import *
from .iterators import *
from .serializer_deserializer import serialize, deserialize, show, compare
from .samplers import *
-from ..core.configuration import config, ConfigurationManager
+from ..core import config
-__all__ = ["config", "ConfigurationManager", "zip",
-           "ImageFolderDatasetV2", "MnistDataset",
+__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler",

@@ -22,7 +22,7 @@ import sys
from mindspore import log as logger
from . import datasets as de
from ..transforms.vision.utils import Inter, Border
-from ..core.configuration import config
+from ..core import config
def serialize(dataset, json_filepath=None):
    """
