@@ -61,8 +61,9 @@ void BroadcastOpHandle::RunImpl() {
                  "Places must be all on CPU or all on CUDA.");
 
   VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-  VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
-                                                          in_tensor.type());
+  VariableVisitor::GetMutableTensor(out_var)
+      .Resize(in_tensor.dims())
+      .mutable_data(out_p, in_tensor.type());
 
   auto dev_ctx = dev_ctxes_[out_p];
   RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
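
A minimal, self-contained sketch of what the first hunk changes (the `ToyTensor` type below is a stand-in, not PaddlePaddle's real `Tensor` API): chaining `.Resize(in_tensor.dims())` before `.mutable_data(...)` matters because allocation sizes the buffer from whatever dims are current at the time of the call, so the output must take the input's shape first.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a tensor with Resize/mutable_data semantics.
struct ToyTensor {
  std::vector<int> dims;
  std::vector<float> buf;

  // Chainable, mirroring the fluent call in the diff.
  ToyTensor &Resize(const std::vector<int> &d) {
    dims = d;
    return *this;
  }

  // Allocates (or reallocates) storage for the current dims.
  float *mutable_data() {
    std::size_t n = 1;
    for (int d : dims) n *= static_cast<std::size_t>(d);
    buf.resize(n);
    return buf.data();
  }
};

int main() {
  ToyTensor in;
  in.Resize({2, 3}).mutable_data();  // source holds 6 elements

  ToyTensor out;
  // New code path: take the source dims first, then allocate, so the
  // destination buffer is large enough for the broadcast copy.
  out.Resize(in.dims).mutable_data();
  std::cout << "out elements: " << out.buf.size() << "\n";  // prints 6
  return 0;
}
```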
@@ -74,8 +75,10 @@ void BroadcastOpHandle::RunImpl() {
 }
 
 void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
-  for (auto &pair : dev_ctxes_) {
-    in_var.generated_op_->Wait(pair.second);
+  if (in_var.generated_op_) {
+    for (auto &pair : dev_ctxes_) {
+      in_var.generated_op_->Wait(pair.second);
+    }
   }
 }
 
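
A minimal, self-contained sketch of the guard the second hunk introduces (`ToyOp` and `ToyVarHandle` below are hypothetical stand-ins, not the framework's real handle classes): when a variable has no producing op, `generated_op_` is null, and the new check skips the wait loop instead of dereferencing a null pointer.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical op handle that can be waited on per device context.
struct ToyOp {
  void Wait(const std::string &dev_ctx) const {
    std::cout << "waiting on " << dev_ctx << "\n";
  }
};

// Hypothetical variable handle; the producing op may legitimately be null.
struct ToyVarHandle {
  const ToyOp *generated_op_ = nullptr;
};

void WaitInputVarGenerated(const ToyVarHandle &in_var,
                           const std::map<int, std::string> &dev_ctxes) {
  if (in_var.generated_op_) {  // the guard added by the diff
    for (const auto &pair : dev_ctxes) {
      in_var.generated_op_->Wait(pair.second);
    }
  }
}

int main() {
  const std::map<int, std::string> dev_ctxes = {{0, "CUDA:0"}, {1, "CUDA:1"}};

  ToyVarHandle no_producer;  // e.g. a variable with no upstream op
  WaitInputVarGenerated(no_producer, dev_ctxes);  // safe: loop is skipped

  ToyOp op;
  ToyVarHandle with_producer;
  with_producer.generated_op_ = &op;
  WaitInputVarGenerated(with_producer, dev_ctxes);  // waits on both contexts
  return 0;
}
```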