@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
     int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
     std::vector<std::function<void()>> broadcast_calls;
 
+    int type = platform::ToNCCLDataType(in_tensor.type());
+    size_t numel = static_cast<size_t>(in_tensor.numel());
+
     for (auto out_var_handle : out_var_handles) {
       Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
                               ->FindVar(out_var_handle->name_);
@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
         send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
         out_handle = out_var_handle;
       } else {
-        send_recv_buffer =
-            VariableVisitor::GetMutableTensor(out_var).mutable_data(
-                out_var_handle->place_);
+        send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
+                               .Resize(in_tensor.dims())
+                               .mutable_data(out_var_handle->place_);
       }
 
-      int type = platform::ToNCCLDataType(in_tensor.type());
-      size_t numel = static_cast<size_t>(in_tensor.numel());
       broadcast_calls.emplace_back(
           [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
             PADDLE_ENFORCE(platform::dynload::ncclBcast(
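Note on the two hunks above: type and numel depend only on the input tensor, so they are now computed once before the loop rather than once per device, and each non-root output is Resize()d to the input's dims before its buffer is allocated. The queued broadcast_calls are later fired inside one platform::NCCLGroupGuard, which scopes ncclGroupStart()/ncclGroupEnd() so the per-device ncclBcast calls launch as one fused collective. A minimal standalone sketch of that deferred-call pattern (plain NCCL 2.x; the device count, buffers, and streams here are illustrative, not Paddle's actual members):

// Standalone sketch (not Paddle code): queue one ncclBcast per device,
// then launch the whole batch inside a single NCCL group.
#include <cassert>
#include <functional>
#include <vector>
#include <cuda_runtime.h>
#include <nccl.h>

int main() {
  constexpr int kDevs = 2;         // illustrative device count
  constexpr size_t kNumel = 1024;  // illustrative element count
  ncclComm_t comms[kDevs];
  assert(ncclCommInitAll(comms, kDevs, nullptr) == ncclSuccess);

  float *bufs[kDevs];
  cudaStream_t streams[kDevs];
  for (int i = 0; i < kDevs; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&bufs[i], kNumel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Collect the calls first, as broadcast_calls does in the diff.
  const int root_id = 0;
  std::vector<std::function<void()>> broadcast_calls;
  for (int i = 0; i < kDevs; ++i) {
    float *buf = bufs[i];
    ncclComm_t comm = comms[i];
    cudaStream_t stream = streams[i];
    broadcast_calls.emplace_back([=] {
      assert(ncclBcast(buf, kNumel, ncclFloat, root_id, comm, stream) ==
             ncclSuccess);
    });
  }

  // One group start/end around all calls -> one fused collective,
  // which is what platform::NCCLGroupGuard scopes in the diff.
  ncclGroupStart();
  for (auto &call : broadcast_calls) call();
  ncclGroupEnd();

  for (int i = 0; i < kDevs; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
  }
  return 0;
}

The grouping matters beyond performance: when a single thread drives several GPUs, each ungrouped ncclBcast would block until every rank arrives, deadlocking on the ranks the thread has not launched yet; inside a group the calls are only submitted at ncclGroupEnd().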
@@ -102,23 +103,50 @@ void BroadcastOpHandle::RunImpl() {
           });
     }
 
-    this->RunAndRecordEvent([&] {
-      {
-        platform::NCCLGroupGuard guard;
-        for (auto &call : broadcast_calls) {
-          call();
-        }
-      }
+    // FIXME(zcd): a temporary fix for some language model that has sparse
+    // parameter.
+    bool use_mutex = true;
+    if (in_var->IsType<paddle::framework::SelectedRows>()) {
+      use_mutex = false;
+    }
+    if (use_mutex) {
+      this->RunAndRecordEvent([&] {
+        {
+          platform::NCCLGroupGuard guard;
+          for (auto &call : broadcast_calls) {
+            call();
+          }
+        }
 
-      if (!out_handle->IsTheSameVar(*in_var_handle)) {
-        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
-                           ->FindVar(out_var_handles[0]->name_);
-        paddle::framework::TensorCopy(
-            in_tensor, in_var_handle->place_,
-            *(dev_ctxes_.at(in_var_handle->place_)),
-            &VariableVisitor::GetMutableTensor(out_var));
-      }
-    });
+        if (!out_handle->IsTheSameVar(*in_var_handle)) {
+          auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                             ->FindVar(out_var_handles[0]->name_);
+          paddle::framework::TensorCopy(
+              in_tensor, in_var_handle->place_,
+              *(dev_ctxes_.at(in_var_handle->place_)),
+              &VariableVisitor::GetMutableTensor(out_var));
+        }
+      });
+    } else {
+      this->RunAndRecordEventNoMutex([&] {
+        {
+          platform::NCCLGroupGuard guard;
+          for (auto &call : broadcast_calls) {
+            call();
+          }
+        }
+
+        if (!out_handle->IsTheSameVar(*in_var_handle)) {
+          auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                             ->FindVar(out_var_handles[0]->name_);
+          paddle::framework::TensorCopy(
+              in_tensor, in_var_handle->place_,
+              *(dev_ctxes_.at(in_var_handle->place_)),
+              &VariableVisitor::GetMutableTensor(out_var));
+        }
+      });
+    }
 
 #else
     PADDLE_THROW("CUDA is not enabled.");
 #endif
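On the last hunk: the two branches differ only in which runner they call. Per the FIXME, parameters stored as SelectedRows (Paddle's sparse representation) go through RunAndRecordEventNoMutex, presumably skipping the lock that RunAndRecordEvent otherwise takes around the NCCL launch. Condensed restatement of the new control flow for readability (broadcast_body is a hypothetical name introduced here only for illustration; it is not in the patch):

// Illustrative restatement, not the patch itself.
auto broadcast_body = [&] {  // hypothetical name for the shared lambda body
  {
    platform::NCCLGroupGuard guard;
    for (auto &call : broadcast_calls) call();
  }
  if (!out_handle->IsTheSameVar(*in_var_handle)) {
    auto out_var = var_scopes.at(in_var_handle->scope_idx_)
                       ->FindVar(out_var_handles[0]->name_);
    paddle::framework::TensorCopy(
        in_tensor, in_var_handle->place_,
        *(dev_ctxes_.at(in_var_handle->place_)),
        &VariableVisitor::GetMutableTensor(out_var));
  }
};
if (in_var->IsType<paddle::framework::SelectedRows>()) {
  this->RunAndRecordEventNoMutex(broadcast_body);  // sparse: skip the lock
} else {
  this->RunAndRecordEvent(broadcast_body);         // dense: locked path
}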