boilerplate code for future IR optimizer

add 2 test cases to IRNode DeepCopy()

address review cmts

fix ut

samplerObj copy

ci

fix ci

fix ci round III

address further review cmts

add a missing macro

fix merge conflict

fix compile err

fix lite compile err

fix compile err

fix lite compile round III

address an issue

fix minor comments
pull/8802/head
Nat Sutyanyong 5 years ago committed by Zirui Wu
parent dabb82ec7a
commit 5e1bb0b697

@ -568,8 +568,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size,
character_coverage, model_type, params);
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
@ -600,8 +600,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first) {
auto vocab = std::make_shared<Vocab>();
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens,
special_first);
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();

@ -190,13 +190,12 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
// PreBuiltSamplerObj
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
: sp_minddataset_(std::move(sampler)) {}
#endif
bool PreBuiltSamplerObj::ValidateParams() { return true; }
@ -207,6 +206,13 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
#ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
#endif
return std::make_shared<PreBuiltSamplerObj>(sp_);
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object

@ -30,8 +30,6 @@
namespace mindspore {
namespace dataset {
TensorOperation::TensorOperation() {}
/* ####################################### Validator Functions ############################################ */
Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
@ -231,7 +229,7 @@ std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
// RandomApplyOperation
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
: transforms_(transforms), prob_(prob) {}
: TensorOperation(true), transforms_(transforms), prob_(prob) {}
Status RandomApplyOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
@ -248,7 +246,7 @@ std::shared_ptr<TensorOp> RandomApplyOperation::Build() {
// RandomChoiceOperation
RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
: transforms_(transforms) {}
: TensorOperation(true), transforms_(transforms) {}
Status RandomChoiceOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));

@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees
scale_range_(scale_range),
shear_ranges_(shear_ranges),
interpolation_(interpolation),
fill_value_(fill_value) {}
fill_value_(fill_value) {
random_op_ = true;
}
Status RandomAffineOperation::ValidateParams() {
// Degrees
@ -867,7 +869,7 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
}
// RandomColorOperation.
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; }
Status RandomColorOperation::ValidateParams() {
// Do some input validation.
@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() {
// RandomColorAdjustOperation.
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
std::vector<float> saturation, std::vector<float> hue)
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {
random_op_ = true;
}
Status RandomColorAdjustOperation::ValidateParams() {
// brightness
@ -1012,11 +1016,14 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
// RandomCropOperation
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
std::vector<uint8_t> fill_value, BorderType padding_mode)
: size_(size),
: TensorOperation(true),
size_(size),
padding_(padding),
pad_if_needed_(pad_if_needed),
fill_value_(fill_value),
padding_mode_(padding_mode) {}
padding_mode_(padding_mode) {
random_op_ = true;
}
Status RandomCropOperation::ValidateParams() {
// size
@ -1083,7 +1090,12 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio,
InterpolationMode interpolation, int32_t max_attempts)
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
: TensorOperation(true),
size_(size),
scale_(scale),
ratio_(ratio),
interpolation_(interpolation),
max_attempts_(max_attempts) {}
Status RandomCropDecodeResizeOperation::ValidateParams() {
// size
@ -1176,7 +1188,8 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding,
bool pad_if_needed, std::vector<uint8_t> fill_value,
BorderType padding_mode)
: size_(size),
: TensorOperation(true),
size_(size),
padding_(padding),
pad_if_needed_(pad_if_needed),
fill_value_(fill_value),
@ -1245,7 +1258,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() {
}
// RandomHorizontalFlipOperation
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability)
: TensorOperation(true), probability_(probability) {}
Status RandomHorizontalFlipOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_));
@ -1260,7 +1274,7 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
// RandomHorizontalFlipWithBBoxOperation
RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability)
: probability_(probability) {}
: TensorOperation(true), probability_(probability) {}
Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_));
@ -1275,7 +1289,8 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipWithBBoxOperation::Build() {
}
// RandomPosterizeOperation
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range)
: TensorOperation(true), bit_range_(bit_range) {}
Status RandomPosterizeOperation::ValidateParams() {
if (bit_range_.size() != 2) {
@ -1309,7 +1324,7 @@ std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
}
// RandomResizeOperation
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : size_(size) {}
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : TensorOperation(true), size_(size) {}
Status RandomResizeOperation::ValidateParams() {
// size
@ -1343,7 +1358,8 @@ std::shared_ptr<TensorOp> RandomResizeOperation::Build() {
}
// RandomResizeWithBBoxOperation
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) : size_(size) {}
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size)
: TensorOperation(true), size_(size) {}
Status RandomResizeWithBBoxOperation::ValidateParams() {
// size
@ -1380,7 +1396,12 @@ std::shared_ptr<TensorOp> RandomResizeWithBBoxOperation::Build() {
RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio, InterpolationMode interpolation,
int32_t max_attempts)
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
: TensorOperation(true),
size_(size),
scale_(scale),
ratio_(ratio),
interpolation_(interpolation),
max_attempts_(max_attempts) {}
Status RandomResizedCropOperation::ValidateParams() {
// size
@ -1536,7 +1557,8 @@ std::shared_ptr<TensorOp> RandomResizedCropWithBBoxOperation::Build() {
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
bool expand, std::vector<float> center,
std::vector<uint8_t> fill_value)
: degrees_(degrees),
: TensorOperation(true),
degrees_(degrees),
interpolation_mode_(interpolation_mode),
expand_(expand),
center_(center),
@ -1603,7 +1625,7 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
// RandomSelectSubpolicyOperation.
RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation(
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy)
: policy_(policy) {}
: TensorOperation(true), policy_(policy) {}
Status RandomSelectSubpolicyOperation::ValidateParams() {
if (policy_.empty()) {
@ -1650,7 +1672,8 @@ std::shared_ptr<TensorOp> RandomSelectSubpolicyOperation::Build() {
}
// Function to create RandomSharpness.
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees)
: TensorOperation(true), degrees_(degrees) {}
Status RandomSharpnessOperation::ValidateParams() {
if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) {
@ -1674,7 +1697,8 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
}
// RandomSolarizeOperation.
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold)
: TensorOperation(true), threshold_(threshold) {}
Status RandomSolarizeOperation::ValidateParams() {
if (threshold_.size() != 2) {
@ -1705,7 +1729,8 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
}
// RandomVerticalFlipOperation
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability)
: TensorOperation(true), probability_(probability) {}
Status RandomVerticalFlipOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_));
@ -1720,7 +1745,7 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
// RandomVerticalFlipWithBBoxOperation
RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability)
: probability_(probability) {}
: TensorOperation(true), probability_(probability) {}
Status RandomVerticalFlipWithBBoxOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_));
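
The constructor changes above all follow one pattern: transforms whose output is non-deterministic now record that fact, either by forwarding `true` to the `TensorOperation` base-class constructor or by setting `random_op_ = true` in the constructor body. The commit title suggests this flag is groundwork for a future IR optimizer. Below is a minimal, self-contained sketch of the idea; the class names (TensorOperationSketch, DecodeLikeOp, RandomFlipLikeOp) are hypothetical and not part of MindSpore's API.

// Sketch only: shows how a pass could query a "random" flag set in constructors.
#include <iostream>
#include <memory>
#include <vector>

class TensorOperationSketch {
 public:
  TensorOperationSketch() : random_op_(false) {}
  explicit TensorOperationSketch(bool random) : random_op_(random) {}
  virtual ~TensorOperationSketch() = default;
  bool IsRandom() const { return random_op_; }

 protected:
  bool random_op_;  // mirrors the random_op_ member set by the constructors above
};

class DecodeLikeOp : public TensorOperationSketch {};  // deterministic by default

class RandomFlipLikeOp : public TensorOperationSketch {
 public:
  RandomFlipLikeOp() : TensorOperationSketch(true) {}  // marks itself as randomized
};

int main() {
  std::vector<std::shared_ptr<TensorOperationSketch>> ops = {
      std::make_shared<DecodeLikeOp>(), std::make_shared<RandomFlipLikeOp>()};
  // A hypothetical optimizer pass could, for example, refuse to cache results
  // of any operation for which IsRandom() returns true.
  for (const auto &op : ops) {
    std::cout << (op->IsRandom() ? "random" : "deterministic") << '\n';
  }
  return 0;
}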

@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
build_sentence_piece_vocab_node.cc
build_vocab_node.cc
concat_node.cc
epoch_ctrl_node.cc
filter_node.cc
map_node.cc
project_node.cc
rename_node.cc
repeat_node.cc
root_node.cc
shuffle_node.cc
skip_node.cc
sync_wait_node.cc

@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, boo
batch_size_func_(batch_size_func),
batch_map_func_(batch_map_func),
pad_map_(pad_map) {
this->children.push_back(child);
this->AddChild(child);
}
#endif
// constructor #2, called by C++ API
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder)
: batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> BatchNode::Copy() {
#ifdef ENABLE_PYTHON
auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_,
col_order_, batch_size_func_, batch_map_func_, pad_map_);
#else
auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_);
#endif
return node;
}
void BatchNode::Print(std::ostream &out) const {
out << Name() + "(batch_size:" + std::to_string(batch_size_) +
" drop_remainder:" + (drop_remainder_ ? "true" : "false") + ")";
}
Status BatchNode::ValidateParams() {

@ -44,6 +44,18 @@ class BatchNode : public DatasetNode {
/// \brief Destructor
~BatchNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBatchNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_);
return node;
}
void BucketBatchByLengthNode::Print(std::ostream &out) const {
out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)";
}
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {

@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode {
/// \brief Destructor
~BucketBatchByLengthNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBucketBatchByLengthNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

@ -22,6 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> chil
character_coverage_(character_coverage),
model_type_(model_type),
params_(params) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> BuildSentenceVocabNode::Copy() {
auto node = std::make_shared<BuildSentenceVocabNode>(nullptr, vocab_, col_names_, vocab_size_, character_coverage_,
model_type_, params_);
return node;
}
void BuildSentenceVocabNode::Print(std::ostream &out) const {
out << Name() + "(<vocab>," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) +
",...)";
}
// Function to build BuildSentenceVocabNode
@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified);
}
// Visitor accepting method for NodePass
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified);
}
} // namespace dataset
} // namespace mindspore
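
The Accept()/AcceptAfter() boilerplate added to each IR node implements classic visitor double dispatch: the node downcasts its own shared pointer and hands the typed pointer to the pass, so a pass only overrides the node types it cares about. A self-contained sketch of the idiom follows; the names (NodeSketch, PassSketch, BuildVocabLikeNode, PrintVocabPass) are hypothetical stand-ins, not the actual NodePass API.

// Sketch only: visitor double dispatch as used by Accept()/AcceptAfter() above.
#include <iostream>
#include <memory>

class BuildVocabLikeNode;

class PassSketch {
 public:
  virtual ~PassSketch() = default;
  // Default behaviour: ignore node types this pass does not care about.
  virtual void Visit(std::shared_ptr<BuildVocabLikeNode> /*node*/) {}
};

class NodeSketch : public std::enable_shared_from_this<NodeSketch> {
 public:
  virtual ~NodeSketch() = default;
  virtual void Accept(PassSketch * /*p*/) {}  // base nodes do nothing by default
};

class BuildVocabLikeNode : public NodeSketch {
 public:
  void Accept(PassSketch *p) override {
    // Downcast shared pointer then call the visitor, mirroring shared_from_base<T>()
    p->Visit(std::static_pointer_cast<BuildVocabLikeNode>(shared_from_this()));
  }
};

class PrintVocabPass : public PassSketch {
 public:
  void Visit(std::shared_ptr<BuildVocabLikeNode> /*node*/) override {
    std::cout << "visited a vocab-building node\n";
  }
};

int main() {
  std::shared_ptr<NodeSketch> node = std::make_shared<BuildVocabLikeNode>();
  PrintVocabPass pass;
  node->Accept(&pass);  // dispatches to the typed Visit() overload
  return 0;
}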

@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \brief Destructor
~BuildSentenceVocabNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBuildSentencePieceVocabNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::shared_ptr<SentencePieceVocab> vocab_;
std::vector<std::string> col_names_;

@ -22,7 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_p
top_k_(top_k),
special_tokens_(special_tokens),
special_first_(special_first) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> BuildVocabNode::Copy() {
auto node =
std::make_shared<BuildVocabNode>(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_, special_first_);
return node;
}
void BuildVocabNode::Print(std::ostream &out) const {
out << Name() + "(<vocab>," + "columns:" + PrintColumns(columns_) + ",...)";
}
// Function to build BuildVocabNode
@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status BuildVocabNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BuildVocabNode>(), modified);
}
// Visitor accepting method for NodePass
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode {
/// \brief Destructor
~BuildVocabNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBuildVocabNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> columns_;

@ -22,7 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
: sampler_(sampler),
children_flag_and_nums_(children_flag_and_nums),
children_start_end_index_(children_start_end_index) {
this->children = datasets;
for (auto const &child : datasets) AddChild(child);
}
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
// create an empty vector to copy a concat
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>());
return node;
}
void ConcatNode::Print(std::ostream &out) const { out << Name(); }
Status ConcatNode::ValidateParams() {
if (children.size() < 2) {
if (children_.size() < 2) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (find(children.begin(), children.end(), nullptr) != children.end()) {
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
@ -73,5 +81,16 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
return node_ops;
}
// Visitor accepting method for NodePass
Status ConcatNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<ConcatNode>(), modified);
}
// Visitor accepting method for NodePass
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode {
/// \brief Destructor
~ConcatNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kConcatNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode {
std::shared_ptr<SamplerObj> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
};
} // namespace dataset

@ -233,14 +233,92 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
return shared_from_this();
}
DatasetNode::DatasetNode() {
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
build_status = Status::OK(); // remove me after changing return val of Build()
}
// This function performs a deep copy of the current node (and its descendants); the parent_ pointer is not copied.
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
std::shared_ptr<DatasetNode> new_node = this->Copy();
for (const auto &child : children_) {
new_node->AddChild(child->DeepCopy());
}
return new_node;
}
std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
std::string me;
if (columns.empty()) {
me = "<nil>";
} else {
me = "[";
auto i = 0;
for (auto it = columns.begin(); it < columns.end(); ++it, ++i) {
me += *it;
if (i < columns.size() - 1) {
me += ", ";
} else {
me += "]";
}
}
}
return me;
}
void DatasetNode::PrintTree(std::ostream &out) const {
int level = 0;
PrintNode(out, &level);
}
void DatasetNode::PrintNode(std::ostream &out, int *level) const {
const std::string prefix = "+-";
const std::string indent = " ";
out << prefix;
Print(out);
for (const auto &c : this->Children()) {
out << '\n';
++(*level);
for (auto i = 0; i < *level; i++) {
out << indent;
}
c->PrintNode(out, level);
--(*level);
}
}
// Add a node as a child; the child's parent must be nullptr.
// This function allows child to be nullptr, in which case it is simply skipped.
void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
if (child != nullptr && child->parent_ == nullptr) {
children_.push_back(child);
child->parent_ = this;
} else if (child != nullptr) {
MS_LOG(WARNING) << "DatasetNode::AddChild() failed: " + child->Name() + "'s parent isn't a nullptr.";
}
}
// Remove this node from its parent and attach this node's child to that parent.
// For now, this removal is limited to nodes with a single child or no child.
Status DatasetNode::Remove() {
CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent.");
CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child.");
if (children_.empty()) { // I am a leaf node, remove me from my parent's children list
parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()),
parent_->children_.end()); // removal using "erase remove idiom"
} else { // replace my position in my parent's children list with my single child
auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent
*itr = std::move(children_[0]); // replace me in my parent's children list with my single child
children_.clear(); // release my single child from my children list
}
parent_ = nullptr;
return Status::OK();
}
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified);
}
Status DatasetNode::GetShardId(int32_t *shard_id) {
if (!Children().empty()) {
// Get shard id from the child node
return Children()[0]->GetShardId(shard_id);
} else {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
}
}
// Visitor accepting method for NodePass
Status SourceNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<SourceNode>(), modified);
}
// Visitor accepting method for NodePass
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<SourceNode>(), modified);
}
} // namespace dataset
} // namespace mindspore
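
The split between Copy() and DeepCopy() above works as follows: each node's Copy() clones only that node (with no children), while DeepCopy() recurses over children_ and re-attaches the copies through AddChild(), so parent pointers are rebuilt inside the copy and never point back into the original tree. A self-contained sketch of that contract (hypothetical classes NodeSketch, LeafSketch, BatchLikeSketch; not MindSpore code and not the test cases mentioned in the commit message):

// Sketch only: Copy() clones one node, DeepCopy() rebuilds the subtree via AddChild().
#include <iostream>
#include <memory>
#include <vector>

class NodeSketch {
 public:
  virtual ~NodeSketch() = default;
  virtual std::shared_ptr<NodeSketch> Copy() const = 0;  // clone this node only

  std::shared_ptr<NodeSketch> DeepCopy() const {
    std::shared_ptr<NodeSketch> new_node = Copy();
    for (const auto &child : children_) new_node->AddChild(child->DeepCopy());
    return new_node;
  }

  void AddChild(std::shared_ptr<NodeSketch> child) {
    if (child != nullptr && child->parent_ == nullptr) {
      children_.push_back(child);
      child->parent_ = this;
    }
  }

  size_t NumChildren() const { return children_.size(); }

 private:
  NodeSketch *parent_ = nullptr;  // raw back-pointer, deliberately never copied
  std::vector<std::shared_ptr<NodeSketch>> children_;
};

class LeafSketch : public NodeSketch {
 public:
  std::shared_ptr<NodeSketch> Copy() const override { return std::make_shared<LeafSketch>(); }
};

class BatchLikeSketch : public NodeSketch {
 public:
  std::shared_ptr<NodeSketch> Copy() const override { return std::make_shared<BatchLikeSketch>(); }
};

int main() {
  auto root = std::make_shared<BatchLikeSketch>();
  root->AddChild(std::make_shared<LeafSketch>());
  auto copy = root->DeepCopy();
  // Both trees have one child; the copy's child is a distinct object from the original leaf.
  std::cout << "original children: " << root->NumChildren()
            << ", copy children: " << copy->NumChildren() << '\n';
  return 0;
}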

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for EpochCtrlNode
EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) {
// The root node's parent must be set to nullptr.
this->AddChild(child);
}
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_);
return node;
}
void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; }
// Function to build the EpochCtrlOp
std::vector<std::shared_ptr<DatasetOp>> EpochCtrlNode::Build() {
// A dummy vector
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<EpochCtrlOp>(num_epochs_));
return node_ops;
}
// Function to validate the parameters for EpochCtrlNode
Status EpochCtrlNode::ValidateParams() {
if (num_epochs_ <= 0 && num_epochs_ != -1) {
std::string err_msg =
"EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (children_.size() != 1 || children_[0] == nullptr) {
std::string err_msg = "Internal error: epoch control node should have one child node";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class EpochCtrlNode : public DatasetNode {
public:
/// \brief Constructor
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
/// \brief Destructor
~EpochCtrlNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kEpochCtrlNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
int32_t num_epochs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_

@ -21,7 +21,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -31,7 +31,16 @@ namespace dataset {
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
std::vector<std::string> input_columns)
: predicate_(predicate), input_columns_(input_columns) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> FilterNode::Copy() {
auto node = std::make_shared<FilterNode>(nullptr, predicate_, input_columns_);
return node;
}
void FilterNode::Print(std::ostream &out) const {
out << Name() + "(<predicate>," + "input_cols:" + PrintColumns(input_columns_) + ")";
}
std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status FilterNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<FilterNode>(), modified);
}
// Visitor accepting method for NodePass
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<FilterNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -35,6 +35,18 @@ class FilterNode : public DatasetNode {
/// \brief Destructor
~FilterNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kFilterNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -43,6 +55,18 @@ class FilterNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::shared_ptr<TensorOp> predicate_;
std::vector<std::string> input_columns_;

@ -22,6 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
project_columns_(project_columns),
DatasetNode(std::move(cache)),
callbacks_(callbacks) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> MapNode::Copy() {
auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_,
callbacks_);
return node;
}
void MapNode::Print(std::ostream &out) const {
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
",<project_cols>" + ",...)";
}
std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
@ -93,5 +105,16 @@ Status MapNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status MapNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<MapNode>(), modified);
}
// Visitor accepting method for NodePass
Status MapNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<MapNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -37,6 +37,18 @@ class MapNode : public DatasetNode {
/// \brief Destructor
~MapNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kMapNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -45,6 +57,23 @@ class MapNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Getter of tensor operations
/// \return Vector of operations the Map node will process
const auto &TensorOperations() const { return operations_; }
auto &TensorOperations() { return operations_; }
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::vector<std::shared_ptr<TensorOperation>> operations_;
std::vector<std::string> input_columns_;

@ -29,9 +29,16 @@ namespace dataset {
// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
: columns_(columns) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> ProjectNode::Copy() {
auto node = std::make_shared<ProjectNode>(nullptr, this->columns_);
return node;
}
void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; }
Status ProjectNode::ValidateParams() {
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";

@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode {
/// \brief Destructor
~ProjectNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kProjectNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

Some files were not shown because too many files have changed in this diff.