|
|
|
@ -17,11 +17,14 @@
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_desc.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include "paddle/fluid/framework/var_desc.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/imperative/type_defs.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
@ -79,6 +82,11 @@ class PreparedOp {
|
|
|
|
|
};
|
|
|
|
|
class OpBase;
|
|
|
|
|
|
|
|
|
|
/* The wrapper for Variable which holds a Variable and a VarBase of its
|
|
|
|
|
* gradient. This object should be managed totally by Python intepreter.
|
|
|
|
|
*
|
|
|
|
|
* Nearly all interface should be implemented in C++.
|
|
|
|
|
*/
|
|
|
|
|
class VarBase {
|
|
|
|
|
public:
|
|
|
|
|
VarBase()
|
|
|
|
@ -86,7 +94,7 @@ class VarBase {
|
|
|
|
|
pre_op_out_idx_(-1),
|
|
|
|
|
var_desc_(nullptr),
|
|
|
|
|
var_(new framework::Variable()),
|
|
|
|
|
grads_(new framework::Variable()),
|
|
|
|
|
grads_(new VarBase(true)),
|
|
|
|
|
stop_gradient_(false) {}
|
|
|
|
|
|
|
|
|
|
explicit VarBase(bool stop_gradient)
|
|
|
|
@ -94,7 +102,7 @@ class VarBase {
|
|
|
|
|
pre_op_out_idx_(-1),
|
|
|
|
|
var_desc_(nullptr),
|
|
|
|
|
var_(new framework::Variable()),
|
|
|
|
|
grads_(new framework::Variable()),
|
|
|
|
|
grads_(stop_gradient ? nullptr : new VarBase(true)),
|
|
|
|
|
stop_gradient_(stop_gradient) {}
|
|
|
|
|
|
|
|
|
|
virtual ~VarBase() {}
|
|
|
|
@ -116,11 +124,14 @@ class VarBase {
|
|
|
|
|
|
|
|
|
|
framework::VarDesc* var_desc_;
|
|
|
|
|
framework::Variable* var_;
|
|
|
|
|
framework::Variable* grads_;
|
|
|
|
|
VarBase* grads_;
|
|
|
|
|
|
|
|
|
|
bool stop_gradient_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
|
|
|
|
|
* gradient. This object should be managed totally by Python intepreter.
|
|
|
|
|
*/
|
|
|
|
|
class OpBase {
|
|
|
|
|
public:
|
|
|
|
|
OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
|
|
|
|
@ -134,13 +145,13 @@ class OpBase {
|
|
|
|
|
framework::OpDesc* op_desc_;
|
|
|
|
|
framework::OpDesc* grad_op_desc_;
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> input_vars_;
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> output_vars_;
|
|
|
|
|
std::map<std::string, std::vector<OpBase*>> pre_ops_;
|
|
|
|
|
VarBasePtrMap input_vars_;
|
|
|
|
|
VarBasePtrMap output_vars_;
|
|
|
|
|
OpBasePtrMap pre_ops_;
|
|
|
|
|
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
|
|
|
|
|
std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
|
|
|
|
|
framework::VariableValueMap grad_input_vars_;
|
|
|
|
|
framework::VariableValueMap grad_output_vars_;
|
|
|
|
|
framework::BlockDesc* block_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|