|
|
|
@ -18,18 +18,21 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
// Returns the variable's type tag as stored directly on the descriptor proto.
proto::VarDesc::VarType VarDesc::GetType() const {
  return desc_.type();
}
|
|
|
|
|
// Returns the concrete variable type from the nested VarType message of the
// descriptor proto.
proto::VarType::Type VarDesc::GetType() const {
  return desc_.type().type();
}
|
|
|
|
|
|
|
|
|
|
// Records the variable's type tag directly on the descriptor proto.
void VarDesc::SetType(proto::VarDesc::VarType type) {
  desc_.set_type(type);
}
|
|
|
|
|
// Records the concrete variable type inside the nested VarType message,
// creating the submessage on first use via the protobuf mutable accessor.
void VarDesc::SetType(proto::VarType::Type type) {
  auto *var_type = desc_.mutable_type();
  var_type->set_type(type);
}
|
|
|
|
|
|
|
|
|
|
// Overwrites the dims of this variable's tensor descriptor with `dims`.
// Delegates to mutable_tensor_desc(), which resolves the tensor descriptor
// appropriate for the variable's current type.
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
  auto *dims_field = mutable_tensor_desc()->mutable_dims();
  VectorToRepeated(dims, dims_field);
}
|
|
|
|
|
|
|
|
|
|
void VarDesc::SetTensorDescNum(size_t num) {
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::READER: {
|
|
|
|
|
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::READER: {
|
|
|
|
|
auto *lod_tensors_ptr =
|
|
|
|
|
desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
|
|
|
|
|
lod_tensors_ptr->Clear();
|
|
|
|
|
for (size_t i = 0; i < num; ++i) {
|
|
|
|
|
lod_tensors_ptr->Add();
|
|
|
|
@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t VarDesc::GetTensorDescNum() const {
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::READER:
|
|
|
|
|
return desc_.reader().lod_tensor_size();
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::READER:
|
|
|
|
|
return desc_.type().reader().lod_tensor_size();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
@ -64,7 +67,7 @@ void VarDesc::SetShapes(
|
|
|
|
|
<< "). The Reader is going to be reinitialized.";
|
|
|
|
|
SetTensorDescNum(multiple_dims.size());
|
|
|
|
|
}
|
|
|
|
|
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
|
|
|
|
|
std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
|
|
|
|
|
for (size_t i = 0; i < multiple_dims.size(); ++i) {
|
|
|
|
|
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
|
|
|
|
|
}
|
|
|
|
@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
|
|
|
|
|
std::vector<proto::TensorDesc> descs = tensor_descs();
|
|
|
|
|
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
|
|
|
|
|
std::vector<std::vector<int64_t>> res;
|
|
|
|
|
res.reserve(descs.size());
|
|
|
|
|
for (const auto &tensor_desc : descs) {
|
|
|
|
@ -98,7 +101,8 @@ void VarDesc::SetDataTypes(
|
|
|
|
|
<< "). The Reader is going to be reinitialized.";
|
|
|
|
|
SetTensorDescNum(multiple_data_type.size());
|
|
|
|
|
}
|
|
|
|
|
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
|
|
|
|
|
std::vector<proto::VarType::TensorDesc *> tensor_descs =
|
|
|
|
|
mutable_tensor_descs();
|
|
|
|
|
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
|
|
|
|
|
tensor_descs[i]->set_data_type(multiple_data_type[i]);
|
|
|
|
|
}
|
|
|
|
@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<proto::DataType> VarDesc::GetDataTypes() const {
|
|
|
|
|
std::vector<proto::TensorDesc> descs = tensor_descs();
|
|
|
|
|
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
|
|
|
|
|
std::vector<proto::DataType> res;
|
|
|
|
|
res.reserve(descs.size());
|
|
|
|
|
for (const auto &tensor_desc : descs) {
|
|
|
|
@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VarDesc::SetLoDLevel(int32_t lod_level) {
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR:
|
|
|
|
|
desc_.mutable_lod_tensor()->set_lod_level(lod_level);
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::LOD_TENSOR:
|
|
|
|
|
desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR_ARRAY:
|
|
|
|
|
desc_.mutable_tensor_array()->set_lod_level(lod_level);
|
|
|
|
|
case proto::VarType::LOD_TENSOR_ARRAY:
|
|
|
|
|
desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
|
|
|
|
|
<< "). The Reader is going to be reinitialized.";
|
|
|
|
|
SetTensorDescNum(multiple_lod_level.size());
|
|
|
|
|
}
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::READER: {
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::READER: {
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
|
|
|
|
|
for (auto &lod_tensor :
|
|
|
|
|
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
|
|
|
|
|
lod_tensor.set_lod_level(multiple_lod_level[i++]);
|
|
|
|
|
}
|
|
|
|
|
} break;
|
|
|
|
@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int32_t VarDesc::GetLoDLevel() const {
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR:
|
|
|
|
|
return desc_.lod_tensor().lod_level();
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR_ARRAY:
|
|
|
|
|
return desc_.tensor_array().lod_level();
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::LOD_TENSOR:
|
|
|
|
|
return desc_.type().lod_tensor().lod_level();
|
|
|
|
|
case proto::VarType::LOD_TENSOR_ARRAY:
|
|
|
|
|
return desc_.type().tensor_array().lod_level();
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Getting 'lod_level' is not supported by the type of var %s.",
|
|
|
|
@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const {
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> VarDesc::GetLoDLevels() const {
|
|
|
|
|
std::vector<int32_t> res;
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::READER:
|
|
|
|
|
res.reserve(desc_.reader().lod_tensor_size());
|
|
|
|
|
for (auto &lod_tensor : desc_.reader().lod_tensor()) {
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::READER:
|
|
|
|
|
res.reserve(desc_.type().reader().lod_tensor_size());
|
|
|
|
|
for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
|
|
|
|
|
res.push_back(lod_tensor.lod_level());
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const proto::TensorDesc &VarDesc::tensor_desc() const {
|
|
|
|
|
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
|
|
|
|
|
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::SELECTED_ROWS:
|
|
|
|
|
return desc_.selected_rows();
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR:
|
|
|
|
|
return desc_.lod_tensor().tensor();
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR_ARRAY:
|
|
|
|
|
return desc_.tensor_array().tensor();
|
|
|
|
|
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::SELECTED_ROWS:
|
|
|
|
|
return desc_.type().selected_rows();
|
|
|
|
|
case proto::VarType::LOD_TENSOR:
|
|
|
|
|
return desc_.type().lod_tensor().tensor();
|
|
|
|
|
case proto::VarType::LOD_TENSOR_ARRAY:
|
|
|
|
|
return desc_.type().tensor_array().tensor();
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Getting 'tensor_desc' is not supported by the type of var %s.",
|
|
|
|
@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
|
|
|
|
|
std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
|
|
|
|
|
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
|
|
|
|
|
std::vector<proto::TensorDesc> res;
|
|
|
|
|
std::vector<proto::VarType::TensorDesc> res;
|
|
|
|
|
res.reserve(GetTensorDescNum());
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::READER:
|
|
|
|
|
for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::READER:
|
|
|
|
|
for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
|
|
|
|
|
res.push_back(lod_tensor.tensor());
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
proto::TensorDesc *VarDesc::mutable_tensor_desc() {
|
|
|
|
|
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
|
|
|
|
|
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::SELECTED_ROWS:
|
|
|
|
|
return desc_.mutable_selected_rows();
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR:
|
|
|
|
|
return desc_.mutable_lod_tensor()->mutable_tensor();
|
|
|
|
|
case proto::VarDesc::LOD_TENSOR_ARRAY:
|
|
|
|
|
return desc_.mutable_tensor_array()->mutable_tensor();
|
|
|
|
|
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::SELECTED_ROWS:
|
|
|
|
|
return desc_.mutable_type()->mutable_selected_rows();
|
|
|
|
|
case proto::VarType::LOD_TENSOR:
|
|
|
|
|
return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
|
|
|
|
|
case proto::VarType::LOD_TENSOR_ARRAY:
|
|
|
|
|
return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Getting 'mutable_tensor_desc' is not supported by the type of var "
|
|
|
|
@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
|
|
|
|
|
std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
|
|
|
|
|
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
|
|
|
|
|
std::vector<proto::TensorDesc *> res;
|
|
|
|
|
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
|
|
|
|
|
std::vector<proto::VarType::TensorDesc *> res;
|
|
|
|
|
res.reserve(GetTensorDescNum());
|
|
|
|
|
switch (desc_.type()) {
|
|
|
|
|
case proto::VarDesc::READER:
|
|
|
|
|
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
|
|
|
|
|
switch (desc_.type().type()) {
|
|
|
|
|
case proto::VarType::READER:
|
|
|
|
|
for (auto &lod_tensor :
|
|
|
|
|
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
|
|
|
|
|
res.push_back(lod_tensor.mutable_tensor());
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|