|
|
|
@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
|
|
|
|
|
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
|
|
|
|
|
std::vector<std::function<void()>> broadcast_calls;
|
|
|
|
|
|
|
|
|
|
int type = platform::ToNCCLDataType(in_tensor.type());
|
|
|
|
|
size_t numel = static_cast<size_t>(in_tensor.numel());
|
|
|
|
|
|
|
|
|
|
for (auto out_var_handle : out_var_handles) {
|
|
|
|
|
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
|
|
|
|
|
->FindVar(out_var_handle->name_);
|
|
|
|
@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
|
|
|
|
|
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
|
|
|
|
|
out_handle = out_var_handle;
|
|
|
|
|
} else {
|
|
|
|
|
send_recv_buffer =
|
|
|
|
|
VariableVisitor::GetMutableTensor(out_var).mutable_data(
|
|
|
|
|
out_var_handle->place_);
|
|
|
|
|
send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
|
|
|
|
|
.Resize(in_tensor.dims())
|
|
|
|
|
.mutable_data(out_var_handle->place_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int type = platform::ToNCCLDataType(in_tensor.type());
|
|
|
|
|
size_t numel = static_cast<size_t>(in_tensor.numel());
|
|
|
|
|
broadcast_calls.emplace_back(
|
|
|
|
|
[send_recv_buffer, numel, type, root_id, &nccl_ctx] {
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclBcast(
|
|
|
|
|