Added support for inference using quantization aware trained dygraph (#30288)
* added support for inference using quantization aware trained dygraph
* corrected boost get usage
* Delete incorrect warning message (#30196); fix warning and no grad
* clean redundant API alias in 2.0 - part 2 (#30013): delete paddle.nn.functional.assign; fix dynamic to static error
* add the op error message for the matmul xpu (#30246)
* Add Static Variable Clone (#30208): add a clone method for static Variable so that this interface is the same as in dygraph; it fixed some bugs in dy2stat
* use wget to replace curl to download the lcov file (#30229); add cache for lcov
* fix test_pool3d_op timeout issue (#30248)
* Fix unittest bugs (#30250)
* modify error message based on comments (#30189); edit code and correct spelling according to review
* Fix bug for 'save multiple method' (#30218); add unittest and edit code for coverage
* Alias from paddle.fluid.layers.auc to paddle.static.auc (#30206)
* corrected naming issues and enforced zero check
* corrected PADDLE_ENFORCE messages and error reporting format
* corrected FindVar usage and PADDLE_ENFORCE usage in scope
* added more error checking and optimized code

Co-authored-by: LielinJiang <50691816+LielinJiang@users.noreply.github.com>
Co-authored-by: XiaoguangHu <46782768+XiaoguangHu01@users.noreply.github.com>
Co-authored-by: wawltor <fangzeyang0904@hotmail.com>
Co-authored-by: Huihuang Zheng <zhhsplendid@gmail.com>
Co-authored-by: YUNSHEN XIE <1084314248@qq.com>
Co-authored-by: Bai Yifan <me@ethanbai.com>
Co-authored-by: gongweibao <weibao.gong@gmail.com>
Co-authored-by: WeiXin <weixin10@baidu.com>
Co-authored-by: Jiaqi Liu <liujiaqi06@baidu.com>
parent 180877e988
commit 7bbf3ac5ab
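As background, here is a minimal standalone sketch (not part of the diff below) of the symmetric quantize-dequantize round trip this pass bakes into the weights; bit_length mirrors the op attribute, and abs_max stands in for the per-tensor weight maximum recorded by the fake-quant op.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int bit_length = 8;                       // from the "bit_length" attr
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for int8
  const float abs_max = 0.5f;                     // hypothetical weight abs-max
  const float scale = range / abs_max;            // mirrors range / scale_data
  std::vector<float> weights = {0.1f, -0.25f, 0.49f};
  for (float& w : weights) {
    // quantize onto the integer grid, then dequantize back to float
    w = std::round(w * scale) / scale;
  }
  for (float w : weights) std::printf("%f\n", w);
  return 0;
}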
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
@@ -0,0 +1,237 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                         \
  GET_IR_NODE(quant_dequant_op_x);        \
  GET_IR_NODE(quant_dequant_op);          \
  GET_IR_NODE(quant_dequant_op_out);      \
  GET_IR_NODE(quant_dequant_op_outscale); \
  GET_IR_NODE(any_op2);

// Delete quant_dequant_op, then quantize and dequantize weight
void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;

  // Create pattern
  patterns::DeleteQuantDequantFilterOpPattern pattern(gpd.mutable_pattern(),
                                                      pattern_name);
  pattern();
  auto* scope = param_scope();
  int found_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;

    std::unordered_set<const Node*> nodes2rm = {};
    int bit_length =
        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);
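    // e.g. for the default bit_length of 8, range = (1 << 7) - 1 = 127,
    // the largest magnitude representable by a signed 8-bit integer.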
    std::vector<float> weight_scale;
    std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();

    auto* any_op2_desc = any_op2->Op();
    auto var_map = any_op2_desc->Inputs();
    std::string arg_name = "";
    for (auto& name_m : var_map) {
      if (std::find(name_m.second.begin(), name_m.second.end(),
                    quant_dequant_op_out_name) != name_m.second.end()) {
        arg_name = name_m.first;
        break;
      }
    }
    PADDLE_ENFORCE_GT(arg_name.size(), 0,
                      platform::errors::InvalidArgument(
                          "Cannot find the input %s.",
                          quant_dequant_op_out_name));
    any_op2_desc->SetAttr("enable_int8", true);
    any_op2_desc->SetAttr("bit_length", bit_length);
    // modify the any_op2's inputs
    any_op2_desc->Flush();
    auto dequant_type = quant_dequant_op->Op()->Type();
    auto quantized_op_type = any_op2_desc->Type();

    // Get weight scale
    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
      auto scales_name = quant_dequant_op->Op()->Output("OutScale");
      PADDLE_ENFORCE_EQ(scales_name.size(), 1,
                        platform::errors::InvalidArgument(
                            "Scales size in channel-wise quant dequantize op "
                            "should be 1, got %d.",
                            scales_name.size()));
      const LoDTensor& channel_scale_tensor =
          scope->GetVar(scales_name[0])->Get<LoDTensor>();
      PADDLE_ENFORCE(
          paddle::platform::is_cpu_place(channel_scale_tensor.place()),
          platform::errors::InvalidArgument(
              "Channel scale tensor's place should be CPU."));
      const float* channel_scale_data = channel_scale_tensor.data<float>();
      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
        weight_scale.push_back(range / channel_scale_data[i]);
      }
    } else {
      auto scale_name = quant_dequant_op_outscale->Name();
      const LoDTensor& scale_tensor =
          scope->GetVar(scale_name)->Get<LoDTensor>();
      const float* scale_data = scale_tensor.data<float>();
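      // Note: (range * range) / scale / range reduces to range / scale,
      // the same per-tensor form as the channel-wise branch above.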
      weight_scale.push_back((range * range) / scale_data[0] / range);
    }

    nodes2rm.insert(quant_dequant_op_outscale);
    // perform quantize dequantize operations
    auto* weight_tensor =
        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
    auto w_dims = weight_tensor->dims();
    float* quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    // If quantized op is fc, weight scale size = 1;
    // If quantized op is conv2d, weight scale size = weight dims[0];
    // If quantized op is conv2d_transpose, weight scale size = weight dims[1].
    if (dequant_type == "fake_quantize_dequantize_abs_max") {
      PADDLE_ENFORCE_EQ(
          weight_scale.size(), 1,
          platform::errors::InvalidArgument(
              "%s op weight dequantized by [fake_quantize_dequantize_abs_max] "
              "requires weight scale size = 1, but got %d.",
              quantized_op_type, weight_scale.size()));
      PADDLE_ENFORCE_NE(weight_scale[0], 0,
                        platform::errors::InvalidArgument(
                            "Weight scale should be nonzero, but got zero."));
      for (int j = 0; j < weight_tensor->numel(); j++) {
        // quantize
        quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
        // dequantize
        quantized_weight_data[j] /= weight_scale[0];
      }
    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
               quantized_op_type == "fc") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
                "mul op weight dequantized by "
                "[fake_channel_wise_quantize_dequantize_abs_max] requires "
                "weight scale size = 2nd dim of mul's weight, which is %zu, "
                "but got %zu.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
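        // The mul/fc weight is laid out [in, out]; j % w_dims[1] selects the
        // per-output-column scale for each flat element index.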
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantize
          PADDLE_ENFORCE_NE(
              weight_scale[j % w_dims[1]], 0,
              platform::errors::InvalidArgument(
                  "fc op weight scale should be nonzero, but got zero."));
          quantized_weight_data[j] =
              quantized_weight_data[j] * weight_scale[j % w_dims[1]];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantize
          quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else if (quantized_op_type == "conv2d" ||
               quantized_op_type == "depthwise_conv2d") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[0]),
            platform::errors::InvalidArgument(
                "conv2d op requires weight scale size = channel size of the "
                "weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[0]), weight_scale.size()));
        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
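        // The conv2d weight is laid out [out_c, in_c, kh, kw], so inner_size
        // is the element count per output channel and j / inner_size maps a
        // flat index to its channel's scale.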
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantize
          PADDLE_ENFORCE_NE(
              weight_scale[j / inner_size], 0,
              platform::errors::InvalidArgument(
                  "conv2d op weight scale should be nonzero, but got zero."));
          quantized_weight_data[j] =
              quantized_weight_data[j] * weight_scale[j / inner_size];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantize
          quantized_weight_data[j] /= weight_scale[j / inner_size];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else if (quantized_op_type == "conv2d_transpose") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
                "conv2d_transpose op requires weight scale size = channel "
                "size of the weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
        int inner_size = w_dims[2] * w_dims[3];
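        // The conv2d_transpose weight is laid out [in_c, out_c, kh, kw];
        // (j / inner_size) % w_dims[1] recovers the output-channel index
        // for each flat element.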
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantize
          PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
                            platform::errors::InvalidArgument(
                                "conv2d_transpose op weight scale should be "
                                "nonzero, but got zero."));
          quantized_weight_data[j] =
              quantized_weight_data[j] *
              weight_scale[(j / inner_size) % w_dims[1]];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantize
          quantized_weight_data[j] /=
              weight_scale[(j / inner_size) % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported quantized op type: %s", quantized_op_type));
    }
    nodes2rm.insert(quant_dequant_op_out);

    // link weight in quant_dequant_op_x to any_op2
    any_op2_desc->RenameInput(quant_dequant_op_out->Var()->Name(),
                              quant_dequant_op_x->Var()->Name());
    any_op2_desc->SetAttr("weight_scale", weight_scale);
    any_op2_desc->Flush();
    IR_NODE_LINK_TO(quant_dequant_op_x, any_op2);
    nodes2rm.insert(quant_dequant_op);
    GraphSafeRemoveNodes(graph, nodes2rm);
    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_quant_dequant_filter_op_pass,
              paddle::framework::ir::DeleteQuantDequantFilterOpPass);
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
@@ -0,0 +1,37 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

class DeleteQuantDequantFilterOpPass : public FusePassBase {
 public:
  virtual ~DeleteQuantDequantFilterOpPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
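For reference, a hedged usage sketch (not part of this commit): a pass registered via REGISTER_PASS is typically looked up by name and applied to a graph roughly as below; building the ir::Graph and attaching the parameter scope are assumed to happen elsewhere, and the exact Pass::Apply signature may vary across Paddle versions.

// Assumes an existing ir::Graph* graph whose param scope attribute is set,
// since ApplyImpl reads weights through param_scope().
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "delete_quant_dequant_filter_op_pass");
pass->Apply(graph);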