@@ -18,14 +18,16 @@
 #include "paddle/fluid/framework/python_headers.h"
 // clang-format on

-#include <map>     // NOLINT
-#include <string>  // NOLINT
-#include <vector>  // NOLINT
-#include <memory>  // NOLINT
+#include <map>            // NOLINT
+#include <string>         // NOLINT
+#include <vector>         // NOLINT
+#include <memory>         // NOLINT
+#include <unordered_map>  // NOLINT

 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/framework/var_type_inference.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/operators/math/math_function.h"

@@ -135,13 +137,13 @@ class VarBase {
                 persistable) {}

  private:
   // TODO(minqiyang): need support SelectedRows
   VarBase(const std::string& name, framework::proto::VarType::Type dtype,
           const framework::DDim& shape, const platform::Place& place,
           framework::Variable* var, VarBase* grad, bool stop_gradient,
           bool persistable)
       : name_(name),
         dtype_(dtype),
         place_(place),
         type_(framework::proto::VarType::LOD_TENSOR),
         var_(var),
         grads_(grad),
         stop_gradient_(stop_gradient),
@@ -151,10 +153,12 @@ class VarBase {
         pre_op_out_idx_(-1) {
     if (!var_) {
       var_ = new framework::Variable();
-      auto tensor = var_->GetMutable<framework::LoDTensor>();
-      tensor->Resize(shape);
-      tensor->mutable_data(place_, dtype_);
     }
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->Resize(shape);
+    tensor->mutable_data(place, dtype);
+    VLOG(10) << "create varbase: " << name_ << " type: " << dtype
+             << " place: " << place;
   }

  public:
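
The tensor is now resized and allocated unconditionally, whether or not a pre-existing Variable was handed in. For orientation, a minimal hypothetical construction through the public overload visible above (the one ending in `persistable) {}`), which presumably forwards null var/grad pointers so the branch here allocates the LoDTensor; place and dtype are chosen arbitrarily:

    // Hedged sketch of a caller; not code from this patch.
    imperative::VarBase v("x", framework::proto::VarType::FP32,
                          framework::make_ddim({2, 3}), platform::CPUPlace(),
                          false /*stop_gradient*/, false /*persistable*/);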
@@ -184,7 +188,23 @@ class VarBase {
     }
   }

   inline framework::proto::VarType::Type DType() const { return dtype_; }

   inline framework::DDim Dims() const {
     return var_->Get<framework::LoDTensor>().dims();
   }

+  // data type, e.g. FP32
+  inline void SetDataType(framework::proto::VarType::Type type) {
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->mutable_data(tensor->place(), type);
+  }
+  inline framework::proto::VarType::Type DataType() const {
+    auto& tensor = var_->Get<framework::LoDTensor>();
+    return tensor.type();
+  }
+
+  // tensor type, e.g. LoDTensor
+  inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
+  inline framework::proto::VarType::Type Type() const { return type_; }
+
   inline void SetStopGradient(bool stop_gradient) {
     stop_gradient_ = stop_gradient;
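
The two new accessor pairs answer different questions. Type()/SetType() track the variable's container kind (the framework::proto::VarType enum value such as LOD_TENSOR, cached in the new type_ member), while DataType()/SetDataType() go through the underlying LoDTensor and describe its element type; note SetDataType reallocates the buffer via mutable_data with the new element type. A hedged sketch, assuming `v` is a live VarBase holding a LoDTensor:

    v.SetType(framework::proto::VarType::LOD_TENSOR);  // container kind
    v.SetDataType(framework::proto::VarType::FP32);    // reallocates as float
    // v.Type()     -> LOD_TENSOR (from the cached type_ field)
    // v.DataType() -> FP32       (read back from the tensor itself)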

@@ -238,7 +258,7 @@ class VarBase {
   }

   std::string name_;
   framework::proto::VarType::Type dtype_;
+  framework::proto::VarType::Type type_;
   platform::Place place_;

   framework::Variable* var_;

@@ -334,11 +354,13 @@ class PYBIND11_HIDDEN OpBase {
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;

   // Inputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_input_vars_;
+  std::vector<VarBasePtrMap> grad_input_vars_;
   // Outputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_output_vars_;
+  std::vector<VarBasePtrMap> grad_output_vars_;

   std::vector<py::object> backward_hooks_;

   framework::AttributeMap attrs_;
 };

 class Layer {
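
Switching the gradient bookkeeping from framework::VariableValueMap to VarBasePtrMap is the heart of this change for OpBase: backward slots now carry VarBase* (which knows its name, var type, and data type) rather than bare framework::Variable*, which is what lets the runtime InferVarType context below query and set types while tracing. VarBasePtrMap is the alias from paddle/fluid/imperative/type_defs.h, a map from slot name to std::vector<imperative::VarBase*>.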

@@ -365,12 +387,131 @@ class PyLayer {
   static std::vector<framework::Variable*> Apply(
       int func_id, const std::vector<VarBase*>& inputs);

-  static std::vector<framework::Variable*> ApplyGrad(
-      int func_id, const std::vector<framework::Variable*>& inputs);
+  static std::vector<VarBase*> ApplyGrad(int func_id,
+                                         const std::vector<VarBase*>& inputs);

  private:
   static std::vector<framework::Variable*> CallPythonFunc(
-      const py::object& callable, const std::vector<framework::Variable*>& ins);
+      const py::object& callable, const std::vector<VarBase*>& ins);
 };
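
The same Variable*-to-VarBase* migration applies to PyLayer: ApplyGrad now both accepts and returns VarBase*, so gradients coming back from user-defined Python functions keep their metadata. The forward Apply and the private CallPythonFunc still return std::vector<framework::Variable*>; only their inputs move to VarBase* in this hunk.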

+
+// infer var type context for imperative mode
+class PYBIND11_HIDDEN RuntimeInferVarTypeContext
+    : public framework::InferVarTypeContext {
+ public:
+  RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs,
+                             imperative::VarBasePtrMap* outputs,
+                             const framework::AttributeMap* attrs_map)
+      : InferVarTypeContext(nullptr, nullptr),
+        inputs_(inputs),
+        outputs_(outputs),
+        attrs_(attrs_map),
+        input_names_(),
+        output_names_(),
+        var_set_() {
+    input_names_.reserve(inputs_->size());
+    for (auto& it : *inputs_) {
+      for (imperative::VarBase* var : it.second) {
+        input_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+
+    output_names_.reserve(outputs_->size());
+    for (auto& it : *outputs_) {
+      for (imperative::VarBase* var : it.second) {
+        output_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+  }
+
+  virtual ~RuntimeInferVarTypeContext() {}
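
The constructor flattens the in/out VarBase maps into name lists plus one name-to-VarBase lookup table, so every query below is a plain map access. A rough sketch of the intended call site on the tracer side (approximate names; `inputs`/`outputs` are VarBasePtrMaps, `attrs_map` an AttributeMap, `op_type` the op's type string):

    imperative::RuntimeInferVarTypeContext infer_ctx(&inputs, &outputs,
                                                     &attrs_map);
    const auto& info = framework::OpInfoMap::Instance().Get(op_type);
    if (info.infer_var_type_) {
      info.infer_var_type_(&infer_ctx);  // may rewrite output types in place
    }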
+
+  framework::Attribute GetAttr(const std::string& name) const override {
+    PADDLE_ENFORCE_NOT_NULL(attrs_);
+    return attrs_->at(name);
+  }
+
+  bool HasVar(const std::string& name) const override {
+    return var_set_.count(name) > 0;
+  }
+
+  bool HasInput(const std::string& name) const override {
+    PADDLE_ENFORCE_NOT_NULL(inputs_);
+    return inputs_->count(name) > 0;
+  }
+
+  bool HasOutput(const std::string& name) const override {
+    PADDLE_ENFORCE_NOT_NULL(outputs_);
+    return outputs_->count(name) > 0;
+  }
+
+  const std::vector<std::string>& Input(
+      const std::string& name) const override {
+    return input_names_.at(name);
+  }
+
+  const std::vector<std::string>& Output(
+      const std::string& name) const override {
+    return output_names_.at(name);
+  }
+
+  framework::proto::VarType::Type GetType(
+      const std::string& name) const override {
+    return var_set_.at(name)->Type();
+  }
+
+  void SetType(const std::string& name,
+               framework::proto::VarType::Type type) override {
+    var_set_[name]->SetType(type);
+  }
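
These lookups are what an op-side VarTypeInference functor sees after the refactor that made the hook take an InferVarTypeContext*. A minimal hypothetical pass-through hook (not an op from this patch) written against that interface:

    class PassThroughVarTypeInference : public framework::VarTypeInference {
     public:
      void operator()(framework::InferVarTypeContext* ctx) const override {
        // Forward the first input's var type and data type to the output.
        auto& in = ctx->Input("X")[0];
        auto& out = ctx->Output("Out")[0];
        ctx->SetType(out, ctx->GetType(in));
        ctx->SetDataType(out, ctx->GetDataType(in));
      }
    };

In static-graph mode the same hook rewrites VarDescs; here SetType/SetDataType mutate the live VarBases directly.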
+
+  framework::proto::VarType::Type GetDataType(
+      const std::string& name) const override {
+    return var_set_.at(name)->DataType();
+  }
+
+  void SetDataType(const std::string& name,
+                   framework::proto::VarType::Type type) override {
+    var_set_[name]->SetDataType(type);
+  }
+
+  std::vector<framework::proto::VarType::Type> GetDataTypes(
+      const std::string& name) const override {
+    PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
+  }
+
+  void SetDataTypes(const std::string& name,
+                    const std::vector<framework::proto::VarType::Type>&
+                        multiple_data_type) override {
+    PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
+  }
+
+  std::vector<int64_t> GetShape(const std::string& name) const override {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+
+  void SetShape(const std::string& name,
+                const std::vector<int64_t>& dims) override {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+
+  int32_t GetLoDLevel(const std::string& name) const override {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+
+  void SetLoDLevel(const std::string& name, int32_t lod_level) override {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
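
Everything shape-, LoD-, or multi-dtype-related deliberately throws: in imperative mode there is no VarDesc to consult, shapes are already concrete on the tensors, and these queries only make sense for the compile-time implementation of InferVarTypeContext. Failing loudly here beats silently returning stale data.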
+
+ private:
+  const imperative::VarBasePtrMap* inputs_;
+  imperative::VarBasePtrMap* outputs_;
+  const framework::AttributeMap* attrs_;
+  std::unordered_map<std::string, std::vector<std::string>> input_names_;
+  std::unordered_map<std::string, std::vector<std::string>> output_names_;
+  std::unordered_map<std::string, imperative::VarBase*> var_set_;
+};

 }  // namespace imperative