Add call for decoupled image and text ops

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
pull/11853/head
alex-yuyue 4 years ago
parent 4cd6588af0
commit 6fd58dc580

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,10 +14,11 @@
* limitations under the License.
*/
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/core/tensor_row.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#ifndef ENABLE_ANDROID
@ -84,5 +85,25 @@ std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Te
return de_output;
}
Status Execute::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
std::vector<std::shared_ptr<Tensor>> *output_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(op_ != nullptr, "Input TensorOperation is not valid");
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");
TensorRow input, output;
std::copy(input_tensor_list.begin(), input_tensor_list.end(), std::back_inserter(input));
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "Input Tensor is not valid");
std::shared_ptr<TensorOp> transform = op_->Build();
Status rc = transform->Compute(input, &output);
if (rc.IsError()) {
// execution failed
RETURN_STATUS_UNEXPECTED("Operation execution failed : " + rc.ToString());
}
std::copy(output.begin(), output.end(), std::back_inserter(*output_tensor_list));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,14 +28,26 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
auto execute = std::make_shared<Execute>(toTensorOperation(operation));
return execute;
}))
.def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) {
.def("__call__",
[](Execute &self, std::shared_ptr<Tensor> in) {
std::shared_ptr<Tensor> out = self(in);
if (out == nullptr) {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
RETURN_STATUS_UNEXPECTED(
"Failed to execute op in eager mode, please check ERROR log above.");
}());
}
return out;
})
.def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) {
std::vector<std::shared_ptr<Tensor>> output_tensor_list;
THROW_IF_ERROR(self(input_tensor_list, &output_tensor_list));
if (output_tensor_list.empty()) {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
}());
}
return output_tensor_list;
});
}));
} // namespace dataset

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,16 +43,23 @@ class Execute {
#ifdef ENABLE_ANDROID
/// \brief callable function to execute the TensorOperation in eager mode
/// \param[inout] input - the tensor to be transformed
/// \param[in] input - the tensor to be transformed
/// \return - the output tensor, nullptr if Compute fails
std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input);
#endif
/// \brief callable function to execute the TensorOperation in eager mode
/// \param[inout] input - the tensor to be transformed
/// \param[in] input - the tensor to be transformed
/// \return - the output tensor, nullptr if Compute fails
std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input);
/// \brief callable function to execute the TensorOperation in eager mode
/// \param[in] input_tensor_list - the tensor to be transformed
/// \param[out] out - the result tensor after transform
/// \return - Status
Status operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
std::vector<std::shared_ptr<Tensor>> *out);
private:
std::shared_ptr<TensorOperation> op_;
};

File diff suppressed because it is too large Load Diff

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,6 +26,14 @@ from .validators import check_num_classes, check_de_type, check_fill_value, chec
from ..core.datatypes import mstype_to_detype
class TensorOperation:
def __call__(self):
raise NotImplementedError("TensorOperation has to implement __call__() method.")
def parse(self):
raise NotImplementedError("TensorOperation has to implement parse() method.")
class OneHot(cde.OneHotOp):
"""
Tensor operation to apply one hot encoding.
@ -304,7 +312,7 @@ class Unique(cde.UniqueOp):
Also return an index tensor that contains the index of each element of the
input tensor in the Unique output tensor.
Finally, return a count tensor that constains the count of each element of
Finally, return a count tensor that contains the count of each element of
the output tensor in the input tensor.
Note:

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -237,8 +237,8 @@ def check_compose_list(method):
type_check(transforms, (list,), transforms)
if not transforms:
raise ValueError("transforms list is empty.")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
for i, transform in enumerate(transforms):
if not callable(transform):
raise ValueError("transforms[{}] is not callable.".format(i))
return method(self, *args, **kwargs)
@ -269,9 +269,10 @@ def check_random_apply(method):
[transforms, prob], _ = parse_user_args(method, *args, **kwargs)
type_check(transforms, (list,), "transforms")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
for i, transform in enumerate(transforms):
if str(transform).find("c_transform") >= 0:
raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
.format(i))
if prob is not None:
type_check(prob, (float, int,), "prob")
@ -290,9 +291,10 @@ def check_transforms_list(method):
[transforms], _ = parse_user_args(method, *args, **kwargs)
type_check(transforms, (list,), "transforms")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
for i, transform in enumerate(transforms):
if str(transform).find("c_transform") >= 0:
raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
.format(i))
return method(self, *args, **kwargs)
return new_method

File diff suppressed because it is too large Load Diff

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,6 +29,20 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_HWC2CHW_callable():
"""
Test HWC2CHW is callable
"""
logger.info("Test HWC2CHW callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = c_vision.Decode()(img)
img = c_vision.HWC2CHW()(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
assert img.shape == (3, 2268, 4032)
def test_HWC2CHW(plot=False):
"""
Test HWC2CHW
@ -122,6 +136,7 @@ def test_HWC2CHW_comp(plot=False):
if __name__ == '__main__':
test_HWC2CHW_callable()
test_HWC2CHW(True)
test_HWC2CHW_md5()
test_HWC2CHW_comp(True)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -219,7 +219,7 @@ def test_c_py_compose_vision_module(plot=False, run_golden=True):
def test_py_transforms_with_c_vision():
"""
These examples will fail, as py_transforms.Random(Apply/Choice/Order) expect callable functions
These examples will fail, as c_transform should not be used in py_transforms.Random(Apply/Choice/Order)
"""
ds.config.set_seed(0)
@ -236,15 +236,15 @@ def test_py_transforms_with_c_vision():
with pytest.raises(ValueError) as error_info:
test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
assert "transforms[0] is not callable." in str(error_info.value)
assert "transforms[0] is not a py transforms." in str(error_info.value)
with pytest.raises(ValueError) as error_info:
test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
assert "transforms[0] is not callable." in str(error_info.value)
assert "transforms[0] is not a py transforms." in str(error_info.value)
with pytest.raises(ValueError) as error_info:
test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
assert "transforms[1] is not callable." in str(error_info.value)
assert "transforms[1] is not a py transforms." in str(error_info.value)
with pytest.raises(RuntimeError) as error_info:
test_config([py_transforms.OneHotOp(20, 0.1)])

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,6 +29,21 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
GENERATE_GOLDEN = False
def test_invert_callable():
"""
Test Invert is callable
"""
logger.info("Test Invert callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = C.Decode()(img)
img = C.Invert()(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
assert img.shape == (2268, 4032, 3)
def test_invert_py(plot=False):
"""
Test Invert python op
@ -247,6 +262,7 @@ def test_invert_md5_c():
if __name__ == "__main__":
test_invert_callable()
test_invert_py(plot=False)
test_invert_c(plot=False)
test_invert_py_c(plot=False)

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -34,6 +34,22 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
GENERATE_GOLDEN = False
def test_random_crop_and_resize_callable():
"""
Test RandomCropAndResize op is callable
"""
logger.info("test_random_crop_and_resize_callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
decode_op = c_vision.Decode()
img = decode_op(img)
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
img = random_crop_and_resize_op(img)
assert np.shape(img) == (256, 512, 3)
def test_random_crop_and_resize_op_c(plot=False):
"""
Test RandomCropAndResize op in c transforms
@ -389,6 +405,7 @@ def test_random_crop_and_resize_06():
if __name__ == "__main__":
test_random_crop_and_resize_callable()
test_random_crop_and_resize_op_c(True)
test_random_crop_and_resize_op_py(True)
test_random_crop_and_resize_op_py_ANTIALIAS()

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@ import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.text import JiebaTokenizer
from mindspore.dataset.text import JiebaMode, to_str
from mindspore import log as logger
DATA_FILE = "../data/dataset/testJiebaDataset/3.txt"
DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*"
@ -24,6 +25,23 @@ HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8"
MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8"
def test_jieba_callable():
"""
Test jieba tokenizer op is callable
"""
logger.info("test_jieba_callable")
jieba_op1 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op2 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM)
text1 = "今天天气太好了我们一起去外面玩吧"
text2 = "男默女泪市长江大桥"
assert np.array_equal(jieba_op1(text1), ['今天天气', '太好了', '我们', '一起', '', '外面', '玩吧'])
assert np.array_equal(jieba_op2(text1), ['今天', '天气', '', '', '', '我们', '一起', '', '外面', '', ''])
jieba_op1.add_word("男默女泪")
assert np.array_equal(jieba_op1(text2), ['男默女泪', '', '长江大桥'])
def test_jieba_1():
"""Test jieba tokenizer with MP mode"""
data = ds.TextFileDataset(DATA_FILE)
@ -457,6 +475,7 @@ def test_jieba_6():
if __name__ == "__main__":
test_jieba_callable()
test_jieba_1()
test_jieba_1_1()
test_jieba_1_2()

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -28,6 +28,24 @@ from util import visualize_list, diff_mse
DATA_DIR = "../data/dataset/testImageNetData/train/"
def test_uniform_augment_callable(num_ops=2):
"""
Test UniformAugment is callable
"""
logger.info("test_uniform_augment_callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
decode_op = C.Decode()
img = decode_op(img)
transforms_ua = [C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
img = uni_aug([img, img])
assert ((np.shape(img) == (2, 2268, 4032, 3)) or (np.shape(img) == (1, 400, 400, 3)))
def test_uniform_augment(plot=False, num_ops=2):
"""
Test UniformAugment
@ -262,6 +280,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
if __name__ == "__main__":
test_uniform_augment_callable(num_ops=2)
test_uniform_augment(num_ops=1, plot=True)
test_cpp_uniform_augment(num_ops=1, plot=True)
test_cpp_uniform_augment_exception_pyops(num_ops=1)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.common.dtype as mstype
from mindspore import log as logger
# this file contains "home is behind the world head" each word is 1 line
DATA_FILE = "../data/dataset/testVocab/words.txt"
@ -25,6 +26,16 @@ VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"
def test_lookup_callable():
"""
Test lookup is callable
"""
logger.info("test_lookup_callable")
vocab = text.Vocab.from_list(['', '', '', '', ''])
lookup = text.Lookup(vocab)
word = ""
assert lookup(word) == 3
def test_from_list_tutorial():
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
lookup = text.Lookup(vocab, "<unk>")
@ -171,6 +182,7 @@ def test_lookup_cast_type():
if __name__ == '__main__':
test_lookup_callable()
test_from_dict_exception()
test_from_list_tutorial()
test_from_file_tutorial()

Loading…
Cancel
Save