1) Add static TRT model loading. 2) Fix bug: TRT misbehaves when device_id is not 0. test=develop
align_pyramid
parent 2070fb246d
commit 1d5ef7c9ee
CMake target tensorrt_plugin:
@@ -1,4 +1,5 @@
 nv_library(tensorrt_plugin
-  SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
+  SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
+  prelu_op_plugin.cu trt_plugin_factory.cc
   avg_pool_op_plugin.cu
   DEPS enforce tensorrt_engine prelu)
trt_plugin_factory.cc (new file)
@@ -0,0 +1,48 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
                                                    const void* serial_data,
                                                    size_t serial_length) {
  const char* plugin_type;
  DeserializeValue(&serial_data, &serial_length, &plugin_type);

  PADDLE_ENFORCE(Has(plugin_type),
                 "trt plugin type %s does not exist, check it.", plugin_type);
  auto plugin = plugin_registry_[plugin_type](serial_data, serial_length);
  owned_plugins_.emplace_back(plugin);

  return plugin;
}

bool PluginFactoryTensorRT::RegisterPlugin(
    const std::string& op_name, PluginDeserializeFunc deserialize_func) {
  if (Has(op_name)) return false;
  auto ret = plugin_registry_.emplace(op_name, deserialize_func);
  return ret.second;
}

void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
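This factory is the piece that enables the static load path: when a previously serialized engine is read back, TensorRT calls createPlugin() for every layer it does not recognize as a built-in, and the factory rebuilds the plugin from the registered deserialize function. A minimal usage sketch follows (not part of the commit). It assumes the TensorRT 5.x IRuntime::deserializeCudaEngine overload that accepts an IPluginFactory*, and the helper name LoadEngine is purely illustrative.

// Minimal usage sketch (not part of the commit): reload a serialized engine
// and let the plugin factory recreate any custom plugins it contains.
#include <NvInfer.h>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/utils/singleton.h"

namespace trt_plugin = paddle::inference::tensorrt::plugin;

// Hypothetical helper: deserialize an engine blob with the global factory.
nvinfer1::ICudaEngine* LoadEngine(nvinfer1::IRuntime* runtime,
                                  const void* engine_blob, size_t blob_size) {
  // createPlugin() is called back for every serialized custom layer.
  auto* factory =
      &paddle::inference::Singleton<trt_plugin::PluginFactoryTensorRT>::Global();
  // Assumes the TensorRT 5.x overload taking an IPluginFactory* (removed in
  // later TensorRT releases).
  return runtime->deserializeCudaEngine(engine_blob, blob_size, factory);
}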
trt_plugin_factory.h (new file)
@@ -0,0 +1,76 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <NvInfer.h>
#include <cstring>
#include <list>
#include <memory>  // for std::unique_ptr used by owned_plugins_
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
 public:
  // Deserialization method
  PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
                               size_t serial_length) override;

  bool RegisterPlugin(const std::string& op_name,
                      PluginDeserializeFunc deserialize_func);

  bool Has(const std::string& op_name) {
    return plugin_registry_.find(op_name) != plugin_registry_.end();
  }

  void DestroyPlugins();

 protected:
  std::unordered_map<std::string, PluginDeserializeFunc> plugin_registry_;

  std::list<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};

class TrtPluginRegistrar {
 public:
  TrtPluginRegistrar(const std::string& name,
                     PluginDeserializeFunc deserialize_func) {
    inference::Singleton<PluginFactoryTensorRT>::Global().RegisterPlugin(
        name, deserialize_func);
  }
};

#define REGISTER_TRT_PLUGIN(name, deserialize_func) \
  REGISTER_TRT_PLUGIN_UNIQ(__COUNTER__, name, deserialize_func)

#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func)      \
  static paddle::inference::tensorrt::plugin::TrtPluginRegistrar   \
      trt_plugin_registrar##ctr __attribute__((unused)) =          \
          paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \
              name, deserialize_func)

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
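REGISTER_TRT_PLUGIN ties a plugin type tag to a deserialize function through the static TrtPluginRegistrar defined above. The following sketch is hypothetical and not part of this commit: it assumes PluginDeserializeFunc, declared in trt_plugin_utils.h (not shown here), is callable as PluginTensorRT* (const void*, size_t), matching the call made in createPlugin(), and that SplitPlugin from split_op_plugin.cu (listed in the CMake target above) can be rebuilt from a (serial_data, serial_length) pair; the header path and helper name are illustrative.

// Hypothetical registration sketch; names and the split_op_plugin.h path are
// assumptions, not part of this commit.
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"  // assumed
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Rebuilds a SplitPlugin from the blob written when the engine was serialized.
PluginTensorRT* CreateSplitPluginDeserialize(const void* serial_data,
                                             size_t serial_length) {
  return new SplitPlugin(serial_data, serial_length);  // assumed constructor
}

// Registered once at start-up; createPlugin() later finds this entry by the
// type tag stored at the front of the serialized plugin data.
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle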