!8672 add removal pass for dataset getters
From: @ziruiwu Reviewed-by: Signed-off-by:pull/8672/MERGE
commit
bf447ff51a
@ -1,13 +1,14 @@
|
||||
# Define SUBMODULE_ID as mindspore::SubModuleId::SM_MD for every .cc file found recursively under this directory.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# Object library holding the optimization passes. Each source file is listed exactly once.
add_library(engine-opt OBJECT
    optional/tensor_op_fusion_pass.cc
    pass.cc
    post/repeat_pass.cc
    pre/cache_error_pass.cc
    pre/cache_transform_pass.cc
    pre/epoch_injection_pass.cc
    pre/getter_pass.cc
    pre/input_validation_pass.cc
    pre/removal_pass.cc
    util/printer_pass.cc
    )
|
||||
|
@ -0,0 +1,87 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// ShuffleOp is queued for removal regardless of the getter type.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
  std::shared_ptr<DatasetOp> shuffle_op = node;
  nodes_to_remove_.push_back(shuffle_op);
  return Status::OK();
}
|
||||
|
||||
// RepeatOp is queued for removal only when the getter is kOutputShapeAndType.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  if (type_ != kOutputShapeAndType) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
|
||||
// SkipOp is queued for removal only when the getter is kOutputShapeAndType.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
  if (type_ != kOutputShapeAndType) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
|
||||
// TakeOp is queued for removal only when the getter is kOutputShapeAndType.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
  if (type_ != kOutputShapeAndType) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
|
||||
// MapOp is treated differently per getter type: for kOutputShapeAndType it stays in the tree and is queued
// to have its callbacks cleared, for kDatasetSize it is queued for removal.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
  switch (type_) {
    case kOutputShapeAndType:
      nodes_to_clear_callback_.push_back(node);
      break;
    case kDatasetSize:
      nodes_to_remove_.push_back(node);
      break;
    default:
      // any other value leaves the op untouched
      break;
  }
  return Status::OK();
}
|
||||
|
||||
// ProjectOp is queued for removal only when the getter is kDatasetSize.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
  if (type_ != kDatasetSize) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
|
||||
// RenameOp is queued for removal only when the getter is kDatasetSize.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
  if (type_ != kDatasetSize) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
|
||||
// ConcatOp is queued for removal only when the getter is kOutputShapeAndType.
Status GetterPass::GetterNodes::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
  if (type_ != kOutputShapeAndType) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// FilterOp is queued for removal only when the getter is kOutputShapeAndType.
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
  if (type_ != kOutputShapeAndType) {
    return Status::OK();
  }
  nodes_to_remove_.push_back(node);
  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
