!12234 [GraphKernel] Support pipeline optimization for parallel fusion.

From: @tronzhang
Reviewed-by: 
Signed-off-by:
pull/12234/MERGE
Committed by mindspore-ci-bot via Gitee
commit 54fc5e0d2b

akg

@@ -1 +1 @@
-Subproject commit 94cb709ecaf5d1d869883dfe80cee7497dd0692c
+Subproject commit 24ba04df564fb3d2578e1b4324c760783b34d551

@@ -17,11 +17,12 @@ from .model import PrimLib
 class ParalGain:
-    def __init__(self, fusion_type, bottleneck, gain, block_assign):
+    def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
         self.fusion_type = fusion_type
         self.bottleneck = bottleneck
         self.gain = gain
         self.block_assign = block_assign
+        self.type_info = type_info
 
 class ScheduleAnalyzer:
@@ -30,6 +31,7 @@ class ScheduleAnalyzer:
     MAX_SM = 80  # Volta
     MAX_NUM_THREADS = 1024
     MAX_BLOCK = 256
+    PIPELINE_OP_THREADHOLD = 5
 
     def __init__(self, graph):
         self.graph = graph
@@ -132,11 +134,141 @@ class ScheduleAnalyzer:
         else:
             self.default_analyze()
 
+    def suitable_to_pipeline(self):
+        """judge whether it is suitable to be pipeline optimized"""
+        # Reduce is not suitable.
+        def _contain_reduce(ops):
+            for op in ops:
+                # Reduce may make the tiling bad.
+                if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE:
+                    return True
+            return False
+
+        suitable = True
+        if _contain_reduce(self.ops):
+            suitable = False
+        return suitable
+    @staticmethod
+    def k_mean(data, class_n=2, exclude_id=()):
+        """
+        Find k clusters in which elements are close to each other.
+
+        Args:
+            data (list): Elements' information.
+            class_n (int): Number of clusters to be analyzed, default is 2.
+            exclude_id (tuple[int]): Indices of excluded elements, default is ().
+
+        Returns:
+            classes (list[list[int]]): The list of clusters. Each cluster is a list of indices.
+        """
+        def _cal_mean(classes):
+            class_datas = [[data[cid] for cid in cls] for cls in classes]
+            return [sum(cls) / len(cls) if cls else float('inf') for cls in class_datas]
+
+        def _cal_distance(a, b):
+            return abs(a - b)
+
+        def _check_different(old_classes, new_classes):
+            for o, n in zip(old_classes, new_classes):
+                if o != n:
+                    return True
+            return False
+
+        if len(data) < class_n:
+            return None
+        classes = []
+        for i, _ in enumerate(data):
+            if i in exclude_id:
+                continue
+            if len(classes) >= class_n:
+                break
+            classes.append([i])
+        changed = True
+        while changed:
+            new_classes = [[] for cls in classes]
+            means = _cal_mean(classes)
+            for idx, d in enumerate(data):
+                if idx in exclude_id:
+                    continue
+                min_idx = -1
+                min_dis = float('inf')
+                for i, m in enumerate(means):
+                    cur_dis = _cal_distance(m, d)
+                    min_idx = i if min_dis > cur_dis else min_idx
+                    min_dis = cur_dis if min_dis > cur_dis else min_dis
+                new_classes[min_idx].append(idx)
+            changed = _check_different(classes, new_classes)
+            classes = new_classes
+        return classes
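
A minimal sketch (not part of the change) of how the one-dimensional k_mean above behaves; the data values are invented for illustration:

    # Hypothetical indicators: three light kernels, two heavy ones.
    data = [7, 210, 7, 13, 239]
    print(ScheduleAnalyzer.k_mean(data, class_n=2))
    # -> [[0, 2, 3], [1, 4]]: indices of the light and the heavy cluster.
    # Excluded indices are assigned to no cluster at all:
    print(ScheduleAnalyzer.k_mean(data, class_n=2, exclude_id=(4,)))
    # -> [[0, 2, 3], [1]]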
+    @staticmethod
+    def pipeline_fusion_analyze(blocks, op_sizes, exclude_id):
+        """analyze whether the segments can be pipeline optimized"""
+        # op size first, block second.
+        def _simple_factor(block, op_size):
+            return block + 5 * op_size
+
+        def _take_second(elem):
+            return elem[1]
+
+        simple_indicators = [_simple_factor(b, s)
+                             for b, s in zip(blocks, op_sizes)]
+        # 2 classes: one heavy, the other light.
+        classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id)
+        if not classes:
+            return []
+        means = [sum([simple_indicators[idx] for idx in cls]) /
+                 len(cls) if cls else float('inf') for cls in classes]
+
+        # The target two clusters should be a heavy one and a light one.
+        # The light one may be suitable to run with pipeline optimization.
+        classes_infos = [[cls, m] for cls, m in zip(classes, means)]
+        classes_infos.sort(key=_take_second)
+        pipeline_target = None
+        for ci in classes_infos:
+            if ci[0]:
+                pipeline_target = ci
+                break
+        if pipeline_target is None:
+            return []
+        pipeline_gids, pipeline_mean = pipeline_target
+        if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks),
+                                          ScheduleAnalyzer.PIPELINE_OP_THREADHOLD):
+            return []
+
+        pipeline_blocks = []
+        pipeline_weight = len(pipeline_gids)
+        # Try to produce at least two parallel parts.
+        if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2:
+            if len(pipeline_gids[:pipeline_weight // 2]) > 1:
+                pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2])
+            if len(pipeline_gids[pipeline_weight // 2:]) > 1:
+                pipeline_blocks.append(pipeline_gids[pipeline_weight // 2:])
+        elif pipeline_weight > 1:
+            pipeline_blocks.append(pipeline_gids)
+        return pipeline_blocks
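
A worked example of the gating arithmetic above (illustrative only; the per-graph block counts and op counts are invented):

    blocks = [2, 60, 2, 3, 64]
    op_sizes = [1, 30, 1, 2, 35]
    # indicators = [7, 210, 7, 13, 239]; k_mean splits them into a light
    # cluster [0, 2, 3] (mean 9) and a heavy one [1, 4] (mean 224.5).
    # The light mean is compared against
    # _simple_factor(MAX_SM / len(blocks), PIPELINE_OP_THREADHOLD)
    #   = 80 / 5 + 5 * 5 = 41,
    # so 9 passes the gate and [0, 2, 3] comes back as one pipeline group.
    print(ScheduleAnalyzer.pipeline_fusion_analyze(blocks, op_sizes, ()))
    # -> [[0, 2, 3]]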
+    @staticmethod
+    def fusion_consult(blocks, op_sizes, exclude_gid):
+        """get a recommendation for parallel fusion"""
+        # Default is block fusion.
+        fusion_type = "block_fusion"
+        type_info = None
+
+        activate_pipeline_optimization = False  # Disable pipeline optimization for now.
+        if activate_pipeline_optimization:
+            pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
+                blocks, op_sizes, exclude_gid)
+            if pipeline_info:
+                fusion_type = "block_pipeline_fusion"
+                type_info = pipeline_info
+
+        return fusion_type, type_info
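
Note that the pipeline path is gated off by the flag above, so for now the consult always degrades to plain block fusion (a sketch, reusing the illustrative inputs from the previous example):

    blocks, op_sizes = [2, 60, 2, 3, 64], [1, 30, 1, 2, 35]
    print(ScheduleAnalyzer.fusion_consult(blocks, op_sizes, ()))
    # -> ('block_fusion', None) while activate_pipeline_optimization is False.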
 
 def block_parallel_estimate(graphs):
     """estimate block parallel gain"""
-    sum_block, max_weight, sum_weight, blocks = 0, 0, 0, []
-    for g in graphs:
+    sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], []
+    for gid, g in enumerate(graphs):
         s = ScheduleAnalyzer(g)
         s.analyze()
         sum_block += s.block_num
@@ -144,9 +276,14 @@ def block_parallel_estimate(graphs):
             max_weight = s.block_weight
         sum_weight += s.block_weight
         blocks.append(s.block_num)
+        op_sizes.append(len(s.ops))
+        if not s.suitable_to_pipeline():
+            exclude_gid.append(gid)
     if sum_block > ScheduleAnalyzer.MAX_SM * 32:
-        return ParalGain("none", sum_weight, 0, [])
-    return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks)
+        return ParalGain("none", sum_weight, 0, [0 for _ in graphs], None)
+    fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid))
+    return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
 
 def parallel_estimate(graphs):

@@ -28,10 +28,8 @@ def estimate_ops(json_str: str):
         for gd in graph_descs:
             graphs.append(model.load_composite(gd).graph)
         estimation = model.parallel_estimate(graphs)
-        if estimation.fusion_type == "block_fusion" and estimation.gain > 0:
-            res = (estimation.block_assign, estimation.gain)
-        else:
-            res = ([0 for g in graphs], 0)
+        res = (estimation.block_assign, estimation.gain,
+               estimation.fusion_type, estimation.type_info)
         return res
     except jd.JSONDecodeError:
         logger.error(traceback.format_exc())
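
For reference, the result tuple now always carries four fields, mirroring what the C++ cost model unpacks below (a sketch; the values are invented for illustration):

    # res as returned above:
    block_assign, gain, fusion_type, type_info = [2, 1], 7, "block_fusion", None
    # The C++ side reads these as ret_tuple[0]..ret_tuple[3] in
    # ParallelCostModel::CalFuseInfo / ProcessFusionInfo.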

@@ -557,30 +557,6 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
   return true;
 }
 
-void AkgKernelJsonGenerator::SetParallelValueToJson(const std::string &processor,
-                                                    const std::map<size_t, size_t> &dim_infos,
-                                                    nlohmann::json *sub_fusion_json) {
-  if (processor == kProcessorCuda) {
-    std::vector<size_t> cnums;
-    std::transform(dim_infos.cbegin(), dim_infos.cend(), std::back_insert_iterator(cnums),
-                   [](const std::pair<size_t, size_t> &dim) { return dim.second; });
-    (*sub_fusion_json)[kJsonKeyCoreNum] = cnums;
-  } else {
-    MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now.";
-  }
-}
-
-void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json) {
-  nlohmann::json parallel_fusion_json;
-  parallel_fusion_json[kJsonKeyFusionType] = "block_fusion";
-  std::vector<std::vector<std::string>> sgraphs;
-  std::transform(sub_graphs_.cbegin(), sub_graphs_.cend(), std::back_insert_iterator(sgraphs),
-                 [](const std::pair<int, std::vector<std::string>> &sg) { return sg.second; });
-  parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
-  SetParallelValueToJson(processor, dim_infos_, &parallel_fusion_json);
-  (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
-}
-
 void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
                                            std::map<AnfNodePtr, nlohmann::json> *node_json_map,
                                            nlohmann::json *kernel_json) {
@@ -633,12 +609,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyOutputDesc] =
     CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
 
-  auto processor = GetProcessorStr(anf_nodes[0]);
-
   // Add parallel fusion information.
-  if (!sub_graphs_.empty()) {
-    AddParalleFusionJsonInfo(processor, kernel_json);
-  }
+  GenParallelJson(anf_nodes, input_list, output_list, node_json_map, kernel_json);
 
   size_t hash_id = std::hash<std::string>()(kernel_json->dump());
   kernel_name_ = "Fused_";
@@ -660,7 +632,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyId] = GetOpCntInc();
   (*kernel_json)[kJsonKeyOp] = kernel_name_;
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
-  (*kernel_json)[kJsonKeyProcess] = processor;
+  (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
   (*kernel_json)[kJsonKeyComposite] = true;
   (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
@@ -755,6 +727,70 @@ nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNod
   return inputs_json;
 }
 
+void AkgKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes,
+                                             const std::vector<AnfNodePtr> &input_list,
+                                             const std::vector<AnfNodePtr> &output_list,
+                                             const std::map<AnfNodePtr, nlohmann::json> &node_json_map,
+                                             nlohmann::json *kernel_json) {
+  std::map<size_t, std::pair<size_t, std::vector<std::string>>> sub_graphs_info;
+  std::string fusion_type;
+  std::vector<std::vector<int>> type_info;
+
+  auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
+  for (size_t i = 0; i < output_index.size(); ++i) {
+    auto [tmp_output, tmp_output_index] = output_index[i];
+    bool found = std::any_of(input_list.cbegin(), input_list.cend(),
+                             [&tmp_output](const AnfNodePtr &in) { return tmp_output == in; });
+    if (!found) {
+      auto tcnode = tmp_output->cast<CNodePtr>();
+      if (tcnode == nullptr) {
+        return;
+      }
+      // Get dim info.
+      if (AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
+        auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
+        if (info.size() != 2) {
+          MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
+        }
+        auto tensor_name =
+          GetTensorName(node_json_map.at(tmp_output), kJsonKeyOutputDesc, std::make_pair(0, tmp_output_index));
+        sub_graphs_info[info[0]].second.push_back(tensor_name);
+        sub_graphs_info[info[0]].first = info[1];
+      }
+      // Get fusion type.
+      if (AnfAlgo::HasNodeAttr(kAttrParallelFusionType, tcnode)) {
+        fusion_type = AnfAlgo::GetNodeAttr<std::string>(tcnode, kAttrParallelFusionType);
+      }
+      // Get fusion type info.
+      if (AnfAlgo::HasNodeAttr(kAttrParallelTypeInfo, tcnode)) {
+        type_info = AnfAlgo::GetNodeAttr<std::vector<std::vector<int>>>(tcnode, kAttrParallelTypeInfo);
+      }
+    }
+  }
+
+  if (!sub_graphs_info.empty()) {
+    auto processor = GetProcessorStr(anf_nodes[0]);
+    if (processor != kProcessorCuda) {
+      MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now.";
+    }
+
+    nlohmann::json parallel_fusion_json;
+    parallel_fusion_json[kJsonKeyFusionType] = fusion_type;
+    parallel_fusion_json[kJsonKeyTypeInfo] = type_info;
+
+    std::vector<std::vector<std::string>> sgraphs;
+    std::vector<size_t> cnums;
+    std::for_each(sub_graphs_info.cbegin(), sub_graphs_info.cend(),
+                  [&sgraphs, &cnums](const std::pair<size_t, std::pair<size_t, std::vector<std::string>>> &sg_info) {
+                    sgraphs.push_back(sg_info.second.second);
+                    cnums.push_back(sg_info.second.first);
+                  });
+    parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
+    parallel_fusion_json[kJsonKeyCoreNum] = cnums;
+
+    (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
+  }
+}
+
 nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes,
                                                          const std::vector<AnfNodePtr> &input_list,
                                                          const std::vector<AnfNodePtr> &output_list,
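
With this in place, the kernel JSON carries a "parallel_fusion" section shaped like the sketch below (key names follow the kJsonKey constants in this change; tensor names and counts are invented for illustration):

    "parallel_fusion": {
        "fusion_type": "block_fusion",
        "type_info": [],
        "sub_graph": [["output_0_0", "output_0_1"], ["output_1_0"]],
        "core_num": [2, 1]
    }
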
@@ -785,17 +821,6 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo
         output_shape.push_back(1);
       }
       output_desc_json[kJsonKeyShape] = output_shape;
-      if (auto tcnode = tmp_output.first->cast<CNodePtr>();
-          tcnode && AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
-        auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
-        if (info.size() != 2) {
-          MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
-        }
-        sub_graphs_[info[0]].push_back(output_desc_json[kJsonKeyTensorName]);
-        if (dim_infos_.find(info[0]) == dim_infos_.end()) {
-          dim_infos_[info[0]] = info[1];
-        }
-      }
     }
     outputs_json.emplace_back(output_desc_json);
   }

@@ -54,6 +54,7 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
 constexpr auto kJsonKeyFusionType = "fusion_type";
 constexpr auto kJsonKeySubGraph = "sub_graph";
 constexpr auto kJsonKeyCoreNum = "core_num";
+constexpr auto kJsonKeyTypeInfo = "type_info";
 constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
 constexpr auto kJsonKeyStitchOp = "stitch_op";
 constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
@@ -89,8 +90,6 @@ class AkgKernelJsonGenerator {
     input_tensor_idx_.clear();
     address_node_map_.clear();
     output_tensor_idx_ = 0;
-    sub_graphs_.clear();
-    dim_infos_.clear();
   }
   void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
   std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
@@ -127,9 +126,10 @@ class AkgKernelJsonGenerator {
   std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
   void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
   OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node);
-  void SetParallelValueToJson(const std::string &processor, const std::map<size_t, size_t> &dim_infos,
-                              nlohmann::json *sub_fusion_json);
-  void AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json);
+  void CollectParallelDimInfo(const AnfNodePtr &anf_node);
+  void GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
+                       const std::vector<AnfNodePtr> &output_list,
+                       const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json);
 
   DumpOption dump_option_;
   static int op_cnt_;
@@ -142,8 +142,6 @@ class AkgKernelJsonGenerator {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::map<std::string, AnfNodePtr> address_node_map_;
-  std::map<size_t, std::vector<std::string>> sub_graphs_;
-  std::map<size_t, size_t> dim_infos_;
   bool is_basic_op_{false};
 };
 }  // namespace kernel

@@ -60,8 +60,9 @@ KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_ge
     return cached_kernel_pack;
   }
 
-  (void)alarm(AUTODIFF_COMPILE_OVERTIME);
   auto kernel_json = json_generator.kernel_json_str();
+  kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
+  (void)alarm(AUTODIFF_COMPILE_OVERTIME);
   auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json);
   (void)alarm(0);
   if (!res) {
@@ -70,7 +71,6 @@ KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_ge
   }
 
   auto new_kernel_pack = InsertCache(kernel_name, processor);
-  kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
   if (new_kernel_pack == nullptr) {
     MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
                   << anf_node->fullname_with_scope() << "].";

@@ -47,7 +47,7 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
   return py::cast<int>(ret);
 }
 
-std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
+std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
   nlohmann::json json_desc;
   std::vector<AnfNodePtrList> graphs;
   std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
@@ -65,7 +65,7 @@ std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const An
   }
 
   py::tuple ret_tuple = py::cast<py::tuple>(ret);
-  if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 2) {
-    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
+  if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 4) {
+    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with four elements!";
   }
@@ -75,8 +75,41 @@ std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const An
     dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i])));
   }
   int benefit = py::cast<int>(ret_tuple[1]);
-  return std::make_tuple(dim_infos, benefit);
+  auto fusion_info = ProcessFusionInfo(ret_tuple[2], ret_tuple[3]);
+  return std::make_tuple(dim_infos, benefit, fusion_info);
+}
+
+FusionInfoPtr ParallelCostModel::ProcessFusionInfo(py::object fusion_type, py::object type_info) {
+  if (!py::isinstance<py::str>(fusion_type)) {
+    MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!";
+  }
+
+  std::string fusion_type_name = py::cast<std::string>(fusion_type);
+  FusionInfoPtr fusion_info;
+  if (fusion_type_name == "block_fusion") {
+    fusion_info = std::make_shared<BlockFusionInfo>();
+  } else if (fusion_type_name == "block_pipeline_fusion") {
+    if (!py::isinstance<py::list>(type_info)) {
+      MS_LOG(EXCEPTION) << "Fusion type info for block pipeline fusion is invalid!";
+    }
+    std::vector<std::vector<int>> pipeline_ids;
+    py::list pipeline_ids_list = py::cast<py::list>(type_info);
+    for (size_t i = 0; i < pipeline_ids_list.size(); ++i) {
+      std::vector<int> part_ids;
+      py::list inner_ids_list = py::cast<py::list>(pipeline_ids_list[i]);
+      for (size_t j = 0; j < inner_ids_list.size(); ++j) {
+        part_ids.push_back(py::cast<int>(inner_ids_list[j]));
+      }
+      pipeline_ids.push_back(part_ids);
+    }
+    fusion_info = std::make_shared<BlockPipelineFusionInfo>(pipeline_ids);
+  } else {
+    MS_LOG(EXCEPTION) << "Unsupported parallel fusion type: " << fusion_type_name;
+  }
+  return fusion_info;
 }
 ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) {
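
For context, a sketch of what the Python side hands over in the pipeline case (values illustrative; compare fusion_consult above):

    # ret_tuple[2], ret_tuple[3] as seen by ProcessFusionInfo:
    ("block_pipeline_fusion", [[0, 2], [1, 3]])
    # i.e. two pipeline groups, each a list of graph ids.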

@@ -29,6 +29,7 @@
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/graph_kernel/parallel_cost_model.h"
 #include "backend/session/kernel_graph.h"
+#include "pipeline/jit/parse/python_adapter.h"
 #include "utils/ms_context.h"
 
 namespace mindspore {
@@ -55,12 +56,50 @@ class CommonDimInfo : public DimInfo {
 using DimInfoPtr = std::shared_ptr<DimInfo>;
 using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>;
 
+class FusionInfo {
+ public:
+  FusionInfo() = default;
+  explicit FusionInfo(const std::string &type) : fusion_type_(type) {}
+  ~FusionInfo() = default;
+  std::string FusionType() { return fusion_type_; }
+  virtual bool ExistTypeInfo() { return false; }
+
+ private:
+  std::string fusion_type_{"none"};
+};
+
+class BlockFusionInfo : public FusionInfo {
+ public:
+  BlockFusionInfo() : FusionInfo("block_fusion") {}
+  ~BlockFusionInfo() = default;
+  bool ExistTypeInfo() { return false; }
+};
+
+class BlockPipelineFusionInfo : public FusionInfo {
+ public:
+  explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids)
+      : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {}
+  ~BlockPipelineFusionInfo() = default;
+  bool ExistTypeInfo() { return true; }
+  std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; }
+
+ private:
+  std::vector<std::vector<int>> pipeline_ids_;
+};
+
+using FusionInfoPtr = std::shared_ptr<FusionInfo>;
+using BlockFusionInfoPtr = std::shared_ptr<BlockFusionInfo>;
+using BlockPipelineFusionInfoPtr = std::shared_ptr<BlockPipelineFusionInfo>;
+
 class ParallelCostModel {
  public:
   ParallelCostModel() {}
   ~ParallelCostModel() {}
   int GetNodeCalAmount(const AnfNodePtr &node);
-  std::tuple<std::vector<DimInfoPtr>, int> CalFuseInfo(const AnfNodePtrList &nodes);
+  std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes);
+
+ private:
+  FusionInfoPtr ProcessFusionInfo(py::object fusion_type, py::object type_info);
 };
 
 using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;

@@ -553,7 +553,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea
       std::tie(other_candidates, std::ignore) =
         GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>());
       int benefit;
-      std::tie(std::ignore, benefit) = cost_model_ptr_->CalFuseInfo(other_candidates);
+      std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
       if (benefit > 0) {
         begin = mid + 1;
       } else {
@@ -567,12 +567,12 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea
       AnfNodePtrList other_candidates;
       std::tie(other_candidates, std::ignore) =
         GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>());
-      auto [dim_infos, benefit] = cost_model_ptr_->CalFuseInfo(other_candidates);
+      auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
       if (benefit <= 0) {
         MS_LOG(EXCEPTION) << "Internal error in candidate search!";
       }
       max_benefit = benefit;
-      best_parallel_info = ParallelInfo(other_candidates, dim_infos);
+      best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
       i += begin - 1;
     }
@@ -676,10 +676,13 @@ std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
 }
 
 void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
+  AnfNodePtr attach_node;
+  // Dim info should be attached to each segment's output.
   for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
     const auto &fuse_nodes = parallel_info.nodes();
     std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
     if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
+      attach_node = fuse_nodes[i];
       SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
     } else {
       auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
@@ -689,11 +692,16 @@ void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &pa
         for (size_t j = 1; j < inputs.size(); ++j) {
           SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
         }
+        attach_node = inputs[1];
       } else {
+        attach_node = out_node;
         SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
       }
     }
   }
+
+  // Fusion info is fine to attach to any one of the segments.
+  SetFusionInfoAttrToNode(attach_node, parallel_info);
 }
 
 void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@@ -741,6 +749,17 @@ void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_pt
   }
 }
 
+void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
+  auto fusion_type = parallel_info.fusion_info()->FusionType();
+  AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
+  if (parallel_info.fusion_info()->ExistTypeInfo()) {
+    if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
+      AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
+                           MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
+    }
+  }
+}
+
 bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                                  const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   bool changed = false;
@@ -755,6 +774,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
     AnfNodePtr sg_node;
     std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
     PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
+    AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
     DumpParallelFusionDetail(fuse_nodes, sg_node);
   }

@@ -37,10 +37,12 @@ namespace opt {
 class ParallelInfo {
  public:
   ParallelInfo() = default;
-  ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {}
+  ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info)
+      : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {}
   ParallelInfo(const ParallelInfo &obj) {
     nodes_ = obj.nodes_;
     dims_ = obj.dims_;
+    fusion_info_ = obj.fusion_info_;
   }
 
   ~ParallelInfo() = default;
@@ -52,10 +54,12 @@ class ParallelInfo {
   }
   const AnfNodePtrList &nodes() const { return nodes_; }
   const std::vector<DimInfoPtr> &dims() const { return dims_; }
+  const FusionInfoPtr &fusion_info() const { return fusion_info_; }
 
  private:
   AnfNodePtrList nodes_;
   std::vector<DimInfoPtr> dims_;
+  FusionInfoPtr fusion_info_;
 };
 
 class ParallelConfig {
@@ -102,6 +106,8 @@ class ParallelOpFusion : public Pass {
   std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);
 
+  void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info);
+
   void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);
 
   bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,

@@ -397,6 +397,9 @@ constexpr auto kAttrIsGrad = "is_grad";
 constexpr auto kAttrRecompute = "recompute";
 constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
 constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
+constexpr auto kAttrParallelFusionType = "parallel_fusion_type";
+constexpr auto kAttrParallelTypeInfo = "parallel_type_info";
+constexpr auto kAttrCompositeType = "composite_type";
 constexpr auto kAttrStitch = "stitch";
 constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
 constexpr auto kAttrSwitchLayer = "switch_layer";
