parent d09d6eadc0
commit 4e3522e5b4

paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
@@ -0,0 +1,144 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"

#include "glog/logging.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Returns the batch size, which must be set before the thread that builds
// the engine is launched.
int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }

TRTInt8Calibrator::TRTInt8Calibrator(
    const std::unordered_map<std::string, size_t>& buffers, int batch_size,
    std::string engine_name, const platform::Place place)
    : batch_size_(batch_size),
      calib_running_(true),
      data_is_set_(false),
      done_(false),
      engine_name_(engine_name) {
  VLOG(4) << "Init a new calibrator: " << engine_name_;
  // Allocate a device buffer for every input so that setBatch() can copy the
  // feed data into it and getBatch() can hand it to TensorRT.
  for (const auto& it : buffers) {
    framework::Tensor temp_tensor;
    std::string input_name = it.first;
    int data_size = it.second;
    int num_ele = data_size / sizeof(int16_t);
    framework::DDim data_shape = framework::make_ddim({num_ele});
    temp_tensor.Resize(data_shape);
    data_tensors_.push_back(temp_tensor);
    data_buffers_[input_name] = std::pair<void*, size_t>(
        static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)), num_ele);
  }
}

TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
    : batch_size_(0),
      calib_running_(false),
      data_is_set_(false),
      done_(true),
      calibration_table_(calib_data) {}

void TRTInt8Calibrator::waitAndSetDone() {
  std::unique_lock<std::mutex> lk(mut_);
  while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
  if (!done_) {
    done_ = true;
    cond_.notify_all();
  }
}

bool TRTInt8Calibrator::setBatch(
    const std::unordered_map<std::string, void*>& data) {
  VLOG(3) << "set batch: " << engine_name_;
  std::unique_lock<std::mutex> lk(mut_);
  // Wait until getBatch() has consumed the previous batch.
  while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
  if (done_) return false;

  // Copy the feed data into the pre-allocated device buffers.
  for (const auto& it : data) {
    auto dataptr = data_buffers_.find(it.first);
    if (dataptr == data_buffers_.end()) {
      LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
                 << "' does not match with the buffer names";
    }
    const auto& d = dataptr->second;
    auto status =
        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
    if (status != cudaSuccess) {
      LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
                 << "' failed with " << status;
    }
  }

  data_is_set_ = true;
  cond_.notify_all();
  return true;
}

bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
                                 int num_bindings) {
  VLOG(4) << "get batch: " << engine_name_;
  std::unique_lock<std::mutex> lk(mut_);
  // Signal that the previous batch has been consumed so setBatch() may fill
  // the buffers with the next one.
  calib_running_ = false;
  cond_.notify_all();

  while (!data_is_set_ && !done_) cond_.wait(lk);
  if (done_) return false;

  // Hand the device buffers to TensorRT in binding order.
  for (int i = 0; i < num_bindings; i++) {
    auto it = data_buffers_.find(names[i]);
    if (it == data_buffers_.end()) {
      LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
                 << names[i] << "' at position " << i;
    }
    bindings[i] = it->second.first;
  }

  data_is_set_ = false;
  calib_running_ = true;
  VLOG(4) << "get batch done: " << engine_name_;
  return true;
}

void TRTInt8Calibrator::setDone() {
  std::unique_lock<std::mutex> lk(mut_);
  done_ = true;
  cond_.notify_all();
}

const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
  if (calibration_table_.empty()) return nullptr;
  length = calibration_table_.size();
  return calibration_table_.data();
}

void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
                                              std::size_t length) {
  calibration_table_ = std::string(static_cast<const char*>(ptr), length);
  VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
          << " length=" << length;
}

TRTInt8Calibrator::~TRTInt8Calibrator() {
  VLOG(4) << "Destroying calibrator for " << engine_name_;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
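
A minimal driver sketch of how the setBatch()/getBatch() handshake above is meant to be used: the TensorRT builder pulls batches through getBatch() on a worker thread, while the inference pass pushes device pointers through setBatch(); waitAndSetDone() ends the exchange. This is not part of the commit, and BuildEngineWithCalibrator, NextDeviceBatch and kNumCalibrationBatches are hypothetical placeholders.

#include <string>
#include <thread>
#include <unordered_map>

#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"

namespace trt = paddle::inference::tensorrt;

// Hypothetical helpers, assumed to exist elsewhere.
void BuildEngineWithCalibrator(trt::TRTInt8Calibrator* calib);
std::unordered_map<std::string, void*> NextDeviceBatch();
constexpr int kNumCalibrationBatches = 100;

void CalibrateOnce(trt::TRTCalibratorRes* res,
                   const std::unordered_map<std::string, size_t>& buffer_sizes,
                   int batch_size, const std::string& engine_name,
                   const paddle::platform::Place& place) {
  // One calibrator per engine; buffer_sizes maps each input name to the byte
  // size used to allocate its device buffer.
  res->calib_.reset(new trt::TRTInt8Calibrator(buffer_sizes, batch_size,
                                               engine_name, place));

  // The build must run on its own thread: the builder blocks inside
  // getBatch() until setBatch() delivers data.
  res->thr_.reset(new std::thread(
      [res] { BuildEngineWithCalibrator(res->calib_.get()); }));

  // Feed calibration batches; setBatch() returns false once calibration has
  // been marked done.
  for (int step = 0; step < kNumCalibrationBatches; ++step) {
    if (!res->calib_->setBatch(NextDeviceBatch())) break;
  }

  // Signal that no more data is coming, then wait for the builder to finish.
  res->calib_->waitAndSetDone();
  res->thr_->join();
}
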
paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
@@ -0,0 +1,128 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cuda_runtime_api.h>

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "NvInfer.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class TensorRTEngine;

struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
 public:
  TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
                    int batch_size, std::string engine_name,
                    const platform::Place place);

  explicit TRTInt8Calibrator(const std::string& calibration_data);
  ~TRTInt8Calibrator();

  int getBatchSize() const override;

  bool getBatch(void* bindings[], const char* names[],
                int num_bindings) override;

  // Called by the feeding thread to copy the next batch into the
  // calibrator's device buffers; blocks until getBatch() has consumed the
  // previous batch.
  bool setBatch(const std::unordered_map<std::string, void*>& data);
  void setDone();
  void waitAndSetDone();

  const void* readCalibrationCache(std::size_t& length) override;
  void writeCalibrationCache(const void* ptr, std::size_t length) override;
  const std::string& getCalibrationTableAsString() {
    return calibration_table_;
  }

 private:
  const int batch_size_;

  bool calib_running_;
  bool data_is_set_;
  bool done_;

  std::mutex mut_;
  std::condition_variable cond_;

  std::unordered_map<std::string, std::pair<void*, size_t>> data_buffers_;
  std::vector<framework::Tensor> data_tensors_;

  std::string engine_name_;
  std::string calibration_table_;
};

// Bundles a calibrator, the thread that builds the engine, and the engine
// itself for one TensorRT subgraph.
class TRTCalibratorRes {
 public:
  TRTCalibratorRes() {}
  std::unique_ptr<TRTInt8Calibrator> calib_;
  std::unique_ptr<std::thread> thr_;
  std::unique_ptr<TensorRTEngine> engine_;
};

/*
 * Manager to control the TensorRT Int8 calibrator creation and deletion.
 */
class TRTCalibratorResManager {
 public:
  bool Has() const { return res_.size() > 0; }
  bool Has(const std::string& name) const {
    if (res_.count(name) == 0) return false;
    return res_.at(name).get() != nullptr;
  }

  // Get the calibrator resource via the engine name.
  TRTCalibratorRes* Get(const std::string& name) const {
    return res_.at(name).get();
  }

  // Look up an existing calibrator resource, or create one if absent.
  TRTCalibratorRes* LookupOrCreate(const std::string& engine_name) {
    if (res_.count(engine_name) == 0) {
      auto* p = new TRTCalibratorRes();
      res_[engine_name].reset(p);
    }
    return res_.at(engine_name).get();
  }

  // Create a calibrator resource for the given engine name.
  TRTCalibratorRes* Create(const std::string& engine_name) {
    auto* p = new TRTCalibratorRes();
    res_[engine_name].reset(p);
    return p;
  }

  void DeleteALL() {
    for (auto& item : res_) {
      item.second.reset(nullptr);
    }
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<TRTCalibratorRes>> res_;
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
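
A minimal sketch of how TRTCalibratorResManager might be driven; it is not part of the commit, the manager's ownership (e.g. a process-wide singleton) is an assumption, and "engine_0" is a made-up engine key.

#include <string>

#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"

namespace trt = paddle::inference::tensorrt;

void UseManager(trt::TRTCalibratorResManager* manager) {
  // The first call inserts an empty TRTCalibratorRes; later calls return the
  // same object, so different passes can share it by engine name.
  trt::TRTCalibratorRes* res = manager->LookupOrCreate("engine_0");

  // Once calibration has finished, the serialized INT8 table can be read
  // back and persisted alongside the model.
  if (res->calib_ != nullptr) {
    const std::string& table = res->calib_->getCalibrationTableAsString();
    (void)table;  // e.g. write it next to the model files
  }

  // Release every calibrator/engine the manager owns. Any thread stored in
  // thr_ must have been joined first, otherwise destroying it terminates
  // the process.
  manager->DeleteALL();
}
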