!13789 dataset: Pybind change for data transforms

From: @cathwong
Reviewed-by: @robingrosman,@pandoublefeng
Signed-off-by: @pandoublefeng
pull/13789/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit a250b8a1f7

@ -20,15 +20,11 @@
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/tensor_helpers.h" #include "minddata/dataset/core/tensor_helpers.h"
#include "minddata/dataset/kernels/data/concatenate_op.h" #include "minddata/dataset/kernels/data/concatenate_op.h"
#include "minddata/dataset/kernels/data/duplicate_op.h"
#include "minddata/dataset/kernels/data/fill_op.h" #include "minddata/dataset/kernels/data/fill_op.h"
#include "minddata/dataset/kernels/data/mask_op.h" #include "minddata/dataset/kernels/data/mask_op.h"
#include "minddata/dataset/kernels/data/one_hot_op.h"
#include "minddata/dataset/kernels/data/pad_end_op.h" #include "minddata/dataset/kernels/data/pad_end_op.h"
#include "minddata/dataset/kernels/data/slice_op.h" #include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/to_float16_op.h" #include "minddata/dataset/kernels/data/to_float16_op.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
#include "minddata/dataset/kernels/data/unique_op.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -38,15 +34,6 @@ PYBIND_REGISTER(ConcatenateOp, 1, ([](const py::module *m) {
.def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>()); .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>());
})); }));
PYBIND_REGISTER(
DuplicateOp, 1, ([](const py::module *m) {
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp").def(py::init<>());
}));
PYBIND_REGISTER(UniqueOp, 1, ([](const py::module *m) {
(void)py::class_<UniqueOp, TensorOp, std::shared_ptr<UniqueOp>>(*m, "UniqueOp").def(py::init<>());
}));
PYBIND_REGISTER( PYBIND_REGISTER(
FillOp, 1, ([](const py::module *m) { FillOp, 1, ([](const py::module *m) {
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>()); (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>());
@ -57,11 +44,6 @@ PYBIND_REGISTER(MaskOp, 1, ([](const py::module *m) {
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>()); .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
})); }));
PYBIND_REGISTER(
OneHotOp, 1, ([](const py::module *m) {
(void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(*m, "OneHotOp").def(py::init<int32_t>());
}));
PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) { PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) {
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(*m, "PadEndOp") (void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(*m, "PadEndOp")
.def(py::init<TensorShape, std::shared_ptr<Tensor>>()); .def(py::init<TensorShape, std::shared_ptr<Tensor>>());
@ -111,12 +93,6 @@ PYBIND_REGISTER(ToFloat16Op, 1, ([](const py::module *m) {
.def(py::init<>()); .def(py::init<>());
})); }));
PYBIND_REGISTER(TypeCastOp, 1, ([](const py::module *m) {
(void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(*m, "TypeCastOp")
.def(py::init<DataType>())
.def(py::init<std::string>());
}));
PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) { PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) {
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic()) (void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
.value("EQ", RelationalOp::kEqual) .value("EQ", RelationalOp::kEqual)

@ -64,6 +64,28 @@ PYBIND_REGISTER(
})); }));
})); }));
PYBIND_REGISTER(
DuplicateOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::DuplicateOperation, TensorOperation, std::shared_ptr<transforms::DuplicateOperation>>(
*m, "DuplicateOperation")
.def(py::init([]() {
auto duplicate = std::make_shared<transforms::DuplicateOperation>();
THROW_IF_ERROR(duplicate->ValidateParams());
return duplicate;
}));
}));
PYBIND_REGISTER(
OneHotOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::OneHotOperation, TensorOperation, std::shared_ptr<transforms::OneHotOperation>>(
*m, "OneHotOperation")
.def(py::init([](int32_t num_classes) {
auto one_hot = std::make_shared<transforms::OneHotOperation>(num_classes);
THROW_IF_ERROR(one_hot->ValidateParams());
return one_hot;
}));
}));
PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) { PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::RandomChoiceOperation, TensorOperation, (void)py::class_<transforms::RandomChoiceOperation, TensorOperation,
std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation") std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation")
@ -87,5 +109,28 @@ PYBIND_REGISTER(RandomApplyOperation, 1, ([](const py::module *m) {
return random_apply; return random_apply;
})); }));
})); }));
PYBIND_REGISTER(
TypeCastOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::TypeCastOperation, TensorOperation, std::shared_ptr<transforms::TypeCastOperation>>(
*m, "TypeCastOperation")
.def(py::init([](std::string data_type) {
auto type_cast = std::make_shared<transforms::TypeCastOperation>(data_type);
THROW_IF_ERROR(type_cast->ValidateParams());
return type_cast;
}));
}));
PYBIND_REGISTER(
UniqueOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::UniqueOperation, TensorOperation, std::shared_ptr<transforms::UniqueOperation>>(
*m, "UniqueOperation")
.def(py::init([]() {
auto unique = std::make_shared<transforms::UniqueOperation>();
THROW_IF_ERROR(unique->ValidateParams());
return unique;
}));
}));
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -64,7 +64,7 @@ std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<Du
// Constructor to OneHot // Constructor to OneHot
struct OneHot::Data { struct OneHot::Data {
explicit Data(int32_t num_classes) : num_classes_(num_classes) {} explicit Data(int32_t num_classes) : num_classes_(num_classes) {}
float num_classes_; int32_t num_classes_;
}; };
OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {} OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {}

@ -351,7 +351,7 @@ class SentencePieceTokenizer final : public TensorTransform {
/// \param[in] vocab a SentencePieceVocab object. /// \param[in] vocab a SentencePieceVocab object.
/// \param[in] out_type The type of output. /// \param[in] out_type The type of output.
SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab, SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab,
mindspore::dataset::SPieceTokenizerOutType out_typee); mindspore::dataset::SPieceTokenizerOutType out_type);
/// \brief Constructor. /// \brief Constructor.
/// \param[in] vocab_path vocab model file path. /// \param[in] vocab_path vocab model file path.
@ -398,14 +398,14 @@ class SlidingWindow final : public TensorTransform {
}; };
/// \brief Tensor operation to convert every element of a string tensor to a number. /// \brief Tensor operation to convert every element of a string tensor to a number.
/// Strings are casted according to the rules specified in the following links: /// Strings are cast according to the rules specified in the following links:
/// https://en.cppreference.com/w/cpp/string/basic_string/stof, /// https://en.cppreference.com/w/cpp/string/basic_string/stof,
/// https://en.cppreference.com/w/cpp/string/basic_string/stoul, /// https://en.cppreference.com/w/cpp/string/basic_string/stoul,
/// except that any strings which represent negative numbers cannot be cast to an unsigned integer type. /// except that any strings which represent negative numbers cannot be cast to an unsigned integer type.
class ToNumber final : public TensorTransform { class ToNumber final : public TensorTransform {
public: public:
/// \brief Constructor. /// \brief Constructor.
/// \param[in] data_type of the tensor to be casted to. Must be a numeric type. /// \param[in] data_type of the tensor to be cast to. Must be a numeric type.
explicit ToNumber(const std::string &data_type) : ToNumber(StringToChar(data_type)) {} explicit ToNumber(const std::string &data_type) : ToNumber(StringToChar(data_type)) {}
explicit ToNumber(const std::vector<char> &data_type); explicit ToNumber(const std::vector<char> &data_type);

@ -38,11 +38,5 @@ Status OneHotOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector
return Status(StatusCode::kMDUnexpectedError, "OneHot: invalid input shape."); return Status(StatusCode::kMDUnexpectedError, "OneHot: invalid input shape.");
} }
Status OneHotOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_classes"] = num_classes_;
*out_json = args;
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -37,8 +37,6 @@ class OneHotOp : public TensorOp {
std::string Name() const override { return kOneHotOp; } std::string Name() const override { return kOneHotOp; }
Status to_json(nlohmann::json *out_json) override;
private: private:
int num_classes_; int num_classes_;
}; };

@ -34,11 +34,5 @@ Status TypeCastOp::OutputType(const std::vector<DataType> &inputs, std::vector<D
return Status::OK(); return Status::OK();
} }
Status TypeCastOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["data_type"] = type_.ToString();
*out_json = args;
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -43,8 +43,6 @@ class TypeCastOp : public TensorOp {
std::string Name() const override { return kTypeCastOp; } std::string Name() const override { return kTypeCastOp; }
Status to_json(nlohmann::json *out_json) override;
private: private:
DataType type_; DataType type_;
}; };

@ -78,6 +78,13 @@ Status OneHotOperation::ValidateParams() {
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
Status OneHotOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_classes"] = num_classes_;
*out_json = args;
return Status::OK();
}
// PreBuiltOperation // PreBuiltOperation
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) { PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
@ -149,6 +156,13 @@ Status TypeCastOperation::ValidateParams() {
std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); } std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
Status TypeCastOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["data_type"] = data_type_;
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// UniqueOperation // UniqueOperation
Status UniqueOperation::ValidateParams() { return Status::OK(); } Status UniqueOperation::ValidateParams() { return Status::OK(); }

@ -81,8 +81,10 @@ class OneHotOperation : public TensorOperation {
std::string Name() const override { return kOneHotOperation; } std::string Name() const override { return kOneHotOperation; }
Status to_json(nlohmann::json *out_json) override;
private: private:
float num_classes_; int32_t num_classes_;
}; };
class PreBuiltOperation : public TensorOperation { class PreBuiltOperation : public TensorOperation {
@ -147,6 +149,8 @@ class TypeCastOperation : public TensorOperation {
std::string Name() const override { return kTypeCastOperation; } std::string Name() const override { return kTypeCastOperation; }
Status to_json(nlohmann::json *out_json) override;
private: private:
std::string data_type_; std::string data_type_;
}; };

@ -362,8 +362,7 @@ def construct_tensor_ops(operations):
if hasattr(op_module_vis, op_name): if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name, None) op_class = getattr(op_module_vis, op_name, None)
elif hasattr(op_module_trans, op_name[:-2]): elif hasattr(op_module_trans, op_name):
op_name = op_name[:-2] # to remove op from the back of the name
op_class = getattr(op_module_trans, op_name, None) op_class = getattr(op_module_trans, op_name, None)
else: else:
raise RuntimeError(op_name + " is not yet supported by deserialize().") raise RuntimeError(op_name + " is not yet supported by deserialize().")

@ -387,18 +387,18 @@ class ToNumber(TextTensorOperation):
""" """
Tensor operation to convert every element of a string tensor to a number. Tensor operation to convert every element of a string tensor to a number.
Strings are casted according to the rules specified in the following links: Strings are cast according to the rules specified in the following links:
https://en.cppreference.com/w/cpp/string/basic_string/stof, https://en.cppreference.com/w/cpp/string/basic_string/stof,
https://en.cppreference.com/w/cpp/string/basic_string/stoul, https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any strings which represent negative numbers cannot be cast to an except that any strings which represent negative numbers cannot be cast to an
unsigned integer type. unsigned integer type.
Args: Args:
data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be data_type (mindspore.dtype): mindspore.dtype to be cast to. Must be
a numeric type. a numeric type.
Raises: Raises:
RuntimeError: If strings are invalid to cast, or are out of range after being casted. RuntimeError: If strings are invalid to cast, or are out of range after being cast.
Examples: Examples:
>>> import mindspore.common.dtype as mstype >>> import mindspore.common.dtype as mstype

@ -21,7 +21,7 @@ import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_option, check_slice_op, \ from .validators import check_num_classes, check_ms_type, check_fill_value, check_slice_option, check_slice_op, \
check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops
from ..core.datatypes import mstype_to_detype from ..core.datatypes import mstype_to_detype
@ -52,7 +52,7 @@ class TensorOperation:
raise NotImplementedError("TensorOperation has to implement parse() method.") raise NotImplementedError("TensorOperation has to implement parse() method.")
class OneHot(cde.OneHotOp): class OneHot(TensorOperation):
""" """
Tensor operation to apply one hot encoding. Tensor operation to apply one hot encoding.
@ -72,7 +72,9 @@ class OneHot(cde.OneHotOp):
@check_num_classes @check_num_classes
def __init__(self, num_classes): def __init__(self, num_classes):
self.num_classes = num_classes self.num_classes = num_classes
super().__init__(num_classes)
def parse(self):
return cde.OneHotOperation(self.num_classes)
class Fill(cde.FillOp): class Fill(cde.FillOp):
@ -102,7 +104,7 @@ class Fill(cde.FillOp):
super().__init__(cde.Tensor(np.array(fill_value))) super().__init__(cde.Tensor(np.array(fill_value)))
class TypeCast(cde.TypeCastOp): class TypeCast(TensorOperation):
""" """
Tensor operation to cast to a given MindSpore data type. Tensor operation to cast to a given MindSpore data type.
@ -123,11 +125,13 @@ class TypeCast(cde.TypeCastOp):
>>> dataset = dataset.map(operations=type_cast_op) >>> dataset = dataset.map(operations=type_cast_op)
""" """
@check_de_type @check_ms_type
def __init__(self, data_type): def __init__(self, data_type):
data_type = mstype_to_detype(data_type) data_type = mstype_to_detype(data_type)
self.data_type = str(data_type) self.data_type = str(data_type)
super().__init__(data_type)
def parse(self):
return cde.TypeCastOperation(self.data_type)
class _SliceOption(cde.SliceOption): class _SliceOption(cde.SliceOption):
@ -314,7 +318,7 @@ class Concatenate(cde.ConcatenateOp):
super().__init__(axis, prepend, append) super().__init__(axis, prepend, append)
class Duplicate(cde.DuplicateOp): class Duplicate(TensorOperation):
""" """
Duplicate the input tensor to output, only support transform one column each time. Duplicate the input tensor to output, only support transform one column each time.
@ -337,8 +341,11 @@ class Duplicate(cde.DuplicateOp):
>>> # +---------+---------+ >>> # +---------+---------+
""" """
def parse(self):
return cde.DuplicateOperation()
class Unique(cde.UniqueOp): class Unique(TensorOperation):
""" """
Perform the unique operation on the input tensor, only support transform one column each time. Perform the unique operation on the input tensor, only support transform one column each time.
@ -373,9 +380,11 @@ class Unique(cde.UniqueOp):
>>> # +---------+-----------------+---------+ >>> # +---------+-----------------+---------+
""" """
def parse(self):
return cde.UniqueOperation()
class Compose(): class Compose(TensorOperation):
""" """
Compose a list of transforms into a single transform. Compose a list of transforms into a single transform.
@ -401,7 +410,7 @@ class Compose():
return cde.ComposeOperation(operations) return cde.ComposeOperation(operations)
class RandomApply(): class RandomApply(TensorOperation):
""" """
Randomly perform a series of transforms with a given probability. Randomly perform a series of transforms with a given probability.
@ -429,7 +438,7 @@ class RandomApply():
return cde.RandomApplyOperation(self.prob, operations) return cde.RandomApplyOperation(self.prob, operations)
class RandomChoice(): class RandomChoice(TensorOperation):
""" """
Randomly select one transform from a list of transforms to perform operation. Randomly select one transform from a list of transforms to perform operation.

@ -87,7 +87,7 @@ def check_num_classes(method):
return new_method return new_method
def check_de_type(method): def check_ms_type(method):
"""Wrapper method to check the parameters of data type.""" """Wrapper method to check the parameters of data type."""
@wraps(method) @wraps(method)

Loading…
Cancel
Save