parent 0b38822624
commit 2d7134bc37
@@ -1,4 +1,5 @@
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(plugin)
add_subdirectory(convert)
Binary file not shown.
@@ -0,0 +1,2 @@
nv_library(tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc
  trt_plugin.cc split_op_plugin.cu DEPS enforce)
@@ -0,0 +1,64 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {

PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
                                                    const void* serial_data,
                                                    size_t serial_length) {
  size_t parsed_byte = 0;
  std::string encoded_op_name =
      ExtractOpName(serial_data, serial_length, &parsed_byte);

  if (!IsPlugin(encoded_op_name)) {
    return nullptr;
  }

  auto plugin_ptr =
      plugin_registry_[encoded_op_name].first(serial_data, serial_length);
  owned_plugins_.emplace_back(plugin_ptr);

  return plugin_ptr;
}

PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(
    const std::string& op_name) {
  if (!IsPlugin(op_name)) return nullptr;

  auto plugin_ptr = plugin_registry_[op_name].second();
  owned_plugins_.emplace_back(plugin_ptr);

  return plugin_ptr;
}

bool PluginFactoryTensorRT::RegisterPlugin(
    const std::string& op_name, PluginDeserializeFunc deserialize_func,
    PluginConstructFunc construct_func) {
  if (IsPlugin(op_name)) return false;

  auto ret = plugin_registry_.emplace(
      op_name, std::make_pair(deserialize_func, construct_func));

  return ret.second;
}

void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
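A minimal usage sketch of the factory above (not part of the commit; MyPlugin is a hypothetical PluginTensorRT subclass, and the lambdas stand in for real deserialize/construct callbacks):

// Hypothetical example; MyPlugin is assumed to subclass PluginTensorRT.
using paddle::inference::tensorrt::PluginFactoryTensorRT;
using paddle::inference::tensorrt::PluginTensorRT;

auto* factory = PluginFactoryTensorRT::GetInstance();
factory->RegisterPlugin(
    "my_op",
    [](const void* data, size_t len) -> PluginTensorRT* {
      return new MyPlugin(data, len);  // deserialization path
    },
    []() -> PluginTensorRT* { return new MyPlugin(); });  // construction path

PluginTensorRT* plugin = factory->CreatePlugin("my_op");  // network build time
// At engine deserialization time, TensorRT itself invokes
// factory->createPlugin(layer_name, serial_data, serial_length).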
@@ -0,0 +1,91 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
 public:
  static PluginFactoryTensorRT* GetInstance() {
    static PluginFactoryTensorRT* factory_instance =
        new PluginFactoryTensorRT();
    return factory_instance;
  }

  // Deserialization method
  PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
                               size_t serial_length) override;

  // Plugin construction; PluginFactoryTensorRT owns the plugin.
  PluginTensorRT* CreatePlugin(const std::string& op_name);

  bool RegisterPlugin(const std::string& op_name,
                      PluginDeserializeFunc deserialize_func,
                      PluginConstructFunc construct_func);

  bool IsPlugin(const std::string& op_name) {
    return plugin_registry_.find(op_name) != plugin_registry_.end();
  }

  size_t CountOwnedPlugins() { return owned_plugins_.size(); }

  void DestroyPlugins();

 protected:
  std::unordered_map<std::string,
                     std::pair<PluginDeserializeFunc, PluginConstructFunc>>
      plugin_registry_;
  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};

class TrtPluginRegistrar {
 public:
  TrtPluginRegistrar(const std::string& name,
                     PluginDeserializeFunc deserialize_func,
                     PluginConstructFunc construct_func) {
    auto factory = PluginFactoryTensorRT::GetInstance();
    // PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
    //                                        construct_func),
    //                "Failed to register plugin [%s]", name);
    factory->RegisterPlugin(name, deserialize_func, construct_func);
  }
};

#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func)    \
  REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
                                  construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
                                        construct_func)              \
  REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
  static ::paddle::inference::tensorrt::TrtPluginRegistrar                    \
      trt_plugin_registrar##ctr __attribute__((unused)) =                     \
          ::paddle::inference::tensorrt::TrtPluginRegistrar(                  \
              name, deserialize_func, construct_func)

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
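For reference, a hedged sketch of how the registration macro is meant to be used at namespace scope (MyPlugin and the two free functions are hypothetical; they only need to be convertible to PluginDeserializeFunc and PluginConstructFunc):

// Hypothetical plugin type with matching deserialize/construct functions.
PluginTensorRT* DeserializeMyPlugin(const void* serial_data,
                                    size_t serial_length) {
  return new MyPlugin(serial_data, serial_length);
}
PluginTensorRT* ConstructMyPlugin() { return new MyPlugin(); }

// Registers "my_op" with the singleton factory at static-init time.
REGISTER_TRT_PLUGIN("my_op", DeserializeMyPlugin, ConstructMyPlugin);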
@@ -0,0 +1,37 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include <cassert>

namespace paddle {
namespace inference {
namespace tensorrt {

std::string ExtractOpName(const void* serial_data, size_t serial_length,
                          size_t* incremental) {
  size_t op_name_char_count = *static_cast<const size_t*>(serial_data);
  *incremental = sizeof(size_t) + op_name_char_count;

  assert(serial_length >= *incremental);

  const char* buffer = static_cast<const char*>(serial_data) + sizeof(size_t);
  std::string op_name(buffer, op_name_char_count);

  return op_name;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
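ExtractOpName implies a wire layout of [size_t op_name_char_count][op_name bytes][plugin payload...]. A hedged sketch of the matching writer (WriteOpName is hypothetical, not part of the commit; requires <cstring> and <string>):

// Hypothetical writer for the prefix that ExtractOpName parses.
void WriteOpName(void** buffer, const std::string& op_name) {
  size_t count = op_name.size();
  std::memcpy(*buffer, &count, sizeof(size_t));  // length prefix
  std::memcpy(static_cast<char*>(*buffer) + sizeof(size_t), op_name.data(),
              count);                            // op name bytes
  *buffer = static_cast<char*>(*buffer) + sizeof(size_t) + count;
}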
@@ -0,0 +1,34 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <functional>
#include <string>

#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

typedef std::function<PluginTensorRT*(const void*, size_t)>
    PluginDeserializeFunc;
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;

std::string ExtractOpName(const void* serial_data, size_t serial_length,
                          size_t* incremental);

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
@@ -0,0 +1,111 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>

template <typename T>
inline void serialize_value(void** buffer, T const& value);

template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
                              T* value);

namespace {

template <typename T, class Enable = void>
struct Serializer {};

template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                             std::is_enum<T>::value ||
                                             std::is_pod<T>::value>::type> {
  static size_t serialized_size(T const& value) { return sizeof(T); }
  static void serialize(void** buffer, T const& value) {
    ::memcpy(*buffer, &value, sizeof(T));
    reinterpret_cast<char*&>(*buffer) += sizeof(T);
  }
  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
    assert(*buffer_size >= sizeof(T));
    ::memcpy(value, *buffer, sizeof(T));
    reinterpret_cast<char const*&>(*buffer) += sizeof(T);
    *buffer_size -= sizeof(T);
  }
};

template <>
struct Serializer<const char*> {
  static size_t serialized_size(const char* value) { return strlen(value) + 1; }
  static void serialize(void** buffer, const char* value) {
    ::strcpy(static_cast<char*>(*buffer), value);
    reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
  }
  static void deserialize(void const** buffer, size_t* buffer_size,
                          const char** value) {
    *value = static_cast<char const*>(*buffer);
    size_t data_size = strnlen(*value, *buffer_size) + 1;
    assert(*buffer_size >= data_size);
    reinterpret_cast<char const*&>(*buffer) += data_size;
    *buffer_size -= data_size;
  }
};

template <typename T>
struct Serializer<std::vector<T>,
                  typename std::enable_if<std::is_arithmetic<T>::value ||
                                          std::is_enum<T>::value ||
                                          std::is_pod<T>::value>::type> {
  static size_t serialized_size(std::vector<T> const& value) {
    return sizeof(value.size()) + value.size() * sizeof(T);
  }
  static void serialize(void** buffer, std::vector<T> const& value) {
    serialize_value(buffer, value.size());
    size_t nbyte = value.size() * sizeof(T);
    ::memcpy(*buffer, value.data(), nbyte);
    reinterpret_cast<char*&>(*buffer) += nbyte;
  }
  static void deserialize(void const** buffer, size_t* buffer_size,
                          std::vector<T>* value) {
    size_t size;
    deserialize_value(buffer, buffer_size, &size);
    value->resize(size);
    size_t nbyte = value->size() * sizeof(T);
    assert(*buffer_size >= nbyte);
    ::memcpy(value->data(), *buffer, nbyte);
    reinterpret_cast<char const*&>(*buffer) += nbyte;
    *buffer_size -= nbyte;
  }
};

}  // namespace

template <typename T>
inline size_t serialized_size(T const& value) {
  return Serializer<T>::serialized_size(value);
}

template <typename T>
inline void serialize_value(void** buffer, T const& value) {
  return Serializer<T>::serialize(buffer, value);
}

template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
                              T* value) {
  return Serializer<T>::deserialize(buffer, buffer_size, value);
}
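A round-trip sketch of these helpers (not part of the commit; assumes function scope and <vector>): serialize an int and a vector<float> into a byte buffer, then read them back in the same order.

std::vector<float> weights = {0.5f, 1.5f};
int axis = 1;

// serialized_size() tells us exactly how many bytes the two values need.
std::vector<char> buf(serialized_size(axis) + serialized_size(weights));
void* p = buf.data();
serialize_value(&p, axis);     // advances p past sizeof(int)
serialize_value(&p, weights);  // writes size_t length, then the floats

void const* q = buf.data();
size_t remaining = buf.size();
int axis2;
std::vector<float> weights2;
deserialize_value(&q, &remaining, &axis2);
deserialize_value(&q, &remaining, &weights2);  // remaining is now 0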
@@ -0,0 +1,114 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

SplitPlugin* CreateSplitPlugin() { return new SplitPlugin(); }

nvinfer1::Dims SplitPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* inputDims, int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const& input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  output_dims.d[axis_] = output_length_.at(index);
  return output_dims;
}

int SplitPlugin::initialize() {
  std::vector<int> segment_offsets(1, 0);
  for (int i = 0; i < this->getNbOutputs(); ++i) {
    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
  }
  d_segment_offsets_ = segment_offsets;
  nvinfer1::Dims dims = this->getInputDims(0);
  nx_ = 1;
  for (int i = dims.nbDims - 1; i > axis_; --i) {
    nx_ *= dims.d[i];
  }
  ny_ = dims.d[axis_];
  nz_ = 1;
  for (int i = axis_ - 1; i >= 0; --i) {
    nz_ *= dims.d[i];
  }
  return 0;
}

template <typename T>
__device__ int upper_bound(T const* vals, int n, T const& key) {
  int i = 0;
  while (n > 0) {
    int m = n / 2;
    int j = i + m;
    if (!(key < vals[j])) {
      i = j + 1;
      n -= m + 1;
    } else {
      n = m;
    }
  }
  return i;
}

template <typename T>
__global__ void split_kernel(int nsegment,
                             int const* __restrict__ segment_offsets,
                             T const* __restrict__ idata, T* const* odatas,
                             int nx, int srcny_, int nz) {
  int x0 = threadIdx.x + blockIdx.x * blockDim.x;
  int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
  int z0 = threadIdx.z + blockIdx.z * blockDim.z;
  for (int z = z0; z < nz; z += blockDim.z * gridDim.z) {
    for (int src_y = src_y0; src_y < srcny_; src_y += blockDim.y * gridDim.y) {
      for (int x = x0; x < nx; x += blockDim.x * gridDim.x) {
        int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
        int dst_y = src_y - segment_offsets[segment];
        int dstny_ = segment_offsets[segment + 1] - segment_offsets[segment];
        odatas[segment][x + nx * (dst_y + dstny_ * z)] =
            idata[x + nx * (src_y + srcny_ * z)];
      }
    }
  }
}

int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                         void** outputs, void* workspace, cudaStream_t stream) {
  auto const& input_dims = this->getInputDims(0);
  int const* d_segment_offsets_ptr =
      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
  float const* idata = reinterpret_cast<float const*>(inputs[0]);
  float** odatas = reinterpret_cast<float**>(outputs);

  int nz = nz_ * batchSize;
  dim3 block(32, 16);
  dim3 grid(std::min((nx_ - 1) / block.x + 1, 65535u),
            std::min((ny_ - 1) / block.y + 1, 65535u),
            std::min((nz_ - 1) / block.z + 1, 65535u));

  split_kernel<<<grid, block, 0, stream>>>(d_segment_offsets_.size(),
                                           d_segment_offsets_ptr, idata, odatas,
                                           nx_, ny_, nz);

  return cudaGetLastError() != cudaSuccess;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
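To make the index math concrete (an illustrative case, not from the commit): for a CHW input of dims {8, 32, 32} split along axis_ = 0, initialize() computes ny_ = 8 (the extent of the split axis), nx_ = 32 * 32 = 1024 (the product of all dimensions after the axis), and nz_ = 1 (the product of all dimensions before it). With output_length_ = {3, 5}, segment_offsets becomes {0, 3, 8}, and each kernel thread binary-searches those offsets to route element (x, src_y, z) into the output segment whose offset range contains src_y.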
@@ -0,0 +1,62 @@

#pragma once

#include <thrust/device_vector.h>

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class SplitPlugin : public PluginTensorRT {
  int axis_;
  std::vector<int> output_length_;
  int nx_, ny_, nz_;
  thrust::device_vector<int> d_segment_offsets_;

 protected:
  virtual size_t getSerializationSize() override {
    return serialized_size(axis_) + serialized_size(output_length_) +
           getBaseSerializationSize();
  }

  virtual void serialize(void* buffer) override {
    serializeBase(buffer);
    serialize_value(&buffer, axis_);
    serialize_value(&buffer, output_length_);
  }

 public:
  SplitPlugin() {}
  SplitPlugin(int axis, std::vector<int> const& output_lengths)
      : axis_(axis), output_length_(output_lengths) {}
  SplitPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    deserialize_value(&serialData, &serialLength, &axis_);
    deserialize_value(&serialData, &serialLength, &output_length_);
  }

  SplitPlugin* clone() const override {
    return new SplitPlugin(axis_, output_length_);
  }

  virtual const char* getPluginType() const override { return "split"; }
  virtual int getNbOutputs() const override { return output_length_.size(); }
  virtual nvinfer1::Dims getOutputDimensions(int index,
                                             const nvinfer1::Dims* inputs,
                                             int nbInputDims) override;
  virtual int initialize() override;
  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
                      void* workspace, cudaStream_t stream) override;

  void setAxis(int axis) { axis_ = axis; }

  void setOutputLengths(const std::vector<int>& output_lengths) {
    output_length_ = output_lengths;
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
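A hedged sketch of wiring this plugin into a network at build time (not part of the commit; `network` is assumed to be an nvinfer1::INetworkDefinition* and `input` an nvinfer1::ITensor* from the surrounding builder code):

// Hypothetical build-time usage: split 8 channels into outputs of 3 and 5.
SplitPlugin* split = new SplitPlugin();
split->setAxis(0);                // split along the channel axis
split->setOutputLengths({3, 5});  // must sum to the input's axis extent
auto* layer = network->addPluginExt(&input, 1, *split);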
@@ -0,0 +1,63 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"

namespace paddle {
namespace inference {
namespace tensorrt {

void PluginTensorRT::serializeBase(void*& buffer) {
  serialize_value(&buffer, input_dims_);
  serialize_value(&buffer, max_batch_size_);
  serialize_value(&buffer, data_type_);
  serialize_value(&buffer, data_format_);
}

void PluginTensorRT::deserializeBase(void const*& serialData,
                                     size_t& serialLength) {
  deserialize_value(&serialData, &serialLength, &input_dims_);
  deserialize_value(&serialData, &serialLength, &max_batch_size_);
  deserialize_value(&serialData, &serialLength, &data_type_);
  deserialize_value(&serialData, &serialLength, &data_format_);
}

size_t PluginTensorRT::getBaseSerializationSize() {
  return (serialized_size(input_dims_) + serialized_size(max_batch_size_) +
          serialized_size(data_type_) + serialized_size(data_format_));
}

bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
                                    nvinfer1::PluginFormat format) const {
  return ((type == nvinfer1::DataType::kFLOAT ||
           type == nvinfer1::DataType::kHALF) &&
          (format == nvinfer1::PluginFormat::kNCHW));
}

void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
                                         int nbInputs,
                                         const nvinfer1::Dims* outputDims,
                                         int nbOutputs, nvinfer1::DataType type,
                                         nvinfer1::PluginFormat format,
                                         int maxBatchSize) {
  data_type_ = type;
  data_format_ = format;
  input_dims_.assign(inputDims, inputDims + nbInputs);
  max_batch_size_ = maxBatchSize;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
@@ -0,0 +1,72 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <NvInfer.h>
#include <cassert>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/inference/tensorrt/plugin/serialize.hpp"

namespace paddle {
namespace inference {
namespace tensorrt {

class PluginTensorRT : public nvinfer1::IPluginExt {
 public:
  PluginTensorRT() {}
  PluginTensorRT(const void* serialized_data, size_t length) {}
  nvinfer1::Dims const& getInputDims(int index) const {
    return input_dims_.at(index);
  }
  size_t getMaxBatchSize() const { return max_batch_size_; }
  nvinfer1::DataType getDataType() const { return data_type_; }
  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
  virtual const char* getPluginVersion() const { return "1"; }
  size_t getWorkspaceSize(int) const override { return 0; }
  void terminate() override {}
  virtual ~PluginTensorRT() {}

  // The following functions need to be overridden in the subclass.
  virtual nvinfer1::IPluginExt* clone() const = 0;
  virtual const char* getPluginType() const = 0;
  int initialize() override { return 0; }
  bool supportsFormat(nvinfer1::DataType type,
                      nvinfer1::PluginFormat format) const override;
  void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
                           const nvinfer1::Dims* outputDims, int nbOutputs,
                           nvinfer1::DataType type,
                           nvinfer1::PluginFormat format,
                           int maxBatchSize) override;
  virtual void serialize(void* buffer) override;
  virtual size_t getSerializationSize() override;

 protected:
  void deserializeBase(void const*& serialData, size_t& serialLength);
  size_t getBaseSerializationSize();
  void serializeBase(void*& buffer);

  std::vector<nvinfer1::Dims> input_dims_;
  size_t max_batch_size_;
  nvinfer1::DataType data_type_;
  nvinfer1::PluginFormat data_format_;
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle