Don't insert UpdateState for HyperMap func graph call.

Move auto monad eliminator out from CSE.
Eliminate auto monad nodes for output node.
pull/13050/head
Zhang Qinghua 4 years ago
parent 8e307818d0
commit e853df4ecd

@ -2,7 +2,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,6 +35,7 @@
#include "ir/signature.h"
#include "debug/trace.h"
#include "utils/ms_context.h"
#include "utils/utils.h"
namespace mindspore {
// namespace to support composite operators definition
@ -184,7 +185,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
});
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
auto call_node = func_graph->NewCNodeInOrder(inputs2);
call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true));
inputs.push_back(call_node);
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -222,7 +225,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
});
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
auto call_node = func_graph->NewCNodeInOrder(inputs2);
call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true));
inputs.push_back(call_node);
}
return func_graph->NewCNodeInOrder(inputs);
}
@ -253,7 +258,9 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGrap
j++;
}
inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
auto call_node = func_graph->NewCNodeInOrder(inputs2);
call_node->AddAttr(kAttrIgnoreSideEffect, MakeValue(true));
inputs.push_back(call_node);
}
return func_graph->NewCNodeInOrder(inputs);
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_
#include "ir/anf.h"
#include "ir/manager.h"
#include "frontend/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
class AutoMonadEliminator {
public:
AutoMonadEliminator() = default;
virtual ~AutoMonadEliminator() = default;
bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
auto manager = optimizer->resource()->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
// Never report change.
(void)ReplaceAutoMonadNode(manager);
(void)EliminateAutoMonadNode(manager);
return false;
}
private:
bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const;
bool EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AUTO_MONAD_ELIMINATOR_H_

File diff suppressed because it is too large Load Diff

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,7 +42,6 @@ class CSE {
private:
bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;
bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const;
bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
};

@ -33,6 +33,7 @@
#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/graph_transform.h"
#include "frontend/optimizer/auto_monad_eliminate.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/cache_embedding/cache_embedding.h"
@ -183,6 +184,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
{"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
{"a_3", a_3}});

@ -27,6 +27,7 @@
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/multitype_funcgraph.h"
#include "utils/flags.h"
#include "utils/utils.h"
#include "utils/ordered_map.h"
#include "base/core_ops.h"
#include "abstract/abstract_value.h"
@ -1295,6 +1296,14 @@ class AutoMonadConverter {
}
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
// Not attach UpdateState if set kAttrIgnoreSideEffect.
auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect);
auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
GetValue<bool>(attr_ignore_side_effect);
if (ignore_side_effect) {
return state;
}
auto update_state = NewValueNode(prim::kPrimUpdateState);
auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach});
update_state_cnode->set_abstract(state->abstract());

@ -408,6 +408,7 @@ constexpr auto kAttrParallelTypeInfo = "parallel_type_info";
constexpr auto kAttrCompositeType = "composite_type";
constexpr auto kAttrStitch = "stitch";
constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect";
constexpr auto kAttrSwitchLayer = "switch_layer";
constexpr auto kAttrReturn = "return";

Loading…
Cancel
Save