Cleanup work for Concate, Mask, Slice, PadEnd and TruncatePair

pull/2372/head
hesham 5 years ago
parent bc4b1c2460
commit 674415f7be

@ -403,7 +403,7 @@ def check_to_number(method):
if not isinstance(data_type, typing.Type): if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.") raise TypeError("data_type is not a MindSpore data type.")
if not data_type in mstype.number_type: if data_type not in mstype.number_type:
raise TypeError("data_type is not numeric data type.") raise TypeError("data_type is not numeric data type.")
kwargs["data_type"] = data_type kwargs["data_type"] = data_type

@ -79,12 +79,13 @@ class Slice(cde.SliceOp):
(Currently only rank 1 Tensors are supported) (Currently only rank 1 Tensors are supported)
Args: Args:
*slices: Maximum n number of objects to slice a tensor of rank n. *slices(Variable length argument list): Maximum `n` number of arguments to slice a tensor of rank `n`.
One object in slices can be one of: One object in slices can be one of:
1. int: slice this index only. Negative index is supported. 1. int: slice this index only. Negative index is supported.
2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`. 2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`.
3. None: slice the whole dimension. Similar to `:` in python indexing. 3. None: slice the whole dimension. Similar to `:` in python indexing.
4. Ellipses ...: slice all dimensions between the two slices. 4. Ellipses ...: slice all dimensions between the two slices.
Examples: Examples:
>>> # Data before >>> # Data before
>>> # | col | >>> # | col |
@ -134,11 +135,13 @@ class Mask(cde.MaskOp):
""" """
Mask content of the input tensor with the given predicate. Mask content of the input tensor with the given predicate.
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False. Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
Args: Args:
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
constant (python types (str, int, float, or bool): constant to be compared to. constant (python types (str, int, float, or bool): constant to be compared to.
Constant will be casted to the type of the input tensor Constant will be casted to the type of the input tensor
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
Examples: Examples:
>>> # Data before >>> # Data before
>>> # | col1 | >>> # | col1 |
@ -163,11 +166,13 @@ class Mask(cde.MaskOp):
class PadEnd(cde.PadEndOp): class PadEnd(cde.PadEndOp):
""" """
Pad input tensor according to `pad_shape`, need to have same rank. Pad input tensor according to `pad_shape`, need to have same rank.
Args: Args:
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values. not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty
string in case of Tensors of strings. string in case of Tensors of strings.
Examples: Examples:
>>> # Data before >>> # Data before
>>> # | col | >>> # | col |
@ -201,21 +206,25 @@ class Concatenate(cde.ConcatenateOp):
@check_concat_type @check_concat_type
def __init__(self, axis=0, prepend=None, append=None): def __init__(self, axis=0, prepend=None, append=None):
# add some validations here later if prepend is not None:
prepend = cde.Tensor(np.array(prepend))
if append is not None:
append = cde.Tensor(np.array(append))
super().__init__(axis, prepend, append) super().__init__(axis, prepend, append)
class Duplicate(cde.DuplicateOp): class Duplicate(cde.DuplicateOp):
""" """
Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list. Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list.
Examples:
Examples:
>>> # Data before >>> # Data before
>>> # | x | >>> # | x |
>>> # +---------+ >>> # +---------+
>>> # | [1,2,3] | >>> # | [1,2,3] |
>>> # +---------+ >>> # +---------+
>>> data = data.map(input_columns=["x"], operations=Duplicate(), >>> data = data.map(input_columns=["x"], operations=Duplicate(),
>>> output_columns=["x", "y"], output_order=["x", "y"]) >>> output_columns=["x", "y"], columns_order=["x", "y"])
>>> # Data after >>> # Data after
>>> # | x | y | >>> # | x | y |
>>> # +---------+---------+ >>> # +---------+---------+

@ -17,7 +17,6 @@
from functools import wraps from functools import wraps
import numpy as np import numpy as np
import mindspore._c_dataengine as cde
from mindspore._c_expression import typing from mindspore._c_expression import typing
# POS_INT_MIN is used to limit values from starting from 0 # POS_INT_MIN is used to limit values from starting from 0
@ -243,12 +242,13 @@ def check_mask_op(method):
if not isinstance(constant, (str, float, bool, int, bytes)): if not isinstance(constant, (str, float, bool, int, bytes)):
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int") raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
if not isinstance(dtype, typing.Type): if dtype is not None:
raise TypeError("dtype is not a MindSpore data type.") if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
kwargs["dtype"] = dtype
kwargs["operator"] = operator kwargs["operator"] = operator
kwargs["constant"] = constant kwargs["constant"] = constant
kwargs["dtype"] = dtype
return method(self, **kwargs) return method(self, **kwargs)
@ -269,8 +269,10 @@ def check_pad_end(method):
if pad_shape is None: if pad_shape is None:
raise ValueError("pad_shape is not provided.") raise ValueError("pad_shape is not provided.")
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)): if pad_value is not None:
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes.") if not isinstance(pad_value, (str, float, bool, int, bytes)):
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
kwargs["pad_value"] = pad_value
if not isinstance(pad_shape, list): if not isinstance(pad_shape, list):
raise TypeError("pad_shape must be a list") raise TypeError("pad_shape must be a list")
@ -283,7 +285,6 @@ def check_pad_end(method):
raise TypeError("a value in the list is not an integer.") raise TypeError("a value in the list is not an integer.")
kwargs["pad_shape"] = pad_shape kwargs["pad_shape"] = pad_shape
kwargs["pad_value"] = pad_value
return method(self, **kwargs) return method(self, **kwargs)
@ -303,30 +304,22 @@ def check_concat_type(method):
if "axis" in kwargs: if "axis" in kwargs:
axis = kwargs.get("axis") axis = kwargs.get("axis")
if not isinstance(axis, (type(None), int)): if axis is not None:
raise TypeError("axis type is not valid, must be None or an integer.") if not isinstance(axis, int):
raise TypeError("axis type is not valid, must be an integer.")
if isinstance(axis, type(None)): if axis not in (0, -1):
axis = 0 raise ValueError("only 1D concatenation supported.")
kwargs["axis"] = axis
if axis not in (None, 0, -1):
raise ValueError("only 1D concatenation supported.") if prepend is not None:
if not isinstance(prepend, (type(None), np.ndarray)):
if not isinstance(prepend, (type(None), np.ndarray)): raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") kwargs["prepend"] = prepend
if not isinstance(append, (type(None), np.ndarray)): if append is not None:
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") if not isinstance(append, (type(None), np.ndarray)):
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
if isinstance(prepend, np.ndarray): kwargs["append"] = append
prepend = cde.Tensor(prepend)
if isinstance(append, np.ndarray):
append = cde.Tensor(append)
kwargs["axis"] = axis
kwargs["prepend"] = prepend
kwargs["append"] = append
return method(self, **kwargs) return method(self, **kwargs)

