|
|
|
@ -53,42 +53,39 @@ void BroadcastOpHandle::RunImpl() {
|
|
|
|
|
|
|
|
|
|
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
|
|
|
|
|
|
|
|
|
|
// NOTE(zcd): the Place of input can be obtained from in_tensor and in_var_handle,
|
|
|
|
|
// and they may differ, because the Place obtained from in_tensor is
|
|
|
|
|
// determined at runtime, while the other is determined when building the SSA graph.
|
|
|
|
|
// If they are different, DataTransform should be applied. Currently, it has
|
|
|
|
|
// not been done yet.
|
|
|
|
|
// NOTE: The tensors' Place of input and output must be all on GPU or all on
|
|
|
|
|
// CPU.
|
|
|
|
|
for (auto *out_var_handle : out_var_handles) {
|
|
|
|
|
if (*out_var_handle == *in_var_handle) {
|
|
|
|
|
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto &out_p = out_var_handle->place_;
|
|
|
|
|
auto t_out_p = out_var_handle->place_;
|
|
|
|
|
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
|
|
|
|
|
->FindVar(out_var_handle->name_);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(out_var);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_p.which(), in_tensor.place().which(),
|
|
|
|
|
"Currently, Places of input and output must be all on CPU "
|
|
|
|
|
"or all on GPU.");
|
|
|
|
|
if (platform::is_gpu_place(in_tensor.place())) {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
|
|
|
|
|
"Places of input and output must be all on GPU.");
|
|
|
|
|
} else {
|
|
|
|
|
t_out_p = platform::CPUPlace();
|
|
|
|
|
}
|
|
|
|
|
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
|
|
|
|
|
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
|
|
|
|
|
VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
|
|
|
|
|
in_tensor.type());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (platform::is_cpu_place(in_tensor.place())) {
|
|
|
|
|
for (auto *out_var_handle : out_var_handles) {
|
|
|
|
|
if (*out_var_handle == *in_var_handle) {
|
|
|
|
|
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &out_p = out_var_handle->place_;
|
|
|
|
|
auto dev_ctx = dev_ctxes_.at(out_p);
|
|
|
|
|
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
|
|
|
|
|
->FindVar(out_var_handle->name_);
|
|
|
|
|
|
|
|
|
|
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
|
|
|
|
|
RunAndRecordEvent(out_p, [in_tensor, out_var] {
|
|
|
|
|
paddle::framework::TensorCopy(
|
|
|
|
|
in_tensor, out_p, *dev_ctx,
|
|
|
|
|
in_tensor, platform::CPUPlace(),
|
|
|
|
|
&VariableVisitor::GetMutableTensor(out_var));
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
@ -134,8 +131,8 @@ void BroadcastOpHandle::RunImpl() {
|
|
|
|
|
call();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// TODO(zcd): Maybe the unequal operator is not appropriate here.
|
|
|
|
|
if (*out_handle != *in_var_handle) {
|
|
|
|
|
|
|
|
|
|
if (!out_handle->IsTheSameVar(*in_var_handle)) {
|
|
|
|
|
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
|
|
|
|
|
->FindVar(out_var_handles[0]->name_);
|
|
|
|
|
paddle::framework::TensorCopy(
|
|
|
|
|