parent
b57d4ea2f3
commit
2ffe76981d
@ -1,6 +1,8 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-opt OBJECT
|
||||
pass.cc
|
||||
util/printer_pass.cc
|
||||
pass.cc
|
||||
pre/removal_nodes.cc
|
||||
pre/removal_pass.cc
|
||||
util/printer_pass.cc
|
||||
)
|
||||
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/pre/removal_nodes.h"
|
||||
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
|
||||
|
||||
// Perform ShuffleOp removal check.
|
||||
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||
if (removal_pass_) {
|
||||
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
|
||||
} else {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RemovalPass;
|
||||
|
||||
/// \class RemovalNodes removal_nodes.h
|
||||
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||
/// It works in conjunction with the removal_pass.
|
||||
class RemovalNodes : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||
explicit RemovalNodes(RemovalPass *removal_pass);
|
||||
|
||||
/// \brief Perform ShuffleOp removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
bool is_caching_;
|
||||
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_
|
@ -0,0 +1,45 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "dataset/engine/opt/pre/removal_nodes.h"
|
||||
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// constructor
|
||||
RemovalPass::RemovalPass() {}
|
||||
|
||||
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||
// Create the removal node pass which can identify which nodes need to be removed.
|
||||
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
|
||||
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
||||
|
||||
// Then, execute the removal of any nodes that were set up for removal
|
||||
for (auto node : removal_nodes_) {
|
||||
node->Remove();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds an operator to the list of operators to be removed
|
||||
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class RemovalPass removal_pass.h
|
||||
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
|
||||
/// nodes should be removed, and then removes them.
|
||||
class RemovalPass : public TreePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
RemovalPass();
|
||||
|
||||
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
/// \brief Adds an operator to the list of operators to be removed
|
||||
/// \param[in] dataset_op The operator to add to the removal list
|
||||
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
Loading…
Reference in new issue