!12176 [lite]match side effect in lite

From: @xu_anyue
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
pull/12176/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 2d04b34656

@ -225,7 +225,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc

@ -59,41 +59,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
}
}
void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
bool hasDepend = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();
inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr inputNode = cnode->input(i);
if (!inputNode->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto dependNode = utils::cast<CNodePtr>(inputNode);
if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
hasDepend = true;
bool maskOut = (dependNode->inputs().size() == 3);
for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
AnfNodePtr dependInputNode = dependNode->input(j);
if (dependInputNode->isa<CNode>()) {
inputs.emplace_back(dependInputNode);
if (maskOut) {
break;
}
}
}
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (hasDepend) {
cnode->set_inputs(inputs);
}
}
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node) {
@ -286,23 +251,11 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
break;
}
}
#ifdef SUPPORT_TRAIN
RemoveIfMakeTuple(cnode);
RemoveIfDepend(cnode);
#endif
if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(primitive_c->Type() == schema::PrimitiveType_Depend) ||
(primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
#endif
(primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
continue;
}
#ifndef SUPPORT_TRAIN
RemoveIfMakeTuple(cnode);
#endif
auto primT = primitive_c->primitiveT();
auto node = std::make_unique<schema::CNodeT>();
if (node == nullptr) {

@ -41,7 +41,6 @@ class AnfExporter {
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *fb_node);
static void RemoveIfMakeTuple(const CNodePtr &cnode);
static void RemoveIfDepend(const CNodePtr &cnode);
protected:
int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);

@ -59,7 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/update_conv2d_param_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/identity_remove_pass.cc
../optimizer/graph/redundant_op_remove_pass.cc
../optimizer/graph/infershape_pass.cc
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc

@ -34,7 +34,7 @@
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
@ -144,7 +144,7 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
const_fold_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
if (!config->trainModel) {
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
inne_context_ptr->Init();

@ -13,37 +13,41 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include <memory>
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"
namespace mindspore::opt {
int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
namespace {
constexpr size_t InputDoubleNum = 2;
constexpr size_t InputTripleNum = 3;
constexpr auto kNameLoad = "Load";
constexpr auto kNameUpdateState = "UpdateState";
} // namespace
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
return lite::RET_NO_CHANGE;
}
auto type = opt::GetCNodeType(anf_node);
if (type != schema::PrimitiveType_Identity) {
MS_LOG(DEBUG) << "anf node is not a identity node.";
return lite::RET_NO_CHANGE;
}
auto identity_cnode = anf_node->cast<CNodePtr>();
if (identity_cnode->inputs().size() != lite::kDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
} else {
bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1));
if (!replace_succ) {
MS_LOG(ERROR) << "replace identity failed.";
return lite::RET_ERROR;
auto cnode = anf_node->cast<CNodePtr>();
if (type == schema::PrimitiveType_Identity) {
if (cnode->size() != InputDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
}
}
bool replace_succ = manager->Replace(anf_node, cnode->input(1));
if (!replace_succ) {
MS_LOG(ERROR) << "replace redundant op failed.";
return lite::RET_ERROR;
}
return RET_OK;
}
int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
return lite::RET_NO_CHANGE;
@ -53,7 +57,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
return lite::RET_NO_CHANGE;
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode->inputs().size() != 3) {
if (cnode->inputs().size() != InputTripleNum) {
MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
return RET_ERROR;
}
@ -81,7 +85,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
return lite::RET_OK;
}
bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
@ -93,10 +97,22 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
}
auto type = opt::GetCNodeType(node);
if (type == schema::PrimitiveType_Identity) {
status = ReplaceIdentity(node, manager);
} else if (type == schema::PrimitiveType_TupleGetItem) {
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameLoad))) {
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameUpdateState))) {
status = ReplaceOp(node, manager);
}
if (type == schema::PrimitiveType_Depend ||
type == schema::PrimitiveType_ControlDepend) { // ControlDepend delete next version.
status = ReplaceOp(node, manager);
}
if (type == schema::PrimitiveType_TupleGetItem) {
status = ReplaceTupleGetItem(node, manager);
} else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
}
if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
#ifndef MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_
#include <string>
#include <set>
#include "backend/optimizer/common/pass.h"
@ -24,11 +24,11 @@
using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class RemoveIdentityOpPass : public Pass {
class RemoveRedundantOpPass : public Pass {
public:
RemoveIdentityOpPass() : Pass("remove_identity_pass") {}
~RemoveIdentityOpPass() override = default;
int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
~RemoveRedundantOpPass() override = default;
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
bool Run(const FuncGraphPtr &graph) override;
@ -36,4 +36,4 @@ class RemoveIdentityOpPass : public Pass {
std::set<AnfNodePtr> remove_cnode_;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
#endif // MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_
Loading…
Cancel
Save