!11665 [GraphKernel] Add parallel fusion support to master.
From: @tronzhang Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
pull/11665/MERGE
commit
ca675c0521
@ -1 +1 @@
|
||||
Subproject commit 20ecddee01cd07d0945240672597d7a36499e537
|
||||
Subproject commit c63b2e6f7e7704f18b217e42c8c5c0b95e04b9fb
|
@ -0,0 +1,153 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Cost model for parallel fusion"""
|
||||
from .model import PrimLib
|
||||
|
||||
|
||||
class ParalGain:
    """Plain record describing the outcome of a parallel-fusion estimate.

    The four constructor arguments are stored unchanged as attributes of the
    same name: ``fusion_type``, ``bottleneck``, ``gain`` and ``block_assign``.
    """

    def __init__(self, fusion_type, bottleneck, gain, block_assign):
        (self.fusion_type,
         self.bottleneck,
         self.gain,
         self.block_assign) = (fusion_type, bottleneck, gain, block_assign)
|
||||
|
||||
|
||||
class ScheduleAnalyzer:
    """schedule analyzer

    Estimates how a single graph would be scheduled. After ``analyze()`` has
    run, ``block_num`` holds the number of blocks attributed to the graph and
    ``block_weight`` the workload attributed to one block (derived from the
    element count of every op output times ``PrimLib.dtype_bytes`` of its
    dtype).
    """
    WRAP_SIZE = 32
    MAX_SM = 80  # Volta
    MAX_NUM_THREADS = 1024
    MAX_BLOCK = 256

    def __init__(self, graph):
        """Collect the ops of ``graph`` and the ops producing its outputs."""
        self.graph = graph
        self.block_num = 0
        self.block_weight = 0
        _, outputs = graph.deduce_parameters()
        self.ops = graph.ops
        # Ops that produce the graph outputs; they drive the strategy choice.
        self.dom_op = [out.op for out in outputs]

    def prod(self, shape):
        """Return the product of all dimensions in ``shape``.

        An empty shape (a scalar) yields 1 instead of raising IndexError.
        """
        res = 1
        for dim in shape:
            res = res * dim
        return res

    def _cal_weight(self, ops):
        """Sum, over ``ops``, of output element count times dtype byte size."""
        return sum(self.prod(op.output.shape) * PrimLib.dtype_bytes(op.output.dtype)
                   for op in ops)

    def injective_analyze(self):
        """analyze injective case"""
        threads = self.MAX_NUM_THREADS
        largest = max(self.prod(op.output.shape) for op in self.dom_op)
        # Round the largest output up to a whole number of thread blocks.
        total_block = (largest + threads - 1) // threads
        const_size = total_block * threads

        total_weight = self._cal_weight(self.ops)
        if const_size > self.MAX_BLOCK * threads:
            # More blocks than MAX_BLOCK: cap the block count and account for
            # the extra rounds ("waves") each block has to run.
            self.block_num = self.MAX_BLOCK
            waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
            self.block_weight = total_weight // total_block * waves
        else:
            self.block_num = total_block
            self.block_weight = total_weight // self.block_num

    def reduce_analyze(self):
        """analyze reduce case

        Raises:
            RuntimeError: if the graph holds more than one reduce op, or none.
        """
        thread_x, thread_y = 32, 32
        reduce_op = None
        for op in self.ops:
            if PrimLib.iter_type(op) == PrimLib.REDUCE:
                if reduce_op:
                    raise RuntimeError(
                        "Not support multiply reduce op in a graph now.")
                reduce_op = op
        if not reduce_op:
            raise RuntimeError("Wrong analyze for reduce!")
        shape = reduce_op.inputs[0].shape
        reduce_axis = reduce_op.attrs['reduce_axis']
        total_space = self.prod(shape)
        # Number of elements folded into each reduced output element.
        red_space = shape[reduce_axis[0]]
        for i in range(1, len(reduce_axis)):
            red_space *= shape[reduce_axis[i]]
        dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)

        weight = self._cal_weight(self.ops)  # reduce + injective
        block_x = (total_space // red_space + thread_y - 1) // thread_y
        block_w = (weight + block_x - 1) // block_x
        waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
        self.block_num = min(self.MAX_BLOCK, block_x)
        all_reduce = 10  # 1 reduce init + 3 sync + 5 bin + 1 write
        self.block_weight = (block_w + all_reduce *
                             dtype_size * thread_x * thread_y) * waves

    def default_analyze(self):
        """analyze default case"""
        def _cal_default_space(op):
            # Largest element count among the op's output and all its inputs.
            sizes = [self.prod(op.output.shape)]
            sizes.extend(self.prod(t.shape) for t in op.inputs)
            return max(sizes)
        space = max(_cal_default_space(op) for op in self.dom_op)

        # at least 4 warps per block
        block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4)
        self.block_num = min(self.MAX_BLOCK, block)
        self.block_weight = self._cal_weight(self.ops) // self.block_num

    def analyze(self):
        """analyze ops"""
        def _ops_type(ops, dom_op):
            have_reduce = any(
                PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops)
            if have_reduce:
                # NOTE(review): this returns the boolean True, not
                # PrimLib.REDUCE, so the REDUCE branch below is only taken if
                # PrimLib.REDUCE happens to compare equal to True. It looks
                # like a bug, but the behaviour is kept as-is because switching
                # to reduce_analyze() would make multi-reduce graphs raise.
                # Confirm the intent against PrimLib.
                return True
            return PrimLib.iter_type(dom_op[0])

        dom_type = _ops_type(self.ops, self.dom_op)
        if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            self.injective_analyze()
        elif dom_type == PrimLib.REDUCE:
            self.reduce_analyze()
        else:
            self.default_analyze()
|
||||
|
||||
|
||||
def block_parallel_estimate(graphs):
    """estimate block parallel gain

    Every graph is analysed with ScheduleAnalyzer. If the summed block count
    exceeds ``ScheduleAnalyzer.MAX_SM * 32`` the result is a "none" ParalGain
    with zero gain; otherwise a "block_fusion" ParalGain whose bottleneck is
    the largest block weight and whose gain is the sum of the remaining
    weights.
    """
    analyzers = []
    for graph in graphs:
        analyzer = ScheduleAnalyzer(graph)
        analyzer.analyze()
        analyzers.append(analyzer)

    blocks = [item.block_num for item in analyzers]
    weights = [item.block_weight for item in analyzers]
    sum_block = sum(blocks)
    sum_weight = sum(weights)
    # The leading 0 reproduces the running maximum that starts from zero.
    max_weight = max([0] + weights)

    block_limit = ScheduleAnalyzer.MAX_SM * 32
    if sum_block > block_limit:
        return ParalGain("none", sum_weight, 0, [])
    return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks)
|
||||
|
||||
|
||||
def parallel_estimate(graphs):
    """Estimate the parallel gain of ``graphs``; delegates to block_parallel_estimate."""
    estimation = block_parallel_estimate(graphs)
    return estimation
|
@ -0,0 +1,49 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""estimate parallel case"""
|
||||
import json
|
||||
import json.decoder as jd
|
||||
import traceback
|
||||
from mindspore import log as logger
|
||||
from . import model
|
||||
|
||||
def estimate_ops(json_str: str):
    """Call costmodel to estimate ops.

    ``json_str`` must decode to an object whose "graph_desc" entry lists the
    graph descriptions. Returns ``(block_assign, gain)`` when block fusion is
    estimated to pay off, ``([0, ...], 0)`` (one zero per graph) otherwise,
    and ``None`` when the JSON cannot be decoded.
    """
    try:
        descs = json.loads(json_str)["graph_desc"]
        graphs = [model.load_composite(desc).graph for desc in descs]
        estimation = model.parallel_estimate(graphs)
        worthwhile = estimation.fusion_type == "block_fusion" and estimation.gain > 0
        if worthwhile:
            return (estimation.block_assign, estimation.gain)
        return ([0] * len(graphs), 0)
    except jd.JSONDecodeError:
        logger.error(traceback.format_exc())
        return None
