|
|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
|
|
|
|
|
this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
@ -26,6 +26,8 @@ namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
|
namespace tensorrt {
|
|
|
|
|
|
|
|
|
|
int TensorRTEngine::runtime_batch_ = 1;
|
|
|
|
|
|
|
|
|
|
void TensorRTEngine::Build(const DescType &paddle_model) {
|
|
|
|
|
PADDLE_ENFORCE(false, "not implemented");
|
|
|
|
|
}
|
|
|
|
|
@ -42,6 +44,7 @@ void TensorRTEngine::Execute(int batch_size) {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(stream_);
|
|
|
|
|
infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
|
|
|
|
|
cudaStreamSynchronize(*stream_);
|
|
|
|
|
SetRuntimeBatch(batch_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorRTEngine::~TensorRTEngine() {
|
|
|
|
|
@ -80,17 +83,17 @@ void TensorRTEngine::FreezeNetwork() {
|
|
|
|
|
auto dims = infer_engine_->getBindingDimensions(slot_offset);
|
|
|
|
|
item.second = kDataTypeSize[static_cast<int>(
|
|
|
|
|
infer_engine_->getBindingDataType(slot_offset))] *
|
|
|
|
|
analysis::AccuDims(dims.d, dims.nbDims);
|
|
|
|
|
analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
|
|
|
|
|
PADDLE_ENFORCE_GT(item.second, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &buf = buffer(item.first);
|
|
|
|
|
buf.max_size = item.second * max_batch_;
|
|
|
|
|
CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
|
|
|
|
|
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, buf.max_size));
|
|
|
|
|
PADDLE_ENFORCE_LE(buf.max_size, 1 << 30); // 10G
|
|
|
|
|
// buf.size will changed in the runtime.
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
|
|
|
|
|
buf.size = 0;
|
|
|
|
|
PADDLE_ENFORCE_LE(buf.max_size, 1 << 30); // 10G
|
|
|
|
|
buf.device = DeviceType::GPU;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@ -105,7 +108,7 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
|
|
|
|
|
auto *input = infer_network_->addInput(name.c_str(), dtype, dims);
|
|
|
|
|
PADDLE_ENFORCE(input, "infer network add input %s failed", name);
|
|
|
|
|
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
|
|
|
|
|
analysis::AccuDims(dims.d, dims.nbDims);
|
|
|
|
|
analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
|
|
|
|
|
PADDLE_ENFORCE(input->isNetworkInput());
|
|
|
|
|
TensorRTEngine::SetITensor(name, input);
|
|
|
|
|
return input;
|
|
|
|
|
@ -149,35 +152,42 @@ void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
|
|
|
|
|
void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
|
|
|
|
|
size_t max_size) {
|
|
|
|
|
// determine data size
|
|
|
|
|
auto *output = TensorRTEngine::GetITensor(name);
|
|
|
|
|
nvinfer1::Dims dims = output->getDimensions();
|
|
|
|
|
auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
|
|
|
|
|
size_t dst_size = dim_size * runtime_batch_ *
|
|
|
|
|
kDataTypeSize[static_cast<int>(output->getType())];
|
|
|
|
|
|
|
|
|
|
auto it = buffer_sizes_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != buffer_sizes_.end());
|
|
|
|
|
PADDLE_ENFORCE_GT(it->second, 0);
|
|
|
|
|
PADDLE_ENFORCE_GE(max_size, it->second);
|
|
|
|
|
PADDLE_ENFORCE_LE(dst_size, it->second);
|
|
|
|
|
PADDLE_ENFORCE_GE(max_size, dst_size);
|
|
|
|
|
auto &buf = buffer(name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
|
|
|
|
|
PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
|
|
|
|
|
PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
|
|
|
|
|
cudaMemcpyDeviceToDevice, *stream_),
|
|
|
|
|
0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
|
|
|
|
|
size_t max_size) {
|
|
|
|
|
VLOG(4) << "get output in cpu";
|
|
|
|
|
auto &buf = buffer(name);
|
|
|
|
|
|
|
|
|
|
// Update needed buffer size.
|
|
|
|
|
auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
|
|
|
|
|
auto dims = infer_engine_->getBindingDimensions(slot_offset);
|
|
|
|
|
buf.size = kDataTypeSize[static_cast<int>(
|
|
|
|
|
infer_engine_->getBindingDataType(slot_offset))] *
|
|
|
|
|
analysis::AccuDims(dims.d, dims.nbDims);
|
|
|
|
|
PADDLE_ENFORCE_LE(buf.size, buf.max_size);
|
|
|
|
|
// determine data size
|
|
|
|
|
|
|
|
|
|
auto *output = TensorRTEngine::GetITensor(name);
|
|
|
|
|
nvinfer1::Dims dims = output->getDimensions();
|
|
|
|
|
auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
|
|
|
|
|
size_t dst_size = dim_size * runtime_batch_ *
|
|
|
|
|
kDataTypeSize[static_cast<int>(output->getType())];
|
|
|
|
|
auto it = buffer_sizes_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != buffer_sizes_.end());
|
|
|
|
|
PADDLE_ENFORCE_GT(it->second, 0);
|
|
|
|
|
PADDLE_ENFORCE_LE(dst_size, it->second);
|
|
|
|
|
PADDLE_ENFORCE_GE(max_size, dst_size);
|
|
|
|
|
auto &buf = buffer(name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
|
|
|
|
|
// DEBUG
|
|
|
|
|
memset(dst, 0, buf.size);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
0, cudaMemcpy(dst, buf.buffer, buf.size, cudaMemcpyDeviceToHost));
|
|
|
|
|
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
|
|
|
|
|
cudaMemcpyDeviceToHost, *stream_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Buffer &TensorRTEngine::buffer(const std::string &name) {
|
|
|
|
|
@ -225,6 +235,12 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
|
|
|
|
|
return itensor_map_[name];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
|
|
|
|
|
runtime_batch_ = batch_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
|
|
|
|
|
|
|
|
|
|
} // namespace tensorrt
|
|
|
|
|
} // namespace inference
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|