|
|
|
@ -26,9 +26,21 @@ namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
|
namespace tensorrt {
|
|
|
|
|
|
|
|
|
|
class ConverterBase {
|
|
|
|
|
class OpConverter {
|
|
|
|
|
public:
|
|
|
|
|
ConverterBase() {}
|
|
|
|
|
OpConverter() {}
|
|
|
|
|
|
|
|
|
|
void Convert(const framework::OpDesc& op) {
|
|
|
|
|
std::string type = op.Type();
|
|
|
|
|
OpConverter& op_converter = this->register_op_converter_[type];
|
|
|
|
|
op_converter.Convert(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static void Register(const std::string key) {
|
|
|
|
|
register_op_converter_[key] = T();
|
|
|
|
|
}
|
|
|
|
|
static std::unordered_map<std::string, OpConverter> register_op_converter_;
|
|
|
|
|
|
|
|
|
|
// fluid inference scope
|
|
|
|
|
framework::Scope* scope_;
|
|
|
|
@ -37,30 +49,14 @@ class ConverterBase {
|
|
|
|
|
std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpConverter : public ConverterBase {
|
|
|
|
|
public:
|
|
|
|
|
OpConverter() {}
|
|
|
|
|
virtual ~OpConverter() {}
|
|
|
|
|
|
|
|
|
|
// convert fluid op to tensorrt layer
|
|
|
|
|
virtual void Convert(const framework::OpDesc& op) = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpConverter*>& GetOpConverter() {
|
|
|
|
|
static std::unordered_map<std::string, OpConverter*> register_op_converter;
|
|
|
|
|
return register_op_converter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
|
|
|
|
|
class convert_class##Register { \
|
|
|
|
|
public: \
|
|
|
|
|
convert_class##Register() { \
|
|
|
|
|
GetOpConverter()[#op_type] = new convert_class; \
|
|
|
|
|
} \
|
|
|
|
|
}; \
|
|
|
|
|
convert_class##Register convert_class##reg;
|
|
|
|
|
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
|
|
|
|
|
class convert_class : public OpConverter { \
|
|
|
|
|
public: \
|
|
|
|
|
convert_class() { OpConverter::Register<convert_class>(#op_type); } \
|
|
|
|
|
void Convert(const framework::OpDesc& op); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class TensorRTConverter : public ConverterBase {
|
|
|
|
|
class TensorRTConverter {
|
|
|
|
|
public:
|
|
|
|
|
TensorRTConverter() {}
|
|
|
|
|
|
|
|
|
|