You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
226 lines
7.4 KiB
226 lines
7.4 KiB
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
#include "minddata/dataset/core/constants.h"
|
|
#include "minddata/dataset/util/status.h"
|
|
|
|
namespace mindspore {
|
|
namespace dataset {
|
|
|
|
class TensorOp;
|
|
|
|
// Abstract class to represent a dataset in the data pipeline.
|
|
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
|
public:
|
|
/// \brief Constructor
|
|
TensorOperation();
|
|
|
|
/// \brief Destructor
|
|
~TensorOperation() = default;
|
|
|
|
/// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
|
|
/// \return shared pointer to the newly created TensorOp.
|
|
virtual std::shared_ptr<TensorOp> Build() = 0;
|
|
|
|
virtual Status ValidateParams() = 0;
|
|
};
|
|
|
|
// Helper function to validate fill value
|
|
Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value);
|
|
|
|
// Helper function to validate probability
|
|
Status ValidateProbability(const std::string &transform_name, const float &probability);
|
|
|
|
// Helper function to validate padding
|
|
Status ValidateVectorPadding(const std::string &transform_name, const std::vector<int32_t> &padding);
|
|
|
|
// Helper function to validate size
|
|
Status ValidateVectorPositive(const std::string &transform_name, const std::vector<int32_t> &size);
|
|
|
|
// Helper function to validate transforms
|
|
Status ValidateVectorTransforms(const std::string &transform_name,
|
|
const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
|
|
|
/// \brief Helper function to compare two float values.
/// \param[in] a First value of the comparison.
/// \param[in] b Second value of the comparison.
/// \param[in] epsilon Tolerance passed to the comparison (default=0.0000000001f).
/// \note NOTE(review): the definition is not in this header; presumably the values are treated as
///     equal when they differ by less than epsilon - confirm against the implementation. The default
///     epsilon is far below the precision of float for values of ordinary magnitude - confirm that
///     this is intended.
/// \return Boolean result of the comparison.
bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f);
|
|
|
|
// Transform operations for performing data transformation.
|
|
namespace transforms {
|
|
|
|
// Forward declarations of the transform operation classes, kept in alphabetical order.
// They allow the factory functions to name their return types ahead of the class definitions.
class ComposeOperation;
class DuplicateOperation;
class OneHotOperation;
class RandomApplyOperation;
class RandomChoiceOperation;
class TypeCastOperation;
#ifndef ENABLE_ANDROID
// Only declared when ENABLE_ANDROID is not defined.
class UniqueOperation;
#endif
|
|
|
|
/// \brief Function to create a Compose TensorOperation.
|
|
/// \notes Compose a list of transforms into a single transform.
|
|
/// \param[in] transforms A vector of transformations to be applied.
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<ComposeOperation> Compose(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
|
|
|
/// \brief Function to create a Duplicate TensorOperation.
|
|
/// \notes Duplicate the input tensor to a new output tensor.
|
|
/// The input tensor is carried over to the output list.
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<DuplicateOperation> Duplicate();
|
|
|
|
/// \brief Function to create a OneHot TensorOperation.
|
|
/// \notes Convert the labels into OneHot format.
|
|
/// \param[in] num_classes number of classes.
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes);
|
|
|
|
/// \brief Function to create a RandomApply TensorOperation.
|
|
/// \notes Randomly perform a series of transforms with a given probability.
|
|
/// \param[in] transforms A vector of transformations to be applied.
|
|
/// \param[in] prob The probability to apply the transformation list (default=0.5)
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<RandomApplyOperation> RandomApply(const std::vector<std::shared_ptr<TensorOperation>> &transforms,
|
|
double prob = 0.5);
|
|
|
|
/// \brief Function to create a RandomChoice TensorOperation.
|
|
/// \notes Randomly selects one transform from a list of transforms to perform operation.
|
|
/// \param[in] transforms A vector of transformations to be chosen from to apply.
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<RandomChoiceOperation> RandomChoice(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
|
|
|
/// \brief Function to create a TypeCast TensorOperation.
|
|
/// \notes Tensor operation to cast to a given MindSpore data type.
|
|
/// \param[in] data_type mindspore.dtype to be cast to.
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<TypeCastOperation> TypeCast(std::string data_type);
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
/// \brief Function to create a Unique TensorOperation.
|
|
/// \notes Return an output tensor containing all the unique elements of the input tensor in
|
|
/// the same order that they occur in the input tensor.
|
|
/// \return Shared pointer to the current TensorOperation.
|
|
std::shared_ptr<UniqueOperation> Unique();
|
|
#endif
|
|
|
|
/* ####################################### Derived TensorOperation classes ################################# */
|
|
|
|
class ComposeOperation : public TensorOperation {
|
|
public:
|
|
explicit ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
|
|
|
~ComposeOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
|
|
private:
|
|
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
|
};
|
|
|
|
class DuplicateOperation : public TensorOperation {
|
|
public:
|
|
DuplicateOperation() = default;
|
|
|
|
~DuplicateOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
};
|
|
|
|
class OneHotOperation : public TensorOperation {
|
|
public:
|
|
explicit OneHotOperation(int32_t num_classes_);
|
|
|
|
~OneHotOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
|
|
private:
|
|
float num_classes_;
|
|
};
|
|
|
|
class RandomApplyOperation : public TensorOperation {
|
|
public:
|
|
explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
|
|
|
|
~RandomApplyOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
|
|
private:
|
|
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
|
double prob_;
|
|
};
|
|
|
|
class RandomChoiceOperation : public TensorOperation {
|
|
public:
|
|
explicit RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
|
|
|
~RandomChoiceOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
|
|
private:
|
|
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
|
};
|
|
|
|
class TypeCastOperation : public TensorOperation {
|
|
public:
|
|
explicit TypeCastOperation(std::string data_type);
|
|
|
|
~TypeCastOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
|
|
private:
|
|
std::string data_type_;
|
|
};
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
class UniqueOperation : public TensorOperation {
|
|
public:
|
|
UniqueOperation() = default;
|
|
|
|
~UniqueOperation() = default;
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
Status ValidateParams() override;
|
|
};
|
|
#endif
|
|
} // namespace transforms
|
|
} // namespace dataset
|
|
} // namespace mindspore
|
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|