You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1033 lines
37 KiB
1033 lines
37 KiB
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""
|
|
Built-in validators.
|
|
"""
|
|
import inspect as ins
|
|
import os
|
|
from functools import wraps
|
|
|
|
import numpy as np
|
|
from mindspore._c_expression import typing
|
|
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
|
|
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
|
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
|
|
check_columns, check_positive
|
|
|
|
from . import datasets
|
|
from . import samplers
|
|
|
|
|
|
def check_imagefolderdatasetv2(method):
    """Wrap the ImageFolderDatasetV2 constructor with an argument checker.

    The returned wrapper parses the user arguments, runs ``check_dir`` on
    ``dataset_dir``, type-validates the optional arguments, runs the
    sampler/shuffle/shard option check and finally calls ``method``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # Optional arguments, grouped by the type each of them must have.
        optional_by_type = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['extensions'], list),
            (['class_indexing'], dict),
        )
        for names, expected_type in optional_by_type:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_mnist_cifar_dataset(method):
    """Wrap a dataset constructor with the Mnist/Cifar argument checker.

    The wrapper runs ``check_dir`` on ``dataset_dir``, type-validates the
    optional int and bool arguments, runs the sampler/shuffle/shard option
    check and then calls ``method``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_options = ['shuffle']
        validate_dataset_param_value(int_options, params, int)
        validate_dataset_param_value(bool_options, params, bool)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_manifestdataset(method):
    """Wrap the ManifestDataset constructor with an argument checker.

    The wrapper runs ``check_file`` on ``dataset_file``, type-validates the
    optional int/bool/str/dict arguments, runs the sampler/shuffle/shard
    option check and then calls ``method``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_file(params.get('dataset_file'))

        # Optional arguments, grouped by the type each of them must have.
        optional_by_type = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['usage'], str),
            (['class_indexing'], dict),
        )
        for names, expected_type in optional_by_type:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_tfrecorddataset(method):
    """Wrap the TFRecordDataset constructor with an argument checker.

    Raises:
        TypeError: if ``dataset_files`` is neither a ``str`` nor a ``list``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        files = params.get('dataset_files')
        if not isinstance(files, (str, list)):
            raise TypeError("dataset_files should be of type str or a list of strings.")

        # Optional arguments, grouped by the type each of them must have.
        optional_by_type = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['columns_list'], list),
            (['shard_equal_rows'], bool),
        )
        for names, expected_type in optional_by_type:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_vocdataset(method):
    """Wrap the VOCDataset constructor with an argument checker.

    Besides the usual directory and optional-argument checks, the wrapper
    derives the image-set file from ``task`` and ``mode`` and runs
    ``check_file`` on it.

    Raises:
        ValueError: if ``task`` is neither "Segmentation" nor "Detection", or
            if ``class_indexing`` is given for the Segmentation task.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        root = params.get('dataset_dir')
        check_dir(root)

        task = params.get('task')
        type_check(task, (str,), "task")

        mode = params.get('mode')
        type_check(mode, (str,), "mode")

        # Sub-directory of "ImageSets" that holds the split list for each task.
        split_dir_by_task = {"Segmentation": "Segmentation", "Detection": "Main"}
        if task not in split_dir_by_task:
            raise ValueError("Invalid task : " + task)

        imagesets_file = os.path.join(root, "ImageSets", split_dir_by_task[task], mode + ".txt")
        if task == "Segmentation" and params.get('class_indexing') is not None:
            raise ValueError("class_indexing is invalid in Segmentation task")

        check_file(imagesets_file)

        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['class_indexing'], params, dict)
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_cocodataset(method):
    """Wrap the CocoDataset constructor with an argument checker.

    Raises:
        ValueError: if ``task`` is not one of 'Detection', 'Stuff',
            'Panoptic' or 'Keypoint', or if ``sampler`` is a ``PKSampler``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))
        check_file(params.get('annotation_file'))

        task = params.get('task')
        type_check(task, (str,), "task")

        supported_tasks = {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}
        if task not in supported_tasks:
            raise ValueError("Invalid task type")

        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        sampler = params.get('sampler')
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("CocoDataset doesn't support PKSampler")
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_celebadataset(method):
    """Wrap the CelebADataset constructor with an argument checker.

    Raises:
        ValueError: if ``dataset_type`` is given but is not one of 'all',
            'train', 'valid' or 'test', or if ``sampler`` is a ``PKSampler``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # Optional arguments, grouped by the type each of them must have.
        optional_by_type = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['extensions'], list),
            (['dataset_type'], str),
        )
        for names, expected_type in optional_by_type:
            validate_dataset_param_value(names, params, expected_type)

        split = params.get('dataset_type')
        if split is not None:
            if split not in ('all', 'train', 'valid', 'test'):
                raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(params)

        sampler = params.get('sampler')
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("CelebADataset does not support PKSampler.")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_minddataset(method):
    """Wrap the MindDataset constructor with an argument checker.

    ``dataset_file`` may be a single entry or a list; ``check_file`` is run on
    every entry before the optional arguments, the sampler/shuffle/shard
    options and the padding options are validated.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_file = params.get('dataset_file')
        # Treat a single entry as a one-element list so both cases share a loop.
        candidates = dataset_file if isinstance(dataset_file, list) else [dataset_file]
        for candidate in candidates:
            check_file(candidate)

        # Optional arguments, grouped by the type each of them must have.
        optional_by_type = (
            (['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'], int),
            (['columns_list'], list),
            (['block_reader'], bool),
            (['padded_sample'], dict),
        )
        for names, expected_type in optional_by_type:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_padding_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_generatordataset(method):
    """Wrap the GeneratorDataset constructor with an argument checker.

    Raises:
        TypeError: if ``source`` is neither callable nor iterable, or if
            ``sampler`` is neither one of the listed sampler classes nor
            iterable.
        ValueError: if neither ``column_names`` nor ``schema`` is given, if
            ``schema`` is neither a ``datasets.Schema`` nor a ``str``, if only
            one of ``num_shards``/``shard_id`` is given, if ``shard_id`` is
            not below ``num_shards``, if ``sampler`` is a ``PKSampler``, or if
            ``sampler``/``num_shards`` is used with a ``source`` that has no
            ``__getitem__``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        source = params.get('source')
        if not callable(source):
            # A non-callable source must at least support iteration.
            try:
                iter(source)
            except TypeError:
                raise TypeError("source should be callable, iterable or random accessible")

        column_names = params.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")

        schema = params.get('schema')
        if schema is None:
            if column_names is None:
                raise ValueError("Neither columns_names not schema are provided.")
        elif not isinstance(schema, (datasets.Schema, str)):
            raise ValueError("schema should be a path to schema file or a schema object.")

        # Optional arguments, grouped by the type each of them must have.
        validate_dataset_param_value(
            ["num_samples", "num_parallel_workers", "num_shards", "shard_id"], params, int)
        validate_dataset_param_value(["column_types"], params, list)
        validate_dataset_param_value(["shuffle"], params, bool)

        num_shards = params.get("num_shards")
        shard_id = params.get("shard_id")
        # The two sharding parameters are only meaningful as a pair.
        if (num_shards is None) != (shard_id is None):
            raise ValueError("num_shards and shard_id need to be passed in together")
        if num_shards is not None:
            type_check(num_shards, (int,), "num_shards")
            check_positive(num_shards, "num_shards")
            if shard_id >= num_shards:
                raise ValueError("shard_id should be less than num_shards")

        sampler = params.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("PKSampler is not supported by GeneratorDataset")
            builtin_samplers = (samplers.SequentialSampler, samplers.DistributedSampler,
                                samplers.RandomSampler, samplers.SubsetRandomSampler,
                                samplers.WeightedRandomSampler, samplers.Sampler)
            if not isinstance(sampler, builtin_samplers):
                # Anything else is only accepted when it can be iterated.
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")

        if sampler is not None and not hasattr(source, "__getitem__"):
            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")
        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_pad_info(key, val):
    """Check one key/value pair of the ``pad_info`` argument of batch.

    ``key`` is type-checked as ``str``. When ``val`` is not None it must have
    length 2 and be a tuple ``(pad_shape, pad_value)``; a non-None
    ``pad_shape`` is type-checked as a list whose non-None entries are
    positive ints, and a non-None ``pad_value`` is type-checked as
    int/float/str/bytes.

    NOTE: the size and positivity checks are ``assert`` statements, so they
    raise ``AssertionError`` and are skipped when Python runs with ``-O``.
    """
    type_check(key, (str,), "key in pad_info")

    if val is None:
        return

    assert len(val) == 2, "value of pad_info should be a tuple of size 2"
    type_check(val, (tuple,), "value in pad_info")

    pad_shape, pad_value = val[0], val[1]

    if pad_shape is not None:
        type_check(pad_shape, (list,), "pad_shape")

        for dim in pad_shape:
            if dim is None:
                continue
            type_check(dim, (int,), "dim in pad_shape")
            assert dim > 0, "pad shape should be positive integers"

    if pad_value is not None:
        type_check(pad_value, (int, float, str, bytes), "pad_value")
|
|
|
|
|
|
def check_bucket_batch_by_length(method):
    """Wrap ``bucket_batch_by_length`` with an argument checker.

    Raises:
        ValueError: if no ``element_length_function`` is given and
            ``column_names`` does not hold exactly one name; if
            ``bucket_boundaries`` is empty, holds a value that is not greater
            than zero or is not strictly increasing; if ``bucket_batch_sizes``
            does not hold exactly one more element than ``bucket_boundaries``
            or holds a value that is not greater than zero.
        TypeError: if ``bucket_boundaries`` or ``bucket_batch_sizes`` holds a
            non-int element.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder] = values

        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,),
                        ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'])
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,),
                        ['pad_to_bucket_boundary', 'drop_remainder'])

        # column_names: must be a list of strings.
        check_columns(column_names, "column_names")

        if element_length_function is None and len(column_names) != 1:
            raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

        # bucket_boundaries: non-empty list of ints, all > 0, strictly increasing.
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")

        if any(not isinstance(boundary, int) for boundary in bucket_boundaries):
            raise TypeError("bucket_boundaries should be a list of int.")

        if any(not boundary > 0 for boundary in bucket_boundaries):
            raise ValueError("bucket_boundaries cannot contain any negative numbers.")

        for lower, upper in zip(bucket_boundaries, bucket_boundaries[1:]):
            if not upper > lower:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # bucket_batch_sizes: one entry per bucket, list of ints, all > 0.
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")

        if any(not isinstance(size, int) for size in bucket_batch_sizes):
            raise TypeError("bucket_batch_sizes should be a list of int.")

        if any(not size > 0 for size in bucket_batch_sizes):
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")

            for column, pad_spec in pad_info.items():
                check_pad_info(column, pad_spec)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_batch(method):
    """Wrap ``batch`` with an argument checker.

    Raises:
        TypeError: if ``batch_size`` is neither an int nor a callable.
        ValueError: if a callable ``batch_size`` does not take exactly one
            parameter; if ``pad_info`` and ``per_batch_map`` are both set; if
            only one of ``per_batch_map``/``input_columns`` is set; if
            ``input_columns`` is empty or its length does not equal the
            number of ``per_batch_map`` parameters minus one.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, params = parse_user_args(method, *args, **kwargs)
        [batch_size, drop_remainder, num_parallel_workers, per_batch_map,
         input_columns, pad_info] = values

        if not isinstance(batch_size, int) and not callable(batch_size):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            batch_size_params = ins.signature(batch_size).parameters
            if len(batch_size_params) != 1:
                raise ValueError("batch_size callable should take one parameter (BatchInfo).")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")

        if pad_info is not None and per_batch_map is not None:
            raise ValueError("pad_info and per_batch_map can't both be set")

        if pad_info is not None:
            type_check(params["pad_info"], (dict,), "pad_info")
            for column, pad_spec in params.get('pad_info').items():
                check_pad_info(column, pad_spec)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        # per_batch_map and input_columns are only meaningful as a pair.
        if (per_batch_map is None) != (input_columns is None):
            raise ValueError("per_batch_map and input_columns need to be passed in together.")

        if input_columns is not None:
            if not input_columns:
                raise ValueError("input_columns can not be empty")
            # One parameter of per_batch_map is not an input column, hence the -1.
            expected_columns = len(ins.signature(per_batch_map).parameters) - 1
            if len(input_columns) != expected_columns:
                raise ValueError("the signature of per_batch_map should match with input columns")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_sync_wait(method):
    """Wrap ``sync_wait`` with an argument checker.

    ``condition_name`` is type-checked as ``str`` and ``num_batch`` as
    ``int``; the third parsed argument is not checked.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [condition_name, num_batch, _] = values

        type_check(condition_name, (str,), "condition_name")
        type_check(num_batch, (int,), "num_batch")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_shuffle(method):
    """Wrap ``shuffle`` with an argument checker.

    ``buffer_size`` is type-checked as ``int`` and then passed to
    ``check_value`` with the bounds ``[2, INT32_MAX]``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [buffer_size] = values

        type_check(buffer_size, (int,), "buffer_size")
        check_value(buffer_size, [2, INT32_MAX], "buffer_size")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_map(method):
    """Wrap ``map`` with an argument checker.

    ``columns_order`` (when given) is type-checked as a list,
    ``num_parallel_workers`` (when given) goes through
    ``check_num_parallel_workers``, ``python_multiprocessing`` is type-checked
    as ``bool``, and ``input_columns``/``output_columns`` (when given) go
    through ``check_columns``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [input_columns, _, output_columns, columns_order, num_parallel_workers,
         python_multiprocessing] = values

        if columns_order is not None:
            type_check(columns_order, (list,), "columns_order")
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        if input_columns is not None:
            check_columns(input_columns, 'input_columns')
        if output_columns is not None:
            check_columns(output_columns, 'output_columns')

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_filter(method):
    """Check the input arguments of filter.

    The returned wrapper parses ``predicate``, ``input_columns`` and
    ``num_parallel_workers`` from the call, validates them and then calls
    ``method`` with the original arguments.

    Raises:
        TypeError: if ``predicate`` is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
        if not callable(predicate):
            raise TypeError("Predicate should be a python function or a callable python object.")

        # Validated exactly once. The original ran this call unconditionally and
        # then repeated it under an `is not None` guard; the repeat could never
        # change the outcome, so it has been dropped.
        check_num_parallel_workers(num_parallel_workers)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method
|
|
|
|
|
|
def check_repeat(method):
    """Wrap ``repeat`` with an argument checker.

    ``count`` is type-checked as ``int`` or ``None``; an int ``count`` is
    additionally passed to ``check_value`` with the bounds
    ``(-1, INT32_MAX)``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [count] = values

        none_type = type(None)
        type_check(count, (int, none_type), "repeat")
        if isinstance(count, int):
            check_value(count, (-1, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_skip(method):
    """Wrap ``skip`` with an argument checker.

    ``count`` is type-checked as ``int`` and then passed to ``check_value``
    with the bounds ``(-1, INT32_MAX)``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [count] = values

        type_check(count, (int,), "count")
        check_value(count, (-1, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_take(method):
    """Wrap ``take`` with an argument checker.

    Raises:
        ValueError: if ``count`` is not -1 and not a positive integer, or if
            it exceeds ``INT32_MAX``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [count] = values

        type_check(count, (int,), "count")

        # Accept -1 or any positive value, capped at INT32_MAX.
        acceptable = (count > 0 or count == -1) and count <= INT32_MAX
        if not acceptable:
            raise ValueError("count should be either -1 or positive integer.")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_zip(method):
    """Wrap the ``zip`` function with an argument checker.

    The single parsed argument is type-checked as a ``tuple``. The wrapper
    takes no ``self`` and forwards the call arguments unchanged.
    """

    @wraps(method)
    def validated(*args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [to_zip] = values
        type_check(to_zip, (tuple,), "datasets")

        return method(*args, **kwargs)

    return validated
|
|
|
|
|
|
def check_zip_dataset(method):
    """Wrap the ``zip`` method of `Dataset` with an argument checker.

    The single parsed argument is type-checked as a ``tuple`` or a
    ``datasets.Dataset``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [to_zip] = values
        type_check(to_zip, (tuple, datasets.Dataset), "datasets")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_concat(method):
    """Wrap the ``concat`` method of `Dataset` with an argument checker.

    The single parsed argument is type-checked as a ``list`` or a
    ``datasets.Dataset``; when it is a list, every element is type-checked as
    a ``datasets.Dataset`` under the name ``dataset[i]``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [to_concat] = values
        type_check(to_concat, (list, datasets.Dataset), "datasets")
        if isinstance(to_concat, list):
            # One display name per element, used in the type-check messages.
            element_names = ["dataset[{0}]".format(index) for index, _ in enumerate(to_concat)]
            type_check_list(to_concat, (datasets.Dataset,), element_names)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_rename(method):
    """Wrap ``rename`` with an argument checker.

    Raises:
        ValueError: if ``input_columns`` and ``output_columns`` do not name
            the same number of columns (a non-list value counts as one).
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)

        for label, columns in zip(['input_columns', 'output_columns'], values):
            check_columns(columns, label)

        input_columns, output_columns = values
        # A non-list value counts as a single column.
        input_size = len(input_columns) if isinstance(input_columns, list) else 1
        output_size = len(output_columns) if isinstance(output_columns, list) else 1
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_project(method):
    """Wrap ``project`` with an argument checker.

    The single parsed argument is passed to ``check_columns`` under the name
    'columns'.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [columns] = values
        check_columns(columns, 'columns')

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_add_column(method):
    """Wrap ``add_column`` with an argument checker.

    Raises:
        TypeError: if ``name`` is empty, if ``de_type`` is None, or if
            ``de_type`` is neither a ``typing.Type`` nor accepted by
            ``check_valid_detype``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [name, de_type, shape] = values

        type_check(name, (str,), "name")

        if not name:
            raise TypeError("Expected non-empty string.")

        if de_type is None:
            raise TypeError("Expected non-empty string.")
        if not (isinstance(de_type, typing.Type) or check_valid_detype(de_type)):
            raise TypeError("Unknown column type.")

        if shape is not None:
            type_check(shape, (list,), "shape")
            # One display name per dimension, used in the type-check messages.
            dim_names = ["shape[{0}]".format(index) for index, _ in enumerate(shape)]
            type_check_list(shape, (int,), dim_names)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_cluedataset(method):
    """Wrap the CLUEDataset constructor with an argument checker.

    Raises:
        ValueError: if ``task`` is not one of AFQMC, TNEWS, IFLYTEK, CMNLI,
            WSC or CSL, or if ``usage`` is not one of train, test or eval.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        type_check(params.get('dataset_files'), (str, list), "dataset files")

        supported_tasks = ('AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL')
        if params.get('task') not in supported_tasks:
            raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")

        supported_usages = ('train', 'test', 'eval')
        if params.get('usage') not in supported_usages:
            raise ValueError("usage should be train, test or eval")

        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_textfiledataset(method):
    """Wrap the TextFileDataset constructor with an argument checker.

    ``dataset_files`` is type-checked as ``str`` or ``list``; the optional int
    arguments and the sampler/shuffle/shard options are validated afterwards.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        type_check(params.get('dataset_files'), (str, list), "dataset files")

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_options, params, int)
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_split(method):
    """Wrap ``split`` with an argument checker.

    Raises:
        ValueError: if ``sizes`` is empty; if it is neither all-int nor
            all-float; if an all-int list holds a value that is not greater
            than zero; if an all-float list holds a value outside (0, 1] or
            its sum differs from 1 by 0.00001 or more.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [sizes, randomize] = values

        type_check(sizes, (list,), "sizes")
        type_check(randomize, (bool,), "randomize")

        # sizes must be a non-empty list of int or a non-empty list of float.
        if not sizes:
            raise ValueError("sizes cannot be empty.")

        ints_only = all(isinstance(size, int) for size in sizes)
        floats_only = all(isinstance(size, float) for size in sizes)

        if not ints_only and not floats_only:
            raise ValueError("sizes should be list of int or list of float.")

        if ints_only:
            if any(not size > 0 for size in sizes):
                raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")

        if floats_only:
            if any(not 0 < size <= 1 for size in sizes):
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")

            # Percentages must add up to 1 within a small tolerance.
            tolerance = 0.00001
            if not abs(sum(sizes) - 1) < tolerance:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_graphdata(method):
    """Wrap the graph-data constructor with an argument checker.

    ``dataset_file`` goes through ``check_file``; ``num_parallel_workers``
    goes through ``check_num_parallel_workers`` only when it is not None.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [dataset_file, num_parallel_workers] = values
        check_file(dataset_file)

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_get_all_nodes(method):
    """Wrap the GNN `get_all_nodes` function with an argument checker.

    ``node_type`` is type-checked as ``int``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [node_type] = values
        type_check(node_type, (int,), "node_type")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_get_all_edges(method):
    """Wrap the GNN `get_all_edges` function with an argument checker.

    ``edge_type`` is type-checked as ``int``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [edge_type] = values
        type_check(edge_type, (int,), "edge_type")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_get_nodes_from_edges(method):
    """Wrap the GNN `get_nodes_from_edges` function with an argument checker.

    ``edge_list`` goes through ``check_gnn_list_or_ndarray``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [edge_list] = values
        check_gnn_list_or_ndarray(edge_list, "edge_list")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_get_all_neighbors(method):
    """Wrap the GNN `get_all_neighbors` function with an argument checker.

    ``node_list`` goes through ``check_gnn_list_or_ndarray`` and
    ``neighbour_type`` is type-checked as ``int``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [node_list, neighbour_type] = values

        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neighbour_type, (int,), "neighbour_type")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_get_sampled_neighbors(method):
    """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.

    ``node_list``, ``neighbor_nums`` and ``neighbor_types`` each go through
    ``check_gnn_list_or_ndarray``.

    Raises:
        ValueError: if ``neighbor_nums`` or ``neighbor_types`` does not hold
            between 1 and 6 members, or if the two differ in length.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')

        # The size checks use len() rather than truthiness: `not x` is ambiguous
        # for a multi-element numpy array and is True for a one-element array
        # holding 0. For lists the outcome is unchanged.
        # NOTE(review): assumes check_gnn_list_or_ndarray lets np.ndarray through,
        # as its name suggests - confirm against the helper.
        check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
        if not 1 <= len(neighbor_nums) <= 6:
            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
                'neighbor_nums', len(neighbor_nums)))

        check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
        if not 1 <= len(neighbor_types) <= 6:
            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
                'neighbor_types', len(neighbor_types)))

        if len(neighbor_nums) != len(neighbor_types):
            raise ValueError(
                "The number of members of neighbor_nums and neighbor_types is inconsistent")

        return method(self, *args, **kwargs)

    return new_method
|
|
|
|
|
|
def check_gnn_get_neg_sampled_neighbors(method):
    """Wrap the GNN `get_neg_sampled_neighbors` function with an argument checker.

    ``node_list`` goes through ``check_gnn_list_or_ndarray``;
    ``neg_neighbor_num`` and ``neg_neighbor_type`` are type-checked as ``int``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [node_list, neg_neighbor_num, neg_neighbor_type] = values

        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neg_neighbor_num, (int,), "neg_neighbor_num")
        type_check(neg_neighbor_type, (int,), "neg_neighbor_type")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_gnn_random_walk(method):
    """Wrap the GNN `random_walk` function with an argument checker.

    ``target_nodes`` and ``meta_path`` go through
    ``check_gnn_list_or_ndarray``; ``step_home_param`` and ``step_away_param``
    are type-checked as ``float`` and ``default_node`` as ``int``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        [target_nodes, meta_path, step_home_param, step_away_param, default_node] = values

        check_gnn_list_or_ndarray(target_nodes, 'target_nodes')
        check_gnn_list_or_ndarray(meta_path, 'meta_path')
        type_check(step_home_param, (float,), "step_home_param")
        type_check(step_away_param, (float,), "step_away_param")
        type_check(default_node, (int,), "default_node")

        return method(self, *args, **kwargs)

    return validated
|
|
|
|
|
|
def check_aligned_list(param, param_name, member_type):
    """Check whether the structure of each member of the list is the same.

    Nested lists are checked recursively. Within one list, the members must either all be
    lists of equal length, or all be non-list values that pass `type_check` against
    `member_type`.

    Args:
        param (list): The (possibly nested) list to check.
        param_name (str): Name of the parameter, used in error messages.
        member_type (type): Type passed to `type_check` for every non-list member.

    Raises:
        TypeError: If `param` or any nested list is empty, if list and non-list members are
            mixed, or if sibling lists differ in length.
    """

    # Report the caller's parameter name (not a fixed placeholder) if the container type is wrong.
    type_check(param, (list,), param_name)
    if not param:
        raise TypeError(
            "Parameter {0} or its members are empty".format(param_name))

    # None until the first member is seen, then True (members are lists) or False (they are not).
    member_have_list = None
    # Length of the first list member; every later list member must match it.
    list_len = None
    for member in param:
        if isinstance(member, list):
            # Validate the nested list before comparing it with its siblings.
            check_aligned_list(member, param_name, member_type)

            if member_have_list not in (None, True):
                raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
                    param_name))
            if list_len is not None and len(member) != list_len:
                raise TypeError("The size of each member of parameter {0} is inconsistent".format(
                    param_name))
            member_have_list = True
            list_len = len(member)
        else:
            type_check(member, (member_type,), param_name)
            if member_have_list not in (None, False):
                raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
                    param_name))
            member_have_list = False
|
|
|
|
|
|
def check_gnn_get_node_feature(method):
    """A wrapper that wraps a parameter checker to the GNN `get_node_feature` function.

    Args:
        method (callable): The method to guard. Its user arguments are resolved by
            `parse_user_args` into `node_list` and `feature_types`.

    Returns:
        callable, a wrapper that validates the arguments and then calls `method`
        with the caller's original arguments.

    Raises:
        TypeError: Raised by the wrapper when `node_list` is a numpy array whose dtype is
            not int32 (other failures come from the shared check helpers).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs)

        type_check(node_list, (list, np.ndarray), "node_list")
        if isinstance(node_list, list):
            # Nested lists must be rectangular and contain only ints.
            check_aligned_list(node_list, 'node_list', int)
        elif isinstance(node_list, np.ndarray):
            if node_list.dtype != np.int32:
                # Name the parameter in the message instead of dumping the array contents.
                raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                    'node_list', node_list.dtype))

        check_gnn_list_or_ndarray(feature_types, 'feature_types')

        return method(self, *args, **kwargs)

    return new_method
|
|
|
|
|
|
def check_gnn_get_edge_feature(method):
    """A wrapper that wraps a parameter checker to the GNN `get_edge_feature` function.

    Args:
        method (callable): The method to guard. Its user arguments are resolved by
            `parse_user_args` into `edge_list` and `feature_types`.

    Returns:
        callable, a wrapper that validates the arguments and then calls `method`
        with the caller's original arguments.

    Raises:
        TypeError: Raised by the wrapper when `edge_list` is a numpy array whose dtype is
            not int32 (other failures come from the shared check helpers).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)

        type_check(edge_list, (list, np.ndarray), "edge_list")
        if isinstance(edge_list, list):
            # Nested lists must be rectangular and contain only ints.
            check_aligned_list(edge_list, 'edge_list', int)
        elif isinstance(edge_list, np.ndarray):
            if edge_list.dtype != np.int32:
                # Name the parameter in the message instead of dumping the array contents.
                raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                    'edge_list', edge_list.dtype))

        check_gnn_list_or_ndarray(feature_types, 'feature_types')

        return method(self, *args, **kwargs)

    return new_method
|
|
|
|
|
|
def check_numpyslicesdataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset).

    Args:
        method (callable): The method to guard. The `data` and `column_names` entries of
            the parameter dict returned by `parse_user_args` are validated.

    Returns:
        callable, a wrapper that validates the arguments and then calls `method`
        with the caller's original arguments.

    Raises:
        ValueError: Raised by the wrapper when `data` is an empty tuple, or when the number
            of names in `column_names` does not match the number of columns in `data`.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        data = param_dict.get("data")
        column_names = param_dict.get("column_names")

        type_check(data, (list, tuple, dict, np.ndarray), "data")
        if isinstance(data, tuple):
            # Guard the data[0] access below: an empty tuple would otherwise surface as a
            # bare IndexError that says nothing about the offending argument.
            if not data:
                raise ValueError("Argument data cannot be empty.")
            type_check(data[0], (list, np.ndarray), "data[0]")

        # check column_names
        if column_names is not None:
            check_columns(column_names, "column_names")

            # check num of input column in column_names
            column_num = 1 if isinstance(column_names, str) else len(column_names)
            if isinstance(data, dict):
                # One column per dict key.
                data_column = len(data)
                if column_num != data_column:
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, data_column))

            elif isinstance(data, tuple):
                # One column per tuple member.
                if column_num != len(data):
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, len(data)))
            else:
                # Any other accepted container is treated as a single column.
                if column_num != 1:
                    raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
                                     .format(column_num, 1))

        return method(self, *args, **kwargs)

    return new_method
|