You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc

148 lines
5.0 KiB

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "glog/logging.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// Reports the batch size this calibrator was constructed with.
// The batch size has to be set before the thread that executes the engine is
// constructed; both constructors fix batch_size_ in their initializer lists.
int TRTInt8Calibrator::getBatchSize() const {
  return batch_size_;
}
// Builds a calibrator that allocates one staging buffer per engine input.
//
// buffers:     maps each input name to the byte size of its data (the size is
//              divided by sizeof(int16_t) to obtain the tensor element count).
// batch_size:  value later reported by getBatchSize().
// engine_name: label used in the log and error messages of this object.
// place:       passed to Tensor::mutable_data to choose where the buffers are
//              allocated.
TRTInt8Calibrator::TRTInt8Calibrator(
    const std::unordered_map<std::string, size_t>& buffers, int batch_size,
    std::string engine_name, const platform::Place place)
    : batch_size_(batch_size), engine_name_(engine_name) {
  VLOG(4) << "Init a new calibrator: " << engine_name_;
  // Bind by reference: iterating by value copied every (name, size) pair.
  for (const auto& it : buffers) {
    const std::string& input_name = it.first;
    const size_t data_size = it.second;
    // Element count of the int16_t tensor backing the buffer, rounded up so
    // that the tensor always holds at least data_size bytes.
    const int num_ele = static_cast<int>(
        (data_size + sizeof(int16_t) - 1) / sizeof(int16_t));

    framework::Tensor temp_tensor;
    temp_tensor.Resize(framework::make_ddim({num_ele}));
    void* buffer =
        static_cast<void*>(temp_tensor.mutable_data<int16_t>(place));
    // NOTE(review): the copy is stored only after mutable_data() so that it
    // can share the allocation made above; presumably a copy taken before the
    // allocation would not keep the buffer alive once temp_tensor goes out of
    // scope - confirm against framework::Tensor copy semantics.
    data_tensors_.push_back(temp_tensor);
    // setBatch() hands the stored size to cudaMemcpy, which expects a byte
    // count, so record the size in bytes rather than the element count.
    data_buffers_[input_name] = std::pair<void*, size_t>(buffer, data_size);
  }
}
TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
: batch_size_(0),
calib_running_(false),
data_is_set_(false),
done_(true),
calibration_table_(calib_data) {}
void TRTInt8Calibrator::waitAndSetDone() {
std::unique_lock<std::mutex> lk(mut_);
while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
if (!done_) {
done_ = true;
cond_.notify_all();
}
}
// There might be more than one input for trt subgraph,
// So, we use a map to store input information.
bool TRTInt8Calibrator::setBatch(
const std::unordered_map<std::string, void*>& data) {
VLOG(3) << "set batch: " << engine_name_;
std::unique_lock<std::mutex> lk(mut_);
// There is a producer and a consumer. The producer set the batch data and
// the consumer get the batch data. The size of the data pool is one.
// So, the producer has to wait for the consumer to finish processing before
// they can set the data.
while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
// The done_ is set to true using waitAndSetDone, When all calibration data
// are processed.
if (done_) return false;
// Sets the batch.
for (const auto& it : data) {
auto dataptr = data_buffers_.find(it.first);
if (dataptr == data_buffers_.end()) {
LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
<< "' does not match with the buffer names";
}
const auto& d = dataptr->second;
PADDLE_ENFORCE(
cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice),
"Fail to cudaMemcpy %s for %s", engine_name_, it.first);
}
data_is_set_ = true;
cond_.notify_all();
return true;
}
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
int num_bindings) {
VLOG(4) << "get batch: " << engine_name_;
std::unique_lock<std::mutex> lk(mut_);
// The consumer has just finished processing a data.
// The producer can set the data again.
calib_running_ = false;
cond_.notify_all();
// As long as there is data in the pool, the consumer can get it.
while (!data_is_set_ && !done_) cond_.wait(lk);
if (done_) return false;
// Gets the batch
for (int i = 0; i < num_bindings; i++) {
auto it = data_buffers_.find(names[i]);
if (it == data_buffers_.end()) {
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
<< names[i] << "' at position " << i;
}
bindings[i] = it->second.first;
}
data_is_set_ = false;
calib_running_ = true;
VLOG(4) << "get batch done: " << engine_name_;
return true;
}
// Unconditionally flags the calibration as finished and wakes every thread
// waiting on cond_, so blocked setBatch()/getBatch() calls return false.
void TRTInt8Calibrator::setDone() {
  std::lock_guard<std::mutex> guard(mut_);
  done_ = true;
  cond_.notify_all();
}
// Hands back the stored calibration table.
//
// length: output; always set to the size of the table in bytes, which is 0
//         when no table is stored (previously it was left untouched on that
//         path, so callers could read an indeterminate value).
// Returns a pointer to the table data, or nullptr when the table is empty.
// The pointer refers to calibration_table_'s own storage.
const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) {
  length = calibration_table_.size();
  if (calibration_table_.empty()) return nullptr;
  return calibration_table_.data();
}
// Stores a copy of the `length` bytes starting at `ptr` as the calibration
// table, replacing whatever table was held before.
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
                                              std::size_t length) {
  const char* bytes = static_cast<const char*>(ptr);
  calibration_table_.assign(bytes, length);
  VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
          << " length=" << length;
}
// Only logs the destruction; the destructor body releases nothing explicitly.
TRTInt8Calibrator::~TRTInt8Calibrator() {
  VLOG(4) << "Destroying calibrator for " << engine_name_;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle