fix tensorrt output shape error (#29308)

* fix tensorrt output shape error

* fix unittest tensorrt_engine_op_test

* fix code style for unittest
revert-31562-mean
Shang Zhizhou 4 years ago committed by GitHub
parent 67c700b479
commit ebf689197d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -151,9 +151,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::set<std::string> output_names; std::set<std::string> output_names;
std::set<std::string> output_names_with_id; std::set<std::string> output_names_with_id;
std::vector<int> origin_output_dims;
for (auto *x : node->outputs) { for (auto *x : node->outputs) {
output_names.insert(x->Name()); output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id())); output_names_with_id.insert(x->Name() + std::to_string(x->id()));
origin_output_dims.push_back(x->Var()->GetShape().size());
} }
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
@ -224,6 +226,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("workspace_size", Get<int>("workspace_size")); op_desc->SetAttr("workspace_size", Get<int>("workspace_size"));
op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id")); op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
op_desc->SetAttr("output_name_mapping", output_mapping); op_desc->SetAttr("output_name_mapping", output_mapping);
op_desc->SetAttr("origin_output_dims", origin_output_dims);
op_desc->SetAttr("parameters", params); op_desc->SetAttr("parameters", params);
// we record all inputs' shapes in attr to check if they are consistent // we record all inputs' shapes in attr to check if they are consistent

@ -288,6 +288,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
// Bind output tensor to TRT. // Bind output tensor to TRT.
int output_index = 0; int output_index = 0;
std::vector<int> origin_output_dims =
Attr<std::vector<int>>("origin_output_dims");
VLOG(4) << "TensorRT Engine Op Outputs:"; VLOG(4) << "TensorRT Engine Op Outputs:";
for (const auto &y : Outputs("Ys")) { for (const auto &y : Outputs("Ys")) {
const int bind_index = const int bind_index =
@ -306,7 +308,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto dims = trt_context->getBindingDimensions(bind_index); auto dims = trt_context->getBindingDimensions(bind_index);
int nb_dims = dims.nbDims; int nb_dims = dims.nbDims;
for (; nb_dims > 0; nb_dims--) { for (; nb_dims > 0; nb_dims--) {
if (dims.d[nb_dims - 1] != 1) break; // some 'x 1' of shape is normal, no need to remove it
if (dims.d[nb_dims - 1] != 1 ||
nb_dims == origin_output_dims[output_index])
break;
} }
for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]); for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]);
#endif #endif

@ -109,6 +109,7 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false)); engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping", engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"})); std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string("")); engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0; int device_id = 0;
@ -210,6 +211,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false)); engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping", engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"})); std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string("")); engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0; int device_id = 0;

Loading…
Cancel
Save