|
|
|
@ -292,7 +292,7 @@ class MultiGRUHandler {
|
|
|
|
|
|
|
|
|
|
auto gru_forward_p0 = AcquireGruPrimitive(layer, dir);
|
|
|
|
|
|
|
|
|
|
dnnl::stream astream(engine_);
|
|
|
|
|
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
|
|
|
|
|
gru_forward_p0->execute(astream, gru_args);
|
|
|
|
|
astream.wait();
|
|
|
|
|
return out_mem;
|
|
|
|
@ -315,7 +315,7 @@ class MultiGRUHandler {
|
|
|
|
|
memory_p = std::make_shared<dnnl::memory>(
|
|
|
|
|
gru_pds_[{layer, dir}]->src_iter_desc(), engine_);
|
|
|
|
|
|
|
|
|
|
dnnl::stream astream(engine_);
|
|
|
|
|
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
|
|
|
|
|
dnnl::reorder(user_h0_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
|
|
|
|
|
.execute(astream, user_h0_memory, *memory_p);
|
|
|
|
|
|
|
|
|
@ -354,7 +354,7 @@ class MultiGRUHandler {
|
|
|
|
|
memory_p = std::make_shared<dnnl::memory>(
|
|
|
|
|
gru_pds_[{layer, dir}]->weights_layer_desc(), engine_);
|
|
|
|
|
|
|
|
|
|
dnnl::stream astream(engine_);
|
|
|
|
|
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
|
|
|
|
|
dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
|
|
|
|
|
.execute(astream, user_memory, *memory_p);
|
|
|
|
|
|
|
|
|
@ -410,7 +410,7 @@ class MultiGRUHandler {
|
|
|
|
|
memory_p = std::make_shared<dnnl::memory>(
|
|
|
|
|
gru_pds_[{layer, dir}]->weights_iter_desc(), engine_);
|
|
|
|
|
|
|
|
|
|
dnnl::stream astream(engine_);
|
|
|
|
|
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
|
|
|
|
|
dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
|
|
|
|
|
.execute(astream, user_memory, *memory_p);
|
|
|
|
|
|
|
|
|
@ -516,7 +516,7 @@ class MultiGRUHandler {
|
|
|
|
|
|
|
|
|
|
auto concat_p = AcquireConcatPrimitive(layer);
|
|
|
|
|
|
|
|
|
|
dnnl::stream astream(engine_);
|
|
|
|
|
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
|
|
|
|
|
concat_p->execute(astream, concat_args);
|
|
|
|
|
astream.wait();
|
|
|
|
|
return out_mem;
|
|
|
|
|