!2247 Synchronize Ascend software suite 17 Jun 2020

Merge pull request !2247 from yanghaoran/master
pull/2247/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit efa61b061e

@ -1 +1 @@
Subproject commit 45ca7863ac6410c8e2f83168481ddc6b43bcea33
Subproject commit 1350673d51b3f8535bc217a7780e6a0b52ff9a41

@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG
}
}
void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(graph_ptr);
auto cnode_ptr_list = graph_ptr->execution_order();
vector<uint32_t> fusion_hcom_index;
vector<CNodePtr> orders;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
if (IsHcom(cur_cnode)) {
fusion_hcom_index.emplace_back(i);
}
}
if (fusion_hcom_index.size() < 2) {
MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them";
return;
}
uint32_t first_index = fusion_hcom_index[0];
uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];
uint32_t cur_event_id = total_event_num_;
uint32_t pre_hcom_stream_id = UINT32_MAX;
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
for (size_t i = first_index; i <= last_index; i++) {
auto cur_cnode = cnode_ptr_list[i];
auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i);
if (it == fusion_hcom_index.end()) {
orders.emplace_back(cur_cnode);
continue;
}
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
if (cur_hcom_stream_id == pre_hcom_stream_id) {
orders.emplace_back(cur_cnode);
continue;
}
if (i == first_index) {
// first fusion hcom
orders.emplace_back(cur_cnode);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(send);
} else if (i == last_index) {
// last fusion hcom
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(recv);
orders.emplace_back(cur_cnode);
cur_event_id++;
} else {
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(recv);
cur_event_id++;
orders.emplace_back(cur_cnode);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(send);
}
pre_hcom_stream_id = cur_hcom_stream_id;
}
std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
graph_ptr->set_execution_order(orders);
total_event_num_ = cur_event_id;
MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]";
MS_LOG(INFO) << "end";
}
void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(graph_ptr);
@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor
graph_ptr->set_execution_order(cnodes);
total_event_num_ = cur_event_id;
MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]";
// Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2)
InsertSendRecvForDiffHcom(graph_ptr);
MS_LOG(INFO) << "end";
}

@ -95,6 +95,7 @@ class AscendStreamAssign {
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);

@ -23,16 +23,26 @@ matmul_op_info = TBERegOp("MatMul") \
.compute_cost(10) \
.kernel_name("matmul") \
.partial_flag(True) \
.attr("transpose_a", "required", "bool", "all") \
.attr("transpose_b", "required", "bool", "all") \
.attr("transpose_x1", "required", "bool", "all") \
.attr("transpose_x2", "required", "bool", "all") \
.attr("offset_x", "optional", "int", "all") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "optional", "all") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default,
DataType.I32_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
DataType.F32_FracNZ) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I8_Default,
DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I8_Default,
DataType.F32_Default) \
.dtype_format(DataType.I32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, DataType.I8_Default,
DataType.I32_NHWC) \
.get_op_info()

Loading…
Cancel
Save