PyNative AllReduce Bucket

pull/12091/head
caifubi 4 years ago
parent d2d6c3cfb5
commit 171b468bb3

@ -64,6 +64,7 @@
#include "toolchain/adx_datadump_server.h"
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#include "runtime/device/ascend/ascend_bucket.h"
#endif
#if ENABLE_CPU && ENABLE_D
#include "ps/util.h"
@ -258,6 +259,7 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
// construct graph, if successfully, graph_sum_ + 1
auto graph = ConstructKernelGraph(lst, outputs);
auto graph_id = graph->graph_id();
InitAllBucket(graph);
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id;
}
@ -632,6 +634,13 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// Run op
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);
@ -1510,5 +1519,9 @@ void AscendSession::SyncStream() {
MS_LOG(EXCEPTION) << "Sync stream error!";
}
}
std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
return std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
}
} // namespace session
} // namespace mindspore

@ -61,6 +61,7 @@ class AscendSession : public SessionBasic {
void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, size_t> &cnode_refcount) override;
std::string GetCommWorldGroup() override { return kHcclWorldGroup; }
private:
// compile child graph when session have multiple child graphs
@ -123,6 +124,7 @@ class AscendSession : public SessionBasic {
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
InputTensorInfo *input_tensor_info);
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs

@ -16,6 +16,7 @@
#include "backend/session/gpu_session.h"
#include <string>
#include <utility>
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
@ -63,6 +64,7 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/cuda_driver.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "runtime/device/gpu/gpu_bucket.h"
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"
@ -394,6 +396,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
InitAllBucket(graph);
// Alloc memory in graph mode, including static memory and dynamic memory
if (!pynative_mode) {
AllocateMemory(graph.get());
@ -473,6 +477,12 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// run op
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -548,6 +558,10 @@ void GPUSession::SyncStream() {
MS_LOG(EXCEPTION) << "Sync stream error!";
}
}
std::shared_ptr<device::Bucket> GPUSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
return std::make_shared<device::gpu::GPUBucket>(bucket_id, bucket_size);
}
} // namespace gpu
} // namespace session
} // namespace mindspore

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include <algorithm>
#include <string>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/session_factory.h"
@ -44,6 +45,8 @@ class GPUSession : public SessionBasic {
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
std::string GetCommWorldGroup() override { return kNcclWorldGroup; }
private:
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;

@ -40,6 +40,7 @@
#include "debug/anf_ir_dump.h"
#include "debug/common.h"
#include "utils/trace_base.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h"
@ -556,10 +557,12 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs,
std::vector<TensorPtr> *runop_output_tensors) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(runop_output_tensors);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
@ -592,6 +595,7 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
runop_output_tensors->emplace_back(output_tensor);
}
}
}
@ -2196,6 +2200,11 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
GetRefCount(kernel_graph.get(), &cnode_refcount);
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
// Clear bucket resources every step
if (kernel_graph->is_bprop()) {
ClearAllBucket(graph_id);
}
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
for (const auto &kernel : kernel_graph->execution_order()) {
// Generate input tensors, tensor masks and input kernel with index
@ -2212,9 +2221,15 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
input_tensor_info.input_tensors_mask);
std::vector<tensor::TensorPtr> new_output_tensors;
// Handle inputs and outputs of current op
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs, &new_output_tensors);
// Save grad node to Bucket
if (kernel_graph->is_bprop()) {
AddGradAddrToBucket(graph_id, new_output_tensors);
}
}
MS_LOG(INFO) << "Finish!";
}
@ -2287,6 +2302,137 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
}
}
std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string group = GetCommWorldGroup();
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
// PyNative not support multi group allreduce
group += "sum1";
return parallel_context->GetAllReduceFusionSplitIndices(group);
}
uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
}
void SetGraphBpropAttr(const KernelGraphPtr &graph) {
auto &execution_orders = graph->execution_order();
if (std::any_of(execution_orders.begin(), execution_orders.end(),
[](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
graph->set_is_bprop(true);
MS_LOG(INFO) << "Match bprop graph";
} else {
graph->set_is_bprop(false);
}
}
std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) {
if (split_index.empty()) {
auto grads_count = GetBpropGraphGradsCount(graph);
if (grads_count == 0) {
MS_LOG(EXCEPTION) << "Bprop graph has no grad";
}
return {grads_count};
}
std::vector<uint32_t> bucket_size_list;
uint32_t old_index = 0;
for (auto &index : split_index) {
if (old_index == 0) {
bucket_size_list.emplace_back(index - old_index + 1);
} else {
bucket_size_list.emplace_back(index - old_index);
}
old_index = index;
}
return bucket_size_list;
}
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
SetGraphBpropAttr(graph);
if (!graph->is_bprop()) {
return;
}
std::vector<std::shared_ptr<device::Bucket>> bucket_list;
// Create bucket for every split allreduce ops
auto split_index = GetAllReduceSplitIndex();
auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
uint32_t bucket_id = 0;
for (auto bucket_size : bucket_size_list) {
MS_LOG(INFO) << "Create new bucket:" << bucket_id;
auto bucket = CreateBucket(bucket_id++, bucket_size);
bucket->Init();
bucket_list.emplace_back(bucket);
}
auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list);
if (!bucket_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id();
}
// set all free bucket index to 0
auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0);
if (!free_bucket_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id();
}
MS_LOG(INFO) << "Init Bucket finish";
}
void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (parallel_mode != parallel::DATA_PARALLEL) {
return;
}
auto iter = bucket_map_.find(graph_id);
if (iter == bucket_map_.end()) {
MS_LOG(EXCEPTION) << "unknown graph id:" << graph_id;
}
auto &bucket_list = iter->second;
auto free_bucket_iter = free_bucket_id_map_.find(graph_id);
if (free_bucket_iter == free_bucket_id_map_.end()) {
MS_LOG(EXCEPTION) << "unknown free graph id:" << graph_id;
}
auto free_bucket_index = free_bucket_iter->second;
for (auto &tensor : grad_tensor) {
if (free_bucket_index >= bucket_list.size()) {
MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second
<< " total bucket num:" << bucket_list.size();
}
auto &free_bucket = bucket_list[free_bucket_index];
free_bucket->AddGradTensor(tensor);
if (free_bucket->full()) {
MS_LOG(INFO) << "bucket is full";
free_bucket->Launch();
free_bucket_index = ++free_bucket_iter->second;
MS_LOG(INFO) << "new free bucket:" << free_bucket_index;
}
}
}
void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
auto iter = bucket_map_.find(graph_id);
if (iter != bucket_map_.end()) {
auto bucket_list = iter->second;
for (auto &bucket : bucket_list) {
MS_LOG(INFO) << "Clear bucket:" << bucket->id();
bucket->Release();
}
}
auto free_iter = free_bucket_id_map_.find(graph_id);
if (free_iter != free_bucket_id_map_.end()) {
free_iter->second = 0;
}
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::PSContext::instance()->is_worker()) {

@ -32,6 +32,7 @@
#include "utils/contract.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "runtime/device/bucket.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "debug/debugger/debugger.h"
#endif
@ -224,12 +225,20 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
void InitAllBucket(const KernelGraphPtr &graph);
void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
void ClearAllBucket(const GraphId &graph_id);
std::vector<uint32_t> GetAllReduceSplitIndex();
virtual std::string GetCommWorldGroup() { return std::string(); }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
void GetBatchElements(const AnfNodePtr &kernel_node) const;
void InitPsWorker(const KernelGraphPtr &kernel_graph);
#endif
std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
std::map<uint32_t, uint32_t> free_bucket_id_map_;
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;

@ -677,34 +677,9 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
return op_exec_info;
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);
void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
auto prim = op_exec_info->py_primitive;
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim));
const auto &signature = prim->signatures();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();
// ignore monad signature
for (auto sig : signature) {
if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
--sig_size;
}
}
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
if (op_exec_info->op_name != prim::kPrimCast->name()) {
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
abstract::AbstractBasePtr abs = nullptr;
const auto &obj = op_exec_info->op_inputs[i];
@ -733,11 +708,42 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (input_node->abstract() != nullptr) {
abs = input_node->abstract();
}
inputs.emplace_back(input_node);
inputs->emplace_back(input_node);
}
}
(*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
}
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);
auto prim = op_exec_info->py_primitive;
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim));
const auto &signature = prim->signatures();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();
// ignore monad signature
for (auto sig : signature) {
if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
--sig_size;
}
}
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
if (op_exec_info->op_name != prim::kPrimCast->name()) {
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
GetArgsSpec(op_exec_info, op_masks, &inputs, args_spec_list);
CNodePtr cnode = nullptr;
if (need_construct_graph()) {

@ -208,6 +208,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
PynativeStatusCode *const status);
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,

@ -1,5 +1,7 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"bucket.cc"
)
if(ENABLE_GPU)

