|
|
|
@ -57,8 +57,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
|
|
|
|
|
const platform::Place &p,
|
|
|
|
|
const size_t &i) const {
|
|
|
|
|
const size_t &i,
|
|
|
|
|
bool create_output) const {
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
|
|
|
|
|
auto var_names = op->InputArgumentNames();
|
|
|
|
|
|
|
|
|
@ -66,10 +69,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
|
|
|
|
|
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
if (create_output) {
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
if (!is_forwarding && op->Type() == "send") {
|
|
|
|
|
auto &p = places_[0];
|
|
|
|
|
auto *s = local_scopes_[0];
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, i);
|
|
|
|
|
// FIXME(wuyi): send op always copy from GPU 0
|
|
|
|
|
result.ops_.emplace_back(new SendOpHandle(*op, s));
|
|
|
|
|
// Create inputs for output on original place and no ssa output
|
|
|
|
|
// is created for send op.
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, 0, false);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -112,23 +119,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
|
|
|
|
|
auto *op_handle = result.ops_.back().get();
|
|
|
|
|
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, i);
|
|
|
|
|
// auto var_names = op->InputArgumentNames();
|
|
|
|
|
|
|
|
|
|
// for (auto &each_var_name : var_names) {
|
|
|
|
|
// VarHandle *var =
|
|
|
|
|
// CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
|
|
|
|
|
// op_handle->AddInput(var);
|
|
|
|
|
// }
|
|
|
|
|
auto var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
// for (auto &each_var_name : var_names) {
|
|
|
|
|
// CreateOpOutput(&result, op_handle, each_var_name, p, i);
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
if (is_forwarding) {
|
|
|
|
|
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|