Migrate repeat_pass.cc to IR optimizer and remove ExecTree optimizer

pull/11787/head
Nat Sutyanyong 4 years ago
parent c16b45ab23
commit 5a7dc0accc

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -258,10 +258,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return The number of required repeats for the operator
int32_t op_total_repeats() { return op_total_repeats_; }
/// \brief Getter function
/// \return The number of required epochs for the operator
int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; }
/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,10 +17,8 @@
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/log_adapter.h"

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -304,6 +304,10 @@ class CsvOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *const modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "CsvOp"; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,20 +16,9 @@
#include "minddata/dataset/engine/execution_tree.h"
#include <iostream>
#include <string>
#include <utility>
#include <limits>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
#if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE)
@ -255,97 +244,13 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
return Status::OK();
}
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status The status code returned
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
num_epochs_ = num_epochs;
partially_prepare_ = partial;
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PreAction());
// Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PostAction());
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
// Existing transformation implementation, will be removed later
RETURN_IF_NOT_OK(this->PrepareDeprecated());
return Status::OK();
}
Status ExecutionTree::PreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
if (!partially_prepare_) {
#ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
}
MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops.";
// Apply pre action passes
for (auto &pass : pre_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Pre passes complete.";
return Status::OK();
}
Status ExecutionTree::PostAction() {
bool modified = false;
OptPass post_actions;
// Construct pre actions
MS_LOG(INFO) << "Running post pass loops.";
#ifndef ENABLE_ANDROID
// Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind.
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
// This is because Python API binding to TensorOperation is still in progress.
post_actions.push_back(std::make_unique<CacheErrorPass>());
post_actions.push_back(std::make_unique<RepeatPass>());
#endif
// Apply post action passes
for (auto &pass : post_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Post passes complete.";
return Status::OK();
}
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
//
// This driver is deprecated.
Status ExecutionTree::PrepareDeprecated() {
// Tree must be in pending prepare state before we can assign root to it
if (tree_state_ != kDeTStatePrepare) {
std::string err_msg =
"Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
Status ExecutionTree::Prepare() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
if (root_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -169,24 +169,6 @@ class ExecutionTree {
// @return the prepare flags
uint32_t PrepareFlags() const { return prepare_flags_; }
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status The status code returned
Status Prepare(int num_epochs = -1, bool partial = false);
// Compulsory transformation/action pre optimization.
// @return Status The status code returned
Status PreAction();
@ -200,7 +182,7 @@ class ExecutionTree {
// it ready for execution.
// @param Total number of epochs that will be run on this tree
// @return Status The status code returned
Status PrepareDeprecated();
Status Prepare();
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
@ -239,10 +221,6 @@ class ExecutionTree {
// Getter for profiling manager, no ownership
ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
private:
// A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print
@ -257,9 +235,7 @@ class ExecutionTree {
int32_t id_count_; // Counter for generating operator id's
uint32_t prepare_flags_; // Flags used during tree prepare
TreeState tree_state_; // Tracking the current tree state
int32_t num_epochs_; // Total number of epochs to run for this tree
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes.
#if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE)
// This rank_id is for numa and device_queue, one process work with only one rank_id,
// for standalone scenario, this rank_id may come from env 'CUDA_VISIBLE_DEVICES',

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -102,9 +102,11 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
node_ops->push_back(project_op);
}
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
pad_map_));
auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
#else
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, pad_map_));

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -84,9 +84,12 @@ void BucketBatchByLengthNode::Print(std::ostream &out) const {
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
node_ops->push_back(std::make_shared<BucketBatchByLengthOp>(
column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
if (bucket_boundaries_[0] == 0) {
bucket_boundaries_.erase(bucket_boundaries_.begin());
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -55,10 +55,11 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
// Function to build BuildSentenceVocabNode
Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::shared_ptr<BuildSentencePieceVocabOp> build_sentence_piece_vocab_op;
build_sentence_piece_vocab_op = std::make_shared<BuildSentencePieceVocabOp>(
vocab_, col_names_, vocab_size_, character_coverage_, model_type_, params_, connector_que_size_);
node_ops->push_back(build_sentence_piece_vocab_op);
auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
model_type_, params_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -54,6 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
build_vocab_op->set_total_repeats(GetTotalRepeats());
build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(build_vocab_op);
return Status::OK();
}

@ -51,10 +51,24 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
"Internal error. Attempt to create a cache lookup node without cache client.");
RETURN_IF_NOT_OK(cache_->Build());
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
lookup_op_->set_total_repeats(GetTotalRepeats());
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(lookup_op_);
return Status::OK();
}
// Visitor accepting method for IRNodePass
Status CacheLookupNode::Accept(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<CacheLookupNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status CacheLookupNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<CacheLookupNode>(), modified);
}
std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
// CacheLookupNode should already been copied, so we just return it here
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);

@ -64,6 +64,18 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
private:
std::shared_ptr<SamplerObj> sampler_;
std::shared_ptr<DatasetOp> lookup_op_;

@ -48,9 +48,23 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> merge_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
merge_op->set_total_repeats(GetTotalRepeats());
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(merge_op);
return Status::OK();
}
// Visitor accepting method for IRNodePass
Status CacheMergeNode::Accept(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<CacheMergeNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status CacheMergeNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<CacheMergeNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -53,6 +53,18 @@ class CacheMergeNode : public DatasetNode {
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore

@ -53,9 +53,23 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
cache_op->SetSampler(sampler_->SamplerBuild());
cache_op->set_total_repeats(GetTotalRepeats());
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cache_op);
return Status::OK();
}
// Visitor accepting method for IRNodePass
Status CacheNode::Accept(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<CacheNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status CacheNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<CacheNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -55,6 +55,18 @@ class CacheNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
private:
std::shared_ptr<SamplerObj> sampler_;
};

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -119,12 +119,16 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
}
Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::shared_ptr<ConcatOp> op;
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
op = std::make_shared<ConcatOp>(connector_que_size_);
} else {
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(),
children_flag_and_nums_, children_start_end_index_));
op = std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_,
children_start_end_index_);
}
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -219,7 +219,9 @@ DatasetNode::DatasetNode()
dataset_size_(-1),
mappable_(kNotADataSource),
nary_op_(false),
descendant_of_cache_(false) {
descendant_of_cache_(false),
total_repeats_(-1),
num_epochs_(1) {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
@ -292,6 +293,24 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status of the function
virtual Status to_json(nlohmann::json *out_json);
/// \brief Setter function, set the number of total repeats for the operator
void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; }
/// \brief Setter function, set the number of epochs for the operator
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
/// \brief Getter function
/// \return The number of required repeats for the operator
int32_t GetTotalRepeats() const { return total_repeats_; }
/// \brief Getter function
/// \return The number of epochs for the operator
int32_t GetNumEpochs() const { return num_epochs_; }
/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; }
protected:
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase
@ -301,6 +320,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
int32_t total_repeats_; // Number of times required to run this operator
int32_t num_epochs_; // Number of epochs
// Establish a parent-child relationship between this node and the input node.
// Used only in the constructor of the class and its derived classes.
void AddChild(std::shared_ptr<DatasetNode> child);

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,6 +44,8 @@ void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" +
// Function to build the EpochCtrlOp
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
new_op_->set_total_repeats(GetTotalRepeats());
new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op_);
op_ = new_op_;
return Status::OK();

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,7 +44,10 @@ void FilterNode::Print(std::ostream &out) const {
}
Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,7 +38,8 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
output_columns_(output_columns),
project_columns_(project_columns),
DatasetNode(std::move(cache)),
callbacks_(callbacks) {
callbacks_(callbacks),
under_a_cache_(false) {
this->AddChild(child);
}
@ -64,6 +65,17 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
[](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
// This is temporary code.
// Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
// to TensorOp, we need to check the randomness here.
// When TensorOperation captures the randomness behaviour, remove this code and the member "under_a_cache_"
// and the temporary code in CacheValidation pre pass in IR optimizer.
if (under_a_cache_) {
auto itr = std::find_if(tensor_ops.begin(), tensor_ops.end(), [](const auto &it) { return !it->Deterministic(); });
if (itr != tensor_ops.end()) {
RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
}
}
// This parameter will be removed with next rebase
std::vector<std::string> col_orders;
auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
@ -74,9 +86,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
if (!project_columns_.empty()) {
auto project_op = std::make_shared<ProjectOp>(project_columns_);
project_op->set_total_repeats(GetTotalRepeats());
project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(project_op);
}
map_op->set_total_repeats(GetTotalRepeats());
map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(map_op);
return Status::OK();
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -79,6 +79,9 @@ class MapNode : public DatasetNode {
/// \brief setter to set all tensor operations
void setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations);
/// \brief indicate this Map will be cached
void Cached() { under_a_cache_ = true; }
/// \brief Getter functions
/// \brief Getter of tensor operations
/// \return Vector of operations the Map node will process
@ -95,12 +98,11 @@ class MapNode : public DatasetNode {
private:
std::vector<std::shared_ptr<TensorOperation>> operations_;
private:
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
std::vector<std::string> project_columns_;
std::vector<std::shared_ptr<DSCallback>> callbacks_;
bool under_a_cache_;
};
} // namespace dataset

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,7 +53,10 @@ Status ProjectNode::ValidateParams() {
}
Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<ProjectOp>(columns_));
auto op = std::make_shared<ProjectOp>(columns_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,7 +58,10 @@ Status RenameNode::ValidateParams() {
}
Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,6 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + st
Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op = std::make_shared<RepeatOp>(repeat_count_);
new_op->set_total_repeats(GetTotalRepeats());
new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op);
op_ = new_op;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save