You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/dataset/engine/validators.py

1033 lines
37 KiB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Built-in validators.
"""
import inspect as ins
import os
from functools import wraps
import numpy as np
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
check_columns, check_positive
from . import datasets
from . import samplers
def check_imagefolderdatasetv2(method):
    """Wrap ``method`` with argument validation for ImageFolderDatasetV2.

    Args:
        method: The callable to guard. It is invoked with the untouched
            arguments once every check has passed.

    Returns:
        A wrapper (metadata copied with ``functools.wraps``) that validates
        the arguments and then delegates to ``method``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # Optional arguments grouped by the type each group is validated
        # against; the groups are processed in exactly this order.
        optional_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['extensions'], list),
            (['class_indexing'], dict),
        )
        for names, expected_type in optional_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_mnist_cifar_dataset(method):
    """Wrap ``method`` with argument validation for directory-based datasets.

    Going by the decorator's name this is shared by the MNIST and
    CIFAR-10/100 dataset classes.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that checks ``dataset_dir``, the optional int and bool
        arguments and the sampler/shuffle/shard combination before delegating.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_options = ['shuffle']
        validate_dataset_param_value(int_options, params, int)
        validate_dataset_param_value(bool_options, params, bool)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_manifestdataset(method):
    """Wrap ``method`` with argument validation for ManifestDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that checks ``dataset_file``, the typed optional arguments
        and the sampler/shuffle/shard combination before delegating.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_file(params.get('dataset_file'))

        # (argument names, expected type), validated in this order.
        optional_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['usage'], str),
            (['class_indexing'], dict),
        )
        for names, expected_type in optional_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_tfrecorddataset(method):
    """Wrap ``method`` with argument validation for TFRecordDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        TypeError: (from the wrapper) if ``dataset_files`` is neither a
            ``str`` nor a ``list``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        files = params.get('dataset_files')
        if not isinstance(files, (str, list)):
            raise TypeError("dataset_files should be of type str or a list of strings.")

        # (argument names, expected type), validated in this order.
        optional_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['columns_list'], list),
            (['shard_equal_rows'], bool),
        )
        for names, expected_type in optional_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_vocdataset(method):
    """Wrap ``method`` with argument validation for VOCDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        ValueError: (from the wrapper) if ``task`` is neither "Segmentation"
            nor "Detection", or if ``class_indexing`` is supplied together
            with the "Segmentation" task.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_dir = params.get('dataset_dir')
        check_dir(dataset_dir)

        task = params.get('task')
        type_check(task, (str,), "task")
        mode = params.get('mode')
        type_check(mode, (str,), "mode")

        # Sub-directory of <dataset_dir>/ImageSets holding the "<mode>.txt" list for each task.
        imageset_subdirs = {"Segmentation": "Segmentation", "Detection": "Main"}
        if task not in imageset_subdirs:
            raise ValueError("Invalid task : " + task)
        imagesets_file = os.path.join(dataset_dir, "ImageSets", imageset_subdirs[task], mode + ".txt")
        if task == "Segmentation" and params.get('class_indexing') is not None:
            raise ValueError("class_indexing is invalid in Segmentation task")
        check_file(imagesets_file)

        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['class_indexing'], params, dict)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_cocodataset(method):
    """Wrap ``method`` with argument validation for CocoDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        ValueError: (from the wrapper) if ``task`` is not one of
            'Detection', 'Stuff', 'Panoptic' or 'Keypoint', or if ``sampler``
            is a ``samplers.PKSampler``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))
        check_file(params.get('annotation_file'))

        task = params.get('task')
        type_check(task, (str,), "task")
        supported_tasks = {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}
        if task not in supported_tasks:
            raise ValueError("Invalid task type")

        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        # isinstance() is already False for None, so no separate None test is needed.
        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler")
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_celebadataset(method):
    """Wrap ``method`` with argument validation for CelebADataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        ValueError: (from the wrapper) if ``dataset_type`` is given but is not
            'all', 'train', 'valid' or 'test', or if ``sampler`` is a
            ``samplers.PKSampler``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # (argument names, expected type), validated in this order.
        optional_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['extensions'], list),
            (['dataset_type'], str),
        )
        for names, expected_type in optional_groups:
            validate_dataset_param_value(names, params, expected_type)

        dataset_type = params.get('dataset_type')
        allowed_types = ('all', 'train', 'valid', 'test')
        if dataset_type is not None and dataset_type not in allowed_types:
            raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(params)

        # isinstance() is already False for None, so no separate None test is needed.
        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CelebADataset does not support PKSampler.")

        return method(self, *args, **kwargs)

    return wrapper