/// \brief Runs the nested node pass over the tree to collect ops, then removes the ops it marked for removal
///     and clears the callbacks of the ops it marked for callback clearing.
/// \param[in, out] tree The execution tree to operate on
/// \param[in, out] modified Forwarded unchanged to the nested node pass
/// \return Status error code; the first failing step aborts the pass and its status is returned
Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
  // pass_ lives as long as this GetterPass, so drop whatever an earlier run collected; otherwise the same ops
  // would be processed a second time when the pass object is reused.
  pass_.nodes_to_remove_.clear();
  pass_.nodes_to_clear_callback_.clear();

  RETURN_IF_NOT_OK(pass_.Run(tree, modified));

  // The result lists are public members of the nested node pass, which is why they can be read directly here.
  // Iterate by const reference to avoid copying a shared_ptr per element.
  for (const auto &node : pass_.nodes_to_remove_) {
    RETURN_IF_NOT_OK(node->Remove());
  }

  // clear the callbacks of the selected ops (MapOp when the getter is kOutputShapeAndType)
  for (const auto &node : pass_.nodes_to_clear_callback_) {
    node->ClearCallbacks();
  }

  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <list>
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class GetterPass
|
||||
/// \brief This is a tree pass that will remove nodes or clears the callback in MapOp
|
||||
class GetterPass : public TreePass {
|
||||
public:
|
||||
enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
|
||||
/// \brief Constructor
|
||||
explicit GetterPass(GetterType tp) : pass_(tp) {}
|
||||
/// \brief Destructor
|
||||
~GetterPass() = default;
|
||||
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
private:
|
||||
/// \class GetterNodes, this is a nested class which is owned via composition by the outter class to identify nodes
|
||||
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||
class GetterNodes : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit GetterNodes(GetterType tp) : type_(tp) {}
|
||||
|
||||
~GetterNodes() = default;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
|
||||
// whether this is Run or PreRun does not matter here, however, Only Accept() is defined in ConcatOp
|
||||
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
GetterType type_;
|
||||
std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
|
||||
std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
|
||||
};
|
||||
// outter class needs only to own the inner class object since it automatically has access to its private variables
|
||||
GetterNodes pass_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_
|
@ -0,0 +1,137 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestOptimizationPass : public UT::DatasetOpTesting {
|
||||
public:
|
||||
MindDataTestOptimizationPass() = default;
|
||||
void SetUp() override { GlobalInit(); }
|
||||
|
||||
// this recursive function helps build a ExecutionTree from a IR node, it is copied from TreeAdapter
|
||||
Status DFSBuild(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op, ExecutionTree *tree) {
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty() && tree != nullptr && op != nullptr, "Fail To Build Tree.");
|
||||
(*op) = ops.front();
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(*op));
|
||||
for (size_t i = 1; i < ops.size(); i++) {
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(ops[i]));
|
||||
RETURN_IF_NOT_OK(ops[i - 1]->AddChild(ops[i]));
|
||||
}
|
||||
for (std::shared_ptr<DatasetNode> child_ir : ir->Children()) {
|
||||
std::shared_ptr<DatasetOp> child_op;
|
||||
RETURN_IF_NOT_OK(DFSBuild(child_ir, &child_op, tree));
|
||||
RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// this function will build an execution_tree from a root ir node. nullptr will be returned if error occurs
|
||||
std::unique_ptr<ExecutionTree> BuildTree(std::shared_ptr<DatasetNode> ir) {
|
||||
std::unique_ptr<ExecutionTree> tree = std::make_unique<ExecutionTree>();
|
||||
std::shared_ptr<DatasetOp> root;
|
||||
if (DFSBuild(ir, &root, tree.get()).IsError()) return nullptr;
|
||||
if (tree->AssignRoot(root).IsError()) return nullptr;
|
||||
return tree;
|
||||
}
|
||||
};
|
||||
|
||||
// Checks that GetterPass(kOutputShapeAndType) removes ShuffleOp and RepeatOp from the tree while ProjectOp and
// BatchOp stay in place.
TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) {
  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestOutputShapeAndTypePass.";

  // RandomData is used as the leaf op so that the test needs no file I/O
  std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
  ASSERT_TRUE(schema->add_column("label", "uint32", {}));
  std::shared_ptr<Dataset> ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2);

  std::unique_ptr<ExecutionTree> exe_tree = BuildTree(ds->IRNode());
  ASSERT_NE(exe_tree, nullptr);

  // Replace every default pre-pass with the pass under test. For this getter type the pass is supposed to
  // remove concat, filter, repeat, shuffle, skip and take, and to clear the callbacks of map.
  std::function<OptPass(OptPass)> only_getter_pass = [](OptPass pre) {
    pre.clear();
    pre.push_back(std::make_unique<GetterPass>(GetterPass::kOutputShapeAndType));
    return pre;
  };
  exe_tree->SetPrePassOverride(only_getter_pass);
  ASSERT_OK(exe_tree->PrepareTreePreAction());

  // dump the tree into a string; the presence or absence of op names shows which nodes were removed
  std::stringstream tree_stream;
  exe_tree->Print(tree_stream);
  const std::string tree_dump = tree_stream.str();

  // tree_dump would look like this
  // +- ( 0) <BatchOp>: [workers: 4] [batch size: 2]
  //   +- ( 2) <ProjectOp>: [workers: 0 (inlined)]
  //       +- ( 4) <RandomDataOp>: [workers: 4] [total rows: 44]
  //

  // ShuffleOp and RepeatOp must be gone, BatchOp and ProjectOp must still be there
  const auto not_found = std::string::npos;
  EXPECT_EQ(tree_dump.find("ShuffleOp"), not_found);
  EXPECT_EQ(tree_dump.find("RepeatOp"), not_found);
  EXPECT_NE(tree_dump.find("ProjectOp"), not_found);
  EXPECT_NE(tree_dump.find("BatchOp"), not_found);
}
|
||||
|
||||
// Checks that GetterPass(kDatasetSize) removes ShuffleOp and ProjectOp from the tree while RepeatOp and BatchOp
// stay in place.
TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestDatasetSizePass.";
  // config leaf_op, use random_data to avoid I/O
  std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
  ASSERT_TRUE(schema->add_column("label", "uint32", {}));
  std::shared_ptr<Dataset> ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2);

  std::unique_ptr<ExecutionTree> exe_tree = BuildTree(ds->IRNode());

  ASSERT_NE(exe_tree, nullptr);

  // test the optimization pass
  // for the dataset-size getter, the pass is supposed to remove shuffle, map, project and rename
  std::function<OptPass(OptPass)> pass = [](OptPass pre) {
    // return a new pass list, this will override all the existing pre-passes
    pre.clear();  // remove all existing pre pass
    pre.push_back(std::make_unique<GetterPass>(GetterPass::kDatasetSize));
    return pre;
  };

  exe_tree->SetPrePassOverride(pass);
  ASSERT_OK(exe_tree->PrepareTreePreAction());
  std::stringstream ss;
  // print the tree in std::string as a way to verify that nodes are indeed removed
  exe_tree->Print(ss);
  std::string ss_str = ss.str();

  // verify that ShuffleOp and ProjectOp are removed, but RepeatOp and BatchOp are not
  EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
  EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos);
  EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos);
  EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
}
|
Loading…
Reference in new issue