[oneDNN] Added UT for testing elementwise_mul caching (#30203)

* - Added UT for testing elementwise_mul caching

* lint fixes
revert-31562-mean
Jacek Czaja 4 years ago committed by GitHub
parent be5c2e6050
commit 4aba17b5db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1 +1 @@
cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_add_op activation_op softmax_op softmax scope device_context enforce)
cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op softmax scope device_context enforce)

@ -27,6 +27,8 @@
USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(elementwise_mul);
USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(softmax);
@ -66,8 +68,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
bool inplace = false) {
framework::Scope scope;
std::map<const std::string, int> num_inputs = {
{"softmax", 1}, {"relu", 1}, {"elementwise_add", 2}};
std::map<const std::string, int> num_inputs = {{"softmax", 1},
{"relu", 1},
{"elementwise_add", 2},
{"elementwise_mul", 2}};
std::string first_input = inplace == true ? output_name : "x";
@ -165,5 +169,17 @@ TEST(test_elementwise_add_reuse_cache, cpu_place) {
"Wrong number of cached oneDNN objects"));
}
TEST(test_elementwises_sequence_reuse_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "elementwise_add", dims, "elementwise_add_out", true);
RunOperator<float>(p, "elementwise_mul", dims, "elementwise_add_out", true);
RunOperator<float>(p, "relu", dims, "elementwise_add_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(11), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
} // namespace operators
} // namespace paddle

@ -516,8 +516,8 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(
dev_ctx, framework::vectorize(x->dims()),
uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) {
dev_ctx, framework::vectorize(x->dims()), uniq_name,
(algo == dnnl::algorithm::binary_mul ? "M" : ""))) {
// bradcasting combined with in-place may require
auto rankdiff = x->dims().size() - y->dims().size();
if (rankdiff > 0) {

Loading…
Cancel
Save