|
|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
#include <boost/logic/tribool.hpp>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
@ -52,13 +53,9 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
op->SetInput("Input", {inputs[0]});
|
|
|
|
|
op->SetInput("Filter", {inputs[1]});
|
|
|
|
|
op->SetInput("Bias", {inputs[2]});
|
|
|
|
|
} else if (type == "gelu") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "leaky_relu") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "relu") {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "tanh") {
|
|
|
|
|
} else if (std::unordered_set<std::string>{"gelu", "leaky_relu", "relu",
|
|
|
|
|
"tanh"}
|
|
|
|
|
.count(type)) {
|
|
|
|
|
op->SetInput("X", inputs);
|
|
|
|
|
} else if (type == "softmax") {
|
|
|
|
|
op->SetAttr("axis", -1);
|
|
|
|
|
@ -100,11 +97,11 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
mkldnn_enabled_op.compare("elementwise_add") == 0);
|
|
|
|
|
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
|
|
|
|
|
std::vector<std::string>({"k"}),
|
|
|
|
|
mkldnn_enabled_op.compare("softmax") == 0);
|
|
|
|
|
mkldnn_enabled_op.compare("relu") == 0);
|
|
|
|
|
SetOp(&prog, "tanh", "tanh1", std::vector<std::string>({"k"}),
|
|
|
|
|
std::vector<std::string>({"l"}),
|
|
|
|
|
mkldnn_enabled_op.compare("tanh") == 0);
|
|
|
|
|
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"l"}),
|
|
|
|
|
SetOp(&prog, "relu", "relu3", std::vector<std::string>({"l"}),
|
|
|
|
|
std::vector<std::string>({"m"}),
|
|
|
|
|
mkldnn_enabled_op.compare("relu") == 0);
|
|
|
|
|
SetOp(&prog, "leaky_relu", "leaky_relu1", std::vector<std::string>({"m"}),
|
|
|
|
|
@ -112,7 +109,7 @@ class MKLDNNInplacePassTest {
|
|
|
|
|
mkldnn_enabled_op.compare("leaky_relu") == 0);
|
|
|
|
|
SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"n"}),
|
|
|
|
|
std::vector<std::string>({"m"}),
|
|
|
|
|
mkldnn_enabled_op.compare("relu") == 0);
|
|
|
|
|
mkldnn_enabled_op.compare("gelu") == 0);
|
|
|
|
|
if (branched == true) {
|
|
|
|
|
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
|
|
|
|
|
std::vector<std::string>({"z"}),
|
|
|
|
|
|