@ -0,0 +1,173 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ascend_bucket.h"
#include <vector>
#include <memory>
#include "runtime/mem.h"
#include "external/hccl/hccl.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "backend/kernel_compiler/hccl/hcom_util.h"
#include "backend/kernel_compiler/hccl/hccl_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/ascend/ascend_event.h"
#include "utils/profile.h"
#define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \
{ \
rtError_t ret = (expression); \
if (ret != RT_ERROR_NONE) { \
MS_LOG(EXCEPTION) << message << ", error code: " << ret; \
} \
}
namespace mindspore::device::ascend {
void AscendBucket::AllocateAllReduceAddr() {
// Check bucket is full
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size()
<< " is not equal to bucket size:" << bucket_size_;
}
auto total_size = 0;
std::vector<size_t> align_size_list;
std::vector<size_t> origin_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
auto origin_size = device_address->GetSize();
auto align_size = MemoryManager::GetCommonAlignSize(origin_size);
origin_size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(std::make_shared<kernel::Address>(
static_cast<uint8_t *>(device_address->GetMutablePtr()), device_address->GetSize()));
}
total_size_ = total_size;
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(runtime_instance);
// AllReduce input output addr need to clear zero
ar_input_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size);
ar_output_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size);
// generate memecpy output addr
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, origin_size_list[i]));
memcpy_output += align_size_list[i];
}
// store output tensor addr
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
}
}
void AscendBucket::FreeDeviceMem(void *dev_ptr) { AscendMemoryPool::GetInstance().FreeTensorMem(dev_ptr); }
void AscendBucket::FreeAllDeviceMem() {
if (ar_input_addr_ != nullptr) {
uint8_t *origin_dev_addr = ar_input_addr_ - kMemAlignSize;
FreeDeviceMem(origin_dev_addr);
ar_input_addr_ = nullptr;
}
if (ar_output_addr_ != nullptr) {
uint8_t *origin_dev_addr = ar_output_addr_ - kMemAlignSize;
FreeDeviceMem(origin_dev_addr);
ar_output_addr_ = nullptr;
}
}
void AscendBucket::CopyTensorToContiguousMemory() {
// Clean input addr
CHECK_ASCEND_RT_WITH_EXCEPTION(rtMemsetAsync(ar_input_addr_, total_size_, 0, total_size_, compute_stream_),
"Call rtMemsetAsync failed");
for (size_t i = 0; i < bucket_size_; ++i) {
MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]);
MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]);
MS_LOG(DEBUG) << "MemcpyAsync dst size:" << memcpy_output_addrs_[i]->size
<< " src size:" << memcpy_input_addrs_[i]->size;
if (memcpy_output_addrs_[i]->size < memcpy_input_addrs_[i]->size) {
MS_LOG(EXCEPTION) << "rtMemcpyAsync dst size < src size";
}
CHECK_ASCEND_RT_WITH_EXCEPTION(
rtMemcpyAsync(memcpy_output_addrs_[i]->addr, memcpy_output_addrs_[i]->size, memcpy_input_addrs_[i]->addr,
memcpy_input_addrs_[i]->size, RT_MEMCPY_DEVICE_TO_DEVICE, compute_stream_),
"Call rtMemcpyAsync failed");
}
}
void AscendBucket::LaunchAllReduce() {
if (tensor_type_list_.empty()) {
MS_LOG(EXCEPTION) << "No tesnor type found";
}
// AllReduce inputs data type should be same
auto type = tensor_type_list_[0];
if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(),
[&type](TypeId tensor_type) { return type != tensor_type; })) {
MS_LOG(EXCEPTION) << "allreduce input have different dtype";
}
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(EXCEPTION) << "unknown data type:" << type;
}
uint32_t type_size;
if (!HcomUtil::GetHcomTypeSize(iter->second, &type_size)) {
MS_LOG(EXCEPTION) << "get hcom type size fialed";
}
if (type_size == 0 || total_size_ % type_size != 0) {
MS_LOG(EXCEPTION) << "Total_size[" << total_size_ << "],Type_size[" << type_size << "] != 0, fail!";
}
auto hccl_count = total_size_ / type_size;
HcclReduceOp op_type = HcclReduceOp::HCCL_REDUCE_SUM;
auto hccl_result = HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, iter->second, op_type,
kernel::HcclContext::GetInstance().hccl_comm(), stream_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "HcclAllReduce faled, ret:" << hccl_result;
}
}
void AscendBucket::Init() {
pre_event_ = std::make_shared<AscendEvent>();
post_event_ = std::make_shared<AscendEvent>();
auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(kernel_runtime);
compute_stream_ = kernel_runtime->compute_stream();
stream_ = kernel_runtime->communication_stream();
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
pre_event_->set_wait_stream(stream_);
pre_event_->set_record_stream(compute_stream_);
post_event_->set_wait_stream(compute_stream_);
post_event_->set_record_stream(stream_);
}
} // namespace mindspore::device::ascend

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#include "runtime/device/bucket.h"
namespace mindspore::device::ascend {
class AscendBucket : public Bucket {
public:
AscendBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size) {}
~AscendBucket() override = default;
void Init() override;
private:
void AllocateAllReduceAddr() override;
void FreeAllDeviceMem() override;
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ascend_event.h"
#include "runtime/event.h"
#include "runtime/stream.h"
#include "utils/log_adapter.h"
namespace mindspore::device::ascend {
AscendEvent::AscendEvent() {
auto ret = rtEventCreate(&event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtEventCreate failed, ret:" << ret;
event_ = nullptr;
}
}
AscendEvent::~AscendEvent() {
auto ret = rtEventDestroy(event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtEventDestory failed, ret:" << ret;
}
}
void AscendEvent::RecordEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(record_stream_);
auto ret = rtEventRecord(event_, record_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtEventRecord failed, ret:" << ret;
}
need_wait_ = true;
}
void AscendEvent::WaitEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(wait_stream_);
auto ret = rtStreamWaitEvent(wait_stream_, event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtStreamWaitEvent failed, ret:" << ret;
}
need_wait_ = false;
}
bool AscendEvent::NeedWait() { return need_wait_; }
} // namespace mindspore::device::ascend

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ASCEND_EVENT_H
#define MINDSPORE_ASCEND_EVENT_H
#include "runtime/base.h"
#include "ir/device_event.h"
namespace mindspore::device::ascend {
class AscendEvent : public DeviceEvent {
public:
AscendEvent();
~AscendEvent() override;
void WaitEvent() override;
void RecordEvent() override;
bool NeedWait() override;
void set_wait_stream(rtStream_t wait_stream) override { wait_stream_ = wait_stream; }
void set_record_stream(rtStream_t record_stream) override { record_stream_ = record_stream; }
private:
rtEvent_t event_{nullptr};
rtStream_t wait_stream_{nullptr};
rtStream_t record_stream_{nullptr};
bool need_wait_{false};
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_ASCEND_EVENT_H

@ -718,6 +718,10 @@ bool AscendKernelRuntime::SyncStream() {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
FreeAndClearBufferPtrs();
return true;
}
@ -786,6 +790,10 @@ bool AscendKernelRuntime::InitDevice() {
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
ret = rtStreamCreate(&communication_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
return true;
}
@ -799,6 +807,14 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
stream_ = nullptr;
}
if (communication_stream_ != nullptr) {
auto ret = rtStreamDestroy(communication_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
communication_stream_ = nullptr;
}
auto ret = rtDeviceReset(device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
@ -919,4 +935,5 @@ uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
return ascend_mem_manager->GetDeviceMemSize();
}
} // namespace mindspore::device::ascend

@ -57,6 +57,8 @@ class AscendKernelRuntime : public KernelRuntime {
void *context() const override { return rt_context_; }
void PreInit() override;
uint64_t GetAvailableMemMaxSize() const override;
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

@ -162,6 +162,14 @@ void AscendMemoryManager::MallocSomasDynamicMem(const session::KernelGraph *grap
somas_reuse_util_ptr_->ConvertToProfilingNode(graph->graph_id());
}
}
// communication memory: [512align_size + data + 512align_size]
// return the pointer to the start of data address.
uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
auto align_size = GetCommunicationAlignSize(size);
uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
return base_ptr + kMemAlignSize;
}
} // namespace ascend
} // namespace device
} // namespace mindspore

