|
|
|
@ -183,7 +183,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
auto stream =
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_names_.empty(), false,
|
|
|
|
|
"should pass at least one input");
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> output_maps =
|
|
|
|
|
Attr<std::vector<std::string>>("output_name_mapping");
|
|
|
|
@ -203,7 +204,21 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
// convert input and copy to TRT engine's buffer
|
|
|
|
|
auto &t =
|
|
|
|
|
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
|
|
|
|
|
auto t_shape = framework::vectorize(t.dims());
|
|
|
|
|
auto t_shape = framework::vectorize<int64_t>(t.dims());
|
|
|
|
|
// check if the input shapes are consistent with model.
|
|
|
|
|
if (HasAttr(x + "_shape")) {
|
|
|
|
|
std::vector<int64_t> i_shape = Attr<std::vector<int64_t>>(x + "_shape");
|
|
|
|
|
std::vector<int64_t> model_input_shape(i_shape.begin() + 1,
|
|
|
|
|
i_shape.end());
|
|
|
|
|
std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
|
|
|
|
|
t_shape.end());
|
|
|
|
|
PADDLE_ENFORCE_EQ(model_input_shape == runtime_input_shape, true,
|
|
|
|
|
"Input shapes are inconsistent with the model. TRT 5 "
|
|
|
|
|
"or lower version "
|
|
|
|
|
"does not support dynamic input shapes. Please check "
|
|
|
|
|
"your input shapes.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
runtime_batch = t_shape[0];
|
|
|
|
|
|
|
|
|
|
const int bind_index = engine->engine()->getBindingIndex(x.c_str());
|
|
|
|
|