def check_minddataset(method):
    """Wrap ``method`` with argument validation for MindDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that checks ``dataset_file`` (a single value or a list of
        them), the typed optional arguments, the sampler/shuffle/shard
        combination and the padding options before delegating.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_file = params.get('dataset_file')
        # Treat a single entry as a one-element list so both forms share a loop.
        candidates = dataset_file if isinstance(dataset_file, list) else [dataset_file]
        for path in candidates:
            check_file(path)

        # (argument names, expected type), validated in this order.
        optional_groups = (
            (['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'], int),
            (['columns_list'], list),
            (['block_reader'], bool),
            (['padded_sample'], dict),
        )
        for names, expected_type in optional_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_padding_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_generatordataset(method):
    """Wrap ``method`` with argument validation for GeneratorDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        TypeError: (from the wrapper) if ``source`` is neither callable nor
            accepted by ``iter()``, or if ``sampler`` is neither one of the
            listed ``samplers`` classes nor accepted by ``iter()``.
        ValueError: (from the wrapper) if neither ``column_names`` nor
            ``schema`` is given, ``schema`` is not a ``datasets.Schema`` or
            ``str``, only one of ``num_shards``/``shard_id`` is given,
            ``shard_id >= num_shards``, ``sampler`` is a ``PKSampler``, or
            ``sampler``/``num_shards`` is used with a ``source`` lacking
            ``__getitem__``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        source = params.get('source')
        if not callable(source):
            # A non-callable source must at least be accepted by iter().
            try:
                iter(source)
            except TypeError:
                raise TypeError("source should be callable, iterable or random accessible")

        column_names = params.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")

        schema = params.get('schema')
        if schema is None:
            if column_names is None:
                raise ValueError("Neither columns_names not schema are provided.")
        elif not isinstance(schema, (datasets.Schema, str)):
            raise ValueError("schema should be a path to schema file or a schema object.")

        # Optional arguments, validated per expected type.
        validate_dataset_param_value(
            ["num_samples", "num_parallel_workers", "num_shards", "shard_id"], params, int)
        validate_dataset_param_value(["column_types"], params, list)
        validate_dataset_param_value(["shuffle"], params, bool)

        num_shards = params.get("num_shards")
        shard_id = params.get("shard_id")
        # Sharding needs both values or neither.
        if (num_shards is None) != (shard_id is None):
            raise ValueError("num_shards and shard_id need to be passed in together")
        if num_shards is not None:
            type_check(num_shards, (int,), "num_shards")
            check_positive(num_shards, "num_shards")
            if shard_id >= num_shards:
                raise ValueError("shard_id should be less than num_shards")

        sampler = params.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("PKSampler is not supported by GeneratorDataset")
            known_samplers = (samplers.SequentialSampler, samplers.DistributedSampler,
                              samplers.RandomSampler, samplers.SubsetRandomSampler,
                              samplers.WeightedRandomSampler, samplers.Sampler)
            if not isinstance(sampler, known_samplers):
                # Anything else must at least be accepted by iter().
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
            if not hasattr(source, "__getitem__"):
                raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")

        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")

        return method(self, *args, **kwargs)

    return wrapper
def check_pad_info(key, val):
    """Validate one ``key: value`` entry of a ``pad_info`` dict.

    Args:
        key: The dict key; type-checked as ``str``.
        val: ``None`` (nothing further is checked) or a 2-tuple
            ``(pad_shape, pad_value)``. ``pad_shape`` may be ``None`` or a
            list whose entries are ``None`` or positive ints. ``pad_value``
            may be ``None`` or an ``int``, ``float``, ``str`` or ``bytes``.

    Raises:
        ValueError: If ``val`` does not have exactly two items, or a
            dimension in ``pad_shape`` is not positive.
    """
    type_check(key, (str,), "key in pad_info")
    if val is None:
        return

    # Type-check before len(): a non-sized value would otherwise fail inside
    # len() with an error that never mentions pad_info.
    type_check(val, (tuple,), "value in pad_info")
    # Explicit raises rather than assert, which is stripped under ``python -O``.
    if len(val) != 2:
        raise ValueError("value of pad_info should be a tuple of size 2")

    pad_shape, pad_value = val
    if pad_shape is not None:
        type_check(pad_shape, (list,), "pad_shape")
        for dim in pad_shape:
            if dim is None:
                continue
            type_check(dim, (int,), "dim in pad_shape")
            if dim <= 0:
                raise ValueError("pad shape should be positive integers")
    if pad_value is not None:
        type_check(pad_value, (int, float, str, bytes), "pad_value")
