add hccl send recv

pull/8870/head
baihuawei 4 years ago
parent 55cc959ac7
commit 7d09dff880

@@ -18,6 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel)
MS_REG_GPU_KERNEL_REGULAR(StreamRecv, KernelAttr(), RecvGpuKernel)
} // namespace kernel
} // namespace mindspore

@@ -18,6 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel)
MS_REG_GPU_KERNEL_REGULAR(StreamSend, KernelAttr(), SendGpuKernel)
} // namespace kernel
} // namespace mindspore
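Note: the two hunks above re-key the GPU stream-control kernels from Recv/Send to StreamRecv/StreamSend, freeing the plain Send/Receive names for the new HCCL point-to-point ops (see the renamed kSendOpName/kRecvOpName constants near the end of this commit). A self-contained sketch of how such name-keyed registration typically works; KernelRegistry, KernelRegistrar, and REG_KERNEL_EXAMPLE are illustrative stand-ins, not the real MS_REG_GPU_KERNEL_REGULAR expansion:

#include <functional>
#include <map>
#include <memory>
#include <string>

struct KernelBase {
  virtual ~KernelBase() = default;
};
using KernelFactory = std::function<std::unique_ptr<KernelBase>()>;

// Name-keyed factory registry: renaming the first macro argument re-keys the
// kernel under a new op name without touching the kernel class itself.
std::map<std::string, KernelFactory> &KernelRegistry() {
  static std::map<std::string, KernelFactory> registry;
  return registry;
}

struct KernelRegistrar {
  KernelRegistrar(const std::string &name, KernelFactory factory) { KernelRegistry()[name] = std::move(factory); }
};

#define REG_KERNEL_EXAMPLE(NAME, CLASS) \
  static KernelRegistrar g_reg_##CLASS##_##NAME{#NAME, [] { return std::unique_ptr<KernelBase>(new CLASS); }}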

@@ -32,6 +32,8 @@ static std::map<std::string, std::string> kMsOpNameToHcomHcclType = {
  {mindspore::kAllReduceOpName, mindspore::kHcomOpTypeAllReduce},
  {mindspore::kAllGatherOpName, mindspore::kHcomOpTypeAllGather},
  {mindspore::kBroadcastOpName, mindspore::kHcomOpTypeBroadcast},
  {mindspore::kHcomSendOpName, mindspore::kHcomOpTypeSend},
  {mindspore::kReceiveOpName, mindspore::kHcomOpTypeReceive},
  {mindspore::kReduceScatterOpName, mindspore::kHcomOpTypeReduceScatter}};
std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
  auto iter = kMsOpNameToHcomHcclType.find(ms_op_type);
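A quick illustration of the lookup this map feeds (return values per the entries above):

// With the two new entries, point-to-point ops resolve like the collectives:
std::string send_type = MsOpNameToHcomOpType(mindspore::kHcomSendOpName);  // "HcomSend"
std::string recv_type = MsOpNameToHcomOpType(mindspore::kReceiveOpName);   // "HcomReceive"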
@@ -80,7 +82,12 @@ HcclKernel::~HcclKernel() {
bool HcclKernel::Init(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  op_name_ = AnfAlgo::GetCNodeName(anf_node);
  if (op_name_ == kReceive) {
    if (!HcomUtil::GetHcomReceiveType(anf_node, &receive_type_)) {
      MS_LOG(ERROR) << "GetHcomReceiveType fail!";
      return false;
    }
  }
  if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) {
    MS_LOG(ERROR) << "GetKernelInputShape fail!";
    return false;
@@ -89,13 +96,27 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
    MS_LOG(ERROR) << "GetKernelOutputShape fail!";
    return false;
  }
  if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) {
  if (op_name_ == kReceive) {
    auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
    if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
      MS_LOG(ERROR) << "HcomDataType can't support current Ascend data type: " << receive_type_;
      return false;
    }
    hccl_data_type_list_.emplace_back(iter->second);
  } else if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) {
    MS_LOG(ERROR) << "GetHcomDataType fail!";
    return false;
  }
  if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) {
    MS_LOG(ERROR) << "GetHcomCount fail!";
    return false;
  if (op_name_ == kReceive) {
    if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_output_shape_list_, &hccl_count_)) {
      MS_LOG(ERROR) << "GetHcomCount fail!";
      return false;
    }
  } else {
    if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) {
      MS_LOG(ERROR) << "GetHcomCount fail!";
      return false;
    }
  }
  if (op_name_ == kAllReduce || op_name_ == kReduceScatter) {
    if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) {
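Note: Receive is the one HCCL op with no data inputs, which is why both branches above special-case it; its dtype comes from the node's "dtype" attribute and its element count from the output shape. A minimal, self-contained sketch of an element count over a shape (assumed semantics; the real GetHcomCount also folds in per-type size handling):

#include <cstdint>
#include <vector>

// Total number of elements in a tensor shape; for Receive this would be run
// over hccl_kernel_output_shape_list_ instead of the (empty) input shapes.
uint64_t ElementCount(const std::vector<size_t> &shape) {
  uint64_t count = 1;
  for (size_t dim : shape) {
    count *= dim;
  }
  return count;
}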
@@ -146,17 +167,24 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return wor
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                             const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  if (inputs.empty() || outputs.empty()) {
  std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
  if (hccl_type == kReceive) {
    if (outputs.empty()) {
      MS_LOG(EXCEPTION) << "Outputs is empty";
    }
  } else if (inputs.empty() || outputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
  }
  stream_id_ = stream_id;
  MS_EXCEPTION_IF_NULL(inputs.at(0));
  auto input_data_addr = inputs.at(0)->addr;
  void *input_data_addr = nullptr;
  if (hccl_type != kReceive) {
    MS_EXCEPTION_IF_NULL(inputs.at(0));
    input_data_addr = inputs.at(0)->addr;
  }
  MS_EXCEPTION_IF_NULL(outputs.at(0));
  auto output_data_addr = outputs.at(0)->addr;
  std::vector<uint8_t> private_def;
  HcclDataType data_type = hccl_data_type_list_[0];
  std::vector<hccl::HcclTaskInfo> task_info;
  bool ret = hccl::GenTask(anf_node_, data_type, &task_info);
  if (!ret) {

@@ -51,6 +51,7 @@ class HcclKernel : public AscendKernelMod {
  uint64_t hccl_count_;
  HcclReduceOp op_type_;
  uint32_t root_id_;
  int64_t receive_type_;
  mutable std::vector<size_t> input_size_list_;
  mutable std::vector<size_t> output_size_list_;
  mutable std::vector<size_t> workspace_size_list_;

@@ -33,6 +33,9 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
  if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
    return kOpFormat_DEFAULT;
  }
  if (op_name == kReceive || op_name == kHcomSend) {
    return kOpFormat_DEFAULT;
  }
  auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
  if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
    return format;
@@ -52,7 +55,8 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
  if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) {
  if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter &&
      op_name != kHcomSend && op_name != kReceive) {
    MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]";
    return;
  }

@@ -0,0 +1,29 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/hccl/hcom_receive.h"
#include <memory>
#include "utils/ms_context.h"
namespace mindspore {
namespace kernel {
bool HcomReceiveKernel::Launch(const std::vector<AddressPtr> & /*inputs*/,
                               const std::vector<AddressPtr> & /*workspace*/,
                               const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
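  // Note: intentionally a logging stub; in this commit the actual receive is
  // emitted as an HCCL task through HcclKernel::GenTask above.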
MS_LOG(INFO) << "HcomReceive launch";
return true;
}
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/hccl/hccl_kernel.h"
namespace mindspore {
namespace kernel {
class HcomReceiveKernel : public HcclKernel {
 public:
  HcomReceiveKernel() = default;
  ~HcomReceiveKernel() override = default;
  /* Inherit from kernelmod */
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

 private:
};
MS_HCCL_REG_KERNEL(Receive, HcomReceiveKernel);
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_

@@ -0,0 +1,29 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/hccl/hcom_send.h"
#include <memory>
#include "utils/ms_context.h"
namespace mindspore {
namespace kernel {
bool HcomSendKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, const std::vector<AddressPtr> & /*workspace*/,
                            const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
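  // As with HcomReceiveKernel, a logging stub; the send itself is generated
  // as an HCCL task via HcclKernel::GenTask.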
MS_LOG(INFO) << "HcomSend launch";
return true;
}
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/hccl/hccl_kernel.h"
namespace mindspore {
namespace kernel {
class HcomSendKernel : public HcclKernel {
 public:
  HcomSendKernel() = default;
  ~HcomSendKernel() override = default;
  /* Inherit from kernelmod */
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

 private:
};
MS_HCCL_REG_KERNEL(Send, HcomSendKernel);
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_

@@ -102,8 +102,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
  uint64_t block_size;
  size_t input_size;
  uint32_t type_size = 4;
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
  size_t size = AnfAlgo::GetInputTensorNum(anf_node);
  if (AnfAlgo::GetCNodeName(anf_node) == kReceiveOpName) {
    size = AnfAlgo::GetOutputTensorNum(anf_node);
  }
  for (size_t i = 0; i < size; ++i) {
    if (!GetHcomTypeSize(data_type_list[i], &type_size)) {
      return false;
    }
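Condensed restatement of the change above, for readability (same identifiers as the hunk):

// Receive sizes its transfer from its outputs; every other HCCL op from its inputs.
size_t size = (AnfAlgo::GetCNodeName(anf_node) == kReceiveOpName) ? AnfAlgo::GetOutputTensorNum(anf_node)
                                                                  : AnfAlgo::GetInputTensorNum(anf_node);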
@@ -183,6 +186,20 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
  return true;
}
bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, int64_t *receive_type) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(receive_type);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  if (primitive->GetAttr("dtype") != nullptr) {
    *receive_type = static_cast<int64_t>(GetValue<NumberPtr>(primitive->GetAttr("dtype"))->type_id());
  } else {
    MS_LOG(ERROR) << "HcomUtil::GetHcomReceiveType fail: Receive node has no 'dtype' attr!";
    return false;
  }
  return true;
}
void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
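Call-site sketch for the new helper (recv_node is a hypothetical Receive CNode; the error handling mirrors HcclKernel::Init):

int64_t receive_type = 0;
if (!HcomUtil::GetHcomReceiveType(recv_node, &receive_type)) {
  MS_LOG(ERROR) << "GetHcomReceiveType fail!";
  return false;
}
// receive_type now holds the TypeId read from the "dtype" attr, later mapped
// through CONST_OP_HCOM_DATA_TYPE_MAP to an HcclDataType.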

@@ -34,6 +34,8 @@ using std::vector;
constexpr auto kAllGather = "AllGather";
constexpr auto kAllReduce = "AllReduce";
constexpr auto kBroadcast = "Broadcast";
constexpr auto kHcomSend = "Send";
constexpr auto kReceive = "Receive";
constexpr auto kReduceScatter = "ReduceScatter";
/* Correspondence between data_type and hcom data type in Ascend */
@@ -64,6 +66,7 @@ class HcomUtil {
  static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type);
  static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
  static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
  static bool GetHcomReceiveType(const AnfNodePtr &anf_node, int64_t *receive_type);
};
} // namespace mindspore

@@ -39,7 +39,7 @@ class RecvKernel : public RtKernel {
  uint32_t event_id_;
};
MS_REG_RTKERNEL(recv, RecvKernel);
MS_REG_RTKERNEL(streamrecv, RecvKernel);
} // namespace kernel
} // namespace mindspore

@@ -37,7 +37,7 @@ class SendKernel : public RtKernel {
  uint32_t event_id_;
};
MS_REG_RTKERNEL(send, SendKernel);
MS_REG_RTKERNEL(streamsend, SendKernel);
} // namespace kernel
} // namespace mindspore

@@ -29,7 +29,7 @@ MemReuseChecker &MemReuseChecker::GetInstance() {
void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) {
  std::string node_name = AnfAlgo::GetCNodeName(c_node);
  if (node_name == kSend || node_name == kRecv) {
  if (node_name == kSendOpName || node_name == kRecvOpName) {
    MS_LOG(INFO) << "MemReuseChecker check op_name of Send or Recv";
    // get op's info && check
    MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node)

@@ -29,8 +29,6 @@
#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
namespace mindspore {
namespace memreuse {
constexpr auto kSend = "Send";
constexpr auto kRecv = "Recv";
constexpr auto kSplitC = '/';
class MemReuseChecker {
public:

@@ -1132,7 +1132,7 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  }
  auto kernel_name = AnfAlgo::GetCNodeName(node);
  if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
      kernel_name == kReduceScatterOpName) {
      kernel_name == kReduceScatterOpName || kernel_name == kHcomSendOpName || kernel_name == kReceiveOpName) {
    return true;
  }
  return false;
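An equivalent set-based formulation of the updated predicate, as a readability sketch (op-name constants as declared in utils.h below):

#include <set>
#include <string>

bool IsCommunicationOpName(const std::string &kernel_name) {
  static const std::set<std::string> kCommOps = {kAllReduceOpName,     kAllGatherOpName, kBroadcastOpName,
                                                 kReduceScatterOpName, kHcomSendOpName,  kReceiveOpName};
  return kCommOps.count(kernel_name) != 0;
}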

@@ -28,6 +28,8 @@
#include "mindspore/core/base/core_ops.h"
#include "transform/graph_ir/util.h"
static constexpr char kGeOpNameHcclSend[] = "HcomSend";
static constexpr char kGeOpNameHcclReceive[] = "HcomReceive";
static constexpr char kGeOpNameHcclAllRudece[] = "HcomAllReduce";
static constexpr char kGeOpNameHcclAllGather[] = "HcomAllGather";
static constexpr char kGeOpNameHcclBroadcast[] = "HcomBroadcast";
@@ -63,6 +65,18 @@ struct IsString<std::string> {
  static constexpr bool value = true;
};
template <class T>
struct IsVector {
  // cppcheck-suppress unusedStructMember
  static constexpr bool value = false;
};
template <>
struct IsVector<std::vector<int64_t>> {
  // cppcheck-suppress unusedStructMember
  static constexpr bool value = true;
};
namespace mindspore::hccl {
template <class T>
static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name,
@@ -78,6 +92,8 @@ static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const st
  auto attr = AnfAlgo::GetNodeAttr<T>(cnode, anf_attr_name);
  if constexpr (IsString<T>::value) {
    ret = ge::AttrUtils::SetStr(*ge_op, ge_attr_name, attr);
  } else if constexpr (IsVector<T>::value) {
    ret = ge::AttrUtils::SetListInt(*ge_op, ge_attr_name, attr);
  } else {
    ret = ge::AttrUtils::SetInt(*ge_op, ge_attr_name, attr);
  }
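The IsString/IsVector traits drive a compile-time dispatch: each ConvertAttr instantiation keeps exactly one Set* call. A self-contained illustration of the pattern (generic names, not the GE API):

#include <cstdint>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

template <class T>
struct IsVec : std::false_type {};
template <>
struct IsVec<std::vector<int64_t>> : std::true_type {};

template <class T>
void SetAttr(const T &) {
  if constexpr (std::is_same_v<T, std::string>) {
    std::cout << "string attr\n";  // ge::AttrUtils::SetStr in the hunk above
  } else if constexpr (IsVec<T>::value) {
    std::cout << "int-list attr\n";  // ge::AttrUtils::SetListInt
  } else {
    std::cout << "int attr\n";  // ge::AttrUtils::SetInt
  }
}

int main() {
  SetAttr(std::string("group"));        // string attr
  SetAttr(std::vector<int64_t>{1, 2});  // int-list attr
  SetAttr(int64_t{7});                  // int attr
}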
@@ -99,6 +115,10 @@ std::string GetGeNodeName(const CNodePtr &cnode) {
    return kGeOpNameHcclBroadcast;
  } else if (IsPrimitiveCNode(cnode, prim::kPrimReduceScatter)) {
    return kGeOpNameHcclReduceScatter;
  } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) {
    return kGeOpNameHcclSend;
  } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
    return kGeOpNameHcclReceive;
  }
  MS_LOG(EXCEPTION) << "Unknown hccl node type " << cnode->DebugString();
@@ -133,6 +153,10 @@ std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr
  // set node attr
  (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrRankSize, ge::HCOM_ATTR_RANK_SIZE);
  (void)ConvertAttr<std::string>(cnode, op_desc, kAttrGroup, ge::HCOM_ATTR_GROUP);
  (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrSrcRank, ge::HCOM_ATTR_SRC_RANK);
  (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrDestRank, ge::HCOM_ATTR_DEST_RANK);
  (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrSrTag, ge::HCOM_ATTR_SR_TAG);
  (void)ConvertAttr<std::vector<int64_t>>(cnode, op_desc, kAttrShape, ge::HCOM_ATTR_SHAPE);
  ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>(kStubDataStructureName);
  MS_EXCEPTION_IF_NULL(ge_graph);
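Note: the (void) casts above indicate these conversions are best-effort, since a given node carries only the attrs relevant to it (a Broadcast has no dest_rank, a Send no root_rank). A hypothetical guard of the shape presumably used inside ConvertAttr:

// Hypothetical: copy the attr only when the anf node actually carries it.
if (AnfAlgo::HasNodeAttr(kAttrSrcRank, cnode)) {
  (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrSrcRank, ge::HCOM_ATTR_SRC_RANK);
}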

@@ -57,6 +57,8 @@ constexpr auto kAllReduceOpName = "AllReduce";
constexpr auto kAllGatherOpName = "AllGather";
constexpr auto kHostAllGatherOpName = "HostAllGather";
constexpr auto kBroadcastOpName = "Broadcast";
constexpr auto kReceiveOpName = "Receive";
constexpr auto kHcomSendOpName = "Send";
constexpr auto kReduceScatterOpName = "ReduceScatter";
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
@@ -142,8 +144,8 @@ constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad";
constexpr auto kStreamSwitchOpName = "StreamSwitch";
constexpr auto kStreamActiveOpName = "StreamActive";
constexpr auto kAssignAddOpName = "AssignAdd";
constexpr auto kSendOpName = "Send";
constexpr auto kRecvOpName = "Recv";
constexpr auto kSendOpName = "StreamSend";
constexpr auto kRecvOpName = "StreamRecv";
constexpr auto kReluV2OpName = "ReLUV2";
constexpr auto kReluGradV2OpName = "ReluGradV2";
constexpr auto kAddNOpName = "AddN";
@@ -248,6 +250,8 @@ constexpr auto kBroadcastToOpName = "BroadcastTo";
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
constexpr auto kHcomOpTypeAllGather = "HcomAllGather";
constexpr auto kHcomOpTypeBroadcast = "HcomBroadcast";
constexpr auto kHcomOpTypeSend = "HcomSend";
constexpr auto kHcomOpTypeReceive = "HcomReceive";
constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter";
// attr key name
@@ -292,6 +296,9 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrOp = "op";
constexpr auto kAttrDestRank = "dest_rank";
constexpr auto kAttrSrcRank = "src_rank";
constexpr auto kAttrSrTag = "sr_tag";
constexpr auto kAttrRootRank = "root_rank";
constexpr auto kAttrIsTraining = "is_training";
constexpr auto kAttrFusionId = "fusion_id";

@@ -193,6 +193,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
