|
|
|
@ -27,7 +27,8 @@ struct VarHandle {
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|
OpHandle *generated_op_;
|
|
|
|
|
std::vector<OpHandle *> deps_ops_;
|
|
|
|
|
|
|
|
|
|
std::vector<OpHandle *> pending_ops_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpHandle {
|
|
|
|
@ -141,7 +142,7 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
auto &place = pair.first;
|
|
|
|
|
VarHandle *var = GetVarHandle(each_var_name, place);
|
|
|
|
|
op_handle->inputs_.emplace_back(var);
|
|
|
|
|
var->deps_ops_.emplace_back(op_handle);
|
|
|
|
|
var->pending_ops_.emplace_back(op_handle);
|
|
|
|
|
}
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
@ -158,7 +159,7 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
op_handle = member_->ops_.back().get();
|
|
|
|
|
auto &place = pair.first;
|
|
|
|
|
VarHandle *loss = GetVarHandle(loss_var_name, place);
|
|
|
|
|
loss->deps_ops_.emplace_back(op_handle);
|
|
|
|
|
loss->pending_ops_.emplace_back(op_handle);
|
|
|
|
|
op_handle->inputs_.emplace_back(loss);
|
|
|
|
|
GenerateVar(op_handle, loss_var_name + "@GRAD", place);
|
|
|
|
|
change_forward = true;
|
|
|
|
@ -188,7 +189,7 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
}
|
|
|
|
|
auto *prev_grad = &vars[vars.size() - 1];
|
|
|
|
|
op_handle->inputs_.emplace_back(prev_grad);
|
|
|
|
|
prev_grad->deps_ops_.emplace_back(op_handle);
|
|
|
|
|
prev_grad->pending_ops_.emplace_back(op_handle);
|
|
|
|
|
auto &var = vars[vars.size()];
|
|
|
|
|
var.place_ = place;
|
|
|
|
|
var.generated_op_ = op_handle;
|
|
|
|
@ -317,7 +318,7 @@ std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
|
|
|
|
|
|
std::vector<OpHandle *> to_run;
|
|
|
|
|
for (auto *var : to_remove) {
|
|
|
|
|
for (auto *op : var->deps_ops_) {
|
|
|
|
|
for (auto *op : var->pending_ops_) {
|
|
|
|
|
if (var->name_ == "mean_0.tmp_0@GRAD") {
|
|
|
|
|
LOG(INFO) << op->DebugString();
|
|
|
|
|
}
|
|
|
|
|