@@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
 void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
                              cudaStream_t stream) {
   freshDeviceId();
+  const std::thread::id tid = std::this_thread::get_id();
   batch_size_ = batch_size;
-  infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
+  if (infer_context_.find(tid) == infer_context_.end()) {
+    PADDLE_ENFORCE_NOT_NULL(
+        infer_engine_,
+        "You should build engine first and then set the context.");
+    infer_context_[tid].reset(infer_engine_->createExecutionContext());
+  }
+  infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
   cudaStreamSynchronize(stream);
   SetRuntimeBatch(batch_size);
 }
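
This hunk makes Execute callable from multiple threads: instead of one shared execution context, each calling thread lazily creates and then reuses its own IExecutionContext, keyed by std::this_thread::get_id(). Below is a minimal self-contained sketch of that per-thread lazy-initialization pattern. Engine, Context, and contexts_ are illustrative stand-ins, not Paddle or TensorRT APIs, and the mutex is an added assumption: the hunk itself shows no lock around the map lookup and insertion.

// Sketch of the per-thread lazy-init pattern applied above (illustrative
// names only). The mutex guarding concurrent first calls is an assumption;
// the diff does not show one.
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>

struct Context {};  // stand-in for nvinfer1::IExecutionContext

class Engine {
 public:
  Context *GetContextForThisThread() {
    const std::thread::id tid = std::this_thread::get_id();
    std::lock_guard<std::mutex> lock(mu_);
    auto it = contexts_.find(tid);
    if (it == contexts_.end()) {
      // First call from this thread: create its private context.
      it = contexts_.emplace(tid, std::make_unique<Context>()).first;
    }
    return it->second.get();
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::thread::id, std::unique_ptr<Context>> contexts_;
};
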
@@ -109,8 +116,6 @@ void TensorRTEngine::FreezeNetwork() {
 
   infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
   PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
-
-  infer_context_.reset(infer_engine_->createExecutionContext());
 }
 
 nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
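
The second hunk is the other half of the change: FreezeNetwork no longer creates a single shared context eagerly after building the engine, since Execute now creates one per thread on first use. The matching header change to the infer_context_ member is not part of this excerpt; from the .find(tid) and [tid].reset(...) usage above it is presumably a map along these lines (hedged sketch, with a plain unique_ptr standing in for Paddle's TRT-deleting infer_ptr):

// Presumed shape of infer_context_ after this change; inferred from usage,
// since the header diff is not shown here.
#include <memory>
#include <thread>
#include <unordered_map>

namespace nvinfer1 { class IExecutionContext {}; }  // stand-in declaration

// Paddle wraps TensorRT objects in a unique_ptr whose deleter calls
// destroy(); simplified here for brevity.
template <typename T>
using infer_ptr = std::unique_ptr<T>;

std::unordered_map<std::thread::id, infer_ptr<nvinfer1::IExecutionContext>>
    infer_context_;
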