Add dygraph double grad implementation (#22939)
* add double grad implementation for dygraph, test=develop * polish code, add uts, test=develop * fix place bug, test=develop * polish codes, add more uts for coverages, test=develop * add no_grad_set, test=develop * add star gan ut, test=develop * follow comments, test=developrevert-23830-2.0-beta
parent
995a6376f7
commit
a31d7328b7
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,57 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/imperative/backward_strategy.h"
|
||||
#include "paddle/fluid/imperative/engine.h"
|
||||
#include "paddle/fluid/imperative/gradient_accumulator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace imperative {
|
||||
|
||||
class VarBase;
|
||||
class OpBase;
|
||||
|
||||
class BasicEngine : public Engine {
|
||||
public:
|
||||
void Init(VarBase* var, const detail::BackwardStrategy& strategy);
|
||||
|
||||
void Execute() override;
|
||||
|
||||
private:
|
||||
void PrepareDeps();
|
||||
|
||||
void CheckBackwardInputs(const OpBase& op);
|
||||
|
||||
void PrepareGradAccumulators(const OpBase& op);
|
||||
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
std::shared_ptr<GradOpNode> init_node_;
|
||||
detail::BackwardStrategy backward_strategy_;
|
||||
std::unordered_map<GradOpNode*, size_t> node_deps_;
|
||||
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
|
||||
accumulators_;
|
||||
std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
|
||||
need_accu_var_list_;
|
||||
};
|
||||
|
||||
} // namespace imperative
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,199 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/type_defs.h"
|
||||
#include "paddle/fluid/framework/variable.h"
|
||||
#include "paddle/fluid/imperative/type_defs.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace imperative {
|
||||
|
||||
template <typename VarType>
|
||||
class DygraphExecutionContext : public framework::ExecutionContext {
|
||||
using Variable = framework::Variable;
|
||||
|
||||
public:
|
||||
DygraphExecutionContext(const framework::OperatorBase& op,
|
||||
const framework::Scope& scope,
|
||||
const platform::DeviceContext& device_context,
|
||||
const framework::RuntimeContext& ctx,
|
||||
std::vector<framework::KernelConfig>* configs,
|
||||
const NameVarMap<VarType>& var_base_map_in,
|
||||
const NameVarMap<VarType>& var_base_map_out,
|
||||
const framework::AttributeMap& attrs)
|
||||
: ExecutionContext(op, scope, device_context, ctx, configs),
|
||||
var_base_map_in_(var_base_map_in),
|
||||
var_base_map_out_(var_base_map_out),
|
||||
attrs_(attrs) {}
|
||||
|
||||
std::string InputName(const std::string& name) const override {
|
||||
auto it = var_base_map_in_.find(name);
|
||||
PADDLE_ENFORCE_NE(it, var_base_map_in_.end(),
|
||||
platform::errors::PreconditionNotMet(
|
||||
"Can not find [%s] in Input", name));
|
||||
return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName;
|
||||
}
|
||||
|
||||
std::vector<std::string> InputNames(const std::string& name) const override {
|
||||
auto it = var_base_map_in_.find(name);
|
||||
PADDLE_ENFORCE_NE(
|
||||
it, var_base_map_in_.end(),
|
||||
platform::errors::NotFound("Can not find [%s] in Input", name));
|
||||
std::vector<std::string> vec_res;
|
||||
vec_res.reserve(it->second.size());
|
||||
for (size_t i = 0; i < it->second.size(); ++i) {
|
||||
if (it->second[i]) {
|
||||
vec_res.push_back(it->second[i]->Name());
|
||||
} else {
|
||||
vec_res.push_back(framework::kEmptyVarName);
|
||||
}
|
||||
}
|
||||
return vec_res;
|
||||
}
|
||||
|
||||
std::string OutputName(const std::string& name) const override {
|
||||
auto it = var_base_map_out_.find(name);
|
||||
PADDLE_ENFORCE_NE(
|
||||
it, var_base_map_out_.end(),
|
||||
platform::errors::NotFound("Can not find [%s] in Output", name));
|
||||
return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName;
|
||||
}
|
||||
|
||||
std::vector<std::string> OutputNames(const std::string& name) const override {
|
||||
auto it = var_base_map_out_.find(name);
|
||||
PADDLE_ENFORCE_NE(
|
||||
it, var_base_map_out_.end(),
|
||||
platform::errors::NotFound("Can not find [%s] in Output", name));
|
||||
std::vector<std::string> vec_res;
|
||||
vec_res.reserve(it->second.size());
|
||||
for (size_t i = 0; i < it->second.size(); ++i) {
|
||||
if (it->second[i]) {
|
||||
vec_res.push_back(it->second[i]->Name());
|
||||
} else {
|
||||
vec_res.push_back(framework::kEmptyVarName);
|
||||
}
|
||||
}
|
||||
return vec_res;
|
||||
}
|
||||
|
||||
bool HasAttr(const std::string& name) const override {
|
||||
return attrs_.count(name) != 0;
|
||||
}
|
||||
|
||||
const framework::AttributeMap& Attrs() const override { return attrs_; }
|
||||
|
||||
const framework::Attribute& GetAttr(const std::string& name) const override {
|
||||
auto it = attrs_.find(name);
|
||||
|
||||
PADDLE_ENFORCE_NE(
|
||||
it, attrs_.end(),
|
||||
platform::errors::NotFound("can not find [%s] in attrs", name));
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> InNameList() const override {
|
||||
std::vector<std::string> vec_temp;
|
||||
vec_temp.reserve(var_base_map_in_.size());
|
||||
|
||||
for (auto& v : var_base_map_in_) {
|
||||
vec_temp.push_back(v.first);
|
||||
}
|
||||
|
||||
return vec_temp;
|
||||
}
|
||||
|
||||
bool HasInput(const std::string& name) const override {
|
||||
auto it = var_base_map_in_.find(name);
|
||||
return (it != var_base_map_in_.end() && it->second.size() > 0);
|
||||
}
|
||||
|
||||
bool HasOutput(const std::string& name) const override {
|
||||
auto it = var_base_map_out_.find(name);
|
||||
return (it != var_base_map_out_.end() && it->second.size() > 0);
|
||||
}
|
||||
|
||||
size_t InputSize(const std::string& name) const override {
|
||||
return InputNames(name).size();
|
||||
}
|
||||
|
||||
size_t OutputSize(const std::string& name) const override {
|
||||
return OutputNames(name).size();
|
||||
}
|
||||
|
||||
const Variable* InputVar(const std::string& name) const override {
|
||||
auto it = var_base_map_in_.find(name);
|
||||
if (it == var_base_map_in_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.empty() || it->second[0] == nullptr
|
||||
? nullptr
|
||||
: it->second[0]->MutableVar();
|
||||
}
|
||||
|
||||
Variable* OutputVar(const std::string& name) const override {
|
||||
auto it = var_base_map_out_.find(name);
|
||||
if (it == var_base_map_out_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.empty() || it->second[0] == nullptr
|
||||
? nullptr
|
||||
: it->second[0]->MutableVar();
|
||||
}
|
||||
|
||||
const std::vector<Variable*> MultiInputVar(
|
||||
const std::string& name) const override {
|
||||
auto it = var_base_map_in_.find(name);
|
||||
if (it == var_base_map_in_.end()) {
|
||||
return {};
|
||||
}
|
||||
std::vector<Variable*> vec_res;
|
||||
vec_res.reserve(it->second.size());
|
||||
for (size_t i = 0; i < it->second.size(); ++i) {
|
||||
vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
|
||||
}
|
||||
|
||||
return vec_res;
|
||||
}
|
||||
|
||||
std::vector<Variable*> MultiOutputVar(
|
||||
const std::string& name) const override {
|
||||
auto it = var_base_map_out_.find(name);
|
||||
if (it == var_base_map_out_.end()) {
|
||||
return {};
|
||||
}
|
||||
std::vector<Variable*> vec_res;
|
||||
vec_res.reserve(it->second.size());
|
||||
for (size_t i = 0; i < it->second.size(); ++i) {
|
||||
vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
|
||||
}
|
||||
|
||||
return vec_res;
|
||||
}
|
||||
|
||||
private:
|
||||
const NameVarMap<VarType>& var_base_map_in_;
|
||||
const NameVarMap<VarType>& var_base_map_out_;
|
||||
const framework::AttributeMap& attrs_;
|
||||
};
|
||||
|
||||
} // namespace imperative
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,188 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/type_defs.h"
|
||||
#include "paddle/fluid/framework/var_type_inference.h"
|
||||
#include "paddle/fluid/imperative/type_defs.h"
|
||||
#include "paddle/fluid/imperative/variable_wrapper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace imperative {
|
||||
|
||||
// infer var type context for imperative mode
|
||||
template <typename VarType>
|
||||
class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
|
||||
public:
|
||||
RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
|
||||
const NameVarMap<VarType>& outputs,
|
||||
const framework::AttributeMap& attrs_map)
|
||||
: InferVarTypeContext(nullptr, nullptr),
|
||||
inputs_(inputs),
|
||||
outputs_(outputs),
|
||||
attrs_(attrs_map),
|
||||
input_names_(),
|
||||
output_names_(),
|
||||
var_set_() {
|
||||
input_names_.reserve(inputs_.size());
|
||||
for (auto& it : inputs_) {
|
||||
for (auto& var : it.second) {
|
||||
if (var) {
|
||||
input_names_[it.first].emplace_back(var->Name());
|
||||
var_set_[var->Name()] = var.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output_names_.reserve(outputs_.size());
|
||||
for (auto& it : outputs_) {
|
||||
for (auto& var : it.second) {
|
||||
if (var) {
|
||||
output_names_[it.first].emplace_back(var->Name());
|
||||
var_set_[var->Name()] = var.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~RuntimeInferVarTypeContext() {}
|
||||
|
||||
framework::Attribute GetAttr(const std::string& name) const override {
|
||||
auto iter = attrs_.find(name);
|
||||
PADDLE_ENFORCE_EQ(
|
||||
iter != attrs_.end(), true,
|
||||
platform::errors::NotFound("Cannot find attribute %s", name));
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
bool HasVar(const std::string& name) const override {
|
||||
return var_set_.count(name) > 0;
|
||||
}
|
||||
|
||||
bool HasInput(const std::string& name) const override {
|
||||
auto it = inputs_.find(name);
|
||||
return (it != inputs_.end() && it->second.size() > 0);
|
||||
}
|
||||
|
||||
bool HasOutput(const std::string& name) const override {
|
||||
auto it = outputs_.find(name);
|
||||
return (it != outputs_.end() && it->second.size() > 0);
|
||||
}
|
||||
|
||||
const std::vector<std::string>& Input(
|
||||
const std::string& name) const override {
|
||||
auto iter = input_names_.find(name);
|
||||
PADDLE_ENFORCE_EQ(
|
||||
iter != input_names_.end(), true,
|
||||
platform::errors::NotFound("Cannot find input var %s", name));
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& Output(
|
||||
const std::string& name) const override {
|
||||
auto iter = output_names_.find(name);
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
iter != output_names_.end(), true,
|
||||
platform::errors::NotFound("Cannot find output var %s", name));
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
framework::proto::VarType::Type GetType(
|
||||
const std::string& name) const override {
|
||||
auto iter = var_set_.find(name);
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
iter != var_set_.end(), true,
|
||||
platform::errors::NotFound("Cannot find var %s in GetType", name));
|
||||
return iter->second->Type();
|
||||
}
|
||||
|
||||
void SetType(const std::string& name,
|
||||
framework::proto::VarType::Type type) override {
|
||||
if (name == "kLookupTablePath") {
|
||||
VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
|
||||
} else {
|
||||
var_set_[name]->SetType(type);
|
||||
if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
|
||||
(var_set_[name]->MutableVar()->Type() != type)) {
|
||||
var_set_[name]->MutableVar()->Clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
framework::proto::VarType::Type GetDataType(
|
||||
const std::string& name) const override {
|
||||
auto iter = var_set_.find(name);
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
iter != var_set_.end(), true,
|
||||
platform::errors::NotFound("Cannot find var %s in GetDataType", name));
|
||||
return iter->second->DataType();
|
||||
}
|
||||
|
||||
void SetDataType(const std::string& name,
|
||||
framework::proto::VarType::Type type) override {
|
||||
var_set_[name]->SetDataType(type);
|
||||
}
|
||||
|
||||
std::vector<framework::proto::VarType::Type> GetDataTypes(
|
||||
const std::string& name) const override {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"GetDataTypes is not supported in runtime InferVarType"));
|
||||
}
|
||||
|
||||
void SetDataTypes(const std::string& name,
|
||||
const std::vector<framework::proto::VarType::Type>&
|
||||
multiple_data_type) override {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"SetDataTypes is not supported in runtime InferVarType"));
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetShape(const std::string& name) const override {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"Do not handle Shape in runtime InferVarType"));
|
||||
}
|
||||
|
||||
void SetShape(const std::string& name,
|
||||
const std::vector<int64_t>& dims) override {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"Do not handle Shape in runtime InferVarType"));
|
||||
}
|
||||
|
||||
int32_t GetLoDLevel(const std::string& name) const override {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"Do not handle LoDLevel in runtime InferVarType"));
|
||||
}
|
||||
|
||||
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"Do not handle LoDLevel in runtime InferVarType"));
|
||||
}
|
||||
|
||||
private:
|
||||
const NameVarMap<VarType>& inputs_;
|
||||
const NameVarMap<VarType>& outputs_;
|
||||
const framework::AttributeMap& attrs_;
|
||||
std::unordered_map<std::string, std::vector<std::string>> input_names_;
|
||||
std::unordered_map<std::string, std::vector<std::string>> output_names_;
|
||||
std::unordered_map<std::string, VarType*> var_set_;
|
||||
};
|
||||
|
||||
} // namespace imperative
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,211 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/type_defs.h"
|
||||
#include "paddle/fluid/imperative/type_defs.h"
|
||||
#include "paddle/fluid/imperative/variable_wrapper.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace imperative {
|
||||
|
||||
// TODO(zjl): to support py_func layer
|
||||
class OpBase {
|
||||
public:
|
||||
OpBase() = default;
|
||||
|
||||
OpBase(const OpBase&) = delete;
|
||||
|
||||
OpBase(OpBase&&) = default;
|
||||
|
||||
OpBase& operator=(const OpBase&) = delete;
|
||||
|
||||
OpBase& operator=(OpBase&&) = default;
|
||||
|
||||
~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
|
||||
|
||||
const std::string& Type() const { return op_->Type(); }
|
||||
|
||||
const framework::AttributeMap& Attrs() const { return attrs_; }
|
||||
|
||||
const framework::OpInfo& Info() const { return op_->Info(); }
|
||||
|
||||
const framework::OperatorBase& InnerOp() const { return *op_; }
|
||||
|
||||
void ClearBackwardTrace();
|
||||
|
||||
NameVarMap<VariableWrapper>* GetMutableOutsMap() { return &outs_; }
|
||||
|
||||
NameVarMap<VariableWrapper>* GetMutableInsMap() { return &ins_; }
|
||||
|
||||
const NameVarMap<VariableWrapper>& GetInsMap() const { return ins_; }
|
||||
|
||||
const NameVarMap<VariableWrapper>& GetOutsMap() const { return outs_; }
|
||||
|
||||
void SetType(const std::string& type);
|
||||
|
||||
void CheckAttrs() {
|
||||
auto& info = op_->Info();
|
||||
if (info.Checker() != nullptr) {
|
||||
info.Checker()->Check(&attrs_, true);
|
||||
}
|
||||
}
|
||||
|
||||
void SetInput(const std::string& name, VariableWrapperList vars,
|
||||
bool is_grad) {
|
||||
auto& in_vars = ins_[name];
|
||||
*(in_vars.MutableVarList()) = std::move(vars);
|
||||
in_vars.SetIsGrad(is_grad);
|
||||
}
|
||||
|
||||
void SetOutput(const std::string& name, VariableWrapperList vars,
|
||||
bool is_grad) {
|
||||
auto& out_vars = outs_[name];
|
||||
*(out_vars.MutableVarList()) = std::move(vars);
|
||||
out_vars.SetIsGrad(is_grad);
|
||||
}
|
||||
|
||||
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
|
||||
|
||||
void SetAttr(const std::string& name, const framework::Attribute& v) {
|
||||
attrs_[name] = v;
|
||||
}
|
||||
|
||||
void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
|
||||
PADDLE_THROW(platform::errors::PermissionDenied(
|
||||
"SetBlockAttr is not support in dygraph OpBase"));
|
||||
}
|
||||
|
||||
const framework::AttributeMap& Attrs() { return attrs_; }
|
||||
|
||||
bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }
|
||||
|
||||
const framework::Attribute& GetAttr(const std::string& name) const {
|
||||
auto it = attrs_.find(name);
|
||||
PADDLE_ENFORCE_NE(
|
||||
it, attrs_.end(),
|
||||
platform::errors::NotFound("can not find attribute [%s]", name));
|
||||
return it->second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline const T& Attr(const std::string& name) const {
|
||||
return boost::get<T>(GetAttr(name));
|
||||
}
|
||||
|
||||
size_t id() const { return id_; }
|
||||
|
||||
void SetId(size_t id) { id_ = id; }
|
||||
|
||||
const platform::Place& place() const { return place_; }
|
||||
|
||||
void SetPlace(const platform::Place& place) { place_ = place; }
|
||||
|
||||
static size_t GenerateUniqueId() {
|
||||
static std::atomic<size_t> unique_id{0};
|
||||
return unique_id.fetch_add(1);
|
||||
}
|
||||
|
||||
static void Run(const framework::OperatorBase& op,
|
||||
const NameVarMap<VarBase>& ins,
|
||||
const NameVarMap<VarBase>& outs,
|
||||
const framework::AttributeMap& attrs,
|
||||
const platform::Place& place);
|
||||
|
||||
static void Run(const framework::OperatorBase& op,
|
||||
const NameVarMap<VariableWrapper>& ins,
|
||||
const NameVarMap<VariableWrapper>& outs,
|
||||
const framework::AttributeMap& attrs,
|
||||
const platform::Place& place);
|
||||
|
||||
private:
|
||||
NameVarMap<VariableWrapper> ins_;
|
||||
NameVarMap<VariableWrapper> outs_;
|
||||
framework::AttributeMap attrs_;
|
||||
std::unique_ptr<framework::OperatorBase> op_;
|
||||
platform::Place place_;
|
||||
size_t id_{-1UL};
|
||||
|
||||
std::vector<std::function<void()>> backward_hooks_;
|
||||
};
|
||||
|
||||
class GradOpNode {
|
||||
public:
|
||||
GradOpNode() = default;
|
||||
|
||||
void reserve(size_t size) { ops_.reserve(size); }
|
||||
|
||||
size_t size() const { return ops_.size(); }
|
||||
|
||||
bool empty() const { return ops_.empty(); }
|
||||
|
||||
void clear() { ops_.clear(); }
|
||||
|
||||
void pop_back() { ops_.pop_back(); }
|
||||
|
||||
template <typename... ARGS>
|
||||
OpBase& emplace_back(ARGS&&... args) { // NOLINT
|
||||
ops_.emplace_back(std::forward<ARGS>(args)...);
|
||||
return ops_.back();
|
||||
}
|
||||
|
||||
const OpBase& back() const { return ops_.back(); }
|
||||
|
||||
OpBase& back() { return ops_.back(); }
|
||||
|
||||
OpBase& operator[](size_t idx) { return ops_[idx]; }
|
||||
|
||||
const OpBase& operator[](size_t idx) const { return ops_[idx]; }
|
||||
|
||||
/* Iterator related */
|
||||
using Iterator = std::vector<OpBase>::iterator;
|
||||
using ConstIterator = std::vector<OpBase>::const_iterator;
|
||||
|
||||
Iterator begin() { return ops_.begin(); }
|
||||
|
||||
Iterator end() { return ops_.end(); }
|
||||
|
||||
ConstIterator begin() const { return ops_.begin(); }
|
||||
|
||||
ConstIterator end() const { return ops_.end(); }
|
||||
|
||||
void InsertGradPendingNode(const std::shared_ptr<GradOpNode>& node) {
|
||||
if (node &&
|
||||
std::find(grad_pending_nodes_.begin(), grad_pending_nodes_.end(),
|
||||
node) == grad_pending_nodes_.end()) {
|
||||
grad_pending_nodes_.emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<GradOpNode>>& GradPendingNodes() const {
|
||||
return grad_pending_nodes_;
|
||||
}
|
||||
|
||||
private:
|
||||
DISABLE_COPY_AND_ASSIGN(GradOpNode);
|
||||
|
||||
private:
|
||||
std::vector<OpBase> ops_;
|
||||
std::vector<std::shared_ptr<GradOpNode>> grad_pending_nodes_;
|
||||
};
|
||||
|
||||
} // namespace imperative
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/imperative/backward_strategy.h"
|
||||
#include "paddle/fluid/imperative/engine.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace imperative {
|
||||
|
||||
class VarBase;
|
||||
|
||||
class PartialGradEngine : public Engine {
|
||||
public:
|
||||
PartialGradEngine(const std::vector<std::shared_ptr<VarBase>> &input_targets,
|
||||
const std::vector<std::shared_ptr<VarBase>> &output_targets,
|
||||
const std::vector<std::shared_ptr<VarBase>> &output_grads,
|
||||
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
|
||||
const platform::Place &place,
|
||||
const detail::BackwardStrategy &strategy,
|
||||
bool create_graph);
|
||||
|
||||
void Execute() override;
|
||||
|
||||
std::vector<std::shared_ptr<VarBase>> GetResult() const;
|
||||
|
||||
private:
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<VarBase>> input_targets_;
|
||||
std::vector<std::shared_ptr<VarBase>> output_targets_;
|
||||
std::vector<std::shared_ptr<VarBase>> output_grads_;
|
||||
std::vector<std::shared_ptr<VarBase>> no_grad_vars_;
|
||||
platform::Place place_;
|
||||
detail::BackwardStrategy strategy_;
|
||||
bool create_graph_;
|
||||
|
||||
std::vector<std::shared_ptr<VarBase>> results_;
|
||||
};
|
||||
|
||||
} // namespace imperative
|
||||
} // namespace paddle
|
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace imperative {
|
||||
|
||||
class VariableWrapper;
|
||||
|
||||
class SavedVariableWrapperList {
|
||||
public:
|
||||
SavedVariableWrapperList() : vars_(), is_grad_(false) {}
|
||||
|
||||
template <typename... Args>
|
||||
explicit SavedVariableWrapperList(bool is_grad, Args&&... args)
|
||||
: vars_(std::forward<Args>(args)...), is_grad_(is_grad) {}
|
||||
|
||||
bool IsGrad() const { return is_grad_; }
|
||||
|
||||
void SetIsGrad(bool is_grad) { is_grad_ = is_grad; }
|
||||
|
||||
const std::vector<std::shared_ptr<VariableWrapper>>& VarList() const {
|
||||
return vars_;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<VariableWrapper>>* MutableVarList() {
|
||||
return &vars_;
|
||||
}
|
||||
|
||||
/* Borrow method from std::vector */
|
||||
size_t size() const { return vars_.size(); }
|
||||
|
||||
bool empty() const { return vars_.empty(); }
|
||||
|
||||
template <typename... ARGS>
|
||||
void emplace_back(ARGS&&... args) {
|
||||
vars_.emplace_back(std::forward<ARGS>(args)...);
|
||||
}
|
||||
|
||||
using Iterator = std::vector<std::shared_ptr<VariableWrapper>>::iterator;
|
||||
|
||||
using ConstIterator =
|
||||
std::vector<std::shared_ptr<VariableWrapper>>::const_iterator;
|
||||
|
||||
Iterator begin() { return vars_.begin(); }
|
||||
|
||||
Iterator end() { return vars_.end(); }
|
||||
|
||||
ConstIterator begin() const { return vars_.begin(); }
|
||||
|
||||
ConstIterator end() const { return vars_.end(); }
|
||||
|
||||
std::shared_ptr<VariableWrapper>& operator[](size_t idx) {
|
||||
return vars_[idx];
|
||||
}
|
||||
|
||||
const std::shared_ptr<VariableWrapper>& operator[](size_t idx) const {
|
||||
return vars_[idx];
|
||||
}
|
||||
|
||||
operator const std::vector<std::shared_ptr<VariableWrapper>>&() const {
|
||||
return vars_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<VariableWrapper>> vars_;
|
||||
bool is_grad_;
|
||||
};
|
||||
|
||||
} // namespace imperative
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue