@@ -57,16 +57,15 @@ namespace imperative {
 static void GetGraphInfoBetweenTargets(
     std::unordered_set<VariableWrapper *> *input_target_grads,
     std::unordered_set<VarBase *> *output_targets,
-    std::unordered_set<const OpBase *> *startup_ops_ptr,
-    std::unordered_map<const OpBase *, std::unordered_set<const OpBase *>>
-        *pending_ops_ptr,
-    std::unordered_map<const OpBase *, size_t> *op_deps_ptr,
+    std::unordered_set<OpBase *> *startup_ops_ptr,
+    std::unordered_map<OpBase *, std::unordered_set<OpBase *>> *pending_ops_ptr,
+    std::unordered_map<OpBase *, size_t> *op_deps_ptr,
     std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
     const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
   /**
    * Step 1. Find the candidate startup grad ops, prepared for following BFS.
    */
-  std::queue<std::pair<const OpBase *, const GradOpNode *>> q;
+  std::queue<std::pair<OpBase *, GradOpNode *>> q;
   std::unordered_set<GradOpNode *> visited;
   for (auto iter = output_targets->begin(); iter != output_targets->end();) {
     auto *output_target = *iter;
@@ -98,9 +97,8 @@ static void GetGraphInfoBetweenTargets(
    * not all input_target_grads would be found.
    */
   std::unordered_set<VariableWrapper *> found_input_target_grads;
-  std::unordered_set<const OpBase *> endpoint_ops;
-  std::unordered_map<const OpBase *, std::unordered_set<const OpBase *>>
-      preceding_ops;
+  std::unordered_set<OpBase *> endpoint_ops;
+  std::unordered_map<OpBase *, std::unordered_set<OpBase *>> preceding_ops;
   while (!q.empty()) {
     auto op_node_pair = q.front();
     q.pop();
@@ -153,8 +151,7 @@ static void GetGraphInfoBetweenTargets(
   auto &target_vars = *related_grad_vars_ptr;
   target_vars = *input_target_grads;
 
-  std::queue<std::pair<const OpBase * /*op*/, const OpBase * /*pending op*/>>
-      op_queue;
+  std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
   for (auto &endpoint_op : endpoint_ops) {
     op_queue.emplace(endpoint_op, nullptr);
   }
@@ -238,7 +235,7 @@ static void GetGraphInfoBetweenTargets(
   for (auto iter = output_targets->begin(); iter != output_targets->end();) {
     auto &grad_node = (*iter)->GradVarBase()->GradNode();
     bool is_valid = std::find_if(grad_node->begin(), grad_node->end(),
-                                 [&](const OpBase &op) {
+                                 [&](OpBase &op) {  // NOLINT
                                    return startup_ops.count(&op) > 0;
                                  }) != grad_node->end();
     if (is_valid) {
@@ -518,12 +515,13 @@ class PartialGradTask {
                   const std::vector<std::shared_ptr<VarBase>> &output_grads,
                   const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
                   const platform::Place &place,
-                  const detail::BackwardStrategy &strategy, bool create_graph);
+                  const detail::BackwardStrategy &strategy, bool create_graph,
+                  bool retain_graph, bool allow_unused, bool only_inputs);
 
   std::vector<std::shared_ptr<VarBase>> Run();
 
  private:
-  void RunEachOp(const OpBase *op);
+  void RunEachOp(OpBase *op);
 
   void PrepareInitialReadyVarsMap(const OpBase *op);
 
@@ -536,10 +534,9 @@ class PartialGradTask {
   }
 
  private:
-  std::unordered_set<const OpBase *> startup_ops_;
-  std::unordered_map<const OpBase *, std::unordered_set<const OpBase *>>
-      pending_ops_;
-  std::unordered_map<const OpBase *, size_t> op_deps_;
+  std::unordered_set<OpBase *> startup_ops_;
+  std::unordered_map<OpBase *, std::unordered_set<OpBase *>> pending_ops_;
+  std::unordered_map<OpBase *, size_t> op_deps_;
 
   ReadyGradVarInfoMap ready_grad_vars_;
 
@@ -562,6 +559,9 @@ class PartialGradTask {
 
   platform::Place place_;
   bool create_graph_;
+  bool retain_graph_;
+  bool allow_unused_;
+  bool only_inputs_;
   detail::BackwardStrategy strategy_;
 };
 
@@ -571,12 +571,19 @@ PartialGradTask::PartialGradTask(
     const std::vector<std::shared_ptr<VarBase>> &output_grads,
     const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
     const platform::Place &place, const detail::BackwardStrategy &strategy,
-    bool create_graph) {
+    bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) {
   input_targets_ = input_targets;
   place_ = place;
   create_graph_ = create_graph;
+  retain_graph_ = retain_graph;
+  allow_unused_ = allow_unused;
+  only_inputs_ = only_inputs;
   strategy_ = strategy;
 
+  PADDLE_ENFORCE_EQ(only_inputs_, true,
+                    platform::errors::Unimplemented(
+                        "only_inputs=False is not supported yet"));
+
   for (auto &var : no_grad_vars) {
     if (var && var->GradVarBase()) {
       no_grad_var_grad_.insert(var->GradVarBase()->SharedVar().get());
@@ -738,7 +745,7 @@ PartialGradTask::PartialGradTask(
 
 std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
   VLOG(10) << "Startup op number " << startup_ops_.size();
-  std::queue<const OpBase *> q;
+  std::queue<OpBase *> q;
   for (auto *op : startup_ops_) {
     q.push(op);
   }
@@ -746,8 +753,13 @@ std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
   while (!q.empty()) {
     auto *op = q.front();
     q.pop();
 
+    VLOG(10) << "Start to run " << op->Type();
     op->EnforceHasInOut();
     RunEachOp(op);
+    if (!retain_graph_) {
+      op->ClearBackwardTrace();
+    }
+    VLOG(10) << "End to run " << op->Type();
 
     auto iter = pending_ops_.find(op);
@@ -773,7 +785,7 @@ std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
   return CreateResult();
 }
 
-void PartialGradTask::RunEachOp(const OpBase *op) {
+void PartialGradTask::RunEachOp(OpBase *op) {
   // Prepare new inputs
   NameVarMap<VarBase> tmp_ins;
   for (auto &input_pair : op->GetInsMap()) {
@@ -960,7 +972,8 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) {
 std::vector<std::shared_ptr<VarBase>> PartialGradTask::CreateResult() {
   std::vector<std::shared_ptr<VarBase>> result;
   result.reserve(input_targets_.size());
-  for (auto &input_target : input_targets_) {
+  for (size_t i = 0; i < input_targets_.size(); ++i) {
+    auto &input_target = input_targets_[i];
     PADDLE_ENFORCE_NOT_NULL(
         input_target->GradVarBase(),
         platform::errors::InvalidArgument("input should have gradient"));
@@ -971,6 +984,12 @@ std::vector<std::shared_ptr<VarBase>> PartialGradTask::CreateResult() {
       ready_var->SetOverridedStopGradient(!create_graph_);
       result.emplace_back(std::move(ready_var));
     } else {  // return None if it does not appear in the graph
+      PADDLE_ENFORCE_EQ(allow_unused_, true,
+                        platform::errors::InvalidArgument(
+                            "The %d-th input does not appear in the backward "
+                            "graph. Please check the input variable or set "
+                            "allow_unused=True to get None result.",
+                            i));
       result.emplace_back();
     }
   }
@@ -995,14 +1014,17 @@ PartialGradEngine::PartialGradEngine(
     const std::vector<std::shared_ptr<VarBase>> &output_grads,
     const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
     const platform::Place &place, const detail::BackwardStrategy &strategy,
-    bool create_graph)
+    bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
     : input_targets_(input_targets),
       output_targets_(output_targets),
       output_grads_(output_grads),
       no_grad_vars_(no_grad_vars),
       place_(place),
       strategy_(strategy),
-      create_graph_(create_graph) {}
+      create_graph_(create_graph),
+      retain_graph_(retain_graph),
+      allow_unused_(allow_unused),
+      only_inputs_(only_inputs) {}
 
 std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const {
   return results_;
@@ -1017,7 +1039,8 @@ void PartialGradEngine::Clear() {
 
 void PartialGradEngine::Execute() {
   PartialGradTask task(input_targets_, output_targets_, output_grads_,
-                       no_grad_vars_, place_, strategy_, create_graph_);
+                       no_grad_vars_, place_, strategy_, create_graph_,
+                       retain_graph_, allow_unused_, only_inputs_);
   VLOG(10) << "Starts to execute PartialGradEngine";
   results_ = task.Run();
   Clear();
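
Note: the scheduling members touched above (startup_ops_, pending_ops_, op_deps_) drive a dependency-counted BFS in PartialGradTask::Run(): startup ops are enqueued first, and finishing an op decrements the dependency count of each of its pending ops, which are enqueued once that count reaches zero. Below is a minimal, self-contained sketch of that pattern only; the Op struct, the toy graph, and main() are illustrative assumptions, not Paddle's API.

// Sketch of the dependency-counted BFS scheduling loop used by Run().
// Types, names, and the toy op graph here are illustrative only.
#include <cstdio>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct Op { std::string type; };

int main() {
  // Toy backward graph: mul_grad -> add_grad -> relu_grad.
  Op mul{"mul_grad"}, add{"add_grad"}, relu{"relu_grad"};
  std::unordered_set<Op *> startup_ops{&mul};
  std::unordered_map<Op *, std::unordered_set<Op *>> pending_ops{
      {&mul, {&add}}, {&add, {&relu}}};
  std::unordered_map<Op *, size_t> op_deps{{&add, 1}, {&relu, 1}};

  std::queue<Op *> q;
  for (auto *op : startup_ops) q.push(op);  // ops with no unmet dependencies

  while (!q.empty()) {
    auto *op = q.front();
    q.pop();
    std::printf("run %s\n", op->type.c_str());  // stands in for RunEachOp(op)
    for (auto *pending : pending_ops[op]) {
      // A pending op becomes ready once all of its preceding ops have run.
      if (--op_deps[pending] == 0) q.push(pending);
    }
  }
  return 0;
}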