|
|
|
@ -23,7 +23,12 @@ USE_OP(softmax);
|
|
|
|
|
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
|
|
|
|
|
USE_OP(elementwise_add);
|
|
|
|
|
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
|
|
|
|
|
USE_OP(leaky_relu);
|
|
|
|
|
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
|
|
|
|
|
USE_OP(gelu);
|
|
|
|
|
USE_OP(relu);
|
|
|
|
|
USE_OP(tanh);
|
|
|
|
|
USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -47,8 +52,14 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
op->SetInput("Input", {inputs[0]});
|
|
|
|
|
op->SetInput("Filter", {inputs[1]});
|
|
|
|
|
op->SetInput("Bias", {inputs[2]});
|
|
|
|
|
} else if (type == "gelu") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "leaky_relu") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "relu") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "tanh") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "softmax") {
|
|
|
|
|
op->SetAttr("axis", -1);
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
@ -67,7 +78,7 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
|
|
|
|
|
for (auto& v :
|
|
|
|
|
std::vector<std::string>({"a", "weights", "bias", "f", "g", "h", "i",
|
|
|
|
|
"j", "k", "l", "m", "z"})) {
|
|
|
|
|
"j", "k", "l", "m", "n", "z"})) {
|
|
|
|
|
auto* var = prog.MutableBlock(0)->Var(v);
|
|
|
|
|
var->SetType(proto::VarType::SELECTED_ROWS);
|
|
|
|
|
if (v == "weights" || v == "bias") {
|
|
|
|
@ -90,6 +101,18 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
|
|
|
|
|
std::vector<std::string>({"k"}),
|
|
|
|
|
mkldnn_enabled_op.compare("softmax") == 0);
|
|
|
|
|
SetOp(&prog, "tanh", "tanh1", std::vector<std::string>({"k"}),
|
|
|
|
|
std::vector<std::string>({"l"}),
|
|
|
|
|
mkldnn_enabled_op.compare("tanh") == 0);
|
|
|
|
|
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"l"}),
|
|
|
|
|
std::vector<std::string>({"m"}),
|
|
|
|
|
mkldnn_enabled_op.compare("relu") == 0);
|
|
|
|
|
SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector<std::string>({"m"}),
|
|
|
|
|
std::vector<std::string>({"n"}),
|
|
|
|
|
mkldnn_enabled_op.compare("leaky_relu") == 0);
|
|
|
|
|
SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"n"}),
|
|
|
|
|
std::vector<std::string>({"m"}),
|
|
|
|
|
mkldnn_enabled_op.compare("relu") == 0);
|
|
|
|
|
if (branched == true) {
|
|
|
|
|
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
|
|
|
|
|
std::vector<std::string>({"z"}),
|
|
|
|
@ -113,11 +136,6 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
std::unordered_map<std::string, std::string> input_names;
|
|
|
|
|
std::unordered_map<std::string, std::string> output_names;
|
|
|
|
|
|
|
|
|
|
input_names["softmax"] = "X";
|
|
|
|
|
output_names["softmax"] = "Out";
|
|
|
|
|
input_names["elementwise_add"] = "X";
|
|
|
|
|
output_names["elementwise_add"] = "Out";
|
|
|
|
|
|
|
|
|
|
VLOG(3) << DebugString(graph);
|
|
|
|
|
|
|
|
|
|
for (auto* node : graph->Nodes()) {
|
|
|
|
@ -127,8 +145,9 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
auto ins = op->Inputs();
|
|
|
|
|
auto outs = op->Outputs();
|
|
|
|
|
// Input and output are the same var
|
|
|
|
|
if (ins[input_names[mkldnn_enabled_op]] ==
|
|
|
|
|
outs[output_names[mkldnn_enabled_op]]) {
|
|
|
|
|
// All inplace ops are inplacing input named: X
|
|
|
|
|
// and output : Out
|
|
|
|
|
if (ins["X"] == outs["Out"]) {
|
|
|
|
|
++use_mkldnn_true_count;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -153,6 +172,15 @@ TEST(MKLDNNInplacePass, inplace_elementwise_add) {
|
|
|
|
|
// Two elementwise_add mkl-dnn enabled op instances to be made inplace
|
|
|
|
|
MKLDNNInplacePassTest().MainTest("elementwise_add", false, 1);
|
|
|
|
|
}
|
|
|
|
|
TEST(MKLDNNInplacePass, inplace_tanh) {
|
|
|
|
|
MKLDNNInplacePassTest().MainTest("tanh", false, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(MKLDNNInplacePass, inplace_leaky_relu) {
|
|
|
|
|
// Input of leaky_relu is used as output of subsequent gelu, so no inplace
|
|
|
|
|
// cannot be done
|
|
|
|
|
MKLDNNInplacePassTest().MainTest("leaky_relu", false, 0);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ir
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|