|
|
|
@ -137,13 +137,13 @@ class VarBase {
|
|
|
|
|
persistable) {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// TODO(minqiyang): need support SelectedRows
|
|
|
|
|
VarBase(const std::string& name, framework::proto::VarType::Type dtype,
|
|
|
|
|
const framework::DDim& shape, const platform::Place& place,
|
|
|
|
|
framework::Variable* var, VarBase* grad, bool stop_gradient,
|
|
|
|
|
bool persistable)
|
|
|
|
|
: name_(name),
|
|
|
|
|
dtype_(dtype),
|
|
|
|
|
place_(place),
|
|
|
|
|
type_(framework::proto::VarType::LOD_TENSOR),
|
|
|
|
|
var_(var),
|
|
|
|
|
grads_(grad),
|
|
|
|
|
stop_gradient_(stop_gradient),
|
|
|
|
@ -153,10 +153,12 @@ class VarBase {
|
|
|
|
|
pre_op_out_idx_(-1) {
|
|
|
|
|
if (!var_) {
|
|
|
|
|
var_ = new framework::Variable();
|
|
|
|
|
auto tensor = var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->Resize(shape);
|
|
|
|
|
tensor->mutable_data(place_, dtype_);
|
|
|
|
|
}
|
|
|
|
|
auto tensor = var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->Resize(shape);
|
|
|
|
|
tensor->mutable_data(place, dtype);
|
|
|
|
|
VLOG(10) << "create varbase: " << name_ << " type: " << dtype
|
|
|
|
|
<< " place: " << place;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
@ -186,11 +188,23 @@ class VarBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void SetDType(framework::proto::VarType::Type type) {
|
|
|
|
|
inline framework::DDim Dims() const {
|
|
|
|
|
return var_->Get<framework::LoDTensor>().dims();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// data type. e.g.. FP32
|
|
|
|
|
inline void SetDataType(framework::proto::VarType::Type type) {
|
|
|
|
|
auto tensor = var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->mutable_data(place_, dtype_);
|
|
|
|
|
tensor->mutable_data(place_, type);
|
|
|
|
|
}
|
|
|
|
|
inline framework::proto::VarType::Type DType() const { return dtype_; }
|
|
|
|
|
inline framework::proto::VarType::Type DataType() const {
|
|
|
|
|
auto tensor = var_->Get<framework::LoDTensor>();
|
|
|
|
|
return tensor.type();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor type. e.g.. LoDTensor
|
|
|
|
|
inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
|
|
|
|
|
inline framework::proto::VarType::Type Type() const { return type_; }
|
|
|
|
|
|
|
|
|
|
inline void SetStopGradient(bool stop_gradient) {
|
|
|
|
|
stop_gradient_ = stop_gradient;
|
|
|
|
@ -244,7 +258,7 @@ class VarBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string name_;
|
|
|
|
|
framework::proto::VarType::Type dtype_;
|
|
|
|
|
framework::proto::VarType::Type type_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|
framework::Variable* var_;
|
|
|
|
@ -339,6 +353,8 @@ class PYBIND11_HIDDEN OpBase {
|
|
|
|
|
std::vector<VarBasePtrMap> grad_output_vars_;
|
|
|
|
|
|
|
|
|
|
std::vector<py::object> backward_hooks_;
|
|
|
|
|
|
|
|
|
|
framework::AttributeMap attrs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Layer {
|
|
|
|
@ -437,22 +453,22 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
|
|
|
|
|
|
|
|
|
|
framework::proto::VarType::Type GetType(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return var_set_.at(name)->DType();
|
|
|
|
|
return var_set_.at(name)->Type();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetType(const std::string& name,
|
|
|
|
|
framework::proto::VarType::Type type) override {
|
|
|
|
|
var_set_[name]->SetDType(type);
|
|
|
|
|
var_set_[name]->SetType(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::proto::VarType::Type GetDataType(
|
|
|
|
|
const std::string& name) const override {
|
|
|
|
|
return var_set_.at(name)->DType();
|
|
|
|
|
return var_set_.at(name)->DataType();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDataType(const std::string& name,
|
|
|
|
|
framework::proto::VarType::Type type) override {
|
|
|
|
|
var_set_[name]->SetDType(type);
|
|
|
|
|
var_set_[name]->SetDataType(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<framework::proto::VarType::Type> GetDataTypes(
|
|
|
|
|