|
||||
|
||||
def estimate_calulation_amount(json_str: str):
    """Call costmodel to estimate calculation amount of op.

    Decodes ``json_str`` as one graph description, runs the parallel estimate
    on that single graph and returns its ``bottleneck``. Returns ``None`` when
    the JSON cannot be decoded.
    """
    try:
        composite = model.load_composite(json.loads(json_str))
        single_graph = [composite.graph]
        return model.parallel_estimate(single_graph).bottleneck
    except jd.JSONDecodeError:
        logger.error(traceback.format_exc())
        return None
|
@ -0,0 +1,155 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/depend_formater.h"
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
||||
const auto &users = mng->node_users()[node];
|
||||
std::vector<std::pair<AnfNodePtr, int>> sons;
|
||||
for (const auto &[user, index] : users) {
|
||||
if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
|
||||
sons.emplace_back(user, index);
|
||||
continue;
|
||||
}
|
||||
auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
|
||||
sons.emplace_back(fake_first_grad_son, grad_index);
|
||||
}
|
||||
|
||||
AnfNodePtrList latter_to_delete;
|
||||
for (const auto &[son, index] : sons) {
|
||||
if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
|
||||
latter_to_delete.push_back(son);
|
||||
}
|
||||
|
||||
if (latter_to_delete.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
|
||||
if (latter_to_delete.size() == sons.size()) {
|
||||
// Left one Depend node relation and delete others!
|
||||
++delete_begin;
|
||||
}
|
||||
for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
|
||||
auto depend_anfnode = *delete_begin;
|
||||
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
|
||||
mng->Replace(depend_anfnode, depend_prior_node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Picks the node that new Depend nodes will be chained onto: the graph output,
// or the first element of the output when the output is a MakeTuple.
// `mng` is not used.
AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) {
  auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(return_cnode);
  auto output_node = return_cnode->input(kFirstDataInputIndex);
  if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
    return output_node;
  }

  auto output_cnode = output_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(output_cnode);
  return output_cnode->input(kFirstDataInputIndex);
}
|
||||
|
||||
void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
AnfNodePtr modified_node = stable_node;
|
||||
for (const auto &free_node : free_nodes) {
|
||||
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node};
|
||||
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
||||
depend_cnode->set_abstract(modified_node->abstract());
|
||||
main_graph->AddNode(depend_cnode);
|
||||
modified_node = depend_cnode;
|
||||
}
|
||||
|
||||
if (!free_nodes.empty()) {
|
||||
mng->Replace(stable_node, modified_node);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool DependFormater::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
// 1. Try to remove redundant depend.
|
||||
bool changed = false;
|
||||
auto nodes = TopoSort(func_graph->get_return());
|
||||
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) {
|
||||
if (RemoveRedundantDepend(node, mng)) {
|
||||
changed = true;
|
||||
}
|
||||
});
|
||||
|
||||
// Should re-toposort for changed graph.
|
||||
if (changed) {
|
||||
nodes = TopoSort(func_graph->get_return());
|
||||
}
|
||||
|
||||
// 2. Move depend to tail of graph.
|
||||
AnfNodePtrList old_depends;
|
||||
AnfNodePtrList free_nodes;
|
||||
|
||||
// Find depend and its free nodes.
|
||||
for (const auto &node : nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
old_depends.push_back(node);
|
||||
free_nodes.push_back(node->cast<CNodePtr>()->input(kDependAttachNodeIndex));
|
||||
}
|
||||
|
||||
if (old_depends.empty()) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Delete old depend.
|
||||
for (const auto &depend_anfnode : old_depends) {
|
||||
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
auto depend_prior_node = depend_cnode->input(kControlDependPriorIndex);
|
||||
mng->Replace(depend_anfnode, depend_prior_node);
|
||||
}
|
||||
|
||||
// Add new depend node in tail.
|
||||
AnfNodePtr patron_node = FindPatronNode(func_graph, mng);
|
||||
AddDepends(patron_node, free_nodes, func_graph, mng);
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,37 @@
|
||||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DependFormater : public Pass {
|
||||
public:
|
||||
DependFormater() : Pass("depend_formater") {}
|
||||
~DependFormater() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
};
|
||||
using DependFormaterPtr = std::shared_ptr<DependFormater>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
|
@ -0,0 +1,89 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Renders the stored dimension as "Dim(<value>)".
std::string CommonDimInfo::ToString() {
  std::string text("Dim(");
  text += std::to_string(dim_info_);
  text += ")";
  return text;
}
|
||||
|
||||
int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
|
||||
nlohmann::json json_desc;
|
||||
AnfNodePtrList nodes = {node};
|
||||
DumpOption dump_option;
|
||||
if (!AnfToJsonDesc(nodes, dump_option, &json_desc)) {
|
||||
MS_LOG(EXCEPTION) << "Collect json desc failed.";
|
||||
}
|
||||
|
||||
auto json_desc_str = json_desc.dump();
|
||||
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str);
|
||||
if (py::isinstance<py::none>(ret)) {
|
||||
MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
|
||||
<< json_desc_str;
|
||||
}
|
||||
return py::cast<int>(ret);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
|
||||
nlohmann::json json_desc;
|
||||
std::vector<AnfNodePtrList> graphs;
|
||||
std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
|
||||
[](const AnfNodePtr &node) -> AnfNodePtrList { return {node}; });
|
||||
DumpOption dump_option;
|
||||
if (!AnfToJsonDesc(graphs, dump_option, &json_desc)) {
|
||||
MS_LOG(EXCEPTION) << "Collect json desc failed.";
|
||||
}
|
||||
|
||||
auto json_desc_str = json_desc.dump();
|
||||
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelEstimateOps, json_desc_str);
|
||||
if (py::isinstance<py::none>(ret)) {
|
||||
MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
|
||||
<< json_desc_str;
|
||||
}
|
||||
|
||||
py::tuple ret_tuple = py::cast<py::tuple>(ret);
|
||||
if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
|
||||
}
|
||||
|
||||
std::vector<DimInfoPtr> dim_infos;
|
||||
py::list dim_list = py::cast<py::list>(ret_tuple[0]);
|
||||
for (size_t i = 0; i < dim_list.size(); ++i) {
|
||||
dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i])));
|
||||
}
|
||||
int benefit = py::cast<int>(ret_tuple[1]);
|
||||
|
||||
return std::make_tuple(dim_infos, benefit);
|
||||
}
|
||||
|
||||
// Returns the shared cost model for `target`; throws (MS_LOG(EXCEPTION)) for any
// target other than kGPUDevice.
ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) {
  const bool supported = (target == kGPUDevice);
  if (!supported) {
    MS_LOG(EXCEPTION) << "Parallel cost model only support " << kGPUDevice << " now.";
  }
  return cost_model_;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,82 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "base/base.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Abstract base for the per-node dimension information produced by a cost model.
class DimInfo {
 public:
  DimInfo() = default;
  // Virtual so that a derived object deleted through a DimInfo pointer is
  // destroyed correctly.
  virtual ~DimInfo() = default;
  // Human-readable description of the dimension information.
  virtual std::string ToString() = 0;
};
|
||||
|
||||
class CommonDimInfo : public DimInfo {
|
||||
public:
|
||||
explicit CommonDimInfo(size_t dim) : dim_info_(dim) {}
|
||||
~CommonDimInfo() {}
|
||||
void set_dim_info(size_t d) { dim_info_ = d; }
|
||||
size_t dim_info() const { return dim_info_; }
|
||||
std::string ToString() override;
|
||||
|
||||
private:
|
||||
size_t dim_info_;
|
||||
};
|
||||
|
||||
using DimInfoPtr = std::shared_ptr<DimInfo>;
|
||||
using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>;
|
||||
|
||||
class ParallelCostModel {
|
||||
public:
|
||||
ParallelCostModel() {}
|
||||
~ParallelCostModel() {}
|
||||
int GetNodeCalAmount(const AnfNodePtr &node);
|
||||
std::tuple<std::vector<DimInfoPtr>, int> CalFuseInfo(const AnfNodePtrList &nodes);
|
||||
};
|
||||
|
||||
using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;
|
||||
|
||||
class ParellelCostModelWarehouse {
|
||||
public:
|
||||
static ParellelCostModelWarehouse &Instance() {
|
||||
static ParellelCostModelWarehouse instance;
|
||||
return instance;
|
||||
}
|
||||
ParallelCostModelPtr GetParallelCostModel(const std::string &target);
|
||||
|
||||
private:
|
||||
ParellelCostModelWarehouse() { cost_model_ = std::make_shared<ParallelCostModel>(); }
|
||||
ParallelCostModelPtr cost_model_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,122 @@
|
||||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "base/base.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ParallelInfo {
|
||||
public:
|
||||
ParallelInfo() = default;
|
||||
ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {}
|
||||
ParallelInfo(const ParallelInfo &obj) {
|
||||
nodes_ = obj.nodes_;
|
||||
dims_ = obj.dims_;
|
||||
}
|
||||
~ParallelInfo() = default;
|
||||
|
||||
size_t GetSize() const {
|
||||
if (nodes_.size() != dims_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Internal error in parallel info!";
|
||||
}
|
||||
return nodes_.size();
|
||||
}
|
||||
const AnfNodePtrList &nodes() const { return nodes_; }
|
||||
const std::vector<DimInfoPtr> &dims() const { return dims_; }
|
||||
|
||||
private:
|
||||
AnfNodePtrList nodes_;
|
||||
std::vector<DimInfoPtr> dims_;
|
||||
};
|
||||
|
||||
// Tuning knobs for parallel fusion.
class ParallelConfig {
 public:
  ParallelConfig() = default;
  explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {}
  // Not `explicit`: an explicit copy constructor forbids copy-initialization
  // (`ParallelConfig b = a;`) and passing or returning the config by value.
  ParallelConfig(const ParallelConfig &obj) : max_num_for_fuse_(obj.max_num_for_fuse_) {}
  ~ParallelConfig() = default;

  // Upper bound on the number of nodes fused together; const so it can be read
  // through a const reference.
  size_t max_num_for_fuse() const { return max_num_for_fuse_; }

 private:
  size_t max_num_for_fuse_{10};  // Too many nodes to fuse together may produce bad result.
};
|
||||
|
||||
struct NodeRelation {
|
||||
public:
|
||||
NodeRelation() {}
|
||||
~NodeRelation() = default;
|
||||
OrderedSet<AnfNodePtr> pres;
|
||||
OrderedSet<AnfNodePtr> nexts;
|
||||
};
|
||||
|
||||
class ParallelOpFusion : public Pass {
|
||||
public:
|
||||
ParallelOpFusion(const std::string &target, const ParallelConfig &config)
|
||||
: Pass("parallel_fusion"), target_(target), config_(config) {}
|
||||
~ParallelOpFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<int> &offsets,
|
||||
const std::vector<bool> &used,
|
||||
const AnfNodePtrList &nodes,
|
||||
const std::set<int> &excludes);
|
||||
|
||||
std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates(
|
||||
size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
|
||||
std::map<AnfNodePtr, int> *sorted_indices);
|
||||
|
||||
std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs);
|
||||
|
||||
void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
|
||||
std::vector<ParallelInfo> *parallel_infos);
|
||||
|
||||
std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);
|
||||
|
||||
void SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info);
|
||||
|
||||
bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
|
||||
OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes);
|
||||
std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels);
|
||||
|
||||
std::string target_;
|
||||
ParallelConfig config_;
|
||||
ParallelCostModelPtr cost_model_ptr_;
|
||||
std::set<AnfNodePtr> virtual_noout_nodes_;
|
||||
std::set<AnfNodePtr> ignore_noin_nodes_;
|
||||
};
|
||||
using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
|
@ -0,0 +1,54 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==========================================================================
|
||||
"""test graph parallel case"""
|
||||
import model
|
||||
|
||||
def injective_graph(shape):
    """Build the 'injective' graph: a float32 tensor of `shape` fed through three chained Abs ops."""
    builder = model.GraphBuilder()
    with builder.graph_scope('injective'):
        value = builder.tensor(shape, 'float32')
        for _ in range(3):
            value = builder.emit('Abs', value)
    graphs = builder.get()
    return graphs[0]
