Fix broadcast

wangkuiyi-patch-3
chengduoZH 7 years ago
parent 28a86aebc3
commit 13de72388d

@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
std::vector<std::function<void()>> broadcast_calls;
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
for (auto out_var_handle : out_var_handles) {
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
out_handle = out_var_handle;
} else {
send_recv_buffer =
VariableVisitor::GetMutableTensor(out_var).mutable_data(
out_var_handle->place_);
send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
.Resize(in_tensor.dims())
.mutable_data(out_var_handle->place_);
}
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
broadcast_calls.emplace_back(
[send_recv_buffer, numel, type, root_id, &nccl_ctx] {
PADDLE_ENFORCE(platform::dynload::ncclBcast(

Loading…
Cancel
Save