@@ -48,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
   auto *in_var =
       var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
 
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
-
-  // NOTE: The tensors' Place of input and output must be all on GPU or all on
-  // CPU.
-  for (auto *out_var_handle : out_var_handles) {
-    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
-      continue;
-    }
-    auto t_out_p = out_var_handle->place_;
-    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
-                        ->FindVar(out_var_handle->name_);
-    PADDLE_ENFORCE_NOT_NULL(out_var);
-    if (platform::is_gpu_place(in_tensor.place())) {
-      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
-                     "Places of input and output must be all on GPU.");
-    } else {
-      t_out_p = platform::CPUPlace();
-    }
-    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
-                                                            in_tensor.type());
-  }
+  InitOutputValue(*in_var_handle, out_var_handles);
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
@@ -145,6 +125,40 @@ void BroadcastOpHandle::RunImpl() {
   }
 }
 
+void BroadcastOpHandle::InitOutputValue(
+    const VarHandle &in_var_handle,
+    const std::vector<VarHandle *> &out_var_handles) const {
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+  auto *in_var =
+      var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
+
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
+
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  for (auto *out_var_handle : out_var_handles) {
+    if (out_var_handle->IsTheSameVar(in_var_handle)) {
+      continue;
+    }
+    auto t_out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
+    PADDLE_ENFORCE_NOT_NULL(out_var);
+    if (platform::is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
+                                                            in_tensor.type());
+  }
+}
+
 std::string BroadcastOpHandle::Name() const { return "broadcast"; }
 }  // namespace details
 }  // namespace framework
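
Note on the refactor: the first hunk swaps RunImpl()'s inline output-preparation loop for a call to the new private helper InitOutputValue(), which the second hunk adds with the same logic. Assuming the conventional .h/.cc split in Paddle's details/ directory, the matching declaration in broadcast_op_handle.h would read roughly as follows (a sketch, not part of this diff):

  // Prepares each output variable for the broadcast: shares dims and LoD
  // from the input, then allocates memory via mutable_data() on the
  // resolved place. A GPU input requires every output on GPU; otherwise
  // outputs fall back to platform::CPUPlace().
  void InitOutputValue(const VarHandle &in_var_handle,
                       const std::vector<VarHandle *> &out_var_handles) const;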