|
|
|
@ -77,14 +77,9 @@ struct TestBroadcastOpHandle {
|
|
|
|
|
local_scopes_[input_scope_idx]->Var("input");
|
|
|
|
|
|
|
|
|
|
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
|
|
|
|
|
|
|
|
|
|
vars_.emplace_back(new VarHandle());
|
|
|
|
|
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get());
|
|
|
|
|
in_var_handle->place_ = gpu_list_[input_scope_idx];
|
|
|
|
|
in_var_handle->name_ = "input";
|
|
|
|
|
in_var_handle->version_ = 1;
|
|
|
|
|
in_var_handle->scope_idx_ = input_scope_idx;
|
|
|
|
|
in_var_handle->generated_op_ = nullptr;
|
|
|
|
|
auto* in_var_handle =
|
|
|
|
|
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
|
|
|
|
|
vars_.emplace_back(in_var_handle);
|
|
|
|
|
op_handle_->AddInput(in_var_handle);
|
|
|
|
|
|
|
|
|
|
// add dummy var
|
|
|
|
@ -96,12 +91,8 @@ struct TestBroadcastOpHandle {
|
|
|
|
|
|
|
|
|
|
for (size_t j = 0; j < gpu_list_.size(); ++j) {
|
|
|
|
|
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
|
|
|
|
|
vars_.emplace_back(new VarHandle());
|
|
|
|
|
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get());
|
|
|
|
|
out_var_handle->place_ = gpu_list_[j];
|
|
|
|
|
out_var_handle->name_ = "out";
|
|
|
|
|
out_var_handle->version_ = 2;
|
|
|
|
|
out_var_handle->scope_idx_ = j;
|
|
|
|
|
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
|
|
|
|
|
vars_.emplace_back(out_var_handle);
|
|
|
|
|
op_handle_->AddOutput(out_var_handle);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|