fix lite unit test. (#29233)

release/2.0-rc1
Wilber 4 years ago committed by GitHub
parent b6a26749dc
commit 74c43ac638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,6 +29,7 @@ namespace inference {
namespace lite {
using inference::lite::AddTensorToBlockDesc;
using paddle::inference::lite::AddFetchListToBlockDesc;
using inference::lite::CreateTensor;
using inference::lite::serialize_params;
@ -65,7 +66,7 @@ void make_fake_model(std::string* model, std::string* param) {
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4}), true);
AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({2, 4}), true);
AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 4}), false);
AddTensorToBlockDesc(block_, "out", std::vector<int64_t>({2, 4}), false);
AddFetchListToBlockDesc(block_, "out");
*block_->add_ops() = *feed0->Proto();
*block_->add_ops() = *feed1->Proto();

@ -25,6 +25,7 @@
USE_NO_KERNEL_OP(lite_engine)
using paddle::inference::lite::AddTensorToBlockDesc;
using paddle::inference::lite::AddFetchListToBlockDesc;
using paddle::inference::lite::CreateTensor;
using paddle::inference::lite::serialize_params;
namespace paddle {
@ -60,7 +61,7 @@ TEST(LiteEngineOp, engine_op) {
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4}), true);
AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({2, 4}), true);
AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 4}), false);
AddTensorToBlockDesc(block_, "out", std::vector<int64_t>({2, 4}), false);
AddFetchListToBlockDesc(block_, "out");
*block_->add_ops() = *feed1->Proto();
*block_->add_ops() = *feed0->Proto();
*block_->add_ops() = *elt_add->Proto();

@ -42,6 +42,16 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
desc.SetPersistable(persistable);
*var = *desc.Proto();
}
void AddFetchListToBlockDesc(framework::proto::BlockDesc* block,
const std::string& name) {
using framework::proto::VarType;
auto* var = block->add_vars();
framework::VarDesc desc(name);
desc.SetType(VarType::FETCH_LIST);
*var = *desc.Proto();
}
void serialize_params(std::string* str, framework::Scope* scope,
const std::vector<std::string>& params) {
std::ostringstream os;

Loading…
Cancel
Save