From f2462bb00de8793efd7d2f3ff2faea89537e28af Mon Sep 17 00:00:00 2001
From: hesham
Date: Tue, 16 Jun 2020 22:46:57 -0400
Subject: [PATCH] Mask Op

---
 .../ccsrc/dataset/api/python_bindings.cc      |  16 ++-
 mindspore/ccsrc/dataset/core/tensor.cc        |  12 +-
 mindspore/ccsrc/dataset/core/tensor.h         |   2 +-
 .../ccsrc/dataset/kernels/data/CMakeLists.txt |  13 +-
 .../ccsrc/dataset/kernels/data/data_utils.cc  | 101 +++++++++++++-
 .../ccsrc/dataset/kernels/data/data_utils.h   |  29 ++++
 .../ccsrc/dataset/kernels/data/mask_op.cc     |  49 +++++++
 .../ccsrc/dataset/kernels/data/mask_op.h      |  54 +++++++
 .../ccsrc/dataset/kernels/data/slice_op.cc    |   3 +-
 .../ccsrc/dataset/kernels/data/slice_op.h     |   6 +-
 mindspore/dataset/transforms/c_transforms.py  |  56 +++++++-
 mindspore/dataset/transforms/validators.py    |  37 +++++
 tests/ut/cpp/dataset/mask_test.cc             |  63 +++++++++
 tests/ut/python/dataset/test_mask_op.py       | 132 ++++++++++++++++++
 tests/ut/python/dataset/test_slice_op.py      |  12 +-
 15 files changed, 560 insertions(+), 25 deletions(-)
 create mode 100644 mindspore/ccsrc/dataset/kernels/data/mask_op.cc
 create mode 100644 mindspore/ccsrc/dataset/kernels/data/mask_op.h
 create mode 100644 tests/ut/cpp/dataset/mask_test.cc
 create mode 100644 tests/ut/python/dataset/test_mask_op.py

diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index 293a443ffe..79c6e9c5dc 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -38,6 +38,7 @@
 #include "dataset/kernels/image/resize_op.h"
 #include "dataset/kernels/image/uniform_aug_op.h"
 #include "dataset/kernels/data/fill_op.h"
+#include "dataset/kernels/data/mask_op.h"
 #include "dataset/kernels/data/slice_op.h"
 #include "dataset/kernels/data/type_cast_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
@@ -369,7 +370,7 @@ void bindTensorOps2(py::module *m) {
     *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
     .def(py::init>());

-  (void)py::class_>(*m, "SliceOp", "")
+  (void)py::class_>(*m, "SliceOp", "Tensor Slice operation.")
     .def(py::init())
     .def(py::init([](const py::list &py_list) {
       std::vector c_list;
@@ -400,6 +401,19 @@
       return std::make_shared(c_slice);
     }));

+  (void)py::enum_(*m, "RelationalOp", py::arithmetic())
+    .value("EQ", RelationalOp::kEqual)
+    .value("NE", RelationalOp::kNotEqual)
+    .value("LT", RelationalOp::kLess)
+    .value("LE", RelationalOp::kLessEqual)
+    .value("GT", RelationalOp::kGreater)
+    .value("GE", RelationalOp::kGreaterEqual)
+    .export_values();
+
+  (void)py::class_>(*m, "MaskOp",
+                           "Tensor mask operation using a relational comparator")
+    .def(py::init, DataType>());
+
   (void)py::class_>(
     *m, "RandomRotationOp", "Tensor operation to apply RandomRotation."
diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc
index 2b587a115a..074603f833 100644
--- a/mindspore/ccsrc/dataset/core/tensor.cc
+++ b/mindspore/ccsrc/dataset/core/tensor.cc
@@ -699,7 +699,7 @@ Status Tensor::GetItemAt(T *o, const std::vector &index) const {
 Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) const {
   RETURN_UNEXPECTED_IF_NULL(data_);
   RETURN_UNEXPECTED_IF_NULL(o);
-  CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not DE_STRING");
+  CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string");

   uchar *start = nullptr;
   offset_t length = 0;
@@ -932,17 +932,17 @@ Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vectordata_;
   dsize_t count = 1;

   for (dsize_t i = 0; i < indices.size(); i++) {
-    dsize_t cur_index = handleNeg(indices[i], dim_length);
+    dsize_t cur_index = HandleNeg(indices[i], dim_length);
     CHECK_FAIL_RETURN_UNEXPECTED(
       cur_index >= 0 && cur_index < dim_length,
       "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")");
     if (i < indices.size() - 1) {
-      dsize_t next_index = handleNeg(indices[i + 1], dim_length);
+      dsize_t next_index = HandleNeg(indices[i + 1], dim_length);
       if (next_index == cur_index + 1) {
         count++;
         continue;
       }
@@ -951,7 +951,7 @@ Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vectorSizeInBytes(), data_ + src_start * type_size, count * type_size);
     out_index += count;
     if (i < indices.size() - 1) {
-      src_start = handleNeg(indices[i + 1], dim_length);  // next index
+      src_start = HandleNeg(indices[i + 1], dim_length);  // next index
     }
     count = 1;
   }
@@ -961,7 +961,7 @@ Status Tensor::SliceString(std::shared_ptr *out, const std::vector strings;
   for (dsize_t index : indices) {
-    dsize_t cur_index = handleNeg(index, dim_length);
+    dsize_t cur_index = HandleNeg(index, dim_length);
     CHECK_FAIL_RETURN_UNEXPECTED(
       cur_index >= 0 && cur_index < dim_length,
       "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")");
diff --git a/mindspore/ccsrc/dataset/core/tensor.h b/mindspore/ccsrc/dataset/core/tensor.h
index 2953b7df81..ad503a9290 100644
--- a/mindspore/ccsrc/dataset/core/tensor.h
+++ b/mindspore/ccsrc/dataset/core/tensor.h
@@ -348,7 +348,7 @@ class Tensor {
   }

   // Handle negative indices.
-  static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
+  static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }

   // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported.
   // Based on the type of tensor, SliceNumeric or SliceString will be called
diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
index 53a1ea6151..03457ca4f5 100644
--- a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
+++ b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
@@ -1,9 +1,10 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(kernels-data OBJECT
-        data_utils.cc
-        one_hot_op.cc
-        type_cast_op.cc
-        to_float16_op.cc
-        fill_op.cc
-        slice_op.cc)
+    data_utils.cc
+    one_hot_op.cc
+    type_cast_op.cc
+    to_float16_op.cc
+    fill_op.cc
+    slice_op.cc
+    mask_op.cc)
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
index 85c4cfc67c..532c3e3cc6 100644
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
+++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc
@@ -120,7 +120,7 @@ Status Fill(const std::shared_ptr input, std::shared_ptr *output
   std::unique_ptr op(new TypeCastOp(to));

   std::shared_ptr fill_output;
-  op->Compute(fill_value, &fill_output);
+  RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));

   RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));

@@ -344,6 +344,8 @@ Status PadEnd(const std::shared_ptr &src, std::shared_ptr *dst,
       return PadEndString(src, dst, pad_shape, "");
     }
   }
+  CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
+                               "Source and pad_value tensors are not of the same type.");
   if (pad_val->type().IsNumeric()) {
     float val = 0;
     RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
@@ -454,5 +456,102 @@ Status PadEndStringHelper(const std::shared_ptr &src, std::vector
+template
+Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output,
+                  const std::shared_ptr &value_tensor, RelationalOp op) {
+  T value;
+  RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
+  auto in_itr = input->begin();
+  auto out_itr = output->begin();
+  for (; in_itr != input->end(); in_itr++, out_itr++) {
+    switch (op) {
+      case RelationalOp::kEqual:
+        *out_itr = (*in_itr == value);
+        break;
+      case RelationalOp::kNotEqual:
+        *out_itr = (*in_itr != value);
+        break;
+      case RelationalOp::kGreater:
+        *out_itr = (*in_itr > value);
+        break;
+      case RelationalOp::kGreaterEqual:
+        *out_itr = (*in_itr >= value);
+        break;
+      case RelationalOp::kLess:
+        *out_itr = (*in_itr < value);
+        break;
+      case RelationalOp::kLessEqual:
+        *out_itr = (*in_itr <= value);
+        break;
+      default:
+        RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
+    }
+  }
+  return Status::OK();
+}
+
+Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value,
+            RelationalOp op) {
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
+                               "Cannot convert constant value to the type of the input tensor.");
+  CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
+
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
+
+  std::unique_ptr value_cast_op(new TypeCastOp(input->type()));
+  std::shared_ptr casted_value;
+  if (input->type().IsNumeric()) {
+    RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
+  } else {
+    casted_value = value;
+  }
+
+  switch (input->type().value()) {
+    case DataType::DE_BOOL:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_INT8:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_UINT8:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_UINT16:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_INT16:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_UINT32:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_INT32:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_UINT64:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_INT64:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_FLOAT16:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_FLOAT32:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_FLOAT64:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_STRING:
+      RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op));
+      break;
+    case DataType::DE_UNKNOWN:
+      RETURN_STATUS_UNEXPECTED("Unsupported input type.");
+      break;
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
index f2faee02dc..4dec0f0470 100644
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.h
+++ b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
@@ -119,6 +119,35 @@ Status PadEndString(const std::shared_ptr &src, std::shared_ptr
 Status PadEndStringHelper(const std::shared_ptr &src, std::vector *dst,
                           const TensorShape &dst_shape, std::vector cur_ind, size_t cur_dim,
                           const std::string &pad_value);
+
+enum class RelationalOp {
+  kEqual = 0,     // ==
+  kNotEqual,      // !=
+  kLess,          // <
+  kLessEqual,     // <=
+  kGreater,       // >
+  kGreaterEqual,  // >=
+};
+
+/// Helper method that masks the input tensor
+/// @tparam T type of the tensor
+/// @param input[in] input tensor
+/// @param output[out] output tensor
+/// @param value_tensor[in] scalar tensor whose value is to be compared with
+/// @param op[in] RelationalOp enum
+/// @return Status ok/error
+template
+Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output,
+                  const std::shared_ptr &value_tensor, RelationalOp op);
+
+/// Mask the input tensor
+/// @param input[in] input tensor
+/// @param output[out] output tensor
+/// @param value[in] scalar tensor whose value is to be compared with
+/// @param op[in] RelationalOp enum
+/// @return Status ok/error
+Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value,
+            RelationalOp op);
 }  // namespace dataset
 }  // namespace mindspore
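For reference, the Mask() routine declared above amounts to an elementwise comparison of the input against a scalar constant, producing a bool tensor of the same shape; the constant is first cast to the input's type, and the input and constant must both be numeric or both be strings. A rough NumPy sketch of those semantics (illustrative only, not code from this patch; the operator names mirror RelationalOp):

    import numpy as np

    def mask_sketch(array, constant, op):
        # Mirrors Mask(): cast the scalar to the input dtype, compare elementwise,
        # and return a bool array with the same shape as the input.
        compare = {"EQ": np.equal, "NE": np.not_equal, "LT": np.less,
                   "LE": np.less_equal, "GT": np.greater, "GE": np.greater_equal}[op]
        return compare(array, np.array(constant).astype(array.dtype))

    print(mask_sketch(np.array([1, 2, 3, 4, 5]), 3, "LE"))  # [ True  True  True False False]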
diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc b/mindspore/ccsrc/dataset/kernels/data/mask_op.cc
new file mode 100644
index 0000000000..ba98ab5892
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/data/mask_op.cc
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "dataset/kernels/data/mask_op.h"
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status MaskOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) {
+  IO_CHECK(input, output);
+  std::shared_ptr temp_output;
+  CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric.");
+
+  RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_));
+
+  // Cast the output to the required type. Skip casting if type_ is bool.
+  if (type_ != DataType::DE_BOOL) {
+    RETURN_IF_NOT_OK(cast_->Compute(temp_output, output));
+  } else {
+    *output = temp_output;
+  }
+
+  return Status::OK();
+}
+
+Status MaskOp::OutputType(const std::vector &inputs, std::vector &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
+  outputs[0] = type_;
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
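MaskOp::Compute above always builds the mask as DE_BOOL first and only invokes the TypeCastOp member when a non-bool dtype was requested, which is also why OutputType() reports type_ rather than the input's type. A small illustration of the resulting Python-level behavior, using the c_transforms wrapper introduced later in this patch (sketch only; the expected output is inferred from this logic, not quoted from a test):

    import mindspore.common.dtype as mstype
    import mindspore.dataset as ds
    import mindspore.dataset.transforms.c_transforms as ops

    # Requesting dtype=mstype.int32 makes MaskOp cast the boolean mask, so True/False become 1/0.
    data = ds.NumpySlicesDataset([[1, 2, 3, 4, 5]])
    data = data.map(operations=ops.Mask(ops.Relational.GT, 3, mstype.int32))
    for row in data:
        print(row[0])  # expected: [0 0 0 1 1], dtype int32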
diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.h b/mindspore/ccsrc/dataset/kernels/data/mask_op.h
new file mode 100644
index 0000000000..0affe543bb
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/data/mask_op.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_DATA_MASK_OP_H_
+#define DATASET_KERNELS_DATA_MASK_OP_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/kernels/data/type_cast_op.h"
+#include "dataset/kernels/data/data_utils.h"
+
+namespace mindspore {
+namespace dataset {
+
+class MaskOp : public TensorOp {
+ public:
+  MaskOp(RelationalOp op, std::shared_ptr value, DataType type = DataType(DataType::DE_BOOL))
+      : op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {}
+
+  ~MaskOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "MaskOp"; }
+
+  Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override;
+
+  Status OutputType(const std::vector &inputs, std::vector &outputs) override;
+
+ private:
+  RelationalOp op_;
+  std::shared_ptr value_;
+  DataType type_;
+  std::unique_ptr cast_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_DATA_MASK_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc
index c90b76b9d9..8401a33511 100644
--- a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc
+++ b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
 #include "dataset/kernels/data/slice_op.h"

 #include "dataset/core/tensor.h"
-#include "dataset/kernels/data/data_utils.h"
 #include "dataset/kernels/tensor_op.h"

 namespace mindspore {
diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/dataset/kernels/data/slice_op.h
index cb1318dcba..1bc7f0d5b9 100644
--- a/mindspore/ccsrc/dataset/kernels/data/slice_op.h
+++ b/mindspore/ccsrc/dataset/kernels/data/slice_op.h
@@ -36,8 +36,8 @@ class Slice {

   std::vector Indices(dsize_t length) {
     std::vector indices;
-    dsize_t index = std::min(Tensor::handleNeg(start_, length), length);
-    dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length);
+    dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
+    dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length);
     if (step_ > 0) {
       for (; index < end_index; index += step_) {
         indices.push_back(index);
@@ -80,4 +80,4 @@ class SliceOp : public TensorOp {
 };
 }  // namespace dataset
 }  // namespace mindspore
-#endif  // DATASET_KERNELS_DATA_ONE_HOT_OP_H_
+#endif  // DATASET_KERNELS_DATA_SLICE_OP_H_
diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py
index fd9fb12d6a..a0fb49aa77 100644
--- a/mindspore/dataset/transforms/c_transforms.py
+++ b/mindspore/dataset/transforms/c_transforms.py
@@ -15,10 +15,14 @@
 """
 This module c_transforms provides common operations, including OneHotOp and TypeCast.
""" -import numpy as np +from enum import IntEnum + +import mindspore.common.dtype as mstype import mindspore._c_dataengine as cde -from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op +import numpy as np + +from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op from ..core.datatypes import mstype_to_detype @@ -48,7 +52,6 @@ class Fill(cde.FillOp): @check_fill_value def __init__(self, fill_value): - print(fill_value) super().__init__(cde.Tensor(np.array(fill_value))) @@ -108,3 +111,50 @@ class Slice(cde.SliceOp): elif dim0 is Ellipsis: dim0 = True super().__init__(dim0) + + +class Relational(IntEnum): + EQ = 0 + NE = 1 + GT = 2 + GE = 3 + LT = 4 + LE = 5 + + +DE_C_RELATIONAL = {Relational.EQ: cde.RelationalOp.EQ, + Relational.NE: cde.RelationalOp.NE, + Relational.GT: cde.RelationalOp.GT, + Relational.GE: cde.RelationalOp.GE, + Relational.LT: cde.RelationalOp.LT, + Relational.LE: cde.RelationalOp.LE} + + +class Mask(cde.MaskOp): + """ + Mask content of the input tensor with the given predicate. + Any element of the tensor that matches the predicate will be evaluated to True, otherwise False. + Args: + operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE + constant (python types (str, int, float, or bool): constant to be compared to. + Constant will be casted to the type of the input tensor + dtype (optional, mindspore.dtype): type of the generated mask. Default to bool + Examples: + >>> # Data before + >>> # | col1 | + >>> # +---------+ + >>> # | [1,2,3] | + >>> # +---------+ + >>> data = data.map(operations=Mask(Relational.EQ, 2)) + >>> # Data after + >>> # | col1 | + >>> # +--------------------+ + >>> # | [False,True,False] | + >>> # +--------------------+ + """ + + @check_mask_op + def __init__(self, operator, constant, dtype=mstype.bool_): + dtype = mstype_to_detype(dtype) + constant = cde.Tensor(np.array(constant)) + super().__init__(DE_C_RELATIONAL[operator], constant, dtype) diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index 0ca0d60412..4033b573ca 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -213,3 +213,40 @@ def check_slice_op(method): return method(self, *args) return new_method + + +def check_mask_op(method): + """Wrapper method to check the parameters of slice.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + operator, constant, dtype = (list(args) + 3 * [None])[:3] + if "operator" in kwargs: + operator = kwargs.get("operator") + if "constant" in kwargs: + constant = kwargs.get("constant") + if "dtype" in kwargs: + dtype = kwargs.get("dtype") + + if operator is None: + raise ValueError("operator is not provided.") + if constant is None: + raise ValueError("constant is not provided.") + + from .c_transforms import Relational + if not isinstance(operator, Relational): + raise TypeError("operator is not a Relational operator enum.") + + if not isinstance(constant, (str, float, bool, int)): + raise TypeError("constant must be either a primitive python str, float, bool, or int") + + if not isinstance(dtype, typing.Type): + raise TypeError("dtype is not a MindSpore data type.") + + kwargs["operator"] = operator + kwargs["constant"] = constant + kwargs["dtype"] = dtype + + return method(self, **kwargs) + + return new_method diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc new file mode 100644 index 
diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc
new file mode 100644
index 0000000000..b2220f2a3f
--- /dev/null
+++ b/tests/ut/cpp/dataset/mask_test.cc
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include "dataset/core/client.h"
+#include "common/common.h"
+#include "gtest/gtest.h"
+#include "securec.h"
+#include "dataset/core/tensor.h"
+#include "dataset/core/cv_tensor.h"
+#include "dataset/core/data_type.h"
+#include "dataset/util/de_error.h"
+#include "dataset/kernels/data/mask_op.h"
+#include "dataset/kernels/data/data_utils.h"
+
+using namespace mindspore::dataset;
+
+namespace py = pybind11;
+
+class MindDataTestMaskOp : public UT::Common {
+ public:
+  MindDataTestMaskOp() {}
+
+  void SetUp() { GlobalInit(); }
+};
+
+TEST_F(MindDataTestMaskOp, Basics) {
+  std::shared_ptr t;
+  Tensor::CreateTensor(&t, std::vector({1, 2, 3, 4, 5, 6}));
+  std::shared_ptr v;
+  Tensor::CreateTensor(&v, std::vector({3}), TensorShape::CreateScalar());
+  std::shared_ptr op = std::make_shared(RelationalOp::kEqual, v, DataType(DataType::DE_UINT16));
+  std::shared_ptr out;
+  ASSERT_TRUE(op->Compute(t, &out).IsOk());
+
+  op = std::make_shared(RelationalOp::kNotEqual, v, DataType(DataType::DE_UINT16));
+  ASSERT_TRUE(op->Compute(t, &out).IsOk());
+
+  op = std::make_shared(RelationalOp::kLessEqual, v, DataType(DataType::DE_UINT16));
+  ASSERT_TRUE(op->Compute(t, &out).IsOk());
+
+  op = std::make_shared(RelationalOp::kLess, v, DataType(DataType::DE_UINT16));
+  ASSERT_TRUE(op->Compute(t, &out).IsOk());
+
+  op = std::make_shared(RelationalOp::kGreaterEqual, v, DataType(DataType::DE_UINT16));
+  ASSERT_TRUE(op->Compute(t, &out).IsOk());
+
+  op = std::make_shared(RelationalOp::kGreater, v, DataType(DataType::DE_UINT16));
+  ASSERT_TRUE(op->Compute(t, &out).IsOk());
+}
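The C++ test above only asserts that each Compute() call returns OK; for reference, the uint16 masks it produces for the input [1, 2, 3, 4, 5, 6] compared against 3 would be (worked out by hand, not asserted by the test itself):

    # input: [1, 2, 3, 4, 5, 6], constant: 3, output dtype: uint16
    expected = {
        "EQ": [0, 0, 1, 0, 0, 0],
        "NE": [1, 1, 0, 1, 1, 1],
        "LE": [1, 1, 1, 0, 0, 0],
        "LT": [1, 1, 0, 0, 0, 0],
        "GE": [0, 0, 1, 1, 1, 1],
        "GT": [0, 0, 0, 1, 1, 1],
    }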
diff --git a/tests/ut/python/dataset/test_mask_op.py b/tests/ut/python/dataset/test_mask_op.py
new file mode 100644
index 0000000000..878f786f97
--- /dev/null
+++ b/tests/ut/python/dataset/test_mask_op.py
@@ -0,0 +1,132 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing Mask op in DE
+"""
+import numpy as np
+import pytest
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as ops
+
+mstype_to_np_type = {
+    mstype.bool_: np.bool,
+    mstype.int8: np.int8,
+    mstype.uint8: np.uint8,
+    mstype.int16: np.int16,
+    mstype.uint16: np.uint16,
+    mstype.int32: np.int32,
+    mstype.uint32: np.uint32,
+    mstype.int64: np.int64,
+    mstype.uint64: np.uint64,
+    mstype.float16: np.float16,
+    mstype.float32: np.float32,
+    mstype.float64: np.float64,
+    mstype.string: np.str
+}
+
+
+def mask_compare(array, op, constant, dtype=mstype.bool_):
+    data = ds.NumpySlicesDataset([array])
+    array = np.array(array)
+    data = data.map(operations=ops.Mask(op, constant, dtype))
+    for d in data:
+        if op == ops.Relational.EQ:
+            array = array == np.array(constant, dtype=array.dtype)
+        elif op == ops.Relational.NE:
+            array = array != np.array(constant, dtype=array.dtype)
+        elif op == ops.Relational.GT:
+            array = array > np.array(constant, dtype=array.dtype)
+        elif op == ops.Relational.GE:
+            array = array >= np.array(constant, dtype=array.dtype)
+        elif op == ops.Relational.LT:
+            array = array < np.array(constant, dtype=array.dtype)
+        elif op == ops.Relational.LE:
+            array = array <= np.array(constant, dtype=array.dtype)
+
+        array = array.astype(dtype=mstype_to_np_type[dtype])
+
+        np.testing.assert_array_equal(array, d[0])
+
+
+def test_int_comparison():
+    for k in mstype_to_np_type:
+        if k == mstype.string:
+            continue
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k)
+
+
+def test_float_comparison():
+    for k in mstype_to_np_type:
+        if k == mstype.string:
+            continue
+        mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.EQ, 3, k)
+        mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.NE, 3, k)
+        mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LT, 3, k)
+        mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LE, 3, k)
+        mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GT, 3, k)
+        mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k)
+
+
+def test_float_comparison2():
+    for k in mstype_to_np_type:
+        if k == mstype.string:
+            continue
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3.5, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3.5, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3.5, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3.5, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3.5, k)
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k)
+
+
+def test_string_comparison():
+    for k in mstype_to_np_type:
+        if k == mstype.string:
+            continue
+        mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.EQ, "3.", k)
+        mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.NE, "3.", k)
+        mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LT, "3.", k)
+        mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LE, "3.", k)
+        mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GT, "3.", k)
+        mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GE, "3.", k)
+
+
+def test_mask_exceptions_str():
+    with pytest.raises(RuntimeError) as info:
+        mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, "3.5")
+    assert "Cannot convert constant value to the type of the input tensor." in str(info.value)
+
+    with pytest.raises(RuntimeError) as info:
+        mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, 3.5)
+    assert "Cannot convert constant value to the type of the input tensor." in str(info.value)
+
+    with pytest.raises(RuntimeError) as info:
+        mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, "3.5", mstype.string)
+    assert "Cannot generate a string mask. Type should be numeric." in str(info.value)
+
+
+if __name__ == "__main__":
+    test_int_comparison()
+    test_float_comparison()
+    test_float_comparison2()
+    test_string_comparison()
+    test_mask_exceptions_str()
diff --git a/tests/ut/python/dataset/test_slice_op.py b/tests/ut/python/dataset/test_slice_op.py
index 10038b9e2b..fd5e8baac9 100644
--- a/tests/ut/python/dataset/test_slice_op.py
+++ b/tests/ut/python/dataset/test_slice_op.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing TypeCast op in DE
+Testing Slice op in DE
 """
 import numpy as np
 import pytest
@@ -109,6 +109,10 @@ def test_slice_exceptions():
         slice_compare([1, 2, 3, 4, 5], slice(0))
     assert "Indices are empty, generated tensor would be empty." in str(info.value)

+    with pytest.raises(RuntimeError) as info:
+        slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1))
+    assert "Indices are empty, generated tensor would be empty." in str(info.value)
+
     with pytest.raises(RuntimeError) as info:
         slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1))
     assert "Indices are empty, generated tensor would be empty." in str(info.value)
@@ -182,6 +186,10 @@ def test_slice_exceptions_str():
         slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0))
     assert "Indices are empty, generated tensor would be empty." in str(info.value)

+    with pytest.raises(RuntimeError) as info:
+        slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1))
+    assert "Indices are empty, generated tensor would be empty." in str(info.value)
+
     with pytest.raises(RuntimeError) as info:
         slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1))
     assert "Indices are empty, generated tensor would be empty." in str(info.value)
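The new slice(3, 1, 1) cases in test_slice_op.py exercise the empty-slice path: with start=3, stop=1, and step=1, Slice::Indices() in slice_op.h generates no indices, so SliceOp reports "Indices are empty, generated tensor would be empty." A quick Python analogue of that index computation (sketch only, covering the positive-step branch shown in the diff):

    def slice_indices(start, stop, step, length):
        # Mirrors Slice::Indices() for a positive step: negative bounds wrap via HandleNeg,
        # then both bounds are clamped to the tensor length.
        start = start + length if start < 0 else start
        stop = stop + length if stop < 0 else stop
        return list(range(min(start, length), min(stop, length), step))

    print(slice_indices(3, 1, 1, 5))   # [] -> SliceOp raises the "Indices are empty" error
    print(slice_indices(5, 10, 1, 5))  # [] -> same error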