|
|
|
@ -134,7 +134,7 @@ def check_tfrecorddataset(method):
|
|
|
|
|
|
|
|
|
|
dataset_files = param_dict.get('dataset_files')
|
|
|
|
|
if not isinstance(dataset_files, (str, list)):
|
|
|
|
|
raise TypeError("dataset_files should be of type str or a list of strings.")
|
|
|
|
|
raise TypeError("dataset_files should be type str or a list of strings.")
|
|
|
|
|
|
|
|
|
|
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
|
|
|
|
validate_dataset_param_value(nreq_param_list, param_dict, list)
|
|
|
|
@ -173,11 +173,11 @@ def check_vocdataset(method):
|
|
|
|
|
if task == "Segmentation":
|
|
|
|
|
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
|
|
|
|
|
if param_dict.get('class_indexing') is not None:
|
|
|
|
|
raise ValueError("class_indexing is invalid in Segmentation task")
|
|
|
|
|
raise ValueError("class_indexing is not supported in Segmentation task.")
|
|
|
|
|
elif task == "Detection":
|
|
|
|
|
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Invalid task : " + task)
|
|
|
|
|
raise ValueError("Invalid task : " + task + ".")
|
|
|
|
|
|
|
|
|
|
check_file(imagesets_file)
|
|
|
|
|
|
|
|
|
@ -214,7 +214,7 @@ def check_cocodataset(method):
|
|
|
|
|
type_check(task, (str,), "task")
|
|
|
|
|
|
|
|
|
|
if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
|
|
|
|
|
raise ValueError("Invalid task type")
|
|
|
|
|
raise ValueError("Invalid task type: " + task + ".")
|
|
|
|
|
|
|
|
|
|
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
|
|
|
|
|
|
|
|
@ -222,7 +222,7 @@ def check_cocodataset(method):
|
|
|
|
|
|
|
|
|
|
sampler = param_dict.get('sampler')
|
|
|
|
|
if sampler is not None and isinstance(sampler, samplers.PKSampler):
|
|
|
|
|
raise ValueError("CocoDataset doesn't support PKSampler")
|
|
|
|
|
raise ValueError("CocoDataset doesn't support PKSampler.")
|
|
|
|
|
check_sampler_shuffle_shard_options(param_dict)
|
|
|
|
|
|
|
|
|
|
cache = param_dict.get('cache')
|
|
|
|
@ -256,13 +256,13 @@ def check_celebadataset(method):
|
|
|
|
|
|
|
|
|
|
usage = param_dict.get('usage')
|
|
|
|
|
if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
|
|
|
|
|
raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")
|
|
|
|
|
raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")
|
|
|
|
|
|
|
|
|
|
check_sampler_shuffle_shard_options(param_dict)
|
|
|
|
|
|
|
|
|
|
sampler = param_dict.get('sampler')
|
|
|
|
|
if sampler is not None and isinstance(sampler, samplers.PKSampler):
|
|
|
|
|
raise ValueError("CelebADataset does not support PKSampler.")
|
|
|
|
|
raise ValueError("CelebADataset doesn't support PKSampler.")
|
|
|
|
|
|
|
|
|
|
cache = param_dict.get('cache')
|
|
|
|
|
check_cache_option(cache)
|
|
|
|
@ -350,14 +350,14 @@ def check_generatordataset(method):
|
|
|
|
|
try:
|
|
|
|
|
iter(source)
|
|
|
|
|
except TypeError:
|
|
|
|
|
raise TypeError("source should be callable, iterable or random accessible")
|
|
|
|
|
raise TypeError("source should be callable, iterable or random accessible.")
|
|
|
|
|
|
|
|
|
|
column_names = param_dict.get('column_names')
|
|
|
|
|
if column_names is not None:
|
|
|
|
|
check_columns(column_names, "column_names")
|
|
|
|
|
schema = param_dict.get('schema')
|
|
|
|
|
if column_names is None and schema is None:
|
|
|
|
|
raise ValueError("Neither columns_names not schema are provided.")
|
|
|
|
|
raise ValueError("Neither columns_names nor schema are provided.")
|
|
|
|
|
|
|
|
|
|
if schema is not None:
|
|
|
|
|
if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
|
|
|
|
@ -375,7 +375,7 @@ def check_generatordataset(method):
|
|
|
|
|
shard_id = param_dict.get("shard_id")
|
|
|
|
|
if (num_shards is None) != (shard_id is None):
|
|
|
|
|
# These two parameters appear together.
|
|
|
|
|
raise ValueError("num_shards and shard_id need to be passed in together")
|
|
|
|
|
raise ValueError("num_shards and shard_id need to be passed in together.")
|
|
|
|
|
if num_shards is not None:
|
|
|
|
|
check_pos_int32(num_shards, "num_shards")
|
|
|
|
|
if shard_id >= num_shards:
|
|
|
|
@ -384,19 +384,19 @@ def check_generatordataset(method):
|
|
|
|
|
sampler = param_dict.get("sampler")
|
|
|
|
|
if sampler is not None:
|
|
|
|
|
if isinstance(sampler, samplers.PKSampler):
|
|
|
|
|
raise ValueError("PKSampler is not supported by GeneratorDataset")
|
|
|
|
|
raise ValueError("GeneratorDataset doesn't support PKSampler.")
|
|
|
|
|
if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
|
|
|
|
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
|
|
|
|
samplers.WeightedRandomSampler, samplers.Sampler)):
|
|
|
|
|
try:
|
|
|
|
|
iter(sampler)
|
|
|
|
|
except TypeError:
|
|
|
|
|
raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
|
|
|
|
|
raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")
|
|
|
|
|
|
|
|
|
|
if sampler is not None and not hasattr(source, "__getitem__"):
|
|
|
|
|
raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")
|
|
|
|
|
raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")
|
|
|
|
|
if num_shards is not None and not hasattr(source, "__getitem__"):
|
|
|
|
|
raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")
|
|
|
|
|
raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")
|
|
|
|
|
|
|
|
|
|
return method(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -433,7 +433,7 @@ def check_pad_info(key, val):
|
|
|
|
|
type_check(key, (str,), "key in pad_info")
|
|
|
|
|
|
|
|
|
|
if val is not None:
|
|
|
|
|
assert len(val) == 2, "value of pad_info should be a tuple of size 2"
|
|
|
|
|
assert len(val) == 2, "value of pad_info should be a tuple of size 2."
|
|
|
|
|
type_check(val, (tuple,), "value in pad_info")
|
|
|
|
|
|
|
|
|
|
if val[0] is not None:
|
|
|
|
@ -521,14 +521,14 @@ def check_batch(method):
|
|
|
|
|
if callable(batch_size):
|
|
|
|
|
sig = ins.signature(batch_size)
|
|
|
|
|
if len(sig.parameters) != 1:
|
|
|
|
|
raise ValueError("batch_size callable should take one parameter (BatchInfo).")
|
|
|
|
|
raise ValueError("callable batch_size should take one parameter (BatchInfo).")
|
|
|
|
|
|
|
|
|
|
if num_parallel_workers is not None:
|
|
|
|
|
check_num_parallel_workers(num_parallel_workers)
|
|
|
|
|
type_check(drop_remainder, (bool,), "drop_remainder")
|
|
|
|
|
|
|
|
|
|
if (pad_info is not None) and (per_batch_map is not None):
|
|
|
|
|
raise ValueError("pad_info and per_batch_map can't both be set")
|
|
|
|
|
raise ValueError("pad_info and per_batch_map can't both be set.")
|
|
|
|
|
|
|
|
|
|
if pad_info is not None:
|
|
|
|
|
type_check(param_dict["pad_info"], (dict,), "pad_info")
|
|
|
|
@ -542,7 +542,7 @@ def check_batch(method):
|
|
|
|
|
if input_columns is not None:
|
|
|
|
|
check_columns(input_columns, "input_columns")
|
|
|
|
|
if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
|
|
|
|
|
raise ValueError("the signature of per_batch_map should match with input columns")
|
|
|
|
|
raise ValueError("The signature of per_batch_map should match with input columns.")
|
|
|
|
|
|
|
|
|
|
if output_columns is not None:
|
|
|
|
|
check_columns(output_columns, "output_columns")
|
|
|
|
@ -816,13 +816,13 @@ def check_add_column(method):
|
|
|
|
|
type_check(name, (str,), "name")
|
|
|
|
|
|
|
|
|
|
if not name:
|
|
|
|
|
raise TypeError("Expected non-empty string.")
|
|
|
|
|
raise TypeError("Expected non-empty string for column name.")
|
|
|
|
|
|
|
|
|
|
if de_type is not None:
|
|
|
|
|
if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
|
|
|
|
|
raise TypeError("Unknown column type.")
|
|
|
|
|
raise TypeError("Unknown column type: {}.".format(de_type))
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError("Expected non-empty string.")
|
|
|
|
|
raise TypeError("Expected non-empty string for de_type.")
|
|
|
|
|
|
|
|
|
|
if shape is not None:
|
|
|
|
|
type_check(shape, (list,), "shape")
|
|
|
|
@ -848,12 +848,12 @@ def check_cluedataset(method):
|
|
|
|
|
# check task
|
|
|
|
|
task_param = param_dict.get('task')
|
|
|
|
|
if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
|
|
|
|
|
raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
|
|
|
|
|
raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")
|
|
|
|
|
|
|
|
|
|
# check usage
|
|
|
|
|
usage_param = param_dict.get('usage')
|
|
|
|
|
if usage_param not in ['train', 'test', 'eval']:
|
|
|
|
|
raise ValueError("usage should be train, test or eval")
|
|
|
|
|
raise ValueError("usage should be 'train', 'test' or 'eval'.")
|
|
|
|
|
|
|
|
|
|
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
|
|
|
|
check_sampler_shuffle_shard_options(param_dict)
|
|
|
|
@ -883,7 +883,7 @@ def check_csvdataset(method):
|
|
|
|
|
field_delim = param_dict.get('field_delim')
|
|
|
|
|
type_check(field_delim, (str,), 'field delim')
|
|
|
|
|
if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
|
|
|
|
|
raise ValueError("field_delim is not legal.")
|
|
|
|
|
raise ValueError("field_delim is invalid.")
|
|
|
|
|
|
|
|
|
|
# check column_defaults
|
|
|
|
|
column_defaults = param_dict.get('column_defaults')
|
|
|
|
@ -892,7 +892,7 @@ def check_csvdataset(method):
|
|
|
|
|
raise TypeError("column_defaults should be type of list.")
|
|
|
|
|
for item in column_defaults:
|
|
|
|
|
if not isinstance(item, (str, int, float)):
|
|
|
|
|
raise TypeError("column type is not legal.")
|
|
|
|
|
raise TypeError("column type in column_defaults is invalid.")
|
|
|
|
|
|
|
|
|
|
# check column_names: must be list of string.
|
|
|
|
|
column_names = param_dict.get("column_names")
|
|
|
|
@ -997,7 +997,7 @@ def check_gnn_graphdata(method):
|
|
|
|
|
raise ValueError("The hostname is illegal")
|
|
|
|
|
type_check(working_mode, (str,), "working_mode")
|
|
|
|
|
if working_mode not in {'local', 'client', 'server'}:
|
|
|
|
|
raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
|
|
|
|
|
raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.")
|
|
|
|
|
type_check(port, (int,), "port")
|
|
|
|
|
check_value(port, (1024, 65535), "port")
|
|
|
|
|
type_check(num_client, (int,), "num_client")
|
|
|
|
@ -1073,17 +1073,17 @@ def check_gnn_get_sampled_neighbors(method):
|
|
|
|
|
|
|
|
|
|
check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
|
|
|
|
|
if not neighbor_nums or len(neighbor_nums) > 6:
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
|
|
|
|
|
'neighbor_nums', len(neighbor_nums)))
|
|
|
|
|
|
|
|
|
|
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
|
|
|
|
|
if not neighbor_types or len(neighbor_types) > 6:
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
|
|
|
|
|
'neighbor_types', len(neighbor_types)))
|
|
|
|
|
|
|
|
|
|
if len(neighbor_nums) != len(neighbor_types):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The number of members of neighbor_nums and neighbor_types is inconsistent")
|
|
|
|
|
"The number of members of neighbor_nums and neighbor_types is inconsistent.")
|
|
|
|
|
|
|
|
|
|
return method(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -1139,17 +1139,17 @@ def check_aligned_list(param, param_name, member_type):
|
|
|
|
|
check_aligned_list(member, param_name, member_type)
|
|
|
|
|
|
|
|
|
|
if member_have_list not in (None, True):
|
|
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
|
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
|
|
|
|
|
param_name))
|
|
|
|
|
if list_len is not None and len(member) != list_len:
|
|
|
|
|
raise TypeError("The size of each member of parameter {0} is inconsistent".format(
|
|
|
|
|
raise TypeError("The size of each member of parameter {0} is inconsistent.".format(
|
|
|
|
|
param_name))
|
|
|
|
|
member_have_list = True
|
|
|
|
|
list_len = len(member)
|
|
|
|
|
else:
|
|
|
|
|
type_check(member, (member_type,), param_name)
|
|
|
|
|
if member_have_list not in (None, False):
|
|
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
|
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
|
|
|
|
|
param_name))
|
|
|
|
|
member_have_list = False
|
|
|
|
|
|
|
|
|
@ -1248,7 +1248,7 @@ def check_paddeddataset(method):
|
|
|
|
|
|
|
|
|
|
padded_samples = param_dict.get("padded_samples")
|
|
|
|
|
if not padded_samples:
|
|
|
|
|
raise ValueError("Argument padded_samples cannot be empty")
|
|
|
|
|
raise ValueError("padded_samples cannot be empty.")
|
|
|
|
|
type_check(padded_samples, (list,), "padded_samples")
|
|
|
|
|
type_check(padded_samples[0], (dict,), "padded_element")
|
|
|
|
|
return method(self, *args, **kwargs)
|
|
|
|
@ -1261,6 +1261,6 @@ def check_cache_option(cache):
|
|
|
|
|
if cache is not None:
|
|
|
|
|
if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
|
|
|
|
|
# temporary disable cache feature in the current release
|
|
|
|
|
raise ValueError("Caching is disabled in the current release")
|
|
|
|
|
raise ValueError("Caching is disabled in the current release.")
|
|
|
|
|
from . import cache_client
|
|
|
|
|
type_check(cache, (cache_client.DatasetCache,), "cache")
|
|
|
|
|