def check_bucket_batch_by_length(method):
    """Wrap ``method`` with argument validation for ``bucket_batch_by_length``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        TypeError: (from the wrapper) if ``bucket_boundaries`` or
            ``bucket_batch_sizes`` contains a non-int item.
        ValueError: (from the wrapper) if more or fewer than one column is
            given without ``element_length_function``; if
            ``bucket_boundaries`` is empty, contains an item ``<= 0`` or is
            not strictly increasing; or if ``bucket_batch_sizes`` does not
            have exactly one more item than ``bucket_boundaries`` or contains
            an item ``<= 0``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder), _ = parse_user_args(method, *args, **kwargs)

        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,),
                        ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'])
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,),
                        ['pad_to_bucket_boundary', 'drop_remainder'])

        check_columns(column_names, "column_names")

        if element_length_function is None and len(column_names) != 1:
            raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

        # bucket_boundaries: non-empty, ints only, every item > 0, strictly increasing.
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")
        if any(not isinstance(boundary, int) for boundary in bucket_boundaries):
            raise TypeError("bucket_boundaries should be a list of int.")
        if any(not boundary > 0 for boundary in bucket_boundaries):
            raise ValueError("bucket_boundaries cannot contain any negative numbers.")
        for lower, upper in zip(bucket_boundaries, bucket_boundaries[1:]):
            if not upper > lower:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # bucket_batch_sizes: one more entry than boundaries, ints only, every item > 0.
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")
        if any(not isinstance(size, int) for size in bucket_batch_sizes):
            raise TypeError("bucket_batch_sizes should be a list of int.")
        if any(not size > 0 for size in bucket_batch_sizes):
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")
            for column, spec in pad_info.items():
                check_pad_info(column, spec)

        return method(self, *args, **kwargs)

    return wrapper
def check_batch(method):
    """Wrap ``method`` with argument validation for ``batch``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        TypeError: (from the wrapper) if ``batch_size`` is neither an int nor
            callable.
        ValueError: (from the wrapper) if a callable ``batch_size`` does not
            take exactly one parameter; if ``pad_info`` and ``per_batch_map``
            are both set; if only one of ``per_batch_map``/``input_columns``
            is set; if ``input_columns`` is empty; or if ``per_batch_map``
            does not take ``len(input_columns) + 1`` parameters.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (batch_size, drop_remainder, num_parallel_workers, per_batch_map,
         input_columns, pad_info), params = parse_user_args(method, *args, **kwargs)

        if not isinstance(batch_size, int) and not callable(batch_size):
            raise TypeError("batch_size should either be an int or a callable.")
        if callable(batch_size):
            arity = len(ins.signature(batch_size).parameters)
            if arity != 1:
                raise ValueError("batch_size callable should take one parameter (BatchInfo).")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")

        if pad_info is not None and per_batch_map is not None:
            raise ValueError("pad_info and per_batch_map can't both be set")

        if pad_info is not None:
            pad_spec = params["pad_info"]
            type_check(pad_spec, (dict,), "pad_info")
            for column, spec in pad_spec.items():
                check_pad_info(column, spec)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        # per_batch_map and input_columns must be given together or not at all.
        if (per_batch_map is None) ^ (input_columns is None):
            raise ValueError("per_batch_map and input_columns need to be passed in together.")

        if input_columns is not None:
            if not input_columns:
                raise ValueError("input_columns can not be empty")
            # per_batch_map takes one parameter per input column plus one extra.
            expected_columns = len(ins.signature(per_batch_map).parameters) - 1
            if len(input_columns) != expected_columns:
                raise ValueError("the signature of per_batch_map should match with input columns")

        return method(self, *args, **kwargs)

    return wrapper
