fix trt engine test error.

test=develop
revert-16807-engine2-interface
nhzlx 6 years ago
parent d065b5bf2b
commit 7cde2d9e84

@ -82,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
}
if (!calibration_mode_) {
if (!calibration_mode_ && !engine_serialized_data_.empty()) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_));
@ -236,6 +236,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (!trt_engine_) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_));
PrepareTRTEngine(scope, trt_engine_.get());
}
return trt_engine_.get();

@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) {
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);

Loading…
Cancel
Save