|
|
|
@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
|
|
|
|
|
const platform::Place &p,
|
|
|
|
|
const size_t &i,
|
|
|
|
|
bool create_output) const {
|
|
|
|
|
const size_t &i) const {
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
@ -69,12 +68,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
|
|
|
|
|
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
if (create_output) {
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
|
}
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
auto &p = places_[0];
|
|
|
|
|
auto *s = local_scopes_[0];
|
|
|
|
|
// FIXME(wuyi): send op always copy from GPU 0
|
|
|
|
|
result.ops_.emplace_back(new SendOpHandle(*op, s));
|
|
|
|
|
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
|
|
|
|
|
// Create inputs for output on original place and no ssa output
|
|
|
|
|
// is created for send op.
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, 0, false);
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, 0);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|