|
|
|
|
@ -1015,34 +1015,28 @@ PartialGradEngine::PartialGradEngine(
|
|
|
|
|
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
|
|
|
|
|
const platform::Place &place, const detail::BackwardStrategy &strategy,
|
|
|
|
|
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
|
|
|
|
|
: input_targets_(input_targets),
|
|
|
|
|
output_targets_(output_targets),
|
|
|
|
|
output_grads_(output_grads),
|
|
|
|
|
no_grad_vars_(no_grad_vars),
|
|
|
|
|
place_(place),
|
|
|
|
|
strategy_(strategy),
|
|
|
|
|
create_graph_(create_graph),
|
|
|
|
|
retain_graph_(retain_graph),
|
|
|
|
|
allow_unused_(allow_unused),
|
|
|
|
|
only_inputs_(only_inputs) {}
|
|
|
|
|
: task_(new PartialGradTask(input_targets, output_targets, output_grads,
|
|
|
|
|
no_grad_vars, place, strategy, create_graph,
|
|
|
|
|
retain_graph, allow_unused, only_inputs)) {}
|
|
|
|
|
|
|
|
|
|
PartialGradEngine::~PartialGradEngine() { Clear(); }
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const {
|
|
|
|
|
return results_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PartialGradEngine::Clear() {
|
|
|
|
|
input_targets_.clear();
|
|
|
|
|
output_targets_.clear();
|
|
|
|
|
output_grads_.clear();
|
|
|
|
|
no_grad_vars_.clear();
|
|
|
|
|
if (task_) {
|
|
|
|
|
delete task_;
|
|
|
|
|
task_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PartialGradEngine::Execute() {
|
|
|
|
|
PartialGradTask task(input_targets_, output_targets_, output_grads_,
|
|
|
|
|
no_grad_vars_, place_, strategy_, create_graph_,
|
|
|
|
|
retain_graph_, allow_unused_, only_inputs_);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied(
|
|
|
|
|
"PartialGradEngine has been destructed"));
|
|
|
|
|
VLOG(10) << "Starts to execute PartialGradEngine";
|
|
|
|
|
results_ = task.Run();
|
|
|
|
|
results_ = task_->Run();
|
|
|
|
|
Clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|