|
|
|
@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
|
|
|
|
|
const OpDesc &op,
|
|
|
|
|
const platform::Place &p,
|
|
|
|
|
const size_t &i) const {
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
|
|
|
|
|
|
|
|
|
|
auto var_names = op->InputArgumentNames();
|
|
|
|
|
auto var_names = op.InputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
var_names = op.OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
@ -107,7 +107,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
|
|
|
|
|
// Create inputs for output on original place and no ssa output
|
|
|
|
|
// is created for send op.
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, 0);
|
|
|
|
|
CreateOpHandleIOs(&result, *op, p, 0);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -117,7 +117,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
|
|
|
|
|
auto *op_handle = result.ops_.back().get();
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, i);
|
|
|
|
|
CreateOpHandleIOs(&result, *op, p, i);
|
|
|
|
|
|
|
|
|
|
auto var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|