def check_sync_wait(method):
    """Wrap ``method`` with argument validation for ``sync_wait``.

    Type-checks ``condition_name`` as ``str`` and ``num_batch`` as ``int``;
    the third parsed argument is not inspected.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        condition_name, num_batch, _ = parsed

        type_check(condition_name, (str,), "condition_name")
        type_check(num_batch, (int,), "num_batch")

        return method(self, *args, **kwargs)

    return wrapper
def check_shuffle(method):
    """Wrap ``method`` with argument validation for ``shuffle``.

    ``buffer_size`` is type-checked as ``int`` and then passed to
    ``check_value`` with the range ``[2, INT32_MAX]``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (buffer_size,), _ = parse_user_args(method, *args, **kwargs)

        type_check(buffer_size, (int,), "buffer_size")
        check_value(buffer_size, [2, INT32_MAX], "buffer_size")

        return method(self, *args, **kwargs)

    return wrapper
def check_map(method):
    """Wrap ``method`` with argument validation for ``map``.

    Optional arguments (``columns_order``, ``num_parallel_workers``,
    ``input_columns``, ``output_columns``) are only checked when they are not
    ``None``; ``python_multiprocessing`` is always type-checked as ``bool``.
    The second parsed argument is not inspected.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (input_columns, _, output_columns, columns_order,
         num_parallel_workers, python_multiprocessing) = parsed

        if columns_order is not None:
            type_check(columns_order, (list,), "columns_order")
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        if input_columns is not None:
            check_columns(input_columns, 'input_columns')
        if output_columns is not None:
            check_columns(output_columns, 'output_columns')

        return method(self, *args, **kwargs)

    return wrapper
def check_filter(method):
    """Wrap ``method`` with argument validation for ``filter``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        A wrapper that validates the arguments and then delegates.

    Raises:
        TypeError: (from the wrapper) if ``predicate`` is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)

        if not callable(predicate):
            raise TypeError("Predicate should be a python function or a callable python object.")

        # num_parallel_workers is optional, so it is validated exactly once and
        # only when supplied, as the map/batch validators in this module do.
        # A second unconditional call here would make this None guard dead.
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method
def check_repeat(method):
    """Wrap ``method`` with argument validation for ``repeat``.

    ``count`` may be ``None`` or an ``int``. An int is additionally passed to
    ``check_value`` with the range ``(-1, INT32_MAX)``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)

        # Report the argument under its real name, "count"; the range check
        # below and the skip/take validators already use that name.
        type_check(count, (int, type(None)), "count")
        if isinstance(count, int):
            check_value(count, (-1, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return new_method
def check_skip(method):
    """Wrap ``method`` with argument validation for ``skip``.

    ``count`` is type-checked as ``int`` and then passed to ``check_value``
    with the range ``(-1, INT32_MAX)``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (count,), _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int,), "count")
        check_value(count, (-1, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return wrapper
def check_take(method):
    """Wrap ``method`` with argument validation for ``take``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        ValueError: (from the wrapper) unless ``count`` is -1 or lies in
            ``1..INT32_MAX``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (count,), _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int,), "count")
        out_of_range = count > INT32_MAX or (count != -1 and count <= 0)
        if out_of_range:
            raise ValueError("count should be either -1 or positive integer.")

        return method(self, *args, **kwargs)

    return wrapper
def check_zip(method):
    """Wrap ``method`` with argument validation for ``zip``.

    Unlike most validators in this module the wrapper takes no ``self``; it
    forwards ``*args``/``**kwargs`` straight through. The single parsed
    argument is type-checked as ``tuple``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(*args, **kwargs):
        (zipped,), _ = parse_user_args(method, *args, **kwargs)
        type_check(zipped, (tuple,), "datasets")
        return method(*args, **kwargs)

    return wrapper
def check_zip_dataset(method):
    """Wrap ``method`` with argument validation for the ``zip`` method of `Dataset`.

    The single parsed argument is type-checked as either a ``tuple`` or a
    ``datasets.Dataset``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (other,), _ = parse_user_args(method, *args, **kwargs)
        type_check(other, (tuple, datasets.Dataset), "datasets")
        return method(self, *args, **kwargs)

    return wrapper
def check_concat(method):
    """Wrap ``method`` with argument validation for the ``concat`` method of `Dataset`.

    The single parsed argument must be a ``datasets.Dataset`` or a ``list``.
    For a list, every element is type-checked as ``datasets.Dataset`` and
    labelled ``dataset[i]`` in the check.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (other,), _ = parse_user_args(method, *args, **kwargs)

        type_check(other, (list, datasets.Dataset), "datasets")
        if isinstance(other, list):
            labels = ["dataset[{0}]".format(index) for index in range(len(other))]
            type_check_list(other, (datasets.Dataset,), labels)

        return method(self, *args, **kwargs)

    return wrapper
def check_rename(method):
    """Wrap ``method`` with argument validation for ``rename``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        ValueError: (from the wrapper) if ``input_columns`` and
            ``output_columns`` name a different number of columns. A non-list
            argument counts as one column.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)

        for label, columns in zip(('input_columns', 'output_columns'), values):
            check_columns(columns, label)

        input_columns, output_columns = values
        input_size = len(input_columns) if isinstance(input_columns, list) else 1
        output_size = len(output_columns) if isinstance(output_columns, list) else 1
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return wrapper
def check_project(method):
    """Wrap ``method`` with argument validation for ``project``.

    The single parsed argument is handed to ``check_columns`` under the name
    'columns'.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (columns,), _ = parse_user_args(method, *args, **kwargs)
        check_columns(columns, 'columns')
        return method(self, *args, **kwargs)

    return wrapper
def check_add_column(method):
    """Wrap ``method`` with argument validation for ``add_column``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        TypeError: (from the wrapper) if ``name`` is an empty string; if
            ``de_type`` is ``None``; or if ``de_type`` is neither a
            ``typing.Type`` instance nor accepted by ``check_valid_detype``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)

        type_check(name, (str,), "name")
        if not name:
            raise TypeError("Expected non-empty string.")

        # de_type is mandatory. A missing value gets its own message so it is
        # not confused with the empty-name error above.
        if de_type is None:
            raise TypeError("de_type should not be None.")
        if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
            raise TypeError("Unknown column type.")

        if shape is not None:
            type_check(shape, (list,), "shape")
            shape_names = ["shape[{0}]".format(i) for i in range(len(shape))]
            type_check_list(shape, (int,), shape_names)

        return method(self, *args, **kwargs)

    return new_method
def check_cluedataset(method):
    """Wrap ``method`` with argument validation for CLUEDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        ValueError: (from the wrapper) if ``task`` is not one of AFQMC,
            TNEWS, IFLYTEK, CMNLI, WSC, CSL, or ``usage`` is not one of
            train, test, eval.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        type_check(params.get('dataset_files'), (str, list), "dataset files")

        supported_tasks = ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']
        if params.get('task') not in supported_tasks:
            raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")

        supported_usages = ['train', 'test', 'eval']
        if params.get('usage') not in supported_usages:
            raise ValueError("usage should be train, test or eval")

        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_textfiledataset(method):
    """Wrap ``method`` with argument validation for TextFileDataset.

    ``dataset_files`` is type-checked as ``str`` or ``list``; the optional
    int arguments and the sampler/shuffle/shard combination are then handed
    to the shared helpers.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        type_check(params.get('dataset_files'), (str, list), "dataset files")

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_options, params, int)
        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
def check_split(method):
    """Wrap ``method`` with argument validation for ``split``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        ValueError: (from the wrapper) if ``sizes`` is empty; if it is not
            made up entirely of ints or entirely of floats; if an int entry
            is ``<= 0``; if a float entry lies outside ``(0, 1]``; or if the
            floats do not sum to 1 within 0.00001.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (sizes, randomize), _ = parse_user_args(method, *args, **kwargs)

        type_check(sizes, (list,), "sizes")
        type_check(randomize, (bool,), "randomize")

        if not sizes:
            raise ValueError("sizes cannot be empty.")

        only_ints = all(isinstance(size, int) for size in sizes)
        only_floats = all(isinstance(size, float) for size in sizes)
        if not only_ints and not only_floats:
            raise ValueError("sizes should be list of int or list of float.")

        if only_ints:
            # Absolute row counts: each one must be strictly positive.
            if any(not size > 0 for size in sizes):
                raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")

        if only_floats:
            # Percentages: each in (0, 1], together summing to 1 within a tolerance.
            if any(not 0 < size <= 1 for size in sizes):
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")
            tolerance = 0.00001
            if not abs(sum(sizes) - 1) < tolerance:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_graphdata(method):
    """Wrap ``method`` with argument validation for graphdata.

    ``dataset_file`` is handed to ``check_file``; ``num_parallel_workers`` is
    only checked when it is not ``None``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (dataset_file, num_parallel_workers), _ = parse_user_args(method, *args, **kwargs)

        check_file(dataset_file)
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_get_all_nodes(method):
    """Wrap the GNN `get_all_nodes` function with argument validation.

    The single parsed argument, ``node_type``, is type-checked as ``int``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (node_type,), _ = parse_user_args(method, *args, **kwargs)
        type_check(node_type, (int,), "node_type")
        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_get_all_edges(method):
    """Wrap the GNN `get_all_edges` function with argument validation.

    The single parsed argument, ``edge_type``, is type-checked as ``int``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (edge_type,), _ = parse_user_args(method, *args, **kwargs)
        type_check(edge_type, (int,), "edge_type")
        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_get_nodes_from_edges(method):
    """Wrap the GNN `get_nodes_from_edges` function with argument validation.

    The single parsed argument, ``edge_list``, is handed to
    ``check_gnn_list_or_ndarray``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (edge_list,), _ = parse_user_args(method, *args, **kwargs)
        check_gnn_list_or_ndarray(edge_list, "edge_list")
        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_get_all_neighbors(method):
    """Wrap the GNN `get_all_neighbors` function with argument validation.

    ``node_list`` is handed to ``check_gnn_list_or_ndarray`` and
    ``neighbour_type`` is type-checked as ``int``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (node_list, neighbour_type), _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neighbour_type, (int,), "neighbour_type")

        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_get_sampled_neighbors(method):
    """Wrap the GNN `get_sampled_neighbors` function with argument validation.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        ValueError: (from the wrapper) if ``neighbor_nums`` or
            ``neighbor_types`` does not hold between 1 and 6 members, or if
            the two differ in length.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')

        for members, label in ((neighbor_nums, 'neighbor_nums'), (neighbor_types, 'neighbor_types')):
            check_gnn_list_or_ndarray(members, label)
            # Count with len() rather than truthiness: `not array` raises for a
            # multi-element NumPy array and is True for a one-element array
            # holding 0, so it cannot be used to detect emptiness.
            if not 1 <= len(members) <= 6:
                raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
                    label, len(members)))

        if len(neighbor_nums) != len(neighbor_types):
            raise ValueError(
                "The number of members of neighbor_nums and neighbor_types is inconsistent")

        return method(self, *args, **kwargs)

    return new_method
