Reset DeviceContext after quantization warmup (#18182)

test=develop
nan-debug-tool
Michał Gallus 6 years ago committed by Tao Luo
parent b7128bac5f
commit 8409693272

@ -355,6 +355,13 @@ AnalysisPredictor::MkldnnQuantizer::Histogram(
return std::make_pair(std::move(hist), std::move(bin_width));
}
void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
dev_ctx->ResetBlobMap();
}
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto& arg = predictor_.argument_;
if (!arg.scope_valid()) arg.SetScope(new framework::Scope);
@ -380,6 +387,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
if (!RunWarmup()) return false;
if (!CalculateScales()) return false;
ClearDeviceContext();
predictor_.PrepareScope(predictor_.scope_);
predictor_.CreateExecutor();
if (!RunQuantizePasses()) return false;

@ -68,6 +68,7 @@ class AnalysisPredictor::MkldnnQuantizer {
const framework::LoDTensor& var_tensor,
bool is_unsigned);
void PrepareArgument() const;
void ClearDeviceContext() const;
bool RunQuantizePasses() const;
std::vector<int> ExpandQuantizedBins(std::vector<int> quantized_bins,

@ -408,6 +408,8 @@ thread_local int cur_thread_id = 0;
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
int get_cur_thread_id(void) { return cur_thread_id; }
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get();

@ -391,6 +391,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
/* \brief Get the active engine */
const mkldnn::engine& GetEngine() const { return engine_; }
// Remove all entries from the blob map
void ResetBlobMap() const;
// Set data to blob (i.e. name/data pair). Create blob if not existing
void SetBlob(const std::string& name, std::shared_ptr<void> data) const;

Loading…
Cancel
Save