!8672 add removal pass for dataset getters

From: @ziruiwu
Reviewed-by: 
Signed-off-by:
pull/8672/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit bf447ff51a

@@ -41,6 +41,9 @@ class CallbackManager {
/// \param [in] callbacks list of callbacks to perform
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);
/// \brief set callbacks to empty
void ClearCallbacks() { callbacks_.clear(); }
/// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true
/// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads
/// \return Status

@@ -393,6 +393,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
/// \brief Remove all callbacks from DatasetOp
void ClearCallbacks() { callback_manager_.ClearCallbacks(); }
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
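
For context, the new ClearCallbacks() hook simply empties the list of registered callbacks so that later tree walks (for example, the shape/type getters) do not fire them. A minimal self-contained sketch of that behaviour (simplified stand-in types, not the MindSpore classes; appending behaviour of AddCallbacks is assumed here for illustration):

#include <memory>
#include <vector>

struct DSCallbackLike {};  // stand-in for DSCallback

class CallbackManagerLike {
 public:
  // Append the given callbacks (appending is an assumption made for this sketch).
  void AddCallbacks(std::vector<std::shared_ptr<DSCallbackLike>> callbacks) {
    callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
  }
  // Counterpart of the new ClearCallbacks(): drop every registered callback.
  void ClearCallbacks() { callbacks_.clear(); }

 private:
  std::vector<std::shared_ptr<DSCallbackLike>> callbacks_;
};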

@@ -16,6 +16,7 @@
#include "minddata/dataset/engine/execution_tree.h"
#include <iostream>
#include <string>
#include <utility>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
@@ -35,7 +36,7 @@
namespace mindspore {
namespace dataset {
// Constructor
ExecutionTree::ExecutionTree() : id_count_(0) {
ExecutionTree::ExecutionTree() : id_count_(0), pre_pass_override_(nullptr) {
tg_ = std::make_unique<TaskGroup>();
tree_state_ = kDeTStateInit;
prepare_flags_ = kDePrepNone;
@@ -234,7 +235,6 @@ Status ExecutionTree::PrepareTreePreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass loops.";
#ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif
@@ -243,6 +243,17 @@
#ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheTransformPass>());
#endif
// This offers a way to override the preset optimization passes with customized ones;
// it is used when certain nodes need to be removed for the tree getters.
if (pre_pass_override_) {
MS_LOG(INFO) << "Default pre optimization passes are being overridden,"
<< " number of passes before the override: " << pre_actions.size() << ".";
pre_actions = pre_pass_override_(std::move(pre_actions));
}
MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops.";
// Apply pre action passes
for (auto &pass : pre_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
@@ -256,7 +267,7 @@ Status ExecutionTree::PrepareTreePostAction() {
tree_state_ = kDeTStatePrepare;
bool modified = false;
std::vector<std::unique_ptr<Pass>> post_actions;
OptPass post_actions;
// Construct pre actions
MS_LOG(INFO) << "Running post pass loops.";
#ifndef ENABLE_ANDROID
@@ -274,7 +285,7 @@
Status ExecutionTree::Optimize() {
// Vector of optimizations, currently only 1, add more as necessary
std::vector<std::unique_ptr<NodePass>> optimizations;
OptPass optimizations;
#ifndef ENABLE_ANDROID
optimizations.push_back(std::make_unique<TensorOpFusionPass>());
#endif

@@ -24,13 +24,13 @@
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/util/status.h"
#include "mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h"
namespace mindspore {
namespace dataset {
// Forward declares
class TaskGroup;
class DatasetOp;
class Pass;
using OptPass = std::vector<std::unique_ptr<Pass>>;
class ExecutionTree {
public:
// Prepare flags used during tree prepare phase
@@ -253,6 +253,10 @@ class ExecutionTree {
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
// Set the function pointer that overrides the pre-pass; this allows the caller to adjust the existing pre-passes
// and introduce new ones, e.g. the caller can override num_epoch in EpochInjectionPass
void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
private:
// A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print
@@ -270,6 +274,7 @@ class ExecutionTree {
int32_t num_epochs_; // Total number of epochs to run for this tree
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
};
} // namespace dataset
} // namespace mindspore
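
For illustration, the override hook boils down to an optional std::function that receives the preset pre-passes and returns the list that is actually run. A minimal self-contained sketch of the pattern (simplified names, not MindSpore code):

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

struct Pass { virtual ~Pass() = default; virtual void Run() = 0; };
struct DefaultPrePass : Pass { void Run() override { std::cout << "default pre pass\n"; } };
struct GetterLikePass : Pass { void Run() override { std::cout << "getter pass\n"; } };

using OptPass = std::vector<std::unique_ptr<Pass>>;

int main() {
  OptPass pre_actions;
  pre_actions.push_back(std::make_unique<DefaultPrePass>());

  // Counterpart of SetPrePassOverride(): the caller rewrites the preset passes.
  std::function<OptPass(OptPass)> pre_pass_override = [](OptPass pre) {
    pre.clear();  // drop the preset passes
    pre.push_back(std::make_unique<GetterLikePass>());
    return pre;
  };

  // Counterpart of the check added in PrepareTreePreAction().
  if (pre_pass_override) pre_actions = pre_pass_override(std::move(pre_actions));
  for (auto &pass : pre_actions) pass->Run();  // prints "getter pass" only
  return 0;
}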

@@ -1,13 +1,14 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
optional/tensor_op_fusion_pass.cc
pass.cc
post/repeat_pass.cc
pre/cache_error_pass.cc
pre/cache_transform_pass.cc
pre/epoch_injection_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc
util/printer_pass.cc
)

@@ -0,0 +1,87 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) {
nodes_to_clear_callback_.push_back(node);
} else if (type_ == kDatasetSize) {
nodes_to_remove_.push_back(node);
}
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
return Status::OK();
}
#endif
Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
RETURN_IF_NOT_OK(pass_.Run(tree, modified));
// private members of the nested class can be accessed directly by its outer class
for (auto node : pass_.nodes_to_remove_) {
RETURN_IF_NOT_OK(node->Remove());
}
// clear the callbacks for the selected ops (MapOp, for the output shape/type getter)
for (auto node : pass_.nodes_to_clear_callback_) node->ClearCallbacks();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
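
Taken together with the override hook above, a dataset getter would install this pass roughly as follows. This is a hypothetical call site, not part of the diff; it mirrors the unit test added later in this change and assumes exe_tree is an already-built ExecutionTree:

// Replace the preset pre-passes with a single GetterPass before preparing the tree.
exe_tree->SetPrePassOverride([](OptPass pre) {
  pre.clear();
  pre.push_back(std::make_unique<GetterPass>(GetterPass::kOutputShapeAndType));
  return pre;
});
RETURN_IF_NOT_OK(exe_tree->PrepareTreePreAction());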

@@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_
#include <memory>
#include <list>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class GetterPass
/// \brief This is a tree pass that will remove nodes or clear the callbacks in MapOp
class GetterPass : public TreePass {
public:
enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
/// \brief Constructor
explicit GetterPass(GetterType tp) : pass_(tp) {}
/// \brief Destructor
~GetterPass() = default;
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
private:
/// \class GetterNodes, this is a nested class which is owned via composition by the outer class to identify nodes
/// \brief This is a NodePass whose job is to identify which nodes should be removed.
class GetterNodes : public NodePass {
public:
/// \brief Constructor
explicit GetterNodes(GetterType tp) : type_(tp) {}
~GetterNodes() = default;
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
// Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
#ifdef ENABLE_PYTHON
Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
#endif
GetterType type_;
std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
};
// The outer class only needs to own the inner class object since it automatically has access to its private members
GetterNodes pass_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_

@@ -95,7 +95,7 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
}
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
// Build the DatasetOp ExecutionTree from the optmized IR tree
// Build the DatasetOp ExecutionTree from the optimized IR tree
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node.");

@@ -1,57 +1,98 @@
include(GoogleTest)
SET(DE_UT_SRCS
common/common.cc
common/cvop_common.cc
common/bboxop_common.cc
auto_contrast_op_test.cc
album_op_test.cc
arena_test.cc
auto_contrast_op_test.cc
batch_op_test.cc
bit_functions_test.cc
storage_container_test.cc
treap_test.cc
interrupt_test.cc
image_folder_op_test.cc
buddy_test.cc
bounding_box_augment_op_test.cc
arena_test.cc
btree_test.cc
buddy_test.cc
build_vocab_test.cc
c_api_cache_test.cc
c_api_dataset_album_test.cc
c_api_dataset_cifar_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_save.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_samplers_test.cc
c_api_text_sentence_piece_vocab_test.cc
c_api_text_vocab_test.cc
c_api_transforms_test.cc
c_api_vision_test.cc
callback_test.cc
celeba_op_test.cc
center_crop_op_test.cc
channel_swap_test.cc
cifar_op_test.cc
circular_pool_test.cc
client_config_test.cc
clue_op_test.cc
coco_op_test.cc
common/bboxop_common.cc
common/common.cc
common/cvop_common.cc
concat_op_test.cc
concatenate_op_test.cc
connector_test.cc
cutmix_batch_op_test.cc
csv_op_test.cc
cut_out_op_test.cc
cutmix_batch_op_test.cc
cyclic_array_test.cc
data_helper_test.cc
datatype_test.cc
decode_op_test.cc
distributed_sampler_test.cc
epoch_ctrl_op_test.cc
equalize_op_test.cc
execution_tree_test.cc
fill_op_test.cc
global_context_test.cc
gnn_graph_test.cc
image_folder_op_test.cc
image_process_test.cc
interrupt_test.cc
jieba_tokenizer_op_test.cc
main_test.cc
map_op_test.cc
mask_test.cc
memory_pool_test.cc
mind_record_op_test.cc
mixup_batch_op_test.cc
memory_pool_test.cc
mnist_op_test.cc
normalize_op_test.cc
one_hot_op_test.cc
optimization_pass_test.cc
pad_end_op_test.cc
pad_op_test.cc
path_test.cc
perf_data_test.cc
project_op_test.cc
queue_test.cc
random_affine_op_test.cc
random_color_adjust_op_test.cc
random_color_op_test.cc
random_crop_op_test.cc
random_crop_with_bbox_op_test.cc
random_crop_decode_resize_op_test.cc
random_crop_and_resize_op_test.cc
random_crop_and_resize_with_bbox_op_test.cc
random_color_adjust_op_test.cc
random_crop_decode_resize_op_test.cc
random_crop_op_test.cc
random_crop_with_bbox_op_test.cc
random_horizontal_flip_op_test.cc
random_horizontal_flip_with_bbox_test.cc
random_resize_op_test.cc
random_resize_op_test.cc
random_resize_with_bbox_op_test.cc
random_rotation_op_test.cc
random_solarize_op_test.cc
@@ -65,74 +106,34 @@ SET(DE_UT_SRCS
rgba_to_bgr_op_test.cc
rgba_to_rgb_op_test.cc
schema_test.cc
skip_op_test.cc
sentence_piece_vocab_op_test.cc
shuffle_op_test.cc
skip_op_test.cc
slice_op_test.cc
sliding_window_op_test.cc
solarize_op_test.cc
stand_alone_samplers_test.cc
status_test.cc
storage_container_test.cc
subset_random_sampler_test.cc
swap_red_blue_test.cc
take_op_test.cc
task_manager_test.cc
tensor_op_fusion_pass_test.cc
tensor_row_test.cc
tensor_string_test.cc
tensor_test.cc
tensorshape_test.cc
text_file_op_test.cc
tfReader_op_test.cc
to_float16_op_test.cc
tokenizer_op_test.cc
treap_test.cc
tree_adapter_test.cc
trucate_pair_test.cc
type_cast_op_test.cc
zip_op_test.cc
random_resize_op_test.cc
subset_random_sampler_test.cc
weighted_random_sampler_test.cc
mnist_op_test.cc
cifar_op_test.cc
celeba_op_test.cc
take_op_test.cc
clue_op_test.cc
csv_op_test.cc
text_file_op_test.cc
concat_op_test.cc
jieba_tokenizer_op_test.cc
tokenizer_op_test.cc
gnn_graph_test.cc
coco_op_test.cc
fill_op_test.cc
mask_test.cc
trucate_pair_test.cc
concatenate_op_test.cc
cyclic_array_test.cc
perf_data_test.cc
build_vocab_test.cc
c_api_samplers_test.cc
c_api_transforms_test.cc
c_api_vision_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_album_test.cc
c_api_dataset_cifar_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_save.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_dataset_iterator_test.cc
c_api_text_sentence_piece_vocab_test.cc
c_api_text_vocab_test.cc
c_api_cache_test.cc
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
sentence_piece_vocab_op_test.cc
solarize_op_test.cc
swap_red_blue_test.cc
distributed_sampler_test.cc
data_helper_test.cc
image_process_test.cc
slice_op_test.cc
zip_op_test.cc
)
if (ENABLE_PYTHON)

@@ -0,0 +1,137 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::MsLogLevel::INFO;
class MindDataTestOptimizationPass : public UT::DatasetOpTesting {
public:
MindDataTestOptimizationPass() = default;
void SetUp() override { GlobalInit(); }
// This recursive function helps build an ExecutionTree from an IR node; it is copied from TreeAdapter
Status DFSBuild(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op, ExecutionTree *tree) {
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty() && tree != nullptr && op != nullptr, "Fail To Build Tree.");
(*op) = ops.front();
RETURN_IF_NOT_OK(tree->AssociateNode(*op));
for (size_t i = 1; i < ops.size(); i++) {
RETURN_IF_NOT_OK(tree->AssociateNode(ops[i]));
RETURN_IF_NOT_OK(ops[i - 1]->AddChild(ops[i]));
}
for (std::shared_ptr<DatasetNode> child_ir : ir->Children()) {
std::shared_ptr<DatasetOp> child_op;
RETURN_IF_NOT_OK(DFSBuild(child_ir, &child_op, tree));
RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops
}
return Status::OK();
}
// This function builds an ExecutionTree from a root IR node; nullptr is returned if an error occurs
std::unique_ptr<ExecutionTree> BuildTree(std::shared_ptr<DatasetNode> ir) {
std::unique_ptr<ExecutionTree> tree = std::make_unique<ExecutionTree>();
std::shared_ptr<DatasetOp> root;
if (DFSBuild(ir, &root, tree.get()).IsError()) return nullptr;
if (tree->AssignRoot(root).IsError()) return nullptr;
return tree;
}
};
TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) {
MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestOutputShapeAndTypePass.";
// config leaf_op, use random_data to avoid I/O
std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
ASSERT_TRUE(schema->add_column("label", "uint32", {}));
std::shared_ptr<Dataset> ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2);
std::unique_ptr<ExecutionTree> exe_tree = BuildTree(ds->IRNode());
ASSERT_NE(exe_tree, nullptr);
// test the optimization pass
// OptPass is supposed to remove concat, filter, repeat, shuffle, skip and take, and to clear the callbacks of map
std::function<OptPass(OptPass)> pass = [](OptPass pre) {
// return a new pass list; this will override all the existing pre-passes
pre.clear();
pre.push_back(std::make_unique<GetterPass>(GetterPass::kOutputShapeAndType));
return pre;
};
exe_tree->SetPrePassOverride(pass);
ASSERT_OK(exe_tree->PrepareTreePreAction());
std::stringstream ss;
// print the tree in std::string as a way to verify that nodes are indeed removed
exe_tree->Print(ss);
std::string ss_str = ss.str();
// ss_str would look like this
// +- ( 0) <BatchOp>: [workers: 4] [batch size: 2]
// +- ( 2) <ProjectOp>: [workers: 0 (inlined)]
// +- ( 4) <RandomDataOp>: [workers: 4] [total rows: 44]
//
// verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not
EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
EXPECT_EQ(ss_str.find("RepeatOp"), ss_str.npos);
EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos);
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
}
TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestDatasetSizePass.";
// config leaf_op, use random_data to avoid I/O
std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
ASSERT_TRUE(schema->add_column("label", "uint32", {}));
std::shared_ptr<Dataset> ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2);
std::unique_ptr<ExecutionTree> exe_tree = BuildTree(ds->IRNode());
ASSERT_NE(exe_tree, nullptr);
// test the optimization pass
// OptPass is supposed to remove shuffle, map, project and rename for the dataset size getter
std::function<OptPass(OptPass)> pass = [](OptPass pre) {
// return a new pass list; this will override all the existing pre-passes
pre.clear(); // remove all existing pre-passes
pre.push_back(std::make_unique<GetterPass>(GetterPass::kDatasetSize));
return pre;
};
exe_tree->SetPrePassOverride(pass);
ASSERT_OK(exe_tree->PrepareTreePreAction());
std::stringstream ss;
// print the tree in std::string as a way to verify that nodes are indeed removed
exe_tree->Print(ss);
std::string ss_str = ss.str();
// verify that ShuffleOp and ProjectOp are removed, but RepeatOp and BatchOp are not
EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos);
EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos);
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
}