|
||||
|
||||
def reduce_graph(shape, reduce_axis):
    """Build the 'reduce' graph: float32 tensor -> Abs -> Abs -> ReduceSum over `reduce_axis`."""
    builder = model.GraphBuilder()
    with builder.graph_scope('reduce'):
        value = builder.tensor(shape, 'float32')
        for _ in range(2):
            value = builder.emit('Abs', value)
        builder.emit('ReduceSum', value, 'C', attrs={'reduce_axis': reduce_axis})
    graphs = builder.get()
    return graphs[0]
|
||||
|
||||
def control_graph(shape):
    """Build the 'control' graph: float32 tensor -> Abs -> ControlDepend."""
    builder = model.GraphBuilder()
    with builder.graph_scope('control'):
        source = builder.tensor(shape, 'float32')
        abs_out = builder.emit('Abs', source)
        builder.emit('ControlDepend', abs_out)
    graphs = builder.get()
    return graphs[0]
|
||||
|
||||
def block_fusion(graphs):
    """Run the parallel estimate on `graphs`, print it, and report whether block fusion pays off."""
    estimate = model.parallel_estimate(graphs)
    summary = "fusion = {}, bottleneck = {}, gain = {}".format(
        estimate.fusion_type, estimate.bottleneck, estimate.gain)
    print(summary)
    if estimate.fusion_type != "block_fusion":
        return False
    return estimate.gain > 0
|
||||
|
||||
if __name__ == "__main__":
    # (expected block_fusion result, builder for the graph list). Graphs are
    # built lazily so each case is constructed only when it is checked.
    cases = (
        (True, lambda: [injective_graph([40, 1024]), injective_graph([40, 1024])]),
        (True, lambda: [reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])]),
        (False, lambda: [reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])]),
        (False, lambda: [reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])]),
        (True, lambda: [control_graph([20, 128]), injective_graph([40, 1024])]),
    )
    for expect_fused, build_graphs in cases:
        assert block_fusion(build_graphs()) == expect_fused
|
Loading…
Reference in new issue