|
|
|
@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
|
|
|
|
|
std::vector<Scope*> local_scopes_;
|
|
|
|
|
std::vector<Scope*> param_scopes_;
|
|
|
|
|
Scope g_scope_;
|
|
|
|
|
std::unique_ptr<OpHandleBase> op_handle_;
|
|
|
|
|
std::vector<std::unique_ptr<VarHandleBase>> vars_;
|
|
|
|
|
OpHandleBase* op_handle_;
|
|
|
|
|
std::vector<VarHandleBase*> vars_;
|
|
|
|
|
std::vector<std::unique_ptr<ir::Node>> nodes_;
|
|
|
|
|
std::vector<p::Place> place_list_;
|
|
|
|
|
bool use_gpu_;
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InitBroadcastOp(size_t input_scope_idx) {
|
|
|
|
|
nodes_.clear();
|
|
|
|
|
for (size_t j = 0; j < place_list_.size(); ++j) {
|
|
|
|
|
local_scopes_.push_back(&(g_scope_.NewScope()));
|
|
|
|
|
Scope& local_scope = local_scopes_.back()->NewScope();
|
|
|
|
@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
|
|
|
|
|
}
|
|
|
|
|
param_scopes_[input_scope_idx]->Var("input");
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Node> n =
|
|
|
|
|
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
|
|
|
|
|
nodes_.emplace_back(
|
|
|
|
|
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
|
|
|
|
|
if (use_gpu_) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
|
|
|
|
|
place_list_, nccl_ctxs_.get()));
|
|
|
|
|
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
|
|
|
|
|
place_list_, nccl_ctxs_.get());
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("CUDA is not support.");
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
|
|
|
|
|
place_list_, nccl_ctxs_.get()));
|
|
|
|
|
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
|
|
|
|
|
place_list_, nccl_ctxs_.get());
|
|
|
|
|
#else
|
|
|
|
|
op_handle_.reset(
|
|
|
|
|
new BroadcastOpHandle(n.get(), local_scopes_, place_list_));
|
|
|
|
|
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
|
|
|
|
|
place_list_);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Node> v =
|
|
|
|
|
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
|
|
|
|
|
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
|
|
|
|
|
place_list_[input_scope_idx]);
|
|
|
|
|
nodes_.emplace_back(
|
|
|
|
|
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
|
|
|
|
|
auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
|
|
|
|
|
"input", place_list_[input_scope_idx]);
|
|
|
|
|
vars_.emplace_back(in_var_handle);
|
|
|
|
|
op_handle_->AddInput(in_var_handle);
|
|
|
|
|
|
|
|
|
|
// add dummy var
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Node> v2 =
|
|
|
|
|
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
|
|
|
|
|
vars_.emplace_back(new DummyVarHandle(v2.get()));
|
|
|
|
|
nodes_.emplace_back(
|
|
|
|
|
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
|
|
|
|
|
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
|
|
|
|
|
DummyVarHandle* dummy_var_handle =
|
|
|
|
|
static_cast<DummyVarHandle*>(vars_.back().get());
|
|
|
|
|
static_cast<DummyVarHandle*>(vars_.back());
|
|
|
|
|
dummy_var_handle->ClearGeneratedOp();
|
|
|
|
|
op_handle_->AddInput(dummy_var_handle);
|
|
|
|
|
|
|
|
|
@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
|
|
|
|
|
if (!use_gpu_) {
|
|
|
|
|
op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<ir::Node> v3 =
|
|
|
|
|
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
|
|
|
|
|
nodes_.emplace_back(
|
|
|
|
|
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
|
|
|
|
|
VarHandle* out_var_handle =
|
|
|
|
|
new VarHandle(v3.get(), 2, j, "out", place_list_[j]);
|
|
|
|
|
new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
|
|
|
|
|
vars_.emplace_back(out_var_handle);
|
|
|
|
|
op_handle_->AddOutput(out_var_handle);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// add dummy var
|
|
|
|
|
std::unique_ptr<ir::Node> v4 =
|
|
|
|
|
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
|
|
|
|
|
vars_.emplace_back(new DummyVarHandle(v4.get()));
|
|
|
|
|
nodes_.emplace_back(
|
|
|
|
|
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
|
|
|
|
|
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
|
|
|
|
|
DummyVarHandle* out_dummy_var_handle =
|
|
|
|
|
static_cast<DummyVarHandle*>(vars_.back().get());
|
|
|
|
|
static_cast<DummyVarHandle*>(vars_.back());
|
|
|
|
|
out_dummy_var_handle->ClearGeneratedOp();
|
|
|
|
|
op_handle_->AddOutput(out_dummy_var_handle);
|
|
|
|
|
}
|
|
|
|
|