|
|
@ -23,6 +23,7 @@ import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
import mindspore._c_dataengine as cde
|
|
|
|
import mindspore._c_dataengine as cde
|
|
|
|
from ..engine import samplers
|
|
|
|
from ..engine import samplers
|
|
|
|
|
|
|
|
|
|
|
|
# POS_INT_MIN is used to limit values from starting from 0
|
|
|
|
# POS_INT_MIN is used to limit values from starting from 0
|
|
|
|
POS_INT_MIN = 1
|
|
|
|
POS_INT_MIN = 1
|
|
|
|
UINT8_MAX = 255
|
|
|
|
UINT8_MAX = 255
|
|
|
@ -289,7 +290,6 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|
|
|
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
|
|
|
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
|
|
|
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
|
|
|
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
|
|
|
num_samples = param_dict.get('num_samples')
|
|
|
|
num_samples = param_dict.get('num_samples')
|
|
|
|
check_sampler(sampler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if sampler is not None:
|
|
|
|
if sampler is not None:
|
|
|
|
if shuffle is not None:
|
|
|
|
if shuffle is not None:
|
|
|
@ -348,6 +348,7 @@ def check_num_samples(value):
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
|
"num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))
|
|
|
|
"num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_dataset_param_value(param_list, param_dict, param_type):
|
|
|
|
def validate_dataset_param_value(param_list, param_dict, param_type):
|
|
|
|
for param_name in param_list:
|
|
|
|
for param_name in param_list:
|
|
|
|
if param_dict.get(param_name) is not None:
|
|
|
|
if param_dict.get(param_name) is not None:
|
|
|
@ -387,6 +388,7 @@ def check_tensor_op(param, param_name):
|
|
|
|
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
|
|
|
|
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
|
|
|
|
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
|
|
|
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_sampler(sampler):
|
|
|
|
def check_sampler(sampler):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Check if the sampler is of valid input.
|
|
|
|
Check if the sampler is of valid input.
|
|
|
@ -419,5 +421,6 @@ def check_sampler(sampler):
|
|
|
|
if not (builtin or base_sampler or list_num):
|
|
|
|
if not (builtin or base_sampler or list_num):
|
|
|
|
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
|
|
|
|
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_none(value, default):
|
|
|
|
def replace_none(value, default):
|
|
|
|
return value if value is not None else default
|
|
|
|
return value if value is not None else default
|
|
|
|