Do not apply CSE to nodes that were set to be recomputed before the recompute pass

pull/11583/head
yujianfeng 4 years ago
parent 6d2ed6cafc
commit 266e960acb

@ -24,6 +24,7 @@
#include "abstract/abstract_function.h"
#include "utils/flags.h"
#include "utils/utils.h"
namespace mindspore {
/* namespace to support opt */
@ -32,6 +33,20 @@ using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;
// Returns true if `node` is a CNode whose fullname-with-scope begins with the
// recompute marker (kAttrRecompute), i.e. it lives in a recomputed scope.
// Throws via MS_EXCEPTION_IF_NULL when `node` is null.
bool WithRecomputedScope(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  // rfind(prefix, 0) == 0 is the standard starts-with idiom: it only probes
  // position 0, so it is equivalent to find(prefix) == 0 for a prefix test.
  return node->fullname_with_scope().rfind(kAttrRecompute, 0) == 0;
}
bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) {
return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) ||
(WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute));
}
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
MS_EXCEPTION_IF_NULL(node);
auto node_abs = node->abstract();
@ -83,7 +98,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
} else if (node->isa<Parameter>()) {
h = node->hash();
} else {
MS_LOG(ERROR) << "Unknow node type";
MS_LOG(ERROR) << "Unknown node type";
}
hashes[node] = h;
@ -142,6 +157,10 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec
} else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>();
// Do not apply CSE to nodes that were set to be recomputed before the recompute pass.
if (IsSetRecomputed(c_main, c_node)) {
return false;
}
// When appsame is true, check if has side effect, do not merge.
if (check_side_effect && HasSideEffect(main)) {
return false;

@ -25,12 +25,12 @@
#include <algorithm>
#include "ir/func_graph.h"
#include "mindspore/core/base/core_ops.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
constexpr auto kAttrRecompute = "recompute";
bool IsBpropNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@ -339,6 +339,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
auto recomputed_node = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(recomputed_node);
recomputed_node->AddAttr("duplicated", MakeValue(true));
recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
recomputed_node->set_abstract(origin_node->abstract());
recomputed_node->set_scope(origin_node->scope());
origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
@ -415,6 +416,12 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) {
DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs,
&origin_to_recomputed_nodes);
}
// Tag recomputed-scope nodes so that CSE is allowed to run on them after the recompute pass.
for (const auto &node : orders) {
if (WithRecomputedScope(node)) {
node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
}
}
}
} // namespace opt
} // namespace mindspore

@ -302,6 +302,11 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
return map;
}
// Builds the optimizer pass group that runs after the recompute pass: a single
// CSE pass (constructed with report_changes = false) registered under "cse".
// NOTE(review): `irpass` is accepted for signature parity with the other
// Get*Phases builders but is not used here.
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
  auto cse_pass = opt::OptPassConfig(opt::CSEPass(false));
  OptPassGroupMap map({{"cse", cse_pass}});
  return map;
}
static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
void InitOpt(const ResourcePtr &res) {
@ -323,6 +328,8 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_grad_epilogue"] =
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
@ -367,6 +374,7 @@ bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res,
bool OptPassGraphKernelGroupB(const ResourcePtr &res) {
  // Run the optimizer group registered under "opt_graph_kernel_b" in InitOpt.
  return OptPassGroup(res, "opt_graph_kernel_b");
}
bool ControlGroup(const ResourcePtr &res) {
  // Run the optimizer group registered under "opt_control" in InitOpt.
  return OptPassGroup(res, "opt_control");
}
bool PrepareGroup(const ResourcePtr &res) {
  // Run the optimizer group registered under "opt_prepare" in InitOpt.
  return OptPassGroup(res, "opt_prepare");
}
bool OptAfterRecomputeGroup(const ResourcePtr &res) {
  // Run the post-recompute CSE optimizer group registered under "opt_after_recompute".
  return OptPassGroup(res, "opt_after_recompute");
}
bool OptPassRNGroup(const ResourcePtr &res) {
  // Run the renormalization optimizer group registered under "renormal".
  return OptPassGroup(res, "renormal");
}
@ -525,7 +533,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_cache_embedding", AddCacheEmbeddingPass},
{"add_control_depend", AddControlDependPass},
{"add_recomputation", AddRecomputationPass}};
{"add_recomputation", AddRecomputationPass},
{"cse_after_recomputation", OptAfterRecomputeGroup}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},

@ -380,6 +380,8 @@ constexpr auto kAttrPadMode = "pad_mode";
constexpr auto kAttrPad = "pad";
constexpr auto kAttrPadding = "padding";
constexpr auto kAttrIsGrad = "is_grad";
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

Loading…
Cancel
Save