|
|
@ -14,20 +14,24 @@
|
|
|
|
* limitations under the License.
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
#include <map>
|
|
|
|
|
|
|
|
#include <set>
|
|
|
|
#include <tuple>
|
|
|
|
#include <tuple>
|
|
|
|
#include <unordered_set>
|
|
|
|
#include <unordered_set>
|
|
|
|
#include "pipeline/jit/parse/python_adapter.h"
|
|
|
|
#include <utility>
|
|
|
|
#include "pipeline/jit/action.h"
|
|
|
|
|
|
|
|
#include "backend/kernel_compiler/common_utils.h"
|
|
|
|
#include "backend/kernel_compiler/common_utils.h"
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
|
|
|
#include "vm/segment_runner.h"
|
|
|
|
|
|
|
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
|
|
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
|
|
|
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
|
|
|
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
|
|
|
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
|
|
|
#include "pipeline/jit/parse/python_adapter.h"
|
|
|
|
|
|
|
|
#include "pipeline/jit/action.h"
|
|
|
|
|
|
|
|
#include "vm/segment_runner.h"
|
|
|
|
#if ENABLE_GPU
|
|
|
|
#if ENABLE_GPU
|
|
|
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
|
|
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
@ -526,12 +530,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
|
|
|
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
|
|
|
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) {
|
|
|
|
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
|
|
|
if (fuse_nodes.empty()) {
|
|
|
|
const std::string &postfix) {
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto mng = kernel_graph->manager();
|
|
|
|
auto mng = kernel_graph->manager();
|
|
|
|
if (mng == nullptr) {
|
|
|
|
if (mng == nullptr) {
|
|
|
|
mng = Manage(kernel_graph, true);
|
|
|
|
mng = Manage(kernel_graph, true);
|
|
|
@ -565,6 +566,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
fuse_op_name += postfix;
|
|
|
|
fuse_op_name += postfix;
|
|
|
|
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
|
|
|
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(fuse_new_node, src_outputs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
|
|
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
|
|
@ -737,7 +740,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
|
|
|
|
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
|
|
|
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
|
|
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
|
|
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
|
|
|
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
|
|
|
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
|
|
|
prim::kPrimTranspose, prim::kPrimCast};
|
|
|
|
prim::kPrimTranspose, prim::kPrimCast, prim::kPrimRealDiv};
|
|
|
|
#elif ENABLE_GPU
|
|
|
|
#elif ENABLE_GPU
|
|
|
|
std::vector<PrimitivePtr> fusible_basic_ops = {
|
|
|
|
std::vector<PrimitivePtr> fusible_basic_ops = {
|
|
|
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
|
|
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
|
|
@ -786,5 +789,123 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
|
|
|
device::gpu::SetKernelInfo(cnode, kernel_type);
|
|
|
|
device::gpu::SetKernelInfo(cnode, kernel_type);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
|
|
|
|
|
|
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
|
|
|
|
|
|
|
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
|
|
|
|
|
|
|
auto cnode = (*iter)->cast<CNodePtr>();
|
|
|
|
|
|
|
|
if (cnode == nullptr) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto prior_node = cnode->input(kControlDependPriorIndex);
|
|
|
|
|
|
|
|
auto depend_node = cnode->input(kControlDependBehindIndex);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prior_node);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(depend_node);
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> prior_nodes = {prior_node};
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> depend_nodes = {depend_node};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int depend_mode = 0;
|
|
|
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
|
|
|
|
|
|
|
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector<AnfNodePtr> {
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> out_nodes;
|
|
|
|
|
|
|
|
auto user_set = param->func_graph()->manager()->node_users()[param];
|
|
|
|
|
|
|
|
for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) {
|
|
|
|
|
|
|
|
if (iter->first != cnode) {
|
|
|
|
|
|
|
|
out_nodes.push_back(iter->first);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return out_nodes;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
|
|
|
|
|
|
|
prior_nodes = GetOutputNodes(prior_node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (depend_node->isa<Parameter>()) {
|
|
|
|
|
|
|
|
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> real_prior_nodes;
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> prior_visited;
|
|
|
|
|
|
|
|
for (const auto &tmp : prior_nodes) {
|
|
|
|
|
|
|
|
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
prior_visited.clear();
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> real_depend_nodes;
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> depend_visited;
|
|
|
|
|
|
|
|
for (const auto &tmp : depend_nodes) {
|
|
|
|
|
|
|
|
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
depend_visited.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &prior : real_prior_nodes) {
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto &depend : real_depend_nodes) {
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
depend_prior->insert({depend, std::make_pair(prior, cnode)});
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
real_prior_nodes.clear();
|
|
|
|
|
|
|
|
real_depend_nodes.clear();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
|
|
|
|
|
|
|
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
|
|
|
|
|
|
|
|
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
|
|
|
|
|
|
|
if (iter->second.second == control_depend_node) {
|
|
|
|
|
|
|
|
iter = depend_prior->erase(iter);
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
++iter;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
|
|
|
|
|
|
|
|
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
|
|
|
|
|
|
|
|
for (auto item : new_depend_prior) {
|
|
|
|
|
|
|
|
depend_prior->insert(item);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
|
|
|
|
|
|
|
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
|
|
|
|
|
|
|
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
|
|
|
|
|
|
|
|
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Need real outputs of makeTuple";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
|
|
|
|
|
|
|
if (iter->first == outputs[out_idx]) {
|
|
|
|
|
|
|
|
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
|
|
|
|
|
|
|
|
iter = depend_prior->erase(iter);
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (iter->second.first == outputs[out_idx]) {
|
|
|
|
|
|
|
|
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
|
|
|
|
|
|
|
|
iter = depend_prior->erase(iter);
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
++iter;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto item : new_fuse_cnode_dep_pri) {
|
|
|
|
|
|
|
|
depend_prior->insert(item);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|
|
|
|