add err modify

pull/8688/head
ms_yan 4 years ago
parent de60d1d98f
commit 0cb5c47856

@@ -186,7 +186,7 @@ class WaitedDSCallback(Callback, DSCallback):
         success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
         self.epoch_event.clear()
         if not success:
-            raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)")
+            raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s).")
         # by the time this thread wakes up, self.epoch_run_context is already available
         self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
@@ -212,7 +212,7 @@ class WaitedDSCallback(Callback, DSCallback):
         success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
         self.step_event.clear()
         if not success:
-            raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)")
+            raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s).")
         # by the time this thread wakes up, self.epoch_run_context is already available
         self.sync_step_begin(self.step_run_context, ds_run_context)
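For context, the synchronization pattern these two hunks adjust: the dataset-side thread blocks on a threading.Event with a timeout and raises if the training side never signals back. A minimal standalone sketch of that pattern (CALLBACK_TIMEOUT is a stand-in for ds.config.get_callback_timeout(); this is an illustration, not the MindSpore source):

import threading

CALLBACK_TIMEOUT = 60  # stand-in for ds.config.get_callback_timeout()
epoch_event = threading.Event()

def ds_epoch_begin():
    # Event.wait() returns False when the timeout expires before set() is called.
    success = epoch_event.wait(timeout=CALLBACK_TIMEOUT)
    epoch_event.clear()
    if not success:
        raise RuntimeError(f"ds_epoch_begin timed out after {CALLBACK_TIMEOUT} second(s).")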

@@ -122,7 +122,7 @@ def check_pos_float64(value, arg_name=""):
 def check_valid_detype(type_):
     if type_ not in valid_detype:
-        raise ValueError("Unknown column type")
+        raise TypeError("Unknown column type.")
     return True
@@ -146,10 +146,10 @@ def check_columns(columns, name):
     type_check(columns, (list, str), name)
     if isinstance(columns, str):
         if not columns:
-            raise ValueError("{0} should not be an empty str".format(name))
+            raise ValueError("{0} should not be an empty str.".format(name))
     elif isinstance(columns, list):
         if not columns:
-            raise ValueError("{0} should not be empty".format(name))
+            raise ValueError("{0} should not be empty.".format(name))
         for i, column_name in enumerate(columns):
             if not column_name:
                 raise ValueError("{0}[{1}] should not be empty.".format(name, i))
@@ -250,10 +250,10 @@ def check_filename(path):
     forbidden_symbols = set(r'\/:*?"<>|`&\';')
     if set(filename) & forbidden_symbols:
-        raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
+        raise ValueError(r"filename should not contain \/:*?\"<>|`&;\'")
     if filename.startswith(' ') or filename.endswith(' '):
-        raise ValueError("filename should not start/end with space")
+        raise ValueError("filename should not start/end with space.")
     return True
@@ -374,4 +374,4 @@ def check_gnn_list_or_ndarray(param, param_name):
 def check_tensor_op(param, param_name):
     """check whether param is a tensor op or a callable Python function"""
     if not isinstance(param, cde.TensorOp) and not callable(param):
-        raise TypeError("{0} is not a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))
+        raise TypeError("{0} is neither a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))

File diff suppressed because it is too large.

@@ -89,7 +89,7 @@ class GraphData:
             while self._graph_data.is_stoped() is not True:
                 time.sleep(1)
         except KeyboardInterrupt:
-            raise Exception("Graph data server receives KeyboardInterrupt")
+            raise Exception("Graph data server receives KeyboardInterrupt.")

     @check_gnn_get_all_nodes
     def get_all_nodes(self, node_type):
@@ -112,7 +112,7 @@ class GraphData:
             TypeError: If `node_type` is not integer.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.get_all_nodes(node_type).as_array()

     @check_gnn_get_all_edges
@@ -136,7 +136,7 @@ class GraphData:
             TypeError: If `edge_type` is not integer.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.get_all_edges(edge_type).as_array()

     @check_gnn_get_nodes_from_edges
@@ -154,7 +154,7 @@ class GraphData:
             TypeError: If `edge_list` is not list or ndarray.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.get_nodes_from_edges(edge_list).as_array()

     @check_gnn_get_all_neighbors
@@ -181,7 +181,7 @@ class GraphData:
             TypeError: If `neighbor_type` is not integer.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()

     @check_gnn_get_sampled_neighbors
@@ -216,7 +216,7 @@ class GraphData:
             TypeError: If `neighbor_types` is not list or ndarray.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.get_sampled_neighbors(
             node_list, neighbor_nums, neighbor_types).as_array()
@@ -246,7 +246,7 @@ class GraphData:
             TypeError: If `neg_neighbor_type` is not integer.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.get_neg_sampled_neighbors(
             node_list, neg_neighbor_num, neg_neighbor_type).as_array()
@@ -274,7 +274,7 @@ class GraphData:
             TypeError: If `feature_types` is not list or ndarray.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         if isinstance(node_list, list):
             node_list = np.array(node_list, dtype=np.int32)
         return [
@@ -306,7 +306,7 @@ class GraphData:
             TypeError: If `feature_types` is not list or ndarray.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         if isinstance(edge_list, list):
             edge_list = np.array(edge_list, dtype=np.int32)
         return [
@@ -324,7 +324,7 @@ class GraphData:
             node_feature_type and edge_feature_type.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.graph_info()

     @check_gnn_random_walk
@@ -360,6 +360,6 @@ class GraphData:
             TypeError: If `meta_path` is not list or ndarray.
         """
         if self._working_mode == 'server':
-            raise Exception("This method is not supported when working mode is server")
+            raise Exception("This method is not supported when working mode is server.")
         return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
                                             default_node).as_array()
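Every public GraphData method above repeats the same server-mode guard. A hedged sketch of how such a guard could be factored into a decorator (illustration only; the committed code keeps the inline checks):

import functools

def reject_server_mode(method):
    """Raise when the instance runs as a graph data server, else delegate."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")
        return method(self, *args, **kwargs)
    return wrapper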

@@ -29,20 +29,25 @@ from . import datasets as de
 _ITERATOR_CLEANUP = False

+
 def _set_iterator_cleanup():
     global _ITERATOR_CLEANUP
     _ITERATOR_CLEANUP = True

+
 def _unset_iterator_cleanup():
     global _ITERATOR_CLEANUP
     _ITERATOR_CLEANUP = False

+
 def check_iterator_cleanup():
     global _ITERATOR_CLEANUP
     return _ITERATOR_CLEANUP

+
 ITERATORS_LIST = list()

+
 def _cleanup():
     """Release all the Iterator."""
     _set_iterator_cleanup()
@@ -51,6 +56,7 @@ def _cleanup():
         if itr is not None:
             itr.release()

+
 def alter_tree(node):
     """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
     if not node.children:
@@ -73,6 +79,7 @@ def _alter_node(node):
         node.iterator_bootstrap()
     return node

+
 class Iterator:
     """
     General Iterator over a dataset.
@@ -93,7 +100,7 @@ class Iterator:
         # The dataset passed into the iterator is not the root of the tree.
         # Trim the tree by saving the parent subtree into self.parent_subtree and
-        # restore it after launching our c++ pipeline.
+        # restore it after launching our C++ pipeline.
         if self.dataset.parent:
             logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
             self.parent_subtree = self.dataset.parent
@@ -101,7 +108,7 @@ class Iterator:
         self.dataset = alter_tree(self.dataset)
         if not self.__is_tree():
-            raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
+            raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).")
         self.depipeline = DEPipeline()

         # for manifest temporary use
@@ -116,7 +123,7 @@ class Iterator:
         """
         Manually terminate Python iterator instead of relying on out of scope destruction.
         """
-        logger.info("terminating Python iterator. This will also terminate c++ pipeline.")
+        logger.info("Terminating Python iterator. This will also terminate C++ pipeline.")
         if hasattr(self, 'depipeline') and self.depipeline:
             del self.depipeline
@@ -205,7 +212,7 @@ class Iterator:
         elif isinstance(dataset, de.CSVDataset):
             op_type = OpName.CSV
         else:
-            raise ValueError("Unsupported DatasetOp")
+            raise ValueError("Unsupported DatasetOp.")

         return op_type
@@ -256,9 +263,9 @@ class Iterator:
     def __next__(self):
         if not self.depipeline:
-            logger.warning("Iterator does not have a running c++ pipeline." +
-                           "It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
-            raise RuntimeError("Iterator does not have a running c++ pipeline.")
+            logger.warning("Iterator does not have a running C++ pipeline." +
+                           "It might because Iterator stop() had been called, or C++ pipeline crashed silently.")
+            raise RuntimeError("Iterator does not have a running C++ pipeline.")

         data = self.get_next()
         if not data:
@@ -298,6 +305,7 @@ class Iterator:
     def __deepcopy__(self, memo):
         return self

+
 class SaveOp(Iterator):
     """
     The derived class of Iterator with dict type.
@@ -375,7 +383,7 @@ class TupleIterator(Iterator):
         return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()]

-class DummyIterator():
+class DummyIterator:
     """
     A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
     """

@@ -24,6 +24,7 @@ import numpy as np
 import mindspore._c_dataengine as cde
 import mindspore.dataset as ds

+
 class Sampler:
     """
     Base class for user defined sampler.
@@ -245,22 +246,22 @@ class DistributedSampler(BuiltinSampler):
     def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
         if num_shards <= 0:
-            raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
+            raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards))

         if shard_id < 0 or shard_id >= num_shards:
-            raise ValueError("shard_id is invalid, shard_id={}".format(shard_id))
+            raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id))

         if not isinstance(shuffle, bool):
-            raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
+            raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))

         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples={}".format(num_samples))
+                                 "value, but got num_samples: {}.".format(num_samples))

         if offset > num_shards:
-            raise ValueError("offset should be no more than num_shards={}, "
-                             "but got offset={}".format(num_shards, offset))
+            raise ValueError("offset should be no more than num_shards: {}, "
+                             "but got offset: {}".format(num_shards, offset))

         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -332,18 +333,18 @@ class PKSampler(BuiltinSampler):
     def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
         if num_val <= 0:
-            raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
+            raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val))

         if num_class is not None:
-            raise NotImplementedError("Not support specify num_class")
+            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

         if not isinstance(shuffle, bool):
-            raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
+            raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle))

         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples={}".format(num_samples))
+                                 "value, but got num_samples: {}.".format(num_samples))

         self.num_val = num_val
         self.shuffle = shuffle
@@ -372,7 +373,7 @@ class PKSampler(BuiltinSampler):
     def create_for_minddataset(self):
         if not self.class_column or not isinstance(self.class_column, str):
             raise ValueError("class_column should be a not empty string value, \
-                but got class_column={}".format(class_column))
+                but got class_column: {}.".format(class_column))
         num_samples = self.num_samples if self.num_samples is not None else 0
         c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
         c_child_sampler = self.create_child_for_minddataset()
@@ -404,12 +405,12 @@ class RandomSampler(BuiltinSampler):
     def __init__(self, replacement=False, num_samples=None):
         if not isinstance(replacement, bool):
-            raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
+            raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples={}".format(num_samples))
+                                 "value, but got num_samples: {}.".format(num_samples))

         self.deterministic = False
         self.replacement = replacement
@@ -462,12 +463,12 @@ class SequentialSampler(BuiltinSampler):
         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples={}".format(num_samples))
+                                 "value, but got num_samples: {}.".format(num_samples))

         if start_index is not None:
             if start_index < 0:
                 raise ValueError("start_index should be a positive integer "
-                                 "value or 0, but got start_index={}".format(start_index))
+                                 "value or 0, but got start_index: {}.".format(start_index))

         self.start_index = start_index
         super().__init__(num_samples)
@@ -517,7 +518,7 @@ class SubsetRandomSampler(BuiltinSampler):
         >>> indices = [0, 1, 2, 3, 7, 88, 119]
         >>>
         >>> # creates a SubsetRandomSampler, will sample from the provided indices
-        >>> sampler = ds.SubsetRandomSampler()
+        >>> sampler = ds.SubsetRandomSampler(indices)
         >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
     """
@@ -525,7 +526,7 @@ class SubsetRandomSampler(BuiltinSampler):
         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples={}".format(num_samples))
+                                 "value, but got num_samples: {}.".format(num_samples))

         if not isinstance(indices, list):
             indices = [indices]
@@ -595,24 +596,24 @@ class WeightedRandomSampler(BuiltinSampler):
         for ind, w in enumerate(weights):
             if not isinstance(w, numbers.Number):
                 raise TypeError("type of weights element should be number, "
-                                "but got w[{}]={}, type={}".format(ind, w, type(w)))
+                                "but got w[{}]: {}, type: {}.".format(ind, w, type(w)))

         if weights == []:
             raise ValueError("weights size should not be 0")

         if list(filter(lambda x: x < 0, weights)) != []:
-            raise ValueError("weights should not contain negative numbers")
+            raise ValueError("weights should not contain negative numbers.")

         if list(filter(lambda x: x == 0, weights)) == weights:
-            raise ValueError("elements of weights should not be all zero")
+            raise ValueError("elements of weights should not be all zeros.")

         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
-                                 "value, but got num_samples={}".format(num_samples))
+                                 "value, but got num_samples: {}.".format(num_samples))

         if not isinstance(replacement, bool):
-            raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
+            raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement))

         self.weights = weights
         self.replacement = replacement
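A quick usage sketch of how the reworded sampler checks surface, assuming the samplers are exposed under mindspore.dataset as in the SubsetRandomSampler docstring above:

import mindspore.dataset as ds

try:
    sampler = ds.DistributedSampler(num_shards=4, shard_id=7)
except ValueError as err:
    print(err)  # per the new message: shard_id should in range [0, 4], but got shard_id: 7.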

@@ -348,15 +348,15 @@ def create_node(node):
     elif dataset_op == 'CacheDataset':
         # Member function cache() is not defined in class Dataset yet.
-        raise RuntimeError(dataset_op + " is not yet supported")
+        raise RuntimeError(dataset_op + " is not yet supported.")

     elif dataset_op == 'FilterDataset':
         # Member function filter() is not defined in class Dataset yet.
-        raise RuntimeError(dataset_op + " is not yet supported")
+        raise RuntimeError(dataset_op + " is not yet supported.")

     elif dataset_op == 'TakeDataset':
         # Member function take() is not defined in class Dataset yet.
-        raise RuntimeError(dataset_op + " is not yet supported")
+        raise RuntimeError(dataset_op + " is not yet supported.")

     elif dataset_op == 'ZipDataset':
         # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
@@ -376,7 +376,7 @@ def create_node(node):
         pyobj = de.Dataset().to_device()

     else:
-        raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize()")
+        raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().")

     return pyobj
@@ -401,7 +401,7 @@ def construct_sampler(in_sampler):
     elif sampler_name == 'WeightedRandomSampler':
         sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
     else:
-        raise ValueError("Sampler type is unknown: " + sampler_name)
+        raise ValueError("Sampler type is unknown: {}.".format(sampler_name))

     return sampler
@@ -461,7 +461,7 @@ def construct_tensor_ops(operations):
             result.append(op_class())
         elif op_name == 'CHW2HWC':
-            raise ValueError("Tensor op is not supported: " + op_name)
+            raise ValueError("Tensor op is not supported: {}.".format(op_name))
         elif op_name == 'OneHot':
             result.append(op_class(op['num_classes']))
@@ -474,6 +474,6 @@ def construct_tensor_ops(operations):
             result.append(op_class(op['padding'], op['fill_value'], Border(op['padding_mode'])))
         else:
-            raise ValueError("Tensor op name is unknown: " + op_name)
+            raise ValueError("Tensor op name is unknown: {}.".format(op_name))

     return result
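One practical gain from switching these messages from string concatenation to str.format: "..." + name raises its own TypeError when the value is not a string, masking the intended error, while format renders any value. A small demonstration:

op_name = None  # e.g. a malformed serialized pipeline

# Old style: the concatenation itself fails, hiding the real complaint.
try:
    raise ValueError("Tensor op name is unknown: " + op_name)
except TypeError as err:
    print(err)  # can only concatenate str (not "NoneType") to str

# New style: always produces the intended ValueError text.
try:
    raise ValueError("Tensor op name is unknown: {}.".format(op_name))
except ValueError as err:
    print(err)  # Tensor op name is unknown: None.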

@@ -134,7 +134,7 @@ def check_tfrecorddataset(method):
         dataset_files = param_dict.get('dataset_files')
         if not isinstance(dataset_files, (str, list)):
-            raise TypeError("dataset_files should be of type str or a list of strings.")
+            raise TypeError("dataset_files should be type str or a list of strings.")

         validate_dataset_param_value(nreq_param_int, param_dict, int)
         validate_dataset_param_value(nreq_param_list, param_dict, list)
@@ -173,11 +173,11 @@ def check_vocdataset(method):
         if task == "Segmentation":
             imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
             if param_dict.get('class_indexing') is not None:
-                raise ValueError("class_indexing is invalid in Segmentation task")
+                raise ValueError("class_indexing is not supported in Segmentation task.")
         elif task == "Detection":
             imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
         else:
-            raise ValueError("Invalid task : " + task)
+            raise ValueError("Invalid task : " + task + ".")

         check_file(imagesets_file)
@@ -214,7 +214,7 @@ def check_cocodataset(method):
         type_check(task, (str,), "task")

         if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
-            raise ValueError("Invalid task type")
+            raise ValueError("Invalid task type: " + task + ".")

         validate_dataset_param_value(nreq_param_int, param_dict, int)
@@ -222,7 +222,7 @@ def check_cocodataset(method):
         sampler = param_dict.get('sampler')
         if sampler is not None and isinstance(sampler, samplers.PKSampler):
-            raise ValueError("CocoDataset doesn't support PKSampler")
+            raise ValueError("CocoDataset doesn't support PKSampler.")
         check_sampler_shuffle_shard_options(param_dict)

         cache = param_dict.get('cache')
@@ -256,13 +256,13 @@ def check_celebadataset(method):
         usage = param_dict.get('usage')
         if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
-            raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")
+            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

         check_sampler_shuffle_shard_options(param_dict)

         sampler = param_dict.get('sampler')
         if sampler is not None and isinstance(sampler, samplers.PKSampler):
-            raise ValueError("CelebADataset does not support PKSampler.")
+            raise ValueError("CelebADataset doesn't support PKSampler.")

         cache = param_dict.get('cache')
         check_cache_option(cache)
@@ -350,14 +350,14 @@ def check_generatordataset(method):
         try:
             iter(source)
         except TypeError:
-            raise TypeError("source should be callable, iterable or random accessible")
+            raise TypeError("source should be callable, iterable or random accessible.")

         column_names = param_dict.get('column_names')
         if column_names is not None:
             check_columns(column_names, "column_names")
         schema = param_dict.get('schema')
         if column_names is None and schema is None:
-            raise ValueError("Neither columns_names not schema are provided.")
+            raise ValueError("Neither columns_names nor schema are provided.")

         if schema is not None:
             if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
@@ -375,7 +375,7 @@ def check_generatordataset(method):
         shard_id = param_dict.get("shard_id")
         if (num_shards is None) != (shard_id is None):
             # These two parameters appear together.
-            raise ValueError("num_shards and shard_id need to be passed in together")
+            raise ValueError("num_shards and shard_id need to be passed in together.")

         if num_shards is not None:
             check_pos_int32(num_shards, "num_shards")
             if shard_id >= num_shards:
@@ -384,19 +384,19 @@ def check_generatordataset(method):
         sampler = param_dict.get("sampler")
         if sampler is not None:
             if isinstance(sampler, samplers.PKSampler):
-                raise ValueError("PKSampler is not supported by GeneratorDataset")
+                raise ValueError("GeneratorDataset doesn't support PKSampler.")
             if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
                                         samplers.RandomSampler, samplers.SubsetRandomSampler,
                                         samplers.WeightedRandomSampler, samplers.Sampler)):
                 try:
                     iter(sampler)
                 except TypeError:
-                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
+                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")

         if sampler is not None and not hasattr(source, "__getitem__"):
-            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")
+            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")

         if num_shards is not None and not hasattr(source, "__getitem__"):
-            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")
+            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")

         return method(self, *args, **kwargs)
@@ -433,7 +433,7 @@ def check_pad_info(key, val):
     type_check(key, (str,), "key in pad_info")

     if val is not None:
-        assert len(val) == 2, "value of pad_info should be a tuple of size 2"
+        assert len(val) == 2, "value of pad_info should be a tuple of size 2."
         type_check(val, (tuple,), "value in pad_info")

         if val[0] is not None:
@@ -521,14 +521,14 @@ def check_batch(method):
         if callable(batch_size):
             sig = ins.signature(batch_size)
             if len(sig.parameters) != 1:
-                raise ValueError("batch_size callable should take one parameter (BatchInfo).")
+                raise ValueError("callable batch_size should take one parameter (BatchInfo).")

         if num_parallel_workers is not None:
             check_num_parallel_workers(num_parallel_workers)
         type_check(drop_remainder, (bool,), "drop_remainder")

         if (pad_info is not None) and (per_batch_map is not None):
-            raise ValueError("pad_info and per_batch_map can't both be set")
+            raise ValueError("pad_info and per_batch_map can't both be set.")

         if pad_info is not None:
             type_check(param_dict["pad_info"], (dict,), "pad_info")
@@ -542,7 +542,7 @@ def check_batch(method):
             if input_columns is not None:
                 check_columns(input_columns, "input_columns")
                 if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
-                    raise ValueError("the signature of per_batch_map should match with input columns")
+                    raise ValueError("The signature of per_batch_map should match with input columns.")

         if output_columns is not None:
             check_columns(output_columns, "output_columns")
@@ -816,13 +816,13 @@ def check_add_column(method):
         type_check(name, (str,), "name")

         if not name:
-            raise TypeError("Expected non-empty string.")
+            raise TypeError("Expected non-empty string for column name.")

         if de_type is not None:
             if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
-                raise TypeError("Unknown column type.")
+                raise TypeError("Unknown column type: {}.".format(de_type))
         else:
-            raise TypeError("Expected non-empty string.")
+            raise TypeError("Expected non-empty string for de_type.")

         if shape is not None:
             type_check(shape, (list,), "shape")
@@ -848,12 +848,12 @@ def check_cluedataset(method):
         # check task
         task_param = param_dict.get('task')
         if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
-            raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
+            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

         # check usage
         usage_param = param_dict.get('usage')
         if usage_param not in ['train', 'test', 'eval']:
-            raise ValueError("usage should be train, test or eval")
+            raise ValueError("usage should be 'train', 'test' or 'eval'.")

         validate_dataset_param_value(nreq_param_int, param_dict, int)
         check_sampler_shuffle_shard_options(param_dict)
@@ -883,7 +883,7 @@ def check_csvdataset(method):
         field_delim = param_dict.get('field_delim')
         type_check(field_delim, (str,), 'field delim')
         if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
-            raise ValueError("field_delim is not legal.")
+            raise ValueError("field_delim is invalid.")

         # check column_defaults
         column_defaults = param_dict.get('column_defaults')
@@ -892,7 +892,7 @@ def check_csvdataset(method):
                 raise TypeError("column_defaults should be type of list.")
             for item in column_defaults:
                 if not isinstance(item, (str, int, float)):
-                    raise TypeError("column type is not legal.")
+                    raise TypeError("column type in column_defaults is invalid.")

         # check column_names: must be list of string.
         column_names = param_dict.get("column_names")
@@ -997,7 +997,7 @@ def check_gnn_graphdata(method):
             raise ValueError("The hostname is illegal")
         type_check(working_mode, (str,), "working_mode")
         if working_mode not in {'local', 'client', 'server'}:
-            raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
+            raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.")

         type_check(port, (int,), "port")
         check_value(port, (1024, 65535), "port")
         type_check(num_client, (int,), "num_client")
@@ -1073,17 +1073,17 @@ def check_gnn_get_sampled_neighbors(method):
         check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
         if not neighbor_nums or len(neighbor_nums) > 6:
-            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
+            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
                 'neighbor_nums', len(neighbor_nums)))

         check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
         if not neighbor_types or len(neighbor_types) > 6:
-            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
+            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
                 'neighbor_types', len(neighbor_types)))

         if len(neighbor_nums) != len(neighbor_types):
             raise ValueError(
-                "The number of members of neighbor_nums and neighbor_types is inconsistent")
+                "The number of members of neighbor_nums and neighbor_types is inconsistent.")

         return method(self, *args, **kwargs)
@@ -1139,17 +1139,17 @@ def check_aligned_list(param, param_name, member_type):
             check_aligned_list(member, param_name, member_type)
             if member_have_list not in (None, True):
-                raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
+                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
                     param_name))
             if list_len is not None and len(member) != list_len:
-                raise TypeError("The size of each member of parameter {0} is inconsistent".format(
+                raise TypeError("The size of each member of parameter {0} is inconsistent.".format(
                     param_name))
             member_have_list = True
             list_len = len(member)
         else:
             type_check(member, (member_type,), param_name)
             if member_have_list not in (None, False):
-                raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
+                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
                     param_name))
             member_have_list = False
@@ -1248,7 +1248,7 @@ def check_paddeddataset(method):
         padded_samples = param_dict.get("padded_samples")

         if not padded_samples:
-            raise ValueError("Argument padded_samples cannot be empty")
+            raise ValueError("padded_samples cannot be empty.")
         type_check(padded_samples, (list,), "padded_samples")
         type_check(padded_samples[0], (dict,), "padded_element")
         return method(self, *args, **kwargs)
@@ -1261,6 +1261,6 @@ def check_cache_option(cache):
     if cache is not None:
         if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
             # temporary disable cache feature in the current release
-            raise ValueError("Caching is disabled in the current release")
+            raise ValueError("Caching is disabled in the current release.")
         from . import cache_client
         type_check(cache, (cache_client.DatasetCache,), "cache")
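All the check_* helpers in this file share one shape: a decorator that pulls the user's arguments, validates them, then defers to the wrapped method. A condensed sketch of that pattern (parse_user_args is the real helper; the kwargs lookup below is a simplification for illustration):

import functools

def check_paddeddataset(method):
    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        padded_samples = kwargs.get("padded_samples")
        if not padded_samples:
            raise ValueError("padded_samples cannot be empty.")
        return method(self, *args, **kwargs)
    return new_method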

@@ -257,7 +257,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
             for k, v in user_dict.items():
                 self.add_word(k, v)
         else:
-            raise ValueError("the type of user_dict must str or dict")
+            raise TypeError("The type of user_dict must str or dict.")

     def __add_dict_py_file(self, file_path):
         """Add user defined word by file"""
@@ -273,7 +273,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
         """parser user defined word by file"""
         if not os.path.exists(file_path):
             raise ValueError(
-                "user dict file {} is not exist".format(file_path))
+                "user dict file {} is not exist.".format(file_path))
         real_file_path = os.path.realpath(file_path)
         file_dict = open(real_file_path)
         data_re = re.compile('^(.+?)( [0-9]+)?$', re.U)
@@ -285,7 +285,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
             words = data_re.match(data).groups()
             if len(words) != 2:
                 raise ValueError(
-                    "user dict file {} format error".format(real_file_path))
+                    "user dict file {} format error.".format(real_file_path))
             words_list.append(words)
         file_dict.close()
         return words_list
@@ -295,14 +295,14 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
         try:
             data = data.decode('utf-8')
         except UnicodeDecodeError:
-            raise ValueError("user dict file must utf8")
+            raise ValueError("user dict file must be utf8 format.")
         return data.lstrip('\ufeff')

     def __check_path__(self, model_path):
         """check model path"""
         if not os.path.exists(model_path):
             raise ValueError(
-                " jieba mode file {} is not exist".format(model_path))
+                " jieba mode file {} is not exist.".format(model_path))


 class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp):
@@ -528,7 +528,7 @@ if platform.system().lower() != 'windows':
         def __init__(self, normalize_form=NormalizeForm.NFKC):
             if not isinstance(normalize_form, NormalizeForm):
-                raise TypeError("Wrong input type for normalization_form, should be NormalizeForm.")
+                raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

             self.normalize_form = DE_C_INTER_NORMALIZE_FORM[normalize_form]
             super().__init__(self.normalize_form)
@@ -650,7 +650,7 @@ if platform.system().lower() != 'windows':
         def __init__(self, lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE,
                      preserve_unused_token=True, with_offsets=False):
             if not isinstance(normalization_form, NormalizeForm):
-                raise TypeError("Wrong input type for normalization_form, should be NormalizeForm.")
+                raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

             self.lower_case = lower_case
             self.keep_whitespace = keep_whitespace
@@ -710,7 +710,7 @@ if platform.system().lower() != 'windows':
                      lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE,
                      preserve_unused_token=True, with_offsets=False):
             if not isinstance(normalization_form, NormalizeForm):
-                raise TypeError("Wrong input type for normalization_form, should be NormalizeForm.")
+                raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

             self.vocab = vocab
             self.suffix_indicator = suffix_indicator

@@ -417,7 +417,7 @@ def check_python_tokenizer(method):
         [tokenizer], _ = parse_user_args(method, *args, **kwargs)

         if not callable(tokenizer):
-            raise TypeError("tokenizer is not a callable Python function")
+            raise TypeError("tokenizer is not a callable Python function.")

         return method(self, *args, **kwargs)
@@ -437,8 +437,7 @@ def check_from_dataset_sentencepiece(method):
         if vocab_size is not None:
             check_uint32(vocab_size, "vocab_size")
         else:
-            raise TypeError("vocab_size must be provided")
-
+            raise TypeError("vocab_size must be provided.")
         if character_coverage is not None:
             type_check(character_coverage, (float,), "character_coverage")

@@ -49,7 +49,7 @@ def compose(transforms, *args):
         if all_numpy(args):
             return args
-        raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(args)))
+        raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args)))
     raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args)))
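A hedged reconstruction of the compose() contract this message describes: every output must be a NumPy ndarray (or a tuple of them) by the end of the chain, which is why the error suggests appending ToTensor(). The all_numpy helper is replaced by an inline check here; a simplified sketch, not the library code:

import numpy as np

def compose(transforms, *args):
    for transform in transforms:
        args = transform(*args)
        args = args if isinstance(args, tuple) else (args,)
    if all(isinstance(a, np.ndarray) for a in args):
        return args
    raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args)))

print(compose([lambda x: x + 1, lambda x: x * 2], np.ones(3)))  # (array([4., 4., 4.]),)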

@@ -971,7 +971,7 @@ class Cutout:
             np_img (numpy.ndarray), NumPy image array with square patches cut out.
         """
         if not isinstance(np_img, np.ndarray):
-            raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
+            raise TypeError("img should be NumPy array. Got {}.".format(type(np_img)))
         _, image_h, image_w = np_img.shape
         scale = (self.length * self.length) / (image_h * image_w)
         bounded = False

@ -26,7 +26,7 @@ from PIL import Image, ImageOps, ImageEnhance, __version__
from .utils import Inter from .utils import Inter
from ..core.py_util_helpers import is_numpy from ..core.py_util_helpers import is_numpy
augment_error_message = 'img should be PIL image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data.' augment_error_message = "img should be PIL image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data."
def is_pil(img): def is_pil(img):
@ -55,19 +55,19 @@ def normalize(img, mean, std):
img (numpy.ndarray), Normalized image. img (numpy.ndarray), Normalized image.
""" """
if not is_numpy(img): if not is_numpy(img):
raise TypeError('img should be NumPy image. Got {}'.format(type(img))) raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
num_channels = img.shape[0] # shape is (C, H, W) num_channels = img.shape[0] # shape is (C, H, W)
if len(mean) != len(std): if len(mean) != len(std):
raise ValueError("Length of mean and std must be equal") raise ValueError("Length of mean and std must be equal.")
# if length equal to 1, adjust the mean and std arrays to have the correct # if length equal to 1, adjust the mean and std arrays to have the correct
# number of channels (replicate the values) # number of channels (replicate the values)
if len(mean) == 1: if len(mean) == 1:
mean = [mean[0]] * num_channels mean = [mean[0]] * num_channels
std = [std[0]] * num_channels std = [std[0]] * num_channels
elif len(mean) != num_channels: elif len(mean) != num_channels:
raise ValueError("Length of mean and std must both be 1 or equal to the number of channels({0})" raise ValueError("Length of mean and std must both be 1 or equal to the number of channels({0})."
.format(num_channels)) .format(num_channels))
mean = np.array(mean, dtype=img.dtype) mean = np.array(mean, dtype=img.dtype)
@ -108,7 +108,7 @@ def hwc_to_chw(img):
""" """
if is_numpy(img): if is_numpy(img):
return img.transpose(2, 0, 1).copy() return img.transpose(2, 0, 1).copy()
raise TypeError('img should be NumPy array. Got {}'.format(type(img))) raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
def to_tensor(img, output_type): def to_tensor(img, output_type):
@ -123,11 +123,11 @@ def to_tensor(img, output_type):
img (numpy.ndarray), Converted image. img (numpy.ndarray), Converted image.
""" """
if not (is_pil(img) or is_numpy(img)): if not (is_pil(img) or is_numpy(img)):
raise TypeError('img should be PIL image or NumPy array. Got {}'.format(type(img))) raise TypeError("img should be PIL image or NumPy array. Got {}.".format(type(img)))
img = np.asarray(img) img = np.asarray(img)
if img.ndim not in (2, 3): if img.ndim not in (2, 3):
raise ValueError('img dimension should be 2 or 3. Got {}'.format(img.ndim)) raise ValueError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
if img.ndim == 2: if img.ndim == 2:
img = img[:, :, None] img = img[:, :, None]
@ -265,7 +265,7 @@ def resize(img, size, interpolation=Inter.BILINEAR):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)): if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)):
raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2.' raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2.'
'Got {}'.format(size)) 'Got {}.'.format(size))
if isinstance(size, int): if isinstance(size, int):
img_width, img_height = img.size img_width, img_height = img.size
@ -424,7 +424,7 @@ def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
img_width, img_height = img.size img_width, img_height = img.size
height, width = size height, width = size
if height > img_height or width > img_width: if height > img_height or width > img_width:
raise ValueError("Crop size {} is larger than input image size {}".format(size, (img_height, img_width))) raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))
if width == img_width and height == img_height: if width == img_width and height == img_height:
return 0, 0, img_height, img_width return 0, 0, img_height, img_width
@ -558,7 +558,7 @@ def to_type(img, output_type):
img (numpy.ndarray), Converted image. img (numpy.ndarray), Converted image.
""" """
if not is_numpy(img): if not is_numpy(img):
raise TypeError('img should be NumPy image. Got {}'.format(type(img))) raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
return img.astype(output_type) return img.astype(output_type)
@ -632,7 +632,7 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
elif isinstance(value, (list, tuple)) and len(value) == 2: elif isinstance(value, (list, tuple)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]: if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError("Please check your value range of {} is valid and " raise ValueError("Please check your value range of {} is valid and "
"within the bound {}".format(input_name, bound)) "within the bound {}.".format(input_name, bound))
else: else:
raise TypeError("Input of {} should be either a single value, or a list/tuple of " raise TypeError("Input of {} should be either a single value, or a list/tuple of "
"length 2.".format(input_name)) "length 2.".format(input_name))
@ -695,7 +695,7 @@ def random_rotation(img, degrees, resample, expand, center, fill_value):
if len(degrees) != 2: if len(degrees) != 2:
raise ValueError("If degrees is a sequence, the length must be 2.") raise ValueError("If degrees is a sequence, the length must be 2.")
else: else:
raise TypeError("Degrees must be a single non-negative number or a sequence") raise TypeError("Degrees must be a single non-negative number or a sequence.")
angle = random.uniform(degrees[0], degrees[1]) angle = random.uniform(degrees[0], degrees[1])
return rotate(img, angle, resample, expand, center, fill_value) return rotate(img, angle, resample, expand, center, fill_value)
@ -729,7 +729,7 @@ def five_crop(img, size):
img_width, img_height = img.size img_width, img_height = img.size
crop_height, crop_width = size crop_height, crop_width = size
if crop_height > img_height or crop_width > img_width: if crop_height > img_height or crop_width > img_width:
raise ValueError("Crop size {} is larger than input image size {}".format(size, (img_height, img_width))) raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))
center = center_crop(img, (crop_height, crop_width)) center = center_crop(img, (crop_height, crop_width))
top_left = img.crop((0, 0, crop_width, crop_height)) top_left = img.crop((0, 0, crop_width, crop_height))
top_right = img.crop((img_width - crop_width, 0, img_width, crop_height)) top_right = img.crop((img_width - crop_width, 0, img_width, crop_height))
@@ -802,7 +802,7 @@ def grayscale(img, num_output_channels):
         np_img = np.dstack([np_gray, np_gray, np_gray])
         img = Image.fromarray(np_img, 'RGB')
     else:
-        raise ValueError('num_output_channels should be either 1 or 3. Got {}'.format(num_output_channels))
+        raise ValueError('num_output_channels should be either 1 or 3. Got {}.'.format(num_output_channels))
     return img
@@ -859,7 +859,7 @@ def pad(img, padding, fill_value, padding_mode):
         raise TypeError("fill_value can be any of: an integer, a string or a tuple.")
     if padding_mode not in ['constant', 'edge', 'reflect', 'symmetric']:
-        raise ValueError("Padding mode can be any of ['constant', 'edge', 'reflect', 'symmetric'].")
+        raise ValueError("Padding mode should be 'constant', 'edge', 'reflect', or 'symmetric'.")
     if padding_mode == 'constant':
         if img.mode == 'P':
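The reworded message pairs with a plain membership test; the four names are the Pillow-compatible padding modes. A one-function sketch of the same guard:

def _check_padding_mode(padding_mode):
    if padding_mode not in ('constant', 'edge', 'reflect', 'symmetric'):
        raise ValueError("Padding mode should be 'constant', 'edge', 'reflect', or 'symmetric'.")
    return padding_mode

_check_padding_mode('edge')   # ok
# _check_padding_mode('wrap') raises the reworded ValueError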
@@ -946,7 +946,7 @@ def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
     """Helper function to get parameters for RandomErasing/ Cutout.
     """
     if not is_numpy(np_img):
-        raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
+        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))
     image_c, image_h, image_w = np_img.shape
     area = image_h * image_w
@@ -1009,7 +1009,7 @@ def erase(np_img, i, j, height, width, erase_value, inplace=False):
         np_img (numpy.ndarray), Erased NumPy image array.
     """
     if not is_numpy(np_img):
-        raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
+        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))
     if not inplace:
         np_img = np_img.copy()
@@ -1111,7 +1111,7 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
         else:
             raise ValueError(
                 "Shear should be a single value or a tuple/list containing " +
-                "two values. Got {}".format(shear))
+                "two values. Got {}.".format(shear))
     scale = 1.0 / scale
@@ -1239,13 +1239,13 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
         np_hsv_imgs (numpy.ndarray), NumPy HSV images with same type of np_rgb_imgs.
     """
     if not is_numpy(np_rgb_imgs):
-        raise TypeError('img should be NumPy image. Got {}'.format(type(np_rgb_imgs)))
+        raise TypeError("img should be NumPy image. Got {}.".format(type(np_rgb_imgs)))
     shape_size = len(np_rgb_imgs.shape)
     if not shape_size in (3, 4):
-        raise TypeError('img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). \
-            Got {}'.format(np_rgb_imgs.shape))
+        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
+            Got {}.".format(np_rgb_imgs.shape))
     if shape_size == 3:
         batch_size = 0
@@ -1261,7 +1261,7 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
         num_channels = np_rgb_imgs.shape[1]
         if num_channels != 3:
-            raise TypeError('img should be 3 channels RGB img. Got {} channels'.format(num_channels))
+            raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(num_channels))
     if batch_size == 0:
         return rgb_to_hsv(np_rgb_imgs, is_hwc)
     return np.array([rgb_to_hsv(img, is_hwc) for img in np_rgb_imgs])
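Taken together, the two rgb_to_hsvs hunks implement a three-step dispatch: type check, rank check (3-D single image vs. 4-D batch), then a channel-count check whose axis depends on is_hwc. A self-contained NumPy sketch of that dispatch (`_check_image_batch` is a hypothetical name; it returns the batch size the caller would loop over):

import numpy as np

def _check_image_batch(np_imgs, is_hwc):
    if not isinstance(np_imgs, np.ndarray):
        raise TypeError("img should be NumPy image. Got {}.".format(type(np_imgs)))
    if np_imgs.ndim not in (3, 4):
        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/"
                        "(N, C, H, W). Got {}.".format(np_imgs.shape))
    batch_size = 0 if np_imgs.ndim == 3 else np_imgs.shape[0]
    # Channel axis: last when HWC; otherwise first (single image) or second (batch).
    channel_axis = -1 if is_hwc else (0 if np_imgs.ndim == 3 else 1)
    if np_imgs.shape[channel_axis] != 3:
        raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(
            np_imgs.shape[channel_axis]))
    return batch_size

assert _check_image_batch(np.zeros((4, 32, 32, 3)), is_hwc=True) == 4
assert _check_image_batch(np.zeros((3, 32, 32)), is_hwc=False) == 0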
@@ -1307,13 +1307,13 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
         np_rgb_imgs (numpy.ndarray), NumPy RGB images with same type of np_hsv_imgs.
     """
     if not is_numpy(np_hsv_imgs):
-        raise TypeError('img should be NumPy image. Got {}'.format(type(np_hsv_imgs)))
+        raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))
     shape_size = len(np_hsv_imgs.shape)
     if not shape_size in (3, 4):
-        raise TypeError('img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
-            Got {}'.format(np_hsv_imgs.shape))
+        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
+            Got {}.".format(np_hsv_imgs.shape))
     if shape_size == 3:
         batch_size = 0
@@ -1329,7 +1329,7 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
         num_channels = np_hsv_imgs.shape[1]
         if num_channels != 3:
-            raise TypeError('img should be 3 channels RGB img. Got {} channels'.format(num_channels))
+            raise TypeError("img should be 3 channels HSV img. Got {} channels.".format(num_channels))
     if batch_size == 0:
         return hsv_to_rgb(np_hsv_imgs, is_hwc)
     return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs])
@@ -1349,7 +1349,7 @@ def random_color(img, degrees):
     """
     if not is_pil(img):
-        raise TypeError('img should be PIL image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL image. Got {}.".format(type(img)))
     v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
     return ImageEnhance.Color(img).enhance(v)
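After the type check, the enhancement factor v is drawn uniformly from [degrees[0], degrees[1]); with Pillow's ImageEnhance.Color semantics, 1.0 returns the image unchanged and 0.0 yields grayscale. A minimal usage sketch, assuming Pillow:

import random
from PIL import Image, ImageEnhance

img = Image.new("RGB", (4, 4), (200, 50, 50))
degrees = (0.5, 1.5)
v = (degrees[1] - degrees[0]) * random.random() + degrees[0]  # uniform in [0.5, 1.5)
out = ImageEnhance.Color(img).enhance(v)
assert out.size == img.size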
@@ -1369,7 +1369,7 @@ def random_sharpness(img, degrees):
     """
     if not is_pil(img):
-        raise TypeError('img should be PIL image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL image. Got {}.".format(type(img)))
     v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
     return ImageEnhance.Sharpness(img).enhance(v)
@@ -1390,7 +1390,7 @@ def auto_contrast(img, cutoff, ignore):
     """
     if not is_pil(img):
-        raise TypeError('img should be PIL image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL image. Got {}.".format(type(img)))
     return ImageOps.autocontrast(img, cutoff, ignore)
@@ -1408,7 +1408,7 @@ def invert_color(img):
     """
     if not is_pil(img):
-        raise TypeError('img should be PIL image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL image. Got {}.".format(type(img)))
     return ImageOps.invert(img)
@@ -1426,7 +1426,7 @@ def equalize(img):
     """
     if not is_pil(img):
-        raise TypeError('img should be PIL image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL image. Got {}.".format(type(img)))
     return ImageOps.equalize(img)
@@ -79,7 +79,7 @@ def check_mix_up_batch_c(method):
 def check_normalize_c_param(mean, std):
     if len(mean) != len(std):
-        raise ValueError("Length of mean and std must be equal")
+        raise ValueError("Length of mean and std must be equal.")
     for mean_value in mean:
         check_pos_float32(mean_value)
     for std_value in std:
@@ -88,7 +88,7 @@ def check_normalize_c_param(mean, std):
 def check_normalize_py_param(mean, std):
     if len(mean) != len(std):
-        raise ValueError("Length of mean and std must be equal")
+        raise ValueError("Length of mean and std must be equal.")
     for mean_value in mean:
         check_value(mean_value, [0., 1.], "mean_value")
     for std_value in std:
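Both normalize validators above share the same first gate: mean and std must have equal length, then each element is range-checked (positive float32 for the C op, [0, 1] for the Python op). A combined sketch (the helper name and the interval message are mine, for illustration only):

def _check_normalize_params(mean, std, low, high):
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for name, values in (("mean_value", mean), ("std_value", std)):
        for v in values:
            if not low <= v <= high:
                raise ValueError("{} {} is not within the required interval "
                                 "of [{}, {}].".format(name, v, low, high))

_check_normalize_params([0.48, 0.45, 0.40], [0.22, 0.22, 0.22], 0.0, 1.0)   # ok
# _check_normalize_params([0.5, 0.5], [0.2, 0.2, 0.2], 0.0, 1.0) raises ValueError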
@@ -372,7 +372,7 @@ def check_num_channels(method):
         if num_output_channels is not None:
             if num_output_channels not in (1, 3):
                 raise ValueError("Number of channels of the output grayscale image"
-                                 "should be either 1 or 3. Got {0}".format(num_output_channels))
+                                 " should be either 1 or 3. Got {0}.".format(num_output_channels))
         return method(self, *args, **kwargs)
@@ -471,7 +471,7 @@ def check_linear_transform(method):
         if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
             raise ValueError("transformation_matrix should be a square matrix. "
-                             "Got shape {} instead".format(transformation_matrix.shape))
+                             "Got shape {} instead.".format(transformation_matrix.shape))
         if mean_vector.shape[0] != transformation_matrix.shape[0]:
             raise ValueError("mean_vector length {0} should match either one dimension of the square"
                              "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
@@ -556,7 +556,7 @@ def check_uniform_augment_cpp(method):
         check_positive(num_ops, "num_ops")
         if num_ops > len(transforms):
-            raise ValueError("num_ops is greater than transforms list size")
+            raise ValueError("num_ops is greater than transforms list size.")
         type_check_list(transforms, (TensorOp,), "tensor_ops")
         return method(self, *args, **kwargs)
@@ -693,11 +693,11 @@ def check_random_solarize(method):
         type_check(threshold, (tuple,), "threshold")
         type_check_list(threshold, (int,), "threshold")
         if len(threshold) != 2:
-            raise ValueError("threshold must be a sequence of two numbers")
+            raise ValueError("threshold must be a sequence of two numbers.")
         for element in threshold:
             check_value(element, (0, UINT8_MAX))
         if threshold[1] < threshold[0]:
-            raise ValueError("threshold must be in min max format numbers")
+            raise ValueError("threshold must be in (min, max) format.")
         return method(self, *args, **kwargs)
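check_random_solarize layers its guards in order: tuple type, int elements, length 2, per-element range, then (min, max) ordering. A self-contained sketch of the same sequence (the TypeError strings and the range message are mine; only the two ValueError strings on the changed lines come from the hunk):

UINT8_MAX = 255

def _check_solarize_threshold(threshold):
    if not isinstance(threshold, tuple):
        raise TypeError("threshold must be a tuple.")
    if not all(isinstance(t, int) for t in threshold):
        raise TypeError("threshold elements must be ints.")
    if len(threshold) != 2:
        raise ValueError("threshold must be a sequence of two numbers.")
    for element in threshold:
        if not 0 <= element <= UINT8_MAX:
            raise ValueError("threshold values must be within [0, {}].".format(UINT8_MAX))
    if threshold[1] < threshold[0]:
        raise ValueError("threshold must be in (min, max) format.")

_check_solarize_threshold((10, 100))    # ok
# _check_solarize_threshold((100, 10)) raises the ordering ValueError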
@@ -41,7 +41,7 @@ def test_compose():
     # test one python transform followed by a C transform. type after oneHot is float (mixed use-case)
     assert test_config([1, 0], [py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)]) == [[[0, 1]], [[1, 0]]]
     # test exceptions. compose, randomApply randomChoice use the same validator
-    assert "op_list[0] is not a c_transform op" in test_config([1, 0], [1, ops.TypeCast(mstype.int32)])
+    assert "op_list[0] is neither a c_transform op" in test_config([1, 0], [1, ops.TypeCast(mstype.int32)])
     # test empty op list
     assert "op_list can not be empty." in test_config([1, 0], [])
@@ -63,7 +63,7 @@ def test_compose():
         # Test exceptions.
         with pytest.raises(TypeError) as error_info:
             c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
-        assert "op_list[0] is not a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
+        assert "op_list[0] is neither a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
         # Test empty op list
         with pytest.raises(ValueError) as error_info:
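Every reworded message ripples into tests like these, which pin the text via substring checks on str(error_info.value). The pattern in miniature (a stand-in validator, not the MindSpore one), runnable under pytest:

import pytest

def _validate(op_list):
    if not op_list:
        raise ValueError("op_list can not be empty.")
    if not callable(op_list[0]):
        raise TypeError("op_list[0] is neither a c_transform op (TensorOp) nor a callable pyfunc.")

with pytest.raises(TypeError) as error_info:
    _validate([1])
assert "neither a c_transform op" in str(error_info.value)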
@@ -510,7 +510,8 @@ def test_generator_error_3():
         for _ in data1:
             pass
-    assert "When (len(input_columns) != len(output_columns)), column_order must be specified." in str(info.value)
+    assert "When length of input_columns and output_columns are not equal, column_order must be specified." in \
+        str(info.value)

 def test_generator_error_4():
@@ -279,7 +279,7 @@ def test_cv_minddataset_partition_num_samples_equals_0():
     with pytest.raises(Exception) as error_info:
         partitions(5)
     try:
-        assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value)
+        assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value)
     except Exception as error:
         os.remove(CV_FILE_NAME)
         os.remove("{}.db".format(CV_FILE_NAME))
@@ -242,7 +242,7 @@ def test_normalize_exception_unequal_size_c():
         _ = c_vision.Normalize([100, 250, 125], [50, 50, 75, 75])
     except ValueError as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "Length of mean and std must be equal"
+        assert str(e) == "Length of mean and std must be equal."

 def test_normalize_exception_unequal_size_py():
@@ -255,7 +255,7 @@ def test_normalize_exception_unequal_size_py():
         _ = py_vision.Normalize([0.50, 0.30, 0.75], [0.18, 0.32, 0.71, 0.72])
     except ValueError as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "Length of mean and std must be equal"
+        assert str(e) == "Length of mean and std must be equal."

 def test_normalize_exception_invalid_size_py():
@@ -483,7 +483,7 @@ def test_clue_padded_and_skip_with_0_samples():
         count += 1
     assert count == 0
-    with pytest.raises(ValueError, match="There is no samples in the "):
+    with pytest.raises(ValueError, match="There are no samples in the "):
         dataset = dataset.concat(data_copy1)
         count = 0
         for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
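Note that pytest.raises(..., match=...) applies the pattern with re.search, so the updated literal must actually match the new message (and be valid as a regex). A stand-in demonstration, runnable under pytest:

import pytest

with pytest.raises(ValueError, match="There are no samples in the "):
    # Stand-in error carrying the reworded text; the real one comes from concat().
    raise ValueError("There are no samples in the final dataset.")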
@@ -41,8 +41,8 @@ def test_random_select_subpolicy():
     # test exceptions
     assert "policy can not be empty." in test_config([[1, 2, 3]], [])
     assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
-    assert "op of (op, prob) in policy[1][0] is not a c_transform op (TensorOp) nor a callable pyfunc" in test_config(
-        [[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
+    assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOp) nor a callable pyfunc" \
+        in test_config([[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
     assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
         [(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])