def check_gnn_get_neg_sampled_neighbors(method):
    """Wrap the GNN `get_neg_sampled_neighbors` function with argument validation.

    ``node_list`` is handed to ``check_gnn_list_or_ndarray``;
    ``neg_neighbor_num`` and ``neg_neighbor_type`` are type-checked as ``int``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (node_list, neg_neighbor_num, neg_neighbor_type), _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neg_neighbor_num, (int,), "neg_neighbor_num")
        type_check(neg_neighbor_type, (int,), "neg_neighbor_type")

        return method(self, *args, **kwargs)

    return wrapper
def check_gnn_random_walk(method):
    """Wrap the GNN `random_walk` function with argument validation.

    ``target_nodes`` and ``meta_path`` are handed to
    ``check_gnn_list_or_ndarray``; ``step_home_param`` and
    ``step_away_param`` are type-checked as ``float`` and ``default_node`` as
    ``int``.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        target_nodes, meta_path, step_home_param, step_away_param, default_node = parsed

        check_gnn_list_or_ndarray(target_nodes, 'target_nodes')
        check_gnn_list_or_ndarray(meta_path, 'meta_path')
        type_check(step_home_param, (float,), "step_home_param")
        type_check(step_away_param, (float,), "step_away_param")
        type_check(default_node, (int,), "default_node")

        return method(self, *args, **kwargs)

    return wrapper
