|
|
|
@ -63,9 +63,11 @@ void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
|
|
|
|
|
void TensorRTEngine::FreezeNetwork() {
|
|
|
|
|
freshDeviceId();
|
|
|
|
|
VLOG(3) << "TRT to freeze network";
|
|
|
|
|
PADDLE_ENFORCE(infer_builder_ != nullptr,
|
|
|
|
|
"Call InitNetwork first to initialize network.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(network() != nullptr, true,
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(infer_builder_,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Inference builder of TRT is null. Please make "
|
|
|
|
|
"sure you call InitNetwork first."));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(network(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Call InitNetwork first to initialize network."));
|
|
|
|
|
// build engine.
|
|
|
|
@ -210,7 +212,10 @@ void TensorRTEngine::FreezeNetwork() {
|
|
|
|
|
} else {
|
|
|
|
|
infer_engine_.reset(infer_builder_->buildCudaEngine(*network()));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
infer_engine_, platform::errors::Fatal(
|
|
|
|
|
"Build TensorRT cuda engine failed! Please recheck "
|
|
|
|
|
"you configurations related to paddle-TensorRT."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
|
|
|
|
@ -220,8 +225,16 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The TRT network should be initialized first."));
|
|
|
|
|
auto *input = network()->addInput(name.c_str(), dtype, dims);
|
|
|
|
|
PADDLE_ENFORCE(input, "infer network add input %s failed", name);
|
|
|
|
|
PADDLE_ENFORCE(input->isNetworkInput());
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
input, platform::errors::InvalidArgument("Adding input %s failed in "
|
|
|
|
|
"TensorRT inference network. "
|
|
|
|
|
"Please recheck your input.",
|
|
|
|
|
name));
|
|
|
|
|
PADDLE_ENFORCE_EQ(input->isNetworkInput(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input %s is not the input of TRT inference network. "
|
|
|
|
|
"Please recheck your input.",
|
|
|
|
|
name));
|
|
|
|
|
TensorRTEngine::SetITensor(name, input);
|
|
|
|
|
return input;
|
|
|
|
|
}
|
|
|
|
@ -230,31 +243,53 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
|
|
|
|
|
const std::string &name) {
  // Bind the offset-th output tensor of `layer` under `name`, then mark it as
  // a network output so TensorRT keeps it alive after engine optimization.
  auto *output = layer->getOutput(offset);
  SetITensor(name, output);
  PADDLE_ENFORCE_NOT_NULL(
      output, platform::errors::InvalidArgument(
                  "The output %s of TRT engine should not be null.", name));
  output->setName(name.c_str());
  // A tensor cannot be both a network input and a network output.
  PADDLE_ENFORCE_EQ(output->isNetworkInput(), false,
                    platform::errors::InvalidArgument(
                        "The output %s of TRT engine should not be the input "
                        "of the network at the same time.",
                        name));
  network()->markOutput(*output);
  PADDLE_ENFORCE_EQ(
      output->isNetworkOutput(), true,
      platform::errors::InvalidArgument(
          "The output %s of TRT engine should be the output of the network.",
          name));
}
|
|
|
|
|
|
|
|
|
|
// Mark an already-registered ITensor (looked up by `name`) as a network
// output so TensorRT keeps it alive after engine optimization.
void TensorRTEngine::DeclareOutput(const std::string &name) {
  auto *output = TensorRTEngine::GetITensor(name);
  PADDLE_ENFORCE_NOT_NULL(
      output, platform::errors::InvalidArgument(
                  "The output %s of TRT engine should not be null.", name));
  output->setName(name.c_str());
  // A tensor cannot be both a network input and a network output.
  PADDLE_ENFORCE_EQ(output->isNetworkInput(), false,
                    platform::errors::InvalidArgument(
                        "The output %s of TRT engine should not be the input "
                        "of the network at the same time.",
                        name));
  network()->markOutput(*output);
}
|
|
|
|
|
|
|
|
|
|
void TensorRTEngine::SetITensor(const std::string &name,
|
|
|
|
|
nvinfer1::ITensor *tensor) {
|
|
|
|
|
PADDLE_ENFORCE(tensor != nullptr);
|
|
|
|
|
PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
|
|
|
|
|
name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
tensor, platform::errors::InvalidArgument(
|
|
|
|
|
"Tensor named %s of TRT engine should not be null.", name));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
0, itensor_map_.count(name),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Tensor named %s of TRT engine should not be duplicated", name));
|
|
|
|
|
itensor_map_[name] = tensor;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Look up a previously registered ITensor by name; fails with NotFound if the
// name was never passed to SetITensor / DeclareInput.
nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
  PADDLE_ENFORCE_EQ(itensor_map_.count(name), true,
                    platform::errors::NotFound(
                        "Tensor named %s is not found in TRT engine", name));
  return itensor_map_[name];
}
|
|
|
|
|
|
|
|
|
@ -271,11 +306,11 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
|
|
|
|
|
std::string splitter = "__";
|
|
|
|
|
std::string name_with_suffix = name + splitter + name_suffix;
|
|
|
|
|
platform::CPUPlace cpu_place;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weight_map.count(name_with_suffix), 0,
|
|
|
|
|
"During TRT Op converter: We set weight %s with the same name "
|
|
|
|
|
"twice into the weight_map",
|
|
|
|
|
name_with_suffix);
|
|
|
|
|
PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), 0,
|
|
|
|
|
platform::errors::AlreadyExists(
|
|
|
|
|
"The weight named %s is set into the weight map "
|
|
|
|
|
"twice in TRT OP converter.",
|
|
|
|
|
name_with_suffix));
|
|
|
|
|
weight_map[name_with_suffix].reset(new framework::Tensor());
|
|
|
|
|
weight_map[name_with_suffix]->Resize(weight_tensor->dims());
|
|
|
|
|
TensorCopySync(*weight_tensor, cpu_place, weight_map[name_with_suffix].get());
|
|
|
|
@ -297,7 +332,10 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
|
|
|
|
|
// Make the engine's configured GPU (`device_id_`) the current CUDA device.
void TensorRTEngine::freshDeviceId() {
  // Initialize to 0 so that a failing cudaGetDeviceCount (which may leave the
  // out-param untouched) still trips the range check below instead of reading
  // an indeterminate value.
  int count = 0;
  cudaGetDeviceCount(&count);
  PADDLE_ENFORCE_LT(device_id_, count,
                    platform::errors::OutOfRange(
                        "Device id %d exceeds the current device count: %d.",
                        device_id_, count));
  cudaSetDevice(device_id_);
}
|
|
|
|
|
|
|
|
|
|