@ -33,6 +33,7 @@ class AscendMemoryManager : public MemoryManager {
void *MallocMemFromMemPool(size_t size) override;
uint64_t GetDeviceMemSize();
void MallocSomasDynamicMem(const session::KernelGraph *graph);
uint8_t *MallocCommunicationMemFromMemPool(size_t size) override;
protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;

@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/bucket.h"
#include <memory>
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/profile.h"
namespace mindspore::device {
void Bucket::AddGradTensor(const tensor::TensorPtr &tensor) {
if (grad_tensor_list_.size() >= bucket_size_) {
MS_LOG(EXCEPTION) << "bucket is full";
}
grad_tensor_list_.emplace_back(tensor);
if (grad_tensor_list_.size() > bucket_size_) {
MS_LOG(EXCEPTION) << "too many tensor add to the bucket, bucket_size_:" << bucket_size_
<< " total tensor size:" << grad_tensor_list_.size();
}
MS_LOG(INFO) << "current bucket tensors size:" << grad_tensor_list_.size();
// bucket is full, start to launch allreduce
if (grad_tensor_list_.size() == bucket_size_) {
full_ = true;
}
}
void Bucket::Launch() {
auto start = GetTime();
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "Bucket is not full, grad_tensor_list_ size:" << grad_tensor_list_.size()
<< " bucket_size_:" << bucket_size_;
}
MS_LOG(INFO) << "Bucket is full, start to launch AllReduce";
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
AllocateAllReduceAddr();
CopyTensorToContiguousMemory();
pre_event_->RecordEvent();
pre_event_->WaitEvent();
LaunchAllReduce();
post_event_->RecordEvent();
UpdateTensorAddr();
// pass event to the tensor
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor->SetDeviceEvent(post_event_);
}
MS_LOG(INFO) << "Bucket launch cost:" << (GetTime() - start) * 1e6 << " us";
}
// TODO(caifubi): float16 grad cast to float32 grad
void Bucket::UpdateTensorAddr() {
if (grad_tensor_list_.size() != bucket_size_ || new_tensor_output_addrs_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad_tensor_list size:" << grad_tensor_list_.size()
<< " tensor output addr size:" << new_tensor_output_addrs_.size()
<< " bucket size:" << bucket_size_;
}
for (size_t i = 0; i < bucket_size_; ++i) {
auto &tensor = grad_tensor_list_[i];
MS_EXCEPTION_IF_NULL(tensor);
auto device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
// release old addr and manage addr by this Bucket.
MS_EXCEPTION_IF_NULL(device_address);
auto origin_dev_ptr = device_address->GetMutablePtr();
// FreeDeviceMem(origin_dev_ptr);
tensor_old_addr_list_.emplace_back(origin_dev_ptr);
device_address->from_mem_pool_ = false;
device_address->set_ptr(new_tensor_output_addrs_[i]);
}
}
void Bucket::LazyDeleteOldAddr() {
MS_LOG(INFO) << "Lazy delete old grad address";
for (auto old_addr : tensor_old_addr_list_) {
FreeDeviceMem(old_addr);
}
tensor_old_addr_list_.clear();
}
void Bucket::Release() {
MS_LOG(INFO) << "Clear bucket:" << id_;
grad_tensor_list_.clear();
new_tensor_output_addrs_.clear();
memcpy_input_addrs_.clear();
memcpy_output_addrs_.clear();
tensor_type_list_.clear();
LazyDeleteOldAddr();
FreeAllDeviceMem();
full_ = false;
}
} // namespace mindspore::device

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_
#include <vector>
#include <utility>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/device_event.h"
#include "runtime/device/device_address.h"
#include "backend/session/kernel_graph.h"
namespace mindspore::device {
class Bucket {
public:
Bucket(uint32_t id, uint32_t bucket_size)
: id_(id),
bucket_size_(bucket_size),
full_(false),
stream_(nullptr),
compute_stream_(nullptr),
pre_event_(nullptr),
post_event_(nullptr),
total_size_(0),
ar_input_addr_(nullptr),
ar_output_addr_(nullptr) {}
virtual ~Bucket() = default;
uint32_t id() const { return id_; }
bool full() const { return full_; }
void Launch();
void Release();
void AddGradTensor(const tensor::TensorPtr &tensor);
virtual void Init() = 0;
protected:
uint32_t id_;
uint32_t bucket_size_;
bool full_;
void *stream_;
void *compute_stream_;
std::shared_ptr<DeviceEvent> pre_event_;
std::shared_ptr<DeviceEvent> post_event_;
size_t total_size_;
uint8_t *ar_input_addr_;
uint8_t *ar_output_addr_;
std::string group_;
std::vector<tensor::TensorPtr> grad_tensor_list_;
std::vector<uint8_t *> new_tensor_output_addrs_;
std::vector<kernel::AddressPtr> memcpy_input_addrs_;
std::vector<kernel::AddressPtr> memcpy_output_addrs_;
std::vector<TypeId> tensor_type_list_;
std::vector<void *> tensor_old_addr_list_;
virtual void AllocateAllReduceAddr() = 0;
void UpdateTensorAddr();
virtual void LaunchAllReduce() = 0;
virtual void FreeAllDeviceMem() = 0;
virtual void FreeDeviceMem(void *dev_ptr) = 0;
virtual void CopyTensorToContiguousMemory() = 0;
void LazyDeleteOldAddr();
};
} // namespace mindspore::device
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_

