@@ -22,9 +22,9 @@ namespace details {
 void BroadcastOpHandle::RunImpl() {
   if (places_.size() == 1) return;
 
-  // the input and output may have dummy var.
-  VarHandle *in_var_handle;
+  // The input and output may have dummy vars.
+  VarHandle *in_var_handle;
   {
     auto in_var_handles = DynamicCast<VarHandle>(inputs_);
     PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
@@ -53,23 +53,39 @@ void BroadcastOpHandle::RunImpl() {
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
+  // NOTE(zcd): The Place of the input can be obtained from in_tensor or from
+  // in_var_handle, and the two may differ, because the Place obtained from
+  // in_tensor is determined at runtime while the other is determined when
+  // building the SSA graph. If they differ, DataTransform should be applied;
+  // currently this has not been done yet.
+  for (auto *out_var_handle : out_var_handles) {
+    if (*out_var_handle == *in_var_handle) {
+      continue;
+    }
+
+    auto &out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
+    PADDLE_ENFORCE_NOT_NULL(out_var);
+    PADDLE_ENFORCE_EQ(
+        out_p.which(), in_tensor.place().which(),
+        "Currently, Places of input and output must be all on CPU "
+        "or all on GPU.");
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+                                                            in_tensor.type());
+  }
+
   if (platform::is_cpu_place(in_tensor.place())) {
-    for (auto *out : out_var_handles) {
-      if (*out == *in_var_handle) {
+    for (auto *out_var_handle : out_var_handles) {
+      if (*out_var_handle == *in_var_handle) {
         continue;
       }
 
-      auto &out_p = out->place_;
-      auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
-      PADDLE_ENFORCE_NOT_NULL(out_var);
-      PADDLE_ENFORCE_EQ(out_p.which(), in_tensor.place().which(),
-                        "Places must be all on CPU or all on CUDA.");
-
-      VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-      VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
-                                                              in_tensor.type());
-
+      auto &out_p = out_var_handle->place_;
       auto dev_ctx = dev_ctxes_.at(out_p);
+      auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                          ->FindVar(out_var_handle->name_);
+
       RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
         paddle::framework::TensorCopy(
             in_tensor, out_p, *dev_ctx,
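A side note on the place check introduced above: out_p.which() == in_tensor.place().which() only compares which alternative of the place variant is held (CPU vs. GPU), not the concrete device id, so two tensors on different GPUs still pass it. A minimal standalone sketch of that behaviour, using std::variant in place of the boost::variant that platform::Place is assumed to be built on; the types below are simplified stand-ins, not Paddle's:

// Illustration only, not part of the patch.
#include <cassert>
#include <variant>

struct CPUPlace {};
struct CUDAPlace { int device; };
using Place = std::variant<CPUPlace, CUDAPlace>;

int main() {
  Place in = CUDAPlace{0};
  Place out = CUDAPlace{1};
  // Same alternative (both CUDA) even though the device ids differ,
  // which is all that a which()/index() comparison establishes.
  assert(in.index() == out.index());
  return 0;
}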
@@ -78,35 +94,21 @@ void BroadcastOpHandle::RunImpl() {
     }
   } else {
 #ifdef PADDLE_WITH_CUDA
-    PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place()));
-    VarHandle *out_handle;
-    int root = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
+    VarHandle *out_handle = nullptr;
+    int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
     std::vector<std::function<void()>> broadcast_calls;
 
-    for (size_t j = 0; j < out_var_handles.size(); ++j) {
-      VarHandle *out_var_handle = out_var_handles[j];
+    for (auto out_var_handle : out_var_handles) {
       Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
                               ->FindVar(out_var_handle->name_);
 
-      if (*out_var_handle != *in_var_handle) {
-        PADDLE_ENFORCE_NOT_NULL(out_var);
-        PADDLE_ENFORCE_EQ(out_var_handle->place_.which(),
-                          in_tensor.place().which(),
-                          "Places must be all on CPU or all on CUDA.");
-        VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-        VariableVisitor::GetMutableTensor(out_var).mutable_data(
-            out_var_handle->place_, in_tensor.type());
-      }
-
-      auto out_p = out_var_handle->place_;
-      int dev_id = boost::get<platform::CUDAPlace>(out_p).device;
+      int dst_id =
+          boost::get<platform::CUDAPlace>(out_var_handle->place_).device;
 
-      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-      auto stream = nccl_ctx.stream();
-      auto comm = nccl_ctx.comm_;
+      auto &nccl_ctx = nccl_ctxs_->at(dst_id);
 
       void *send_recv_buffer = nullptr;
-      if (root == dev_id) {
+      if (root_id == dst_id) {
         send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
         out_handle = out_var_handle;
       } else {
@@ -116,11 +118,13 @@ void BroadcastOpHandle::RunImpl() {
       }
 
       int type = platform::ToNCCLDataType(in_tensor.type());
-      broadcast_calls.emplace_back([=] {
-        PADDLE_ENFORCE(platform::dynload::ncclBcast(
-            send_recv_buffer, in_tensor.numel(),
-            static_cast<ncclDataType_t>(type), root, comm, stream));
-      });
+      size_t numel = static_cast<size_t>(in_tensor.numel());
+      broadcast_calls.emplace_back(
+          [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
+            PADDLE_ENFORCE(platform::dynload::ncclBcast(
+                send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
+                root_id, nccl_ctx.comm_, nccl_ctx.stream()));
+          });
     }
 
     this->RunAndRecordEvent([&] {
@@ -130,6 +134,7 @@ void BroadcastOpHandle::RunImpl() {
           call();
         }
       }
+      // TODO(zcd): Maybe the unequal operator is not appropriate here.
       if (*out_handle != *in_var_handle) {
         auto out_var = var_scopes.at(in_var_handle->scope_idx_)
                            ->FindVar(out_var_handles[0]->name_);
@@ -140,7 +145,7 @@ void BroadcastOpHandle::RunImpl() {
       }
     });
 #else
-    PADDLE_THROW("CUDA is not support.");
+    PADDLE_THROW("CUDA is not enabled.");
 #endif
   }
 }
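The GPU path above defers the actual transfers: one closure per output device is pushed into broadcast_calls, capturing the send/recv buffer, element count, dtype, root device id, and that device's NCCL context, and all of them are then issued together inside RunAndRecordEvent under a platform::NCCLGroupGuard (which, as I understand it, brackets the calls with ncclGroupStart/ncclGroupEnd). A minimal standalone sketch of that collect-then-run pattern in plain C++, with no NCCL; GroupGuard and the device loop below are illustrative stand-ins:

// Illustration only, not part of the patch.
#include <functional>
#include <iostream>
#include <vector>

struct GroupGuard {                                 // stands in for platform::NCCLGroupGuard
  GroupGuard() { std::cout << "group start\n"; }    // analogous to ncclGroupStart()
  ~GroupGuard() { std::cout << "group end\n"; }     // analogous to ncclGroupEnd()
};

int main() {
  std::vector<std::function<void()>> broadcast_calls;
  for (int dst_id = 0; dst_id < 4; ++dst_id) {
    // Each closure would issue ncclBcast on dst_id's comm/stream in the real code.
    broadcast_calls.emplace_back(
        [dst_id] { std::cout << "bcast on device " << dst_id << "\n"; });
  }
  {
    GroupGuard guard;  // batch the per-device calls into one group
    for (auto &call : broadcast_calls) call();
  }
  return 0;
}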