commit
2e309b11c2
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,147 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
// set the batch size before constructing the thread to execute engine
|
||||
int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
|
||||
|
||||
TRTInt8Calibrator::TRTInt8Calibrator(
|
||||
const std::unordered_map<std::string, size_t>& buffers, int batch_size,
|
||||
std::string engine_name, const platform::Place place)
|
||||
: batch_size_(batch_size), engine_name_(engine_name) {
|
||||
int i = 0;
|
||||
VLOG(4) << "Init a new calibrator: " << engine_name_;
|
||||
for (const auto it : buffers) {
|
||||
framework::Tensor temp_tensor;
|
||||
std::string input_name = it.first;
|
||||
int data_size = it.second;
|
||||
int num_ele = data_size / sizeof(int16_t);
|
||||
framework::DDim data_shape = framework::make_ddim({num_ele});
|
||||
temp_tensor.Resize(data_shape);
|
||||
data_tensors_.push_back(temp_tensor);
|
||||
data_buffers_[input_name] = std::pair<void*, size_t>(
|
||||
static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)), num_ele);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
|
||||
: batch_size_(0),
|
||||
calib_running_(false),
|
||||
data_is_set_(false),
|
||||
done_(true),
|
||||
calibration_table_(calib_data) {}
|
||||
|
||||
void TRTInt8Calibrator::waitAndSetDone() {
|
||||
std::unique_lock<std::mutex> lk(mut_);
|
||||
while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
|
||||
if (!done_) {
|
||||
done_ = true;
|
||||
cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
// There might be more than one input for trt subgraph,
|
||||
// So, we use a map to store input information.
|
||||
bool TRTInt8Calibrator::setBatch(
|
||||
const std::unordered_map<std::string, void*>& data) {
|
||||
VLOG(3) << "set batch: " << engine_name_;
|
||||
std::unique_lock<std::mutex> lk(mut_);
|
||||
// There is a producer and a consumer. The producer set the batch data and
|
||||
// the consumer get the batch data. The size of the data pool is one.
|
||||
// So, the producer has to wait for the consumer to finish processing before
|
||||
// they can set the data.
|
||||
while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
|
||||
// The done_ is set to true using waitAndSetDone, When all calibration data
|
||||
// are processed.
|
||||
if (done_) return false;
|
||||
|
||||
// Sets the batch.
|
||||
for (const auto& it : data) {
|
||||
auto dataptr = data_buffers_.find(it.first);
|
||||
if (dataptr == data_buffers_.end()) {
|
||||
LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
|
||||
<< "' does not match with the buffer names";
|
||||
}
|
||||
const auto& d = dataptr->second;
|
||||
PADDLE_ENFORCE(
|
||||
cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice),
|
||||
"Fail to cudaMemcpy %s for %s", engine_name_, it.first);
|
||||
}
|
||||
|
||||
data_is_set_ = true;
|
||||
cond_.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
|
||||
int num_bindings) {
|
||||
VLOG(4) << "get batch: " << engine_name_;
|
||||
std::unique_lock<std::mutex> lk(mut_);
|
||||
// The consumer has just finished processing a data.
|
||||
// The producer can set the data again.
|
||||
calib_running_ = false;
|
||||
cond_.notify_all();
|
||||
|
||||
// As long as there is data in the pool, the consumer can get it.
|
||||
while (!data_is_set_ && !done_) cond_.wait(lk);
|
||||
if (done_) return false;
|
||||
|
||||
// Gets the batch
|
||||
for (int i = 0; i < num_bindings; i++) {
|
||||
auto it = data_buffers_.find(names[i]);
|
||||
if (it == data_buffers_.end()) {
|
||||
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
|
||||
<< names[i] << "' at position " << i;
|
||||
}
|
||||
bindings[i] = it->second.first;
|
||||
}
|
||||
|
||||
data_is_set_ = false;
|
||||
calib_running_ = true;
|
||||
VLOG(4) << "get batch done: " << engine_name_;
|
||||
return true;
|
||||
}
|
||||
|
||||
void TRTInt8Calibrator::setDone() {
|
||||
std::unique_lock<std::mutex> lk(mut_);
|
||||
done_ = true;
|
||||
cond_.notify_all();
|
||||
}
|
||||
|
||||
const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) {
|
||||
if (calibration_table_.empty()) return nullptr;
|
||||
length = calibration_table_.size();
|
||||
return calibration_table_.data();
|
||||
}
|
||||
|
||||
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
|
||||
std::size_t length) {
|
||||
calibration_table_ = std::string((const char*)ptr, length);
|
||||
VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
|
||||
<< " length=" << length;
|
||||
}
|
||||
TRTInt8Calibrator::~TRTInt8Calibrator() {
|
||||
VLOG(4) << "Destroying calibrator for " << engine_name_;
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,128 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <NvInfer.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/inference/tensorrt/engine.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
class TensorRTEngine;
|
||||
|
||||
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
|
||||
public:
|
||||
TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
|
||||
int batch_size, std::string engine_name,
|
||||
const platform::Place place);
|
||||
|
||||
explicit TRTInt8Calibrator(const std::string& calibration_data);
|
||||
~TRTInt8Calibrator();
|
||||
|
||||
int getBatchSize() const override;
|
||||
|
||||
bool getBatch(void* bindings[], const char* names[],
|
||||
int num_bindings) override;
|
||||
|
||||
bool setBatch(const std::unordered_map<std::string, void*>& data);
|
||||
void setDone();
|
||||
void waitAndSetDone();
|
||||
|
||||
const void* readCalibrationCache(std::size_t& length) override;
|
||||
void writeCalibrationCache(const void* ptr, std::size_t length) override;
|
||||
const std::string& getCalibrationTableAsString() {
|
||||
return calibration_table_;
|
||||
}
|
||||
|
||||
private:
|
||||
const int batch_size_;
|
||||
|
||||
bool calib_running_{true};
|
||||
bool data_is_set_{false};
|
||||
bool done_{false};
|
||||
|
||||
std::mutex mut_;
|
||||
std::condition_variable cond_;
|
||||
|
||||
std::unordered_map<std::string, std::pair<void*, size_t>> data_buffers_;
|
||||
std::vector<framework::Tensor> data_tensors_;
|
||||
|
||||
std::string engine_name_;
|
||||
std::string calibration_table_;
|
||||
};
|
||||
|
||||
class TRTCalibratorEngine {
|
||||
public:
|
||||
TRTCalibratorEngine() {}
|
||||
std::unique_ptr<TRTInt8Calibrator> calib_;
|
||||
std::unique_ptr<std::thread> thr_;
|
||||
std::unique_ptr<TensorRTEngine> engine_;
|
||||
};
|
||||
/*
|
||||
* Manager to control the TensorRT Int8 calibration creation and deltetion.
|
||||
*/
|
||||
class TRTCalibratorEngineManager {
|
||||
public:
|
||||
bool Has() const { return res_.size() > 0; }
|
||||
bool Has(const std::string& name) const {
|
||||
if (res_.count(name) == 0) return false;
|
||||
return res_.at(name).get() != nullptr;
|
||||
}
|
||||
|
||||
// Get Int8Calibrator via name
|
||||
TRTCalibratorEngine* Get(const std::string& name) const {
|
||||
return res_.at(name).get();
|
||||
}
|
||||
|
||||
// Look up or create a calibrator.
|
||||
TRTCalibratorEngine* LookupOrCreate(const std::string& engine_name) {
|
||||
if (res_.count(engine_name) == 0) {
|
||||
auto* p = new TRTCalibratorEngine;
|
||||
res_[engine_name].reset(p);
|
||||
}
|
||||
return res_.at(engine_name).get();
|
||||
}
|
||||
|
||||
// Create an Int8Calibrator
|
||||
TRTCalibratorEngine* Create(const std::string& engine_name) {
|
||||
auto* p = new TRTCalibratorEngine;
|
||||
res_[engine_name].reset(p);
|
||||
return p;
|
||||
}
|
||||
|
||||
void DeleteALL() {
|
||||
for (auto& item : res_) {
|
||||
item.second.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::unique_ptr<TRTCalibratorEngine>> res_;
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue