|
|
|
@ -28,6 +28,17 @@ namespace dataset {
|
|
|
|
|
|
|
|
|
|
class TensorOp;
|
|
|
|
|
|
|
|
|
|
// Char arrays storing name of corresponding classes (in alphabetical order)
|
|
|
|
|
constexpr char kComposeOperation[] = "Compose";
|
|
|
|
|
constexpr char kDuplicateOperation[] = "Duplicate";
|
|
|
|
|
constexpr char kOneHotOperation[] = "OneHot";
|
|
|
|
|
constexpr char kPreBuiltOperation[] = "PreBuilt";
|
|
|
|
|
constexpr char kRandomApplyOperation[] = "RandomApply";
|
|
|
|
|
constexpr char kRandomChoiceOperation[] = "RandomChoice";
|
|
|
|
|
constexpr char kRandomSelectSubpolicyOperation[] = "RandomSelectSubpolicy";
|
|
|
|
|
constexpr char kTypeCastOperation[] = "TypeCast";
|
|
|
|
|
constexpr char kUniqueOperation[] = "Unique";
|
|
|
|
|
|
|
|
|
|
// Abstract class to represent a dataset in the data pipeline.
|
|
|
|
|
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
|
|
|
|
public:
|
|
|
|
@ -46,6 +57,8 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
|
|
|
|
|
|
|
|
|
virtual Status ValidateParams() = 0;
|
|
|
|
|
|
|
|
|
|
virtual std::string Name() const = 0;
|
|
|
|
|
|
|
|
|
|
/// \brief Check whether the operation is deterministic.
|
|
|
|
|
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
|
|
|
|
|
bool IsRandomOp() const { return random_op_; }
|
|
|
|
@ -146,6 +159,8 @@ class ComposeOperation : public TensorOperation {
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kComposeOperation; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
|
|
|
|
};
|
|
|
|
@ -159,6 +174,8 @@ class DuplicateOperation : public TensorOperation {
|
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kDuplicateOperation; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OneHotOperation : public TensorOperation {
|
|
|
|
@ -171,6 +188,8 @@ class OneHotOperation : public TensorOperation {
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kOneHotOperation; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
float num_classes_;
|
|
|
|
|
};
|
|
|
|
@ -185,6 +204,8 @@ class PreBuiltOperation : public TensorOperation {
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kPreBuiltOperation; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::shared_ptr<TensorOp> op_;
|
|
|
|
|
};
|
|
|
|
@ -199,6 +220,8 @@ class RandomApplyOperation : public TensorOperation {
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kRandomApplyOperation; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
|
|
|
|
double prob_;
|
|
|
|
@ -214,6 +237,8 @@ class RandomChoiceOperation : public TensorOperation {
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kRandomChoiceOperation; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
|
|
|
|
};
|
|
|
|
@ -227,6 +252,8 @@ class TypeCastOperation : public TensorOperation {
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kTypeCastOperation; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::string data_type_;
|
|
|
|
|
};
|
|
|
|
@ -241,6 +268,8 @@ class UniqueOperation : public TensorOperation {
|
|
|
|
|
std::shared_ptr<TensorOp> Build() override;
|
|
|
|
|
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
std::string Name() const override { return kUniqueOperation; }
|
|
|
|
|
};
|
|
|
|
|
#endif
|
|
|
|
|
} // namespace transforms
|
|
|
|
|