@ -26,6 +26,7 @@
namespace mindspore {
namespace device {
class Bucket;
namespace cpu {
class CPUSimpleMemPlan;
class CPUMemoryManager;
@ -100,6 +101,7 @@ class DeviceAddress : public mindspore::DeviceSync {
friend class mindspore::device::ascend::AscendKernelRuntime;
friend class mindspore::device::ascend::AscendMemoryManager;
friend class mindspore::device::ascend::DataDumper;
friend class mindspore::device::Bucket;
};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

@ -0,0 +1,177 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gpu/gpu_bucket.h"
#include <cuda_runtime_api.h>
#include <nccl.h>
#include <vector>
#include <memory>
#include "abstract/utils.h"
#include "runtime/device/gpu/gpu_event.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
#include "runtime/device/gpu/gpu_common.h"
namespace {
const size_t kCommunicationMemAlignSize = 16;
size_t AlignMemorySize(size_t size) {
if (size == 0) {
return kCommunicationMemAlignSize;
}
return ((size + kCommunicationMemAlignSize - 1) / kCommunicationMemAlignSize) * kCommunicationMemAlignSize;
}
} // namespace
namespace mindspore::device::gpu {
GPUBucket::GPUBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size), collective_handle_(nullptr) {
group_ = kNcclWorldGroup;
}
void GPUBucket::AllocateAllReduceAddr() {
MS_LOG(INFO) << "start";
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size()
<< " is not equal to bucket size:" << bucket_size_;
}
auto total_size = 0;
std::vector<size_t> size_list;
std::vector<size_t> align_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
MS_EXCEPTION_IF_NULL(device_address);
auto origin_size = device_address->GetSize();
auto align_size = AlignMemorySize(origin_size);
size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(
std::make_shared<kernel::Address>(static_cast<uint8_t *>(device_address->GetMutablePtr()), origin_size));
}
total_size_ = total_size;
ar_input_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size));
ar_output_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size));
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, size_list[i]));
memcpy_output += align_size_list[i];
}
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
}
MS_LOG(INFO) << "end";
}
void GPUBucket::FreeDeviceMem(void *dev_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(dev_ptr); }
void GPUBucket::FreeAllDeviceMem() {
MS_LOG(INFO) << "start";
if (ar_input_addr_ != nullptr) {
FreeDeviceMem(ar_input_addr_);
ar_input_addr_ = nullptr;
}
if (ar_output_addr_ != nullptr) {
FreeDeviceMem(ar_output_addr_);
ar_output_addr_ = nullptr;
}
MS_LOG(INFO) << "end";
}
void GPUBucket::CopyTensorToContiguousMemory() {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(compute_stream_);
// Clean allreduce input
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemsetAsync(ar_input_addr_, 0, total_size_, static_cast<cudaStream_t>(compute_stream_)),
"Call cudaMemsetAsync failed");
for (size_t i = 0; i < bucket_size_; ++i) {
MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]);
MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]);
if (!GPUDeviceManager::GetInstance().CopyDeviceMemToDeviceAsync(memcpy_output_addrs_[i]->addr,
memcpy_input_addrs_[i]->addr,
memcpy_output_addrs_[i]->size, compute_stream_)) {
MS_LOG(EXCEPTION) << "Copy memory failed";
}
}
MS_LOG(INFO) << "end";
}
void GPUBucket::LaunchAllReduce() {
MS_LOG(INFO) << "start";
collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
auto all_reduce_funcptr =
reinterpret_cast<kernel::AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
MS_EXCEPTION_IF_NULL(stream_);
if (tensor_type_list_.empty()) {
MS_LOG(EXCEPTION) << "No tesnor type found";
}
auto type = tensor_type_list_[0];
if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(),
[&type](TypeId tensor_type) { return type != tensor_type; })) {
MS_LOG(EXCEPTION) << "AllReduce input have different dtype";
}
auto type_size = abstract::TypeIdSize(type);
if (type_size == 0) {
MS_LOG(EXCEPTION) << "Invalid type:" << type;
}
// typeid to nccl_data_type
auto nccl_data_type_iter = kernel::kNcclDtypeMap.find(TypeIdLabel(type));
if (nccl_data_type_iter == kernel::kNcclDtypeMap.end()) {
MS_LOG(EXCEPTION) << "Invalid type:" << type;
}
auto nccl_result =
(*all_reduce_funcptr)(ar_input_addr_, ar_output_addr_, total_size_ / type_size, nccl_data_type_iter->second,
ncclRedOp_t::ncclSum, static_cast<cudaStream_t>(stream_), group_);
if (nccl_result != ncclSuccess) {
MS_LOG(EXCEPTION) << "AllReduce failed, ret:" << nccl_result;
}
MS_LOG(INFO) << "end";
}
void GPUBucket::Init() {
pre_event_ = std::make_shared<GpuEvent>();
post_event_ = std::make_shared<GpuEvent>();
auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(kernel_runtime);
stream_ = kernel_runtime->communication_stream();
compute_stream_ = kernel_runtime->compute_stream();
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
pre_event_->set_record_stream(compute_stream_);
pre_event_->set_wait_stream(stream_);
post_event_->set_record_stream(stream_);
post_event_->set_wait_stream(compute_stream_);
}
} // namespace mindspore::device::gpu

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#include "runtime/device/bucket.h"
namespace mindspore::device::gpu {
class GPUBucket : public Bucket {
public:
GPUBucket(uint32_t id, uint32_t bucket_size);
~GPUBucket() override = default;
void Init() override;
private:
void AllocateAllReduceAddr() override;
void FreeAllDeviceMem() override;
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
const void *collective_handle_;
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gpu/gpu_event.h"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore::device::gpu {
GpuEvent::GpuEvent() {
auto ret = cudaEventCreate(&event_);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cudaEventCreate failed, ret:" << ret;
event_ = nullptr;
}
}
GpuEvent::~GpuEvent() { CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaEventDestroy(event_), "cudaEventDestory failed"); }
void GpuEvent::WaitEvent() {
MS_EXCEPTION_IF_NULL(wait_stream_);
MS_EXCEPTION_IF_NULL(event_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamWaitEvent(wait_stream_, event_, 0), "cudaStreamWaitEvent failed");
need_wait_ = false;
}
void GpuEvent::RecordEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(record_stream_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(event_, record_stream_), "cudaEventRecord failed");
need_wait_ = true;
}
bool GpuEvent::NeedWait() { return need_wait_; }
} // namespace mindspore::device::gpu

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
#include <cuda_runtime_api.h>
#include "ir/device_event.h"
namespace mindspore::device::gpu {
class GpuEvent : public DeviceEvent {
public:
GpuEvent();
~GpuEvent() override;
void WaitEvent() override;
void RecordEvent() override;
bool NeedWait() override;
void set_wait_stream(void *wait_stream) override { wait_stream_ = static_cast<cudaStream_t>(wait_stream); }
void set_record_stream(void *record_stream) override { record_stream_ = static_cast<cudaStream_t>(record_stream); }
private:
cudaEvent_t event_{nullptr};
cudaStream_t wait_stream_{nullptr};
cudaStream_t record_stream_{nullptr};
bool need_wait_{false};
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_

@ -229,6 +229,11 @@ bool GPUKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "No default CUDA stream found.";
return false;
}
GPUDeviceManager::GetInstance().CreateStream(&communication_stream_);
if (communication_stream_ == nullptr) {
MS_LOG(ERROR) << "Invalid communication stream";
return false;
}
return true;
}
@ -1251,6 +1256,7 @@ session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &n
return addr_iter->second[i];
}
} // namespace gpu
} // namespace device
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save