def check_aligned_list(param, param_name, member_type):
    """Check whether the structure of each member of the list is the same.

    The list is walked recursively. At every nesting level the members must
    either all be lists (each matching the length of the sibling list before
    it) or all be non-list values type-checked against ``member_type``.

    Args:
        param: The (possibly nested) list to inspect.
        param_name: Name used for ``param`` in error messages.
        member_type: Type that every non-list member is checked against.

    Raises:
        TypeError: If ``param`` or any nested list is empty, if one level
            mixes list and non-list members, or if sibling lists differ in
            length.
    """
    # Report failures under the caller's argument name, not the literal "param".
    type_check(param, (list,), param_name)
    if not param:
        raise TypeError(
            "Parameter {0} or its members are empty".format(param_name))

    # Tri-state: None until the first member is seen, then True/False
    # depending on whether the members at this level are lists.
    member_have_list = None
    list_len = None
    for member in param:
        if isinstance(member, list):
            check_aligned_list(member, param_name, member_type)
            if member_have_list not in (None, True):
                raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
                    param_name))
            if list_len is not None and len(member) != list_len:
                raise TypeError("The size of each member of parameter {0} is inconsistent".format(
                    param_name))
            member_have_list = True
            list_len = len(member)
        else:
            type_check(member, (member_type,), param_name)
            if member_have_list not in (None, False):
                raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
                    param_name))
            member_have_list = False
