|
|
|
@ -27,6 +27,7 @@
|
|
|
|
|
#include "paddle/fluid/framework/tensor_util.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/platform/device_context.h"
|
|
|
|
|
#include "paddle/fluid/platform/profiler.h"
|
|
|
|
|
#include "paddle/fluid/string/printf.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -256,7 +257,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
|
|
|
|
|
const detail::BackwardStrategy& bck_stratedy) {
|
|
|
|
|
PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0,
|
|
|
|
|
"%s has no backward implementation", Type());
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "apply op grad: " << Type();
|
|
|
|
|
std::vector<VarBasePtrMap> tmp_grad_outputs;
|
|
|
|
|
if (backward_id_ > 0) {
|
|
|
|
@ -272,8 +272,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
|
|
|
|
|
tmp_grad_outputs.resize(grad_op_count);
|
|
|
|
|
for (size_t k = 0; k < grad_op_count; ++k) {
|
|
|
|
|
framework::OpDesc* grad_op_desc = grad_op_descs_[k];
|
|
|
|
|
platform::RecordEvent record_event(grad_op_desc->Type());
|
|
|
|
|
auto& grad_output_variable_map = grad_output_vars_[k];
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "apply grad op " << grad_op_desc->Type();
|
|
|
|
|
|
|
|
|
|
// Allocate tmp grad output variable
|
|
|
|
@ -345,6 +345,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event("merge_grads");
|
|
|
|
|
// Add tmp grad outputs to original grad vars
|
|
|
|
|
for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
|
|
|
|
|
for (const auto& it : grad_output_vars_[k]) {
|
|
|
|
@ -424,7 +425,7 @@ void OpBase::RegisterBackwardHooks(const py::object& callable) {
|
|
|
|
|
|
|
|
|
|
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
|
|
|
|
|
if (!pre_op_) return;
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event("Imperative Backward");
|
|
|
|
|
VLOG(3) << "start backward";
|
|
|
|
|
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
operators::math::set_constant(
|
|
|
|
|