!2772 add a pre pass for node removals
Merge pull request !2772 from Jamie/removalpasspull/2772/MERGE
commit
efe07bd169
@ -1,6 +1,8 @@
|
|||||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
add_library(engine-opt OBJECT
|
add_library(engine-opt OBJECT
|
||||||
pass.cc
|
pass.cc
|
||||||
util/printer_pass.cc
|
pre/removal_nodes.cc
|
||||||
|
pre/removal_pass.cc
|
||||||
|
util/printer_pass.cc
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/opt/pre/removal_nodes.h"
|
||||||
|
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||||
|
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
|
||||||
|
|
||||||
|
// Perform ShuffleOp removal check.
|
||||||
|
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||||
|
if (is_caching_) {
|
||||||
|
MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||||
|
if (removal_pass_) {
|
||||||
|
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
} else {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,51 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class RemovalPass;
|
||||||
|
|
||||||
|
/// \class RemovalNodes removal_nodes.h
|
||||||
|
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||||
|
/// It works in conjunction with the removal_pass.
|
||||||
|
class RemovalNodes : public NodePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||||
|
explicit RemovalNodes(RemovalPass *removal_pass);
|
||||||
|
|
||||||
|
/// \brief Perform ShuffleOp removal check
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool is_caching_;
|
||||||
|
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_
|
@ -0,0 +1,45 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "dataset/engine/opt/pre/removal_nodes.h"
|
||||||
|
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// constructor
|
||||||
|
RemovalPass::RemovalPass() {}
|
||||||
|
|
||||||
|
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||||
|
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||||
|
// Create the removal node pass which can identify which nodes need to be removed.
|
||||||
|
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
|
||||||
|
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
||||||
|
|
||||||
|
// Then, execute the removal of any nodes that were set up for removal
|
||||||
|
for (auto node : removal_nodes_) {
|
||||||
|
node->Remove();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds an operator to the list of operators to be removed
|
||||||
|
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,53 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class DatasetOp;
|
||||||
|
|
||||||
|
/// \class RemovalPass removal_pass.h
|
||||||
|
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
|
||||||
|
/// nodes should be removed, and then removes them.
|
||||||
|
class RemovalPass : public TreePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
RemovalPass();
|
||||||
|
|
||||||
|
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||||
|
/// \param[inout] tree The tree to operate on.
|
||||||
|
/// \param[inout] Indicate of the tree was modified.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Adds an operator to the list of operators to be removed
|
||||||
|
/// \param[in] dataset_op The operator to add to the removal list
|
||||||
|
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
Loading…
Reference in new issue