open graph kernel expander opt for gpu

pull/6160/head
r1chardf1d0 5 years ago
parent 12f3665167
commit 88de0cffa9

@ -648,9 +648,6 @@ bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, Anf
std::unordered_set<PrimitivePtr> GetExpandOps() { std::unordered_set<PrimitivePtr> GetExpandOps() {
std::unordered_set<PrimitivePtr> expand_ops = { std::unordered_set<PrimitivePtr> expand_ops = {
prim::kPrimSquare, prim::kPrimSquare,
prim::kPrimGelu,
prim::kPrimSoftmax,
prim::kPrimLayerNorm,
}; };
return expand_ops; return expand_ops;
} }

@ -116,6 +116,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>()); pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());

Loading…
Cancel
Save