Add bfloat16 passes (#26999)
parent 6947a58a1f
commit 1483ea2304

This commit adds two MKL-DNN graph passes: cpu_bfloat16_pass, which surrounds bfloat16 operators with quantize/dequantize ops (or sets force_fp32_output where available), and cpu_bfloat16_placement_pass, which marks supported operators to run with mkldnn_data_type set to bfloat16.

paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
@@ -0,0 +1,159 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"

#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

// Removes the direct edge between nodes a and b from the graph.
void UnlinkNodes(ir::Node* a, ir::Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
                                          "first_bfloat16_ops"};
  bfloat16_ops();
  int quantize_counter = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
    GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
    GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);

    if (op->Op()->Type() != "conv2d" && prev_op->Op()->Type() != "quantize") {
      VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
      auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);

      // create a quantize op node
      OpDesc q_desc;
      q_desc.SetType("quantize");
      q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
      q_desc.SetOutput("Output",
                       std::vector<std::string>({quantize_out_node->Name()}));
      q_desc.SetAttr("Scale", 1.f);
      q_desc.SetAttr("bfloat16", true);
      q_desc.SetAttr("output_format", Has("data_layout")
                                          ? Get<std::string>("data_layout")
                                          : "NCHW");
      auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

      // find which of the op's inputs is fed by op_in
      std::string op_input_name;
      for (auto name : op->Op()->InputNames()) {
        for (auto input_name : op->Op()->Input(name)) {
          if (input_name == op_in->Name()) op_input_name = name;
        }
      }

      PADDLE_ENFORCE_NE(op_input_name.empty(), true,
                        platform::errors::NotFound(
                            "The operator should have an input fed by the "
                            "output of the preceding operator."));

      op->Op()->SetInput(op_input_name,
                         std::vector<std::string>({quantize_out_node->Name()}));

      UnlinkNodes(op_in, op);
      IR_NODE_LINK_TO(op_in, quantize_op);
      IR_NODE_LINK_TO(quantize_op, quantize_out_node);
      IR_NODE_LINK_TO(quantize_out_node, op);
      quantize_counter++;
    }
  };
  gpd(graph, handler);
  PrettyLogDetail("--- added %d quantize ops before bfloat16 ops",
                  quantize_counter);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
                                         "last_bfloat16_ops"};
  bfloat16_ops();
  int force_fp32_counter = 0, dequantize_counter = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
    GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);

    if ((op->Op()->HasAttr("force_fp32_output") ||
         op->Op()->HasProtoAttr("force_fp32_output")) &&
        !op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
      op->Op()->SetAttr("force_fp32_output", true);
      force_fp32_counter++;
    } else if (op->Op()->Type() != "prior_box") {
      // create a dequantize input variable
      VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
      auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

      // create a dequantize op node for the output
      OpDesc deq_desc;
      deq_desc.SetType("dequantize");
      deq_desc.SetInput("Input",
                        std::vector<std::string>({dequantize_in_node->Name()}));
      deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
      deq_desc.SetAttr("Scale", 1.0f);
      auto dequantize_op = g->CreateOpNode(&deq_desc);

      // find which of the op's outputs feeds op_out
      std::string op_output_name;
      for (auto name : op->Op()->OutputNames()) {
        for (auto output_name : op->Op()->Output(name)) {
          if (output_name == op_out->Name()) op_output_name = name;
        }
      }

      PADDLE_ENFORCE_NE(op_output_name.empty(), true,
                        platform::errors::NotFound(
                            "The operator should have an output consumed by "
                            "the following operator."));

      op->Op()->SetOutput(op_output_name, std::vector<std::string>(
                                              {dequantize_in_node->Name()}));

      UnlinkNodes(op, op_out);
      IR_NODE_LINK_TO(op, dequantize_in_node);
      IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
      IR_NODE_LINK_TO(dequantize_op, op_out);
      dequantize_counter++;
    }
  };
  gpd(graph, handler);
  PrettyLogDetail("--- added %d dequantize ops and used %d force_fp32_output",
                  dequantize_counter, force_fp32_counter);
}

void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
  SetInputDataType(graph);
  SetOutputDataType(graph);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cpu_bfloat16_pass, paddle::framework::ir::CPUBFloat16Pass);
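A minimal usage sketch, assuming a graph already built from a ProgramDesc; the RunBFloat16Pass helper name is illustrative, not part of this commit. The comments summarize the rewrites the two methods perform:

// SetInputDataType (first bfloat16 op in a chain):
//   before:  op_in -> (bfloat16 op)
//   after:   op_in -> quantize -> quantize_out -> (bfloat16 op)
// SetOutputDataType (last bfloat16 op in a chain):
//   before:  (bfloat16 op) -> op_out
//   after:   (bfloat16 op) -> dequantize_in -> dequantize -> op_out
std::unique_ptr<paddle::framework::ir::Graph> RunBFloat16Pass(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("cpu_bfloat16_pass");
  graph.reset(pass->Apply(graph.release()));
  return graph;
}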
paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h
@@ -0,0 +1,34 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {
/*
 * Surrounds bfloat16 operators with quantize/dequantize ops and, where
 * the attribute is available, sets force_fp32_output instead.
 */
class CPUBFloat16Pass : public Pass {
 protected:
  void SetInputDataType(ir::Graph* graph) const;
  void SetOutputDataType(ir::Graph* graph) const;
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
@@ -0,0 +1,145 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn,
           const std::string& mkldnn_data_type = "float32",
           const bool force_fp32_output = false) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("use_mkldnn", use_mkldnn);
  op->SetAttr("name", name);

  if (type == "conv2d") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
    op->SetAttr("force_fp32_output", force_fp32_output);
  } else if (type == "pool2d" || type == "transpose2" || type == "reshape2" ||
             type == "dropout") {
    op->SetInput("X", {inputs[0]});
    op->SetOutput("Out", {outputs[0]});
    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
  } else if (type == "fc") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Out", {outputs[0]});
    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
  } else if (type == "concat") {
    op->SetInput("X", inputs);
    op->SetOutput("Out", outputs);
    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
  } else if (type == "matmul" || type == "elementwise_add") {
    op->SetInput("X", {inputs[0]});
    if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
    op->SetOutput("Out", {outputs[0]});
    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
  }
}

void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
                 const std::initializer_list<std::string> variable_names,
                 int* original_nodes_num, int* current_nodes_num) {
  auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");

  // the pass is applied twice; the node counts below bracket the second
  // application
  graph->reset(pass->Apply(graph->release()));

  *original_nodes_num = (*graph)->Nodes().size();
  (*graph).reset(pass->Apply((*graph).release()));
  *current_nodes_num = (*graph)->Nodes().size();
}

static const std::initializer_list<std::string> variable_names{
    "z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};

ProgramDesc BuildProgramDesc(bool use_mkldnn) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }
  SetOp(&prog, "dropout", "Dropout1", {"z"}, {"a"}, use_mkldnn, "float32");
  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, "bfloat16");
  SetOp(&prog, "pool2d", "Pool1", {"b"}, {"c"}, use_mkldnn, "bfloat16");
  SetOp(&prog, "conv2d", "Conv2", {"c"}, {"d"}, use_mkldnn, "bfloat16");
  SetOp(&prog, "dropout", "Dropout2", {"d"}, {"e"}, use_mkldnn, "float32");
  SetOp(&prog, "transpose2", "Transpose1", {"e"}, {"f"}, use_mkldnn,
        "bfloat16");
  SetOp(&prog, "reshape2", "Reshape1", {"f"}, {"g"}, use_mkldnn, "bfloat16");
  SetOp(&prog, "concat", "Concat1", {"g"}, {"h"}, use_mkldnn, "bfloat16");
  SetOp(&prog, "dropout", "Dropout3", {"h"}, {"i"}, use_mkldnn, "float32");

  return prog;
}

void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
              int transpose_count, int quant_count, int dequant_count,
              int added_nodes_count) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int original_nodes_num, current_nodes_num;
  PreparePass(&graph, prog, variable_names, &original_nodes_num,
              &current_nodes_num);

  int quantize_nodes_count = 0;
  int dequantize_nodes_count = 0;
  int conv2d_nodes_count = 0;
  int pool2d_nodes_count = 0;
  int transpose2_nodes_count = 0;

  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      auto* op = node->Op();
      if (op->Type() == "conv2d") {
        conv2d_nodes_count++;
      } else if (op->Type() == "pool2d") {
        pool2d_nodes_count++;
      } else if (op->Type() == "transpose2") {
        transpose2_nodes_count++;
      } else if (op->Type() == "quantize") {
        quantize_nodes_count++;
      } else if (op->Type() == "dequantize") {
        dequantize_nodes_count++;
      }
    }
  }
  EXPECT_EQ(conv2d_nodes_count, conv_count);
  EXPECT_EQ(pool2d_nodes_count, pool_count);
  EXPECT_EQ(transpose2_nodes_count, transpose_count);
  EXPECT_EQ(quantize_nodes_count, quant_count);
  EXPECT_EQ(dequantize_nodes_count, dequant_count);
  EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}

TEST(CpuBFloat16Pass, quantize) {
  bool use_mkldnn = true;
  // 2 graph nodes are added by the second pass application (see PreparePass)
  int added_nodes = 2;
  MainTest(BuildProgramDesc(use_mkldnn), 2, 1, 1, 1, 2, added_nodes);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(cpu_bfloat16_pass);
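For orientation, how BuildProgramDesc's bfloat16 island should be rewritten, as read from the pass sources above; the test itself asserts only aggregate counts, so this trace is an interpretation rather than something the commit states:

// - conv2d "Conv1" (a->b): no quantize inserted, since conv2d is explicitly
//   excluded in SetInputDataType.
// - transpose2 "Transpose1" (e->f): follows a float32 dropout, so a quantize
//   op (Scale = 1, bfloat16 = true) is inserted before it.
// - conv2d "Conv2" (c->d): precedes a float32 dropout and has the
//   force_fp32_output attribute, so the attribute is set to true instead of
//   inserting a dequantize op.
// - concat "Concat1" (g->h): precedes a float32 dropout and has no
//   force_fp32_output attribute, so a dequantize op is inserted after it.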
paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
@@ -0,0 +1,91 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h"

#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void CPUBfloat16PlacementPass::SetMkldnnDataType(
    ir::Graph* graph, int* bfloat16_operators) const {
  const auto& op_types_list =
      Get<std::unordered_set<std::string>>("bfloat16_enabled_op_types");
  // set mkldnn_data_type to bfloat16 for all operators that are in the
  // bfloat16_enabled_op_types set or are matched by the Bfloat16Placement
  // pattern
  GraphPatternDetector gpd;
  patterns::Bfloat16Placement bfloat16_placement_pattern{gpd.mutable_pattern(),
                                                         "bfloat16_placement"};
  bfloat16_placement_pattern(op_types_list);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);

    if ((op->Op()->HasAttr("mkldnn_data_type") ||
         op->Op()->HasProtoAttr("mkldnn_data_type")) &&
        !platform::HasOpINT8DataType(op->Op())) {
      op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
      (*bfloat16_operators)++;
    }
  };
  gpd(graph, handler);
}

void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
    ir::Graph* graph, int* bfloat16_operators) const {
  // find each orphaned bfloat16 operator, i.e. one between two float32
  // operators, and revert its mkldnn_data_type attribute to float32
  GraphPatternDetector gpd;
  patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(),
                                                       "orphaned_bfloat16"};
  orphaned_bfloat16_pattern();
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern);

    op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
    (*bfloat16_operators)--;
  };
  gpd(graph, handler);
}

void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
  int bfloat16_operators = 0;
  SetMkldnnDataType(graph, &bfloat16_operators);
  RemoveOrphanedOperators(graph, &bfloat16_operators);
  PrettyLogDetail("--- marked %d operators to run on bfloat16",
                  bfloat16_operators);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cpu_bfloat16_placement_pass,
              paddle::framework::ir::CPUBfloat16PlacementPass)
    // a set of operator type names with bfloat16 support ("conv2d" etc.);
    // the second param is the default value for this set
    .DefaultPassAttr("bfloat16_enabled_op_types",
                     new std::unordered_set<std::string>());
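A minimal usage sketch, mirroring the placement tester below: the empty set registered as the default above can be replaced before the pass runs; the op type names here are just examples.

// enable bfloat16 placement only for selected op types
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "cpu_bfloat16_placement_pass");
pass->Set("bfloat16_enabled_op_types",
          new std::unordered_set<std::string>({"conv2d", "pool2d"}));
graph.reset(pass->Apply(graph.release()));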
paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
@@ -0,0 +1,38 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {
/*
 * Specifies which operators should run in bfloat16.
 */
class CPUBfloat16PlacementPass : public Pass {
 protected:
  void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;

  void RemoveOrphanedOperators(ir::Graph* graph,
                               int* bfloat16_operators) const;

  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc
@@ -0,0 +1,132 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs,
           const std::string& mkldnn_data_type = "float32") {
  auto* op = prog->MutableBlock(0)->AppendOp();

  op->SetType(type);
  op->SetAttr("mkldnn_data_type", mkldnn_data_type);

  if (type == "conv2d") {
    op->SetAttr("name", name);
    op->SetInput("Input", {inputs[0]});
  } else if (type == "relu") {
    op->SetInput("X", inputs);
  } else if (type == "concat") {
    op->SetAttr("axis", 1);
    op->SetInput("X", {inputs[0], inputs[1]});
  } else if (type == "pool2d") {
    op->SetInput("X", {inputs[0]});
  } else {
    FAIL() << "Unexpected operator type.";
  }
  op->SetOutput("Out", {outputs[0]});
}

// operator           mkldnn_data_type
// ---------------------------------------
// (a,b)->concat->c   float32
// c->conv->f         float32
// f->relu->g         float32
// g->pool->h         float32
// h->conv->k         float32
// k->pool->l         float32
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;

  for (auto& v :
       std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l"})) {
    prog.MutableBlock(0)->Var(v);
  }

  SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
  SetOp(&prog, "conv2d", "conv1", {"c"}, {"f"});
  SetOp(&prog, "relu", "relu1", {"f"}, {"g"});
  SetOp(&prog, "pool2d", "pool1", {"g"}, {"h"});
  SetOp(&prog, "conv2d", "conv2", {"h"}, {"k"});
  SetOp(&prog, "pool2d", "pool2", {"k"}, {"l"});

  return prog;
}

void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
              unsigned expected_bfloat16_data_type_count) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
  pass->Set("bfloat16_enabled_op_types",
            new std::unordered_set<std::string>(bfloat16_enabled_op_types));

  graph.reset(pass->Apply(graph.release()));

  unsigned bfloat16_data_type_count = 0;

  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      if (platform::HasOpBFLOAT16DataType(node->Op())) {
        ++bfloat16_data_type_count;
      }
    }
  }

  EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}

void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
  auto prog = BuildProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
  graph.reset(pass->Apply(graph.release()));

  unsigned bfloat16_data_type_count = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      if (platform::HasOpBFLOAT16DataType(node->Op())) {
        ++bfloat16_data_type_count;
      }
    }
  }
  EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}

TEST(Bfloat16PlacementPass, enable_all) {
  MainTest({"conv2d", "pool2d", "relu", "concat"}, 6);
}

TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
  // 2 conv2d + 2 pool2d - 1 orphaned conv2d
  MainTest({"conv2d", "pool2d"}, 3);
}

TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(0); }

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(cpu_bfloat16_placement_pass);