GradMaker for dygraph (#19706)
* refactor dygraph, test=develop
* fix failed unittest, test=develop
* polish code, test=develop
* check windows ci error, test=develop
* try to fix windows ci error by np.allclose, test=develop
* polish vlog and profiler, test=develop
* try to fix preceding ops order, test=develop
* test transformer in windows ci, test=develop
* use python c-api to speed up tracer.trace, test=develop
* test=develop, fix docker with paddle nccl problem
* test=develop, add ut for debug string and gradient_accumulator
* test=develop, add tests for layer/gradient_accumulator/prepared_op
* test=develop, fix compile error for test_prepared_op
* test=develop, add more ut for dygraph
* test=develop, create API.spec for dygraph api change
* optimize grad maker; test=develop
* optimize grad maker
* test
* grad maker optim; test=develop
* fix unittest bugs; test=develop
* add dygraph grad op maker and split_op
* grad op maker refactor; test=develop
* add dygraph grad maker; test=develop
* fix deformable_conv_v1_op bug; test=develop
* fix deformable_conv and prroi_pool bugs
* fix new op grad op maker bug; test=develop
* fix split by ref bug; test=develop
* fix dygraph auto prune bug; test=develop
* fix test_trace bug; test=develop
* fix fused_emb_seq_pool bug; test=develop
* remove useless code in op_desc file; test=develop
* remove useless code, StrVarBaseNode; test=develop
* fix review issues; test=develop
* fix rank_loss grad maker; test=develop
* remove flag in VarBase; test=develop
* fix distributed_notify_op compile bug; test=develop
* fix reshape op double grad; test=develop
* fix expand_as op; test=develop
* add imperative type_defs.h for demo_train; test=develop
* fix inference lib cmake; test=develop
* fix inference lib; test=develop
* fix inference lib; test=develop
* fix inference cmake; test=develop
* fix inference lib; test=develop
* fix inference lib; test=develop
* remove conditional dygraph grad maker, modify local name; test=develop
* fix split grad maker bug; test=develop
* fix pyramid_op bug; test=develop
* change travis time out limit; test=develop
* restore travis; test=develop
* change timeout limit; test=develop

yaoxuefeng

parent b741761098
commit 8c4573a3cb
@@ -0,0 +1,153 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace imperative {
// Base class for grad op makers in imperative (dygraph) mode. A concrete
// maker is constructed from the forward OpBase and its input/output VarBase
// maps, and builds the backward OpBase(s) in operator().
class GradOpBaseMakerBase {
 public:
  explicit GradOpBaseMakerBase(const OpBase* fw_op_base,
                               const NameVarBaseMap& var_base_map_in,
                               const NameVarBaseMap& var_base_map_out)
      : fw_op_base_(fw_op_base),
        var_base_map_in_(var_base_map_in),
        var_base_map_out_(var_base_map_out) {}

  virtual ~GradOpBaseMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpBase>> operator()() const = 0;

  // Gradient VarBases of the forward op's input slot `name`.
  std::vector<std::shared_ptr<VarBase>> InputGrad(
      const std::string& name, bool drop_empty_grad = true) const {
    return GetVarBaseList(name, /*is_grad=*/true, /*is_input=*/true);
  }

  // Gradient VarBases of the forward op's output slot `name`.
  std::vector<std::shared_ptr<VarBase>> OutputGrad(
      const std::string& name) const {
    return GetVarBaseList(name, /*is_grad=*/true, /*is_input=*/false);
  }

  // Forward-pass VarBases of input slot `name`.
  std::vector<std::shared_ptr<VarBase>> Input(const std::string& name) const {
    return GetVarBaseList(name, /*is_grad=*/false, /*is_input=*/true);
  }

  // Forward-pass VarBases of output slot `name`.
  std::vector<std::shared_ptr<VarBase>> Output(const std::string& name) const {
    return GetVarBaseList(name, /*is_grad=*/false, /*is_input=*/false);
  }

  std::vector<std::shared_ptr<VarBase>> Empty() const { return {}; }

  std::vector<std::string> InputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_in_.size());
    for (auto& it : var_base_map_in_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  std::vector<std::string> OutputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_out_.size());
    for (auto& it : var_base_map_out_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  const std::unordered_map<std::string, framework::Attribute>& Attrs() const {
    return fw_op_base_->Attrs();
  }

  const framework::Attribute& GetAttr(const std::string& name) const {
    auto& map = fw_op_base_->Attrs();
    auto it = map.find(name);
    PADDLE_ENFORCE(it != map.end(),
                   "Cannot find attribute [%s] in operator [%s]", name,
                   fw_op_base_->Type());
    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return boost::get<T>(GetAttr(name));
  }

  std::string ForwardOpType() const { return fw_op_base_->Type(); }

 protected:
  bool HasInput(const std::string& name) const {
    auto it = var_base_map_in_.find(name);
    return it != var_base_map_in_.end();
  }

 private:
  // Looks up slot `name` in the forward input or output map. When `is_grad`
  // is true, returns the grad VarBase of each found VarBase, resizing each
  // grad tensor to the dims of its forward counterpart; otherwise returns
  // the forward VarBases themselves.
  std::vector<std::shared_ptr<VarBase>> GetVarBaseList(const std::string& name,
                                                       bool is_grad,
                                                       bool is_input) const {
    const NameVarBaseMap& data_map =
        is_input ? var_base_map_in_ : var_base_map_out_;
    auto iterator = data_map.find(name);

    std::vector<std::shared_ptr<imperative::VarBase>> vec_temp;
    if (iterator != data_map.end()) {
      vec_temp.reserve(iterator->second.size());

      for (auto& var_base_temp : iterator->second) {
        if (is_grad) {
          PADDLE_ENFORCE_NOT_NULL(var_base_temp->GradVarBase(),
                                  "VarBase grad of OP [%s] should not be null",
                                  fw_op_base_->Type());
          auto grad_var_base_tmp = var_base_temp->GradVarBase();
          auto* tensor = grad_var_base_tmp->MutableVar()
                             ->GetMutable<framework::LoDTensor>();
          tensor->Resize(
              var_base_temp->Var().Get<framework::LoDTensor>().dims());
          vec_temp.emplace_back(grad_var_base_tmp);
        } else {
          vec_temp.emplace_back(var_base_temp);
        }
      }
    }

    return vec_temp;
  }

 private:
  const OpBase* fw_op_base_;
  const NameVarBaseMap& var_base_map_in_;
  const NameVarBaseMap& var_base_map_out_;

 protected:
  std::vector<framework::BlockDesc*> grad_block_;
};

}  // namespace imperative
}  // namespace paddle
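
For context, a concrete grad maker subclasses GradOpBaseMakerBase and implements operator() to describe the backward op(s). The sketch below is a minimal illustration, not part of this commit: MulGradOpMaker is hypothetical, and the OpBase setters it calls (SetType, SetInput, SetOutput, SetAttrMap) are assumed for the sake of the example rather than defined in this header.

// Illustrative sketch only; the OpBase setter API used here is assumed.
class MulGradOpMaker : public GradOpBaseMakerBase {
 public:
  using GradOpBaseMakerBase::GradOpBaseMakerBase;  // inherit the ctor

  std::vector<std::unique_ptr<OpBase>> operator()() const override {
    std::vector<std::unique_ptr<OpBase>> ops;
    auto grad_op = std::unique_ptr<OpBase>(new OpBase());
    grad_op->SetType("mul_grad");                      // assumed setter
    grad_op->SetInput("X", Input("X"));                // forward input X
    grad_op->SetInput("Y", Input("Y"));                // forward input Y
    grad_op->SetInput("Out@GRAD", OutputGrad("Out"));  // grad of forward Out
    grad_op->SetOutput("X@GRAD", InputGrad("X"));      // grad slot for X
    grad_op->SetOutput("Y@GRAD", InputGrad("Y"));      // grad slot for Y
    grad_op->SetAttrMap(Attrs());                      // reuse forward attrs
    ops.emplace_back(std::move(grad_op));
    return ops;
  }
};

The point of the base class is visible in the sketch: the maker hands back traced VarBase objects directly via Input/InputGrad/OutputGrad, rather than variable names read from an OpDesc as in the static-graph framework::GradOpDescMakerBase, and GetVarBaseList pre-sizes each grad tensor from its forward counterpart.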