|
|
|
@ -12,12 +12,10 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#include "paddle/fluid/string/pretty_log.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -33,8 +31,38 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
|
|
|
|
|
b->inputs.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Checking whether a reorder from FP32 to BF16 should be added before the input
|
|
|
|
|
// to the operator
|
|
|
|
|
bool IsPermittedInputName(const std::string& input_name) {
|
|
|
|
|
// Only the inputs listed in \"permitted_names\" requires quanitization before
|
|
|
|
|
// the bfloat16 operator. Other inputs, such as Filter and Bias are reordered
|
|
|
|
|
// in the kernel.
|
|
|
|
|
const std::vector<std::string> permitted_names = {"X", "Y", "Input",
|
|
|
|
|
"ResidualData"};
|
|
|
|
|
return (std::find(permitted_names.begin(), permitted_names.end(),
|
|
|
|
|
input_name) != permitted_names.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Checking whether a reorder from BF16 to FP32 should be added after the output
|
|
|
|
|
// to the operator
|
|
|
|
|
bool IsPermittedOutputName(const std::string& output_name) {
|
|
|
|
|
// XShape is output in transpose2 and reshape2 operators used to store the
|
|
|
|
|
// shape and lod of X. So this output do not need dequantize before.
|
|
|
|
|
return (output_name != "XShape");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
|
|
|
|
|
int* quantize_counter) {
|
|
|
|
|
std::vector<std::string> input_names;
|
|
|
|
|
|
|
|
|
|
// Find the name of the input linking op to op_in
|
|
|
|
|
for (auto name : op->Op()->InputNames())
|
|
|
|
|
for (auto input_name : op->Op()->Input(name))
|
|
|
|
|
if (input_name == op_in->Name() && IsPermittedInputName(name))
|
|
|
|
|
input_names.push_back(name);
|
|
|
|
|
|
|
|
|
|
if (input_names.empty()) return;
|
|
|
|
|
|
|
|
|
|
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
|
|
|
|
|
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
|
|
|
|
|
|
|
|
|
@ -44,23 +72,12 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
|
|
|
|
|
q_desc.SetOutput("Output",
|
|
|
|
|
std::vector<std::string>({quantize_out_node->Name()}));
|
|
|
|
|
q_desc.SetAttr("Scale", 1.f);
|
|
|
|
|
q_desc.SetAttr("Shift", 0.0f);
|
|
|
|
|
q_desc.SetAttr("bfloat16", true);
|
|
|
|
|
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
|
|
|
|
|
? op->Op()->GetAttr("data_layout")
|
|
|
|
|
: std::string("NCHW"));
|
|
|
|
|
auto quantize_op = g->CreateOpNode(&q_desc);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> input_names;
|
|
|
|
|
for (auto name : op->Op()->InputNames()) {
|
|
|
|
|
for (auto input_name : op->Op()->Input(name)) {
|
|
|
|
|
if (input_name == op_in->Name()) input_names.push_back(name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
input_names.empty(), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Operator before operator should have input as op output"));
|
|
|
|
|
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
|
|
|
|
|
|
|
|
|
|
for (auto name = input_names.begin(); name < input_names.end(); name++)
|
|
|
|
|
op->Op()->SetInput(*name,
|
|
|
|
@ -99,11 +116,12 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
|
|
|
|
|
q_desc.SetOutput("Output",
|
|
|
|
|
std::vector<std::string>({quantize_out_node_names[i]}));
|
|
|
|
|
q_desc.SetAttr("Scale", 1.f);
|
|
|
|
|
q_desc.SetAttr("Shift", 0.0f);
|
|
|
|
|
q_desc.SetAttr("bfloat16", true);
|
|
|
|
|
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
|
|
|
|
|
? op->Op()->GetAttr("data_layout")
|
|
|
|
|
: std::string("NCHW"));
|
|
|
|
|
auto quantize_op = g->CreateOpNode(&q_desc);
|
|
|
|
|
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
|
|
|
|
|
|
|
|
|
|
UnlinkNodes(inputs[i], op);
|
|
|
|
|
IR_NODE_LINK_TO(inputs[i], quantize_op);
|
|
|
|
@ -115,6 +133,9 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
|
|
|
|
|
op->Op()->SetInput("X", quantize_out_node_names);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Operators like Concat and Sum have a single input name X, which actually
|
|
|
|
|
// consists of multiple inputs. Such operators require a different way to find
|
|
|
|
|
// pattern and add quantize ops.
|
|
|
|
|
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
|
|
|
|
@ -128,38 +149,8 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RemoveUnnecessaryReorders(ir::Graph* graph, int* quantize_counter) {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::UnnecessaryReorders unnecessary_reorders{gpd.mutable_pattern(),
|
|
|
|
|
"unnecessary_reorders"};
|
|
|
|
|
unnecessary_reorders();
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, unnecessary_reorders);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, unnecessary_reorders);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, unnecessary_reorders);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, unnecessary_reorders);
|
|
|
|
|
|
|
|
|
|
std::string op_output_name;
|
|
|
|
|
for (auto name : prev_op->Op()->OutputNames())
|
|
|
|
|
for (auto output_name : prev_op->Op()->Output(name))
|
|
|
|
|
if (output_name == quant_in->Name()) op_output_name = name;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
op_output_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Operator before operator should have input as op output"));
|
|
|
|
|
|
|
|
|
|
prev_op->Op()->SetOutput(op_output_name,
|
|
|
|
|
std::vector<std::string>({quant_out->Name()}));
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(prev_op, quant_out);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {quant_in, quant_op});
|
|
|
|
|
(*quantize_counter)--;
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Adding quantize ops before all operators except Concat and Sum, which have
|
|
|
|
|
// already been handled in AddReoderBeforeDuplicatedInputs
|
|
|
|
|
void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
|
|
|
|
@ -167,12 +158,9 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
|
|
|
|
|
bfloat16_ops();
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
|
|
|
|
|
auto prev_op_type = prev_op->Op()->Type();
|
|
|
|
|
if (op->Op()->Type() != "conv2d" && prev_op_type != "quantize" &&
|
|
|
|
|
prev_op_type != "sum" && prev_op_type != "concat") {
|
|
|
|
|
if (op->Op()->Type() != "sum" && op->Op()->Type() != "concat") {
|
|
|
|
|
AddQuantize(g, op, op_in, quantize_counter);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -182,9 +170,8 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
|
|
|
|
|
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
|
|
|
|
|
int quantize_counter = 0;
|
|
|
|
|
AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
|
|
|
|
|
RemoveUnnecessaryReorders(graph, &quantize_counter);
|
|
|
|
|
AddReoderBeforeSingleInputs(graph, &quantize_counter);
|
|
|
|
|
PrettyLogDetail("--- added %d quantize op before bfloat16 op",
|
|
|
|
|
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
|
|
|
|
|
quantize_counter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -193,55 +180,51 @@ void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
|
|
|
|
|
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
|
|
|
|
|
"last_bfloat16_ops"};
|
|
|
|
|
bfloat16_ops();
|
|
|
|
|
int force_fp32_counter = 0, dequantize_counter = 0;
|
|
|
|
|
int dequantize_counter = 0;
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);
|
|
|
|
|
if ((op->Op()->HasAttr("force_fp32_output") ||
|
|
|
|
|
op->Op()->HasProtoAttr("force_fp32_output")) &&
|
|
|
|
|
!op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
|
|
|
|
|
op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
|
force_fp32_counter++;
|
|
|
|
|
} else if (op->Op()->Type() != "prior_box") {
|
|
|
|
|
VarDesc dequantize_out_desc(patterns::PDNodeName("dequantize", "out"));
|
|
|
|
|
auto* dequantize_out_node = g->CreateVarNode(&dequantize_out_desc);
|
|
|
|
|
|
|
|
|
|
if (op->Op()->Type() != "prior_box") {
|
|
|
|
|
// Find the name of the output linking op to op_out
|
|
|
|
|
std::vector<std::string> output_names;
|
|
|
|
|
for (auto name : op->Op()->OutputNames())
|
|
|
|
|
for (auto output_name : op->Op()->Output(name))
|
|
|
|
|
if (output_name == op_out->Name() && IsPermittedOutputName(name))
|
|
|
|
|
output_names.push_back(name);
|
|
|
|
|
|
|
|
|
|
if (output_names.empty()) return;
|
|
|
|
|
|
|
|
|
|
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
|
|
|
|
|
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
|
|
|
|
|
|
|
|
|
|
OpDesc deq_desc;
|
|
|
|
|
deq_desc.SetType("dequantize");
|
|
|
|
|
deq_desc.SetInput("Input", std::vector<std::string>({op_out->Name()}));
|
|
|
|
|
deq_desc.SetOutput(
|
|
|
|
|
"Output", std::vector<std::string>({dequantize_out_node->Name()}));
|
|
|
|
|
deq_desc.SetInput("Input",
|
|
|
|
|
std::vector<std::string>({dequantize_in_node->Name()}));
|
|
|
|
|
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
|
|
|
|
|
deq_desc.SetAttr("Scale", 1.0f);
|
|
|
|
|
auto dequantize_op = g->CreateOpNode(&deq_desc);
|
|
|
|
|
|
|
|
|
|
std::string next_op_input_name;
|
|
|
|
|
for (auto name : next_op->Op()->InputNames()) {
|
|
|
|
|
for (auto input_name : next_op->Op()->Input(name)) {
|
|
|
|
|
if (input_name == op_out->Name()) next_op_input_name = name;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
next_op_input_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Operator before operator should have input as op output"));
|
|
|
|
|
|
|
|
|
|
next_op->Op()->SetInput(
|
|
|
|
|
next_op_input_name,
|
|
|
|
|
std::vector<std::string>({dequantize_out_node->Name()}));
|
|
|
|
|
UnlinkNodes(op_out, next_op);
|
|
|
|
|
IR_NODE_LINK_TO(op_out, dequantize_op);
|
|
|
|
|
IR_NODE_LINK_TO(dequantize_op, dequantize_out_node);
|
|
|
|
|
IR_NODE_LINK_TO(dequantize_out_node, next_op);
|
|
|
|
|
deq_desc.SetAttr("Shift", 0.0f);
|
|
|
|
|
auto dequantize_op =
|
|
|
|
|
g->CreateOpNode(&deq_desc); // OpDesc will be copied.
|
|
|
|
|
|
|
|
|
|
for (auto name = output_names.begin(); name < output_names.end(); name++)
|
|
|
|
|
op->Op()->SetOutput(
|
|
|
|
|
*name, std::vector<std::string>({dequantize_in_node->Name()}));
|
|
|
|
|
|
|
|
|
|
UnlinkNodes(op, op_out);
|
|
|
|
|
IR_NODE_LINK_TO(op, dequantize_in_node);
|
|
|
|
|
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
|
|
|
|
|
IR_NODE_LINK_TO(dequantize_op, op_out);
|
|
|
|
|
|
|
|
|
|
dequantize_counter++;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output",
|
|
|
|
|
dequantize_counter, force_fp32_counter);
|
|
|
|
|
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
|
|
|
|
|
dequantize_counter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|