@ -62,7 +62,7 @@ def mask_compare(array, op, constant, dtype=mstype.bool_):
np.testing.assert_array_equal(array, d[0]) np.testing.assert_array_equal(array, d[0])
def test_int_comparison(): def test_mask_int_comparison():
for k in mstype_to_np_type: for k in mstype_to_np_type:
if k == mstype.string: if k == mstype.string:
continue continue
@ -74,7 +74,7 @@ def test_int_comparison():
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k) mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k)
def test_float_comparison(): def test_mask_float_comparison():
for k in mstype_to_np_type: for k in mstype_to_np_type:
if k == mstype.string: if k == mstype.string:
continue continue
@ -86,7 +86,7 @@ def test_float_comparison():
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k) mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k)
def test_float_comparison2(): def test_mask_float_comparison2():
for k in mstype_to_np_type: for k in mstype_to_np_type:
if k == mstype.string: if k == mstype.string:
continue continue
@ -98,7 +98,7 @@ def test_float_comparison2():
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k) mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k)
def test_string_comparison(): def test_mask_string_comparison():
for k in mstype_to_np_type: for k in mstype_to_np_type:
if k == mstype.string: if k == mstype.string:
continue continue
@ -125,8 +125,8 @@ def test_mask_exceptions_str():
if __name__ == "__main__": if __name__ == "__main__":
test_int_comparison() test_mask_int_comparison()
test_float_comparison() test_mask_float_comparison()
test_float_comparison2() test_mask_float_comparison2()
test_string_comparison() test_mask_string_comparison()
test_mask_exceptions_str() test_mask_exceptions_str()

Loading…
Cancel
Save