From 7d09dff88094e1c17fd085b78e95b294415a92cd Mon Sep 17 00:00:00 2001 From: baihuawei Date: Tue, 1 Dec 2020 21:00:20 +0800 Subject: [PATCH] add hccl send recv --- .../gpu/control/recv_gpu_kernel.cc | 2 +- .../gpu/control/send_gpu_kernel.cc | 2 +- .../kernel_compiler/hccl/hccl_kernel.cc | 46 +++++++++++++++---- .../kernel_compiler/hccl/hccl_kernel.h | 1 + .../hccl/hccl_kernel_metadata.cc | 6 ++- .../kernel_compiler/hccl/hcom_receive.cc | 29 ++++++++++++ .../kernel_compiler/hccl/hcom_receive.h | 42 +++++++++++++++++ .../backend/kernel_compiler/hccl/hcom_send.cc | 29 ++++++++++++ .../backend/kernel_compiler/hccl/hcom_send.h | 41 +++++++++++++++++ .../backend/kernel_compiler/hccl/hcom_util.cc | 21 ++++++++- .../backend/kernel_compiler/hccl/hcom_util.h | 3 ++ .../ccsrc/backend/kernel_compiler/rts/recv.h | 2 +- .../ccsrc/backend/kernel_compiler/rts/send.h | 2 +- .../optimizer/mem_reuse/mem_reuse_checker.cc | 2 +- .../optimizer/mem_reuse/mem_reuse_checker.h | 2 - .../backend/session/anf_runtime_algorithm.cc | 2 +- .../ccsrc/runtime/hccl_adapter/converter.cc | 24 ++++++++++ mindspore/ccsrc/utils/utils.h | 11 ++++- mindspore/core/base/core_ops.h | 1 + 19 files changed, 246 insertions(+), 22 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc index a89d4e9baf..88ccf5496c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc @@ -18,6 +18,6 @@ namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel) +MS_REG_GPU_KERNEL_REGULAR(StreamRecv, KernelAttr(), RecvGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc index 946038bb18..0d073d287e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc @@ -18,6 +18,6 @@ namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel) +MS_REG_GPU_KERNEL_REGULAR(StreamSend, KernelAttr(), SendGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 9ad10704f0..698549e96b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -32,6 +32,8 @@ static std::map kMsOpNameToHcomHcclType = { {mindspore::kAllReduceOpName, mindspore::kHcomOpTypeAllReduce}, {mindspore::kAllGatherOpName, mindspore::kHcomOpTypeAllGather}, {mindspore::kBroadcastOpName, mindspore::kHcomOpTypeBroadcast}, + {mindspore::kHcomSendOpName, mindspore::kHcomOpTypeSend}, + {mindspore::kReceiveOpName, mindspore::kHcomOpTypeReceive}, {mindspore::kReduceScatterOpName, mindspore::kHcomOpTypeReduceScatter}}; std::string MsOpNameToHcomOpType(const std::string &ms_op_type) { auto iter = kMsOpNameToHcomHcclType.find(ms_op_type); @@ -80,7 +82,12 @@ HcclKernel::~HcclKernel() { bool HcclKernel::Init(const AnfNodePtr &anf_node) { MS_EXCEPTION_IF_NULL(anf_node); op_name_ = AnfAlgo::GetCNodeName(anf_node); - + if (op_name_ == kReceive) { + if (!HcomUtil::GetHcomReceiveType(anf_node, &receive_type_)) { + MS_LOG(ERROR) << "GetHcomReceiveType fail!"; + return false; + } + } if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) { MS_LOG(ERROR) << "GetKernelInputShape fail!"; return false; @@ -89,13 +96,27 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) { MS_LOG(ERROR) << "GetKernelOutputShape fail!"; return false; } - if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { + if (op_name_ == kReceive) { + auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); + if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { + MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_; + return false; + } + hccl_data_type_list_.emplace_back(iter->second); + } else if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { MS_LOG(ERROR) << "GetHcomDataType fail!"; return false; } - if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { - MS_LOG(ERROR) << "GetHcomCount fail!"; - return false; + if (op_name_ == kReceive) { + if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_output_shape_list_, &hccl_count_)) { + MS_LOG(ERROR) << "GetHcomCount fail!"; + return false; + } + } else { + if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { + MS_LOG(ERROR) << "GetHcomCount fail!"; + return false; + } } if (op_name_ == kAllReduce || op_name_ == kReduceScatter) { if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) { @@ -146,17 +167,24 @@ const std::vector &HcclKernel::GetWorkspaceSizeList() const { return wor std::vector HcclKernel::GenTask(const std::vector &inputs, const std::vector &, const std::vector &outputs, uint32_t stream_id) { - if (inputs.empty() || outputs.empty()) { + std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); + if (hccl_type == kReceive) { + if (outputs.empty()) { + MS_LOG(EXCEPTION) << "Outputs is empty"; + } + } else if (inputs.empty() || outputs.empty()) { MS_LOG(EXCEPTION) << "Inputs or outputs is empty"; } stream_id_ = stream_id; - MS_EXCEPTION_IF_NULL(inputs.at(0)); - auto input_data_addr = inputs.at(0)->addr; + void *input_data_addr = nullptr; + if (hccl_type != kReceive) { + MS_EXCEPTION_IF_NULL(inputs.at(0)); + input_data_addr = inputs.at(0)->addr; + } MS_EXCEPTION_IF_NULL(outputs.at(0)); auto output_data_addr = outputs.at(0)->addr; std::vector private_def; HcclDataType data_type = hccl_data_type_list_[0]; - std::vector task_info; bool ret = hccl::GenTask(anf_node_, data_type, &task_info); if (!ret) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h index 2ba888ecee..9f2780a51c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -51,6 +51,7 @@ class HcclKernel : public AscendKernelMod { uint64_t hccl_count_; HcclReduceOp op_type_; uint32_t root_id_; + int64_t receive_type_; mutable std::vector input_size_list_; mutable std::vector output_size_list_; mutable std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc index 4b8654ed84..f1bde11cec 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc @@ -33,6 +33,9 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { return kOpFormat_DEFAULT; } + if (op_name == kReceive || op_name == kHcomSend) { + return kOpFormat_DEFAULT; + } auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); if (op_name != kReduceScatter && op_name != kAllGatherOpName) { return format; @@ -52,7 +55,8 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector +#include "utils/ms_context.h" +namespace mindspore { +namespace kernel { +bool HcomReceiveKernel::Launch(const std::vector & /*inputs*/, + const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void * /*stream_ptr*/) { + MS_LOG(INFO) << "HcomReceive launch"; + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.h new file mode 100644 index 0000000000..e514e60e7c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_receive.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_ + +#include +#include +#include "backend/kernel_compiler/hccl/hccl_kernel.h" + +namespace mindspore { +namespace kernel { +class HcomReceiveKernel : public HcclKernel { + public: + HcomReceiveKernel() = default; + ~HcomReceiveKernel() override = default; + + /* Inherit from kernelmod */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: +}; + +MS_HCCL_REG_KERNEL(Receive, HcomReceiveKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.cc new file mode 100644 index 0000000000..57f31dcdec --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hcom_send.h" +#include +#include "utils/ms_context.h" + +namespace mindspore { +namespace kernel { +bool HcomSendKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void * /*stream_ptr*/) { + MS_LOG(INFO) << "HcomSend launch"; + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.h new file mode 100644 index 0000000000..f10aa7f763 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_send.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_ + +#include +#include +#include "backend/kernel_compiler/hccl/hccl_kernel.h" + +namespace mindspore { +namespace kernel { +class HcomSendKernel : public HcclKernel { + public: + HcomSendKernel() = default; + ~HcomSendKernel() override = default; + + /* Inherit from kernelmod */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: +}; + +MS_HCCL_REG_KERNEL(Send, HcomSendKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc index 8d1836f3b8..4477d4e4b7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc @@ -102,8 +102,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vectorGetAttr("dtype") != nullptr) { + *receive_type = (int64_t)(GetValue(primitive->GetAttr("dtype"))->type_id()); + } else { + MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRTAG_INDEX fail, not support!"; + return false; + } + return true; +} + void HcomUtil::GetHcomGroup(NotNull anf_node, NotNull group) { auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); MS_EXCEPTION_IF_NULL(primitive); diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h index 3e1843561a..a2d43c9d2f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h @@ -34,6 +34,8 @@ using std::vector; constexpr auto kAllGather = "AllGather"; constexpr auto kAllReduce = "AllReduce"; constexpr auto kBroadcast = "Broadcast"; +constexpr auto kHcomSend = "Send"; +constexpr auto kReceive = "Receive"; constexpr auto kReduceScatter = "ReduceScatter"; /* Correspondence between data_type and hcom data type in Ascend */ @@ -64,6 +66,7 @@ class HcomUtil { static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type); static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); static void GetHcomGroup(NotNull anf_node, NotNull group); + static bool GetHcomReceiveType(const AnfNodePtr &anf_node, int64_t *receive_type); }; } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h index 13dd91d55e..1ffe6ad83f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h @@ -39,7 +39,7 @@ class RecvKernel : public RtKernel { uint32_t event_id_; }; -MS_REG_RTKERNEL(recv, RecvKernel); +MS_REG_RTKERNEL(streamrecv, RecvKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h index 6550a3b11a..ff110cc6d3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h @@ -37,7 +37,7 @@ class SendKernel : public RtKernel { uint32_t event_id_; }; -MS_REG_RTKERNEL(send, SendKernel); +MS_REG_RTKERNEL(streamsend, SendKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc index 81dc3f8ba0..5d5b46f998 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc @@ -29,7 +29,7 @@ MemReuseChecker &MemReuseChecker::GetInstance() { void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) { std::string node_name = AnfAlgo::GetCNodeName(c_node); - if (node_name == kSend || node_name == kRecv) { + if (node_name == kSendOpName || node_name == kRecvOpName) { MS_LOG(INFO) << "MemReuseChecker check op_name of Send or Send"; // get op's info && check MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node) diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h index c5a5a128a1..86aac4ccef 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h @@ -29,8 +29,6 @@ #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" namespace mindspore { namespace memreuse { -constexpr auto kSend = "Send"; -constexpr auto kRecv = "Recv"; constexpr auto kSplitC = '/'; class MemReuseChecker { public: diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 6ca596a6a3..c2237b36b8 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1132,7 +1132,7 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { } auto kernel_name = AnfAlgo::GetCNodeName(node); if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName || - kernel_name == kReduceScatterOpName) { + kernel_name == kReduceScatterOpName || kernel_name == kHcomSendOpName || kernel_name == kReceiveOpName) { return true; } return false; diff --git a/mindspore/ccsrc/runtime/hccl_adapter/converter.cc b/mindspore/ccsrc/runtime/hccl_adapter/converter.cc index 432093a641..d01fdd0930 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/converter.cc +++ b/mindspore/ccsrc/runtime/hccl_adapter/converter.cc @@ -28,6 +28,8 @@ #include "mindspore/core/base/core_ops.h" #include "transform/graph_ir/util.h" +static constexpr char kGeOpNameHcclSend[] = "HcomSend"; +static constexpr char kGeOpNameHcclReceive[] = "HcomReceive"; static constexpr char kGeOpNameHcclAllRudece[] = "HcomAllReduce"; static constexpr char kGeOpNameHcclAllGather[] = "HcomAllGather"; static constexpr char kGeOpNameHcclBroadcast[] = "HcomBroadcast"; @@ -63,6 +65,18 @@ struct IsString { static constexpr bool value = true; }; +template +struct IsVector { + // cppcheck-suppress unusedStructMember + static constexpr bool value = false; +}; + +template <> +struct IsVector> { + // cppcheck-suppress unusedStructMember + static constexpr bool value = true; +}; + namespace mindspore::hccl { template static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name, @@ -78,6 +92,8 @@ static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const st auto attr = AnfAlgo::GetNodeAttr(cnode, anf_attr_name); if constexpr (IsString::value) { ret = ge::AttrUtils::SetStr(*ge_op, ge_attr_name, attr); + } else if constexpr (IsVector::value) { + ret = ge::AttrUtils::SetListInt(*ge_op, ge_attr_name, attr); } else { ret = ge::AttrUtils::SetInt(*ge_op, ge_attr_name, attr); } @@ -99,6 +115,10 @@ std::string GetGeNodeName(const CNodePtr &cnode) { return kGeOpNameHcclBroadcast; } else if (IsPrimitiveCNode(cnode, prim::kPrimReduceScatter)) { return kGeOpNameHcclReduceScatter; + } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) { + return kGeOpNameHcclSend; + } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) { + return kGeOpNameHcclReceive; } MS_LOG(EXCEPTION) << "Unknown hccl node type " << cnode->DebugString(); @@ -133,6 +153,10 @@ std::tuple GenerateStubGeNode(const AnfNodePtr // set node attr (void)ConvertAttr(cnode, op_desc, kAttrRankSize, ge::HCOM_ATTR_RANK_SIZE); (void)ConvertAttr(cnode, op_desc, kAttrGroup, ge::HCOM_ATTR_GROUP); + (void)ConvertAttr(cnode, op_desc, kAttrSrcRank, ge::HCOM_ATTR_SRC_RANK); + (void)ConvertAttr(cnode, op_desc, kAttrDestRank, ge::HCOM_ATTR_DEST_RANK); + (void)ConvertAttr(cnode, op_desc, kAttrSrTag, ge::HCOM_ATTR_SR_TAG); + (void)ConvertAttr>(cnode, op_desc, kAttrShape, ge::HCOM_ATTR_SHAPE); ge::ComputeGraphPtr ge_graph = std::make_shared(kStubDataStructureName); MS_EXCEPTION_IF_NULL(ge_graph); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 89bdc00065..54c20b1700 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -57,6 +57,8 @@ constexpr auto kAllReduceOpName = "AllReduce"; constexpr auto kAllGatherOpName = "AllGather"; constexpr auto kHostAllGatherOpName = "HostAllGather"; constexpr auto kBroadcastOpName = "Broadcast"; +constexpr auto kReceiveOpName = "Receive"; +constexpr auto kHcomSendOpName = "Send"; constexpr auto kReduceScatterOpName = "ReduceScatter"; constexpr auto kHostReduceScatterOpName = "HostReduceScatter"; constexpr auto kMemCpyAsyncOpName = "memcpy_async"; @@ -142,8 +144,8 @@ constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; constexpr auto kStreamSwitchOpName = "StreamSwitch"; constexpr auto kStreamActiveOpName = "StreamActive"; constexpr auto kAssignAddOpName = "AssignAdd"; -constexpr auto kSendOpName = "Send"; -constexpr auto kRecvOpName = "Recv"; +constexpr auto kSendOpName = "StreamSend"; +constexpr auto kRecvOpName = "StreamRecv"; constexpr auto kReluV2OpName = "ReLUV2"; constexpr auto kReluGradV2OpName = "ReluGradV2"; constexpr auto kAddNOpName = "AddN"; @@ -248,6 +250,8 @@ constexpr auto kBroadcastToOpName = "BroadcastTo"; constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; constexpr auto kHcomOpTypeAllGather = "HcomAllGather"; constexpr auto kHcomOpTypeBroadcast = "HcomBroadcast"; +constexpr auto kHcomOpTypeSend = "HcomSend"; +constexpr auto kHcomOpTypeReceive = "HcomReceive"; constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter"; // attr key name @@ -292,6 +296,9 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active constexpr auto kAttrFusion = "fusion"; constexpr auto kAttrGroup = "group"; constexpr auto kAttrOp = "op"; +constexpr auto kAttrDestRank = "dest_rank"; +constexpr auto kAttrSrcRank = "src_rank"; +constexpr auto kAttrSrTag = "sr_tag"; constexpr auto kAttrRootRank = "root_rank"; constexpr auto kAttrIsTraining = "is_training"; constexpr auto kAttrFusionId = "fusion_id"; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index a743e18ede..d3f2b0c048 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -193,6 +193,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared("SGD"); inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); +inline const PrimitivePtr kPrimSend = std::make_shared("Send"); inline const PrimitivePtr kPrimReceive = std::make_shared("Receive"); inline const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); inline const PrimitivePtr kPrimAllSwap = std::make_shared("AllSwap");