|
|
|
@ -130,16 +130,15 @@ public:
|
|
|
|
|
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDataType(int type_id) {
|
|
|
|
|
desc_.mutable_lod_tensor()->set_data_type(
|
|
|
|
|
static_cast<enum DataType>(type_id));
|
|
|
|
|
void SetDataType(framework::DataType data_type) {
|
|
|
|
|
desc_.mutable_lod_tensor()->set_data_type(data_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> Shape() {
|
|
|
|
|
return RepeatedToVector(desc_.lod_tensor().dims());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int DataType() { return desc_.lod_tensor().data_type(); }
|
|
|
|
|
framework::DataType DataType() { return desc_.lod_tensor().data_type(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
VarDesc desc_;
|
|
|
|
@ -502,14 +501,21 @@ void BindBlockDesc(py::module &m) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindVarDsec(py::module &m) {
|
|
|
|
|
py::enum_<framework::DataType>(m, "DataType", "")
|
|
|
|
|
.value("BOOL", DataType::BOOL)
|
|
|
|
|
.value("INT16", DataType::INT16)
|
|
|
|
|
.value("INT32", DataType::INT32)
|
|
|
|
|
.value("INT64", DataType::INT64)
|
|
|
|
|
.value("FP16", DataType::FP16)
|
|
|
|
|
.value("FP32", DataType::FP32)
|
|
|
|
|
.value("FP64", DataType::FP64);
|
|
|
|
|
|
|
|
|
|
py::class_<VarDescBind>(m, "VarDesc", "")
|
|
|
|
|
.def("name", &VarDescBind::Name, py::return_value_policy::reference)
|
|
|
|
|
.def("set_shape", &VarDescBind::SetShape)
|
|
|
|
|
.def("set_data_type", &VarDescBind::SetDataType)
|
|
|
|
|
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
|
|
|
|
|
.def("data_type",
|
|
|
|
|
&VarDescBind::DataType,
|
|
|
|
|
py::return_value_policy::reference);
|
|
|
|
|
.def("data_type", &VarDescBind::DataType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindOpDesc(py::module &m) {
|
|
|
|
|