You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/minddata/dataset/include/transforms.h

226 lines
7.4 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class TensorOp;
/// \brief Abstract base class describing a tensor transform in the data pipeline.
/// \note Concrete operations in this header derive from this class and are passed around as
///     std::shared_ptr<TensorOperation>.
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
 public:
  /// \brief Constructor
  TensorOperation();

  /// \brief Destructor
  /// \note Virtual because this is a polymorphic base class: destroying a derived operation through a
  ///     TensorOperation pointer must run the derived destructor.
  virtual ~TensorOperation() = default;

  /// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
  /// \return shared pointer to the newly created TensorOp.
  virtual std::shared_ptr<TensorOp> Build() = 0;

  /// \brief Pure virtual function for checking the parameters held by the operation.
  /// \return Status of the check (Status is presumably declared in util/status.h - confirm).
  virtual Status ValidateParams() = 0;
};
/// \brief Helper function to validate a fill value.
/// \param[in] transform_name Name of the transform being validated (presumably used in error messages - confirm
///     against the definition).
/// \param[in] fill_value Vector of uint8_t fill values to check.
/// \return Status of the validation.
Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value);
/// \brief Helper function to validate a probability.
/// \param[in] transform_name Name of the transform being validated (presumably used in error messages - confirm).
/// \param[in] probability Probability value to check (accepted range is set by the definition, not visible here).
/// \return Status of the validation.
Status ValidateProbability(const std::string &transform_name, const float &probability);
/// \brief Helper function to validate padding.
/// \param[in] transform_name Name of the transform being validated (presumably used in error messages - confirm).
/// \param[in] padding Vector of int32_t padding values to check.
/// \return Status of the validation.
Status ValidateVectorPadding(const std::string &transform_name, const std::vector<int32_t> &padding);
/// \brief Helper function to validate size.
/// \param[in] transform_name Name of the transform being validated (presumably used in error messages - confirm).
/// \param[in] size Vector of int32_t size values to check (the function name suggests a positivity check - confirm).
/// \return Status of the validation.
Status ValidateVectorPositive(const std::string &transform_name, const std::vector<int32_t> &size);
/// \brief Helper function to validate a list of transforms.
/// \param[in] transform_name Name of the transform being validated (presumably used in error messages - confirm).
/// \param[in] transforms Vector of TensorOperation shared pointers to check.
/// \return Status of the validation.
Status ValidateVectorTransforms(const std::string &transform_name,
const std::vector<std::shared_ptr<TensorOperation>> &transforms);
/// \brief Helper function to compare two float values.
/// \param[in] a First value.
/// \param[in] b Second value.
/// \param[in] epsilon Tolerance used by the comparison (default=0.0000000001f).
/// \return Result of the comparison (presumably true when a and b differ by less than epsilon - confirm).
/// \note NOTE(review): the default epsilon is far below float's machine epsilon (about 1.19e-7); if it is applied as
///     an absolute tolerance, the default is close to an exact comparison for values near 1 - confirm this is intended.
bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f);
// Transform operations for performing data transformation.
namespace transforms {
// Transform Op classes (in alphabetical order)
// Forward declarations: the factory functions declared next in this namespace return shared pointers to
// these classes before the class definitions appear further down in this header.
class ComposeOperation;
class DuplicateOperation;
class OneHotOperation;
class RandomApplyOperation;
class RandomChoiceOperation;
class TypeCastOperation;
#ifndef ENABLE_ANDROID
// UniqueOperation is only declared when ENABLE_ANDROID is not defined.
class UniqueOperation;
#endif
/// \brief Function to create a Compose TensorOperation.
/// \note Compose a list of transforms into a single transform.
/// \param[in] transforms A vector of transformations to be applied.
/// \return Shared pointer to a ComposeOperation.
std::shared_ptr<ComposeOperation> Compose(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
/// \brief Function to create a Duplicate TensorOperation.
/// \note Duplicate the input tensor to a new output tensor.
///     The input tensor is carried over to the output list.
/// \return Shared pointer to a DuplicateOperation.
std::shared_ptr<DuplicateOperation> Duplicate();
/// \brief Function to create a OneHot TensorOperation.
/// \note Convert the labels into OneHot format.
/// \param[in] num_classes Number of classes.
/// \return Shared pointer to a OneHotOperation.
std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes);
/// \brief Function to create a RandomApply TensorOperation.
/// \note Randomly perform a series of transforms with a given probability.
/// \param[in] transforms A vector of transformations to be applied.
/// \param[in] prob The probability to apply the transformation list (default=0.5).
/// \return Shared pointer to a RandomApplyOperation.
std::shared_ptr<RandomApplyOperation> RandomApply(const std::vector<std::shared_ptr<TensorOperation>> &transforms,
double prob = 0.5);
/// \brief Function to create a RandomChoice TensorOperation.
/// \note Randomly select one transform from a list of transforms and perform it.
/// \param[in] transforms A vector of transformations to choose from.
/// \return Shared pointer to a RandomChoiceOperation.
std::shared_ptr<RandomChoiceOperation> RandomChoice(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
/// \brief Function to create a TypeCast TensorOperation.
/// \note Tensor operation to cast to a given MindSpore data type.
/// \param[in] data_type mindspore.dtype to be cast to, passed as a string (accepted spellings are determined by
///     the definition, not visible in this header).
/// \return Shared pointer to a TypeCastOperation.
std::shared_ptr<TypeCastOperation> TypeCast(std::string data_type);
#ifndef ENABLE_ANDROID
/// \brief Function to create a Unique TensorOperation.
/// \note Return an output tensor containing all the unique elements of the input tensor in
///     the same order that they occur in the input tensor.
/// \note Only declared when ENABLE_ANDROID is not defined.
/// \return Shared pointer to a UniqueOperation.
std::shared_ptr<UniqueOperation> Unique();
#endif
/* ####################################### Derived TensorOperation classes ################################# */
/// \brief TensorOperation subclass returned by the Compose() factory; holds a list of transforms.
class ComposeOperation : public TensorOperation {
 public:
  /// \brief Constructor
  /// \param[in] transforms Vector of operations (presumably stored in transforms_ - the definition is elsewhere).
  explicit ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);

  /// \brief Destructor
  ~ComposeOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;

 private:
  // Operations making up the composition.
  std::vector<std::shared_ptr<TensorOperation>> transforms_;
};
/// \brief TensorOperation subclass returned by the Duplicate() factory; it carries no parameters.
class DuplicateOperation : public TensorOperation {
 public:
  /// \brief Constructor
  DuplicateOperation() = default;

  /// \brief Destructor
  ~DuplicateOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;
};
/// \brief TensorOperation subclass returned by the OneHot() factory; holds the number of classes.
class OneHotOperation : public TensorOperation {
 public:
  /// \brief Constructor
  /// \param[in] num_classes Number of classes.
  explicit OneHotOperation(int32_t num_classes);

  /// \brief Destructor
  ~OneHotOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;

 private:
  // Kept as int32_t to match the constructor and the OneHot() factory: a class count is an integer, and a float
  // member would silently round counts above 2^24 and force int<->float conversions on every use.
  int32_t num_classes_;
};
/// \brief TensorOperation subclass returned by the RandomApply() factory; holds a list of transforms and a
///     probability.
class RandomApplyOperation : public TensorOperation {
 public:
  /// \brief Constructor
  /// \param[in] transforms Vector of operations (presumably stored in transforms_ - the definition is elsewhere).
  /// \param[in] prob Probability value (presumably stored in prob_ - the definition is elsewhere).
  explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);

  /// \brief Destructor
  ~RandomApplyOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;

 private:
  // Operations that may be applied.
  std::vector<std::shared_ptr<TensorOperation>> transforms_;
  // Probability associated with applying the list.
  double prob_;
};
/// \brief TensorOperation subclass returned by the RandomChoice() factory; holds the candidate transforms.
class RandomChoiceOperation : public TensorOperation {
 public:
  /// \brief Constructor
  /// \param[in] transforms Vector of operations (presumably stored in transforms_ - the definition is elsewhere).
  explicit RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);

  /// \brief Destructor
  ~RandomChoiceOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;

 private:
  // Candidate operations to choose from.
  std::vector<std::shared_ptr<TensorOperation>> transforms_;
};
/// \brief TensorOperation subclass returned by the TypeCast() factory; holds the target data type as a string.
class TypeCastOperation : public TensorOperation {
 public:
  /// \brief Constructor
  /// \param[in] data_type Data type name (presumably stored in data_type_ - the definition is elsewhere).
  explicit TypeCastOperation(std::string data_type);

  /// \brief Destructor
  ~TypeCastOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;

 private:
  // Target data type, kept in string form.
  std::string data_type_;
};
#ifndef ENABLE_ANDROID
/// \brief TensorOperation subclass returned by the Unique() factory; it carries no parameters.
/// \note Only defined when ENABLE_ANDROID is not defined.
class UniqueOperation : public TensorOperation {
 public:
  /// \brief Constructor
  UniqueOperation() = default;

  /// \brief Destructor
  ~UniqueOperation() = default;

  /// \brief Convert this operation into a runtime TensorOp object.
  /// \return Shared pointer to the TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Check the parameters held by this operation.
  /// \return Status of the check.
  Status ValidateParams() override;
};
#endif
} // namespace transforms
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_