diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 5d1b34537c..181f08d028 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -21,11 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) -cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory) -cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory) - cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) +cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory) +cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory) cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory device_context broadcast_op_handle) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 0fb54a1d3e..0bc3ee78d6 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -61,8 +61,9 @@ void BroadcastOpHandle::RunImpl() { "Places must be all on CPU or all on CUDA."); VariableVisitor::ShareDimsAndLoD(*in_var, out_var); - VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, - in_tensor.type()); + VariableVisitor::GetMutableTensor(out_var) + .Resize(in_tensor.dims()) + .mutable_data(out_p, in_tensor.type()); auto dev_ctx = dev_ctxes_[out_p]; RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { @@ -74,8 +75,10 @@ void BroadcastOpHandle::RunImpl() { } void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) { - for (auto &pair : dev_ctxes_) { - in_var.generated_op_->Wait(pair.second); + if (in_var.generated_op_) { + for (auto &pair : dev_ctxes_) { + in_var.generated_op_->Wait(pair.second); + } } }