[oneDNN] Clear the mkldnn cache in the NaiveExecutor destructor (#24756)

revert-24981-add_device_attr_for_regulization
Jacek Czaja 5 years ago committed by GitHub
parent 8468dae213
commit 40a5f3fd86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -81,8 +81,7 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, unless explicitly
// (as set in constructor) marked not to do so
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
if (platform::is_cpu_place(place_)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

@ -146,6 +146,11 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
if (WITH_GPU)
set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
endif()
cc_test(test_conv_batch_norm_mkldnn_fuse_pass SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc DEPS ${TEST_CONV_BN_PASS_DEPS})
cc_test(test_scale_matmul_fuse_pass SRCS mkldnn/scale_matmul_fuse_pass_tester.cc DEPS scale_matmul_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_mkldnn_inplace_pass SRCS mkldnn/mkldnn_inplace_pass_tester.cc DEPS mkldnn_inplace_pass)

@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <algorithm>
#include <functional>
#include <string>
#include <vector>
@ -278,9 +279,48 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
// update weights and biases
float epsilon =
BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon, conv_type());
// if bias is an input to other ops as well then we cannot overwrite it
// so we create separate elementwise Y in nodes
if (eltwise_y_in->outputs.size() > 1) {
// Make a copy of eltwise Y input tensor
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(patterns::PDNodeName(
name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
eltwise_y_in_desc.SetShape(
framework::vectorize(eltwise_y_in_tensor->dims()));
eltwise_y_in_desc.SetDataType(eltwise_y_in_tensor->type());
eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
eltwise_y_in_desc.SetPersistable(true);
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor_ex =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
// Initialize eltwise_y
TensorCopy(*eltwise_y_in_tensor, platform::CPUPlace(),
eltwise_y_in_tensor_ex);
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor_ex,
epsilon, conv_type());
// Set new var
eltwise->Op()->RenameInput(eltwise_y_in->Name(),
eltwise_y_in_node->Name());
// Link new bias node to eltwise
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
// unlink original bias from eltwise_op
eltwise_y_in->outputs.erase(
std::remove_if(eltwise_y_in->outputs.begin(),
eltwise_y_in->outputs.end(),
[&](Node*& n) {
return n->id() == eltwise->id() ? true : false;
}),
eltwise_y_in->outputs.end());
} else {
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon, conv_type());
}
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);

@ -18,6 +18,7 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
@ -49,6 +50,14 @@ Graph* Pass::Apply(Graph* graph) const {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
graph->Get<PassRecorder>(kPassRecorder).insert(Type());
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(paddle::platform::CPUPlace());
dev_ctx->ResetBlobMap();
#endif
return graph;
}

@ -118,5 +118,20 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_.swap(ops);
}
NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the mkl-dnn blob cache when the executor is torn down so that
  // later executors in the same process (notably mkl-dnn unit tests) start
  // from a clean cache instead of reusing stale cached primitives.
  if (platform::is_cpu_place(place_)) {
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    // pool.Get() returns the base DeviceContext; for a CPU place with
    // MKL-DNN enabled this is an MKLDNNDeviceContext, so use an explicit
    // static_cast instead of a C-style cast to make the downcast greppable.
    auto *dev_ctx =
        static_cast<platform::MKLDNNDeviceContext *>(pool.Get(place_));
    dev_ctx->ResetBlobMap();
    // Reset the thread-local data layout back to the default (kNCHW) so a
    // layout left over from this executor's run does not leak into the next.
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  }
#endif
}
} // namespace framework
} // namespace paddle

@ -32,6 +32,8 @@ class NaiveExecutor {
public:
explicit NaiveExecutor(const platform::Place& place) : place_(place) {}
~NaiveExecutor();
// Create child scope.
// Create variables.
// @with_feed_fetch_ops: whether to work with the feed and fetch operators.

Loading…
Cancel
Save