def check_gnn_get_node_feature(method):
    """Wrap the GNN `get_node_feature` function with argument validation.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        TypeError: (from the wrapper) if ``node_list`` is a NumPy array whose
            dtype is not ``int32``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs)

        type_check(node_list, (list, np.ndarray), "node_list")
        if isinstance(node_list, list):
            check_aligned_list(node_list, 'node_list', int)
        elif isinstance(node_list, np.ndarray):
            if node_list.dtype != np.int32:
                # Name the argument in the message instead of formatting the
                # whole array into it.
                raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                    'node_list', node_list.dtype))

        check_gnn_list_or_ndarray(feature_types, 'feature_types')

        return method(self, *args, **kwargs)

    return new_method
def check_gnn_get_edge_feature(method):
    """Wrap the GNN `get_edge_feature` function with argument validation.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        TypeError: (from the wrapper) if ``edge_list`` is a NumPy array whose
            dtype is not ``int32``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)

        type_check(edge_list, (list, np.ndarray), "edge_list")
        if isinstance(edge_list, list):
            check_aligned_list(edge_list, 'edge_list', int)
        elif isinstance(edge_list, np.ndarray):
            if edge_list.dtype != np.int32:
                # Name the argument in the message instead of formatting the
                # whole array into it.
                raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                    'edge_list', edge_list.dtype))

        check_gnn_list_or_ndarray(feature_types, 'feature_types')

        return method(self, *args, **kwargs)

    return new_method
def check_numpyslicesdataset(method):
    """Wrap ``method`` with argument validation for NumpySlicesDataset.

    Args:
        method: The callable to guard; called unchanged after validation.

    Returns:
        The validating wrapper.

    Raises:
        ValueError: (from the wrapper) if ``data`` is an empty tuple, or if
            the number of ``column_names`` does not match the number of
            columns implied by ``data`` (dict: number of keys; tuple: number
            of elements; anything else: 1).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        data = param_dict.get("data")
        column_names = param_dict.get("column_names")

        type_check(data, (list, tuple, dict, np.ndarray), "data")
        if isinstance(data, tuple):
            # Without this guard an empty tuple fails on data[0] with a bare IndexError.
            if not data:
                raise ValueError("data cannot be an empty tuple.")
            type_check(data[0], (list, np.ndarray), "data[0]")

        # check column_names
        if column_names is not None:
            check_columns(column_names, "column_names")

            # A single string names one column; otherwise count the entries.
            column_num = 1 if isinstance(column_names, str) else len(column_names)
            if isinstance(data, dict):
                data_column = len(data)
                if column_num != data_column:
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, data_column))
            elif isinstance(data, tuple):
                if column_num != len(data):
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, len(data)))
            else:
                if column_num != 1:
                    raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
                                     .format(column_num, 1))

        return method(self, *args, **kwargs)

    return new_method