Added vanilla LSTM and LSTM with peepholes oneDNN fp32 kernel (#30661)

* added external reorder to profiler

* resolved conflict

* added enable_static

* initial version of lstm, not working yet

* added lstm to operators.cmake

* added vanilla lstm mkldnn op

* added peephole weights integration

* minor changes

* added formatting

* added fusion_lstm_mkldnn to static_whitelist

* added formatting

* removed comment

* moved use_peepholes attribute inside is_cached block

* reverted wrong changes

* minor formatting change

* minor changes
revert-31068-fix_conv3d_windows
jakpiase 4 years ago committed by GitHub
parent 1a13626f5f
commit d834f4e6e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -197,7 +197,7 @@ function(op_library TARGET)
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)

@ -14,11 +14,15 @@ register_operators(EXCLUDES
fused_embedding_eltwise_layernorm_op
fusion_group_op
fusion_gru_op
fusion_lstm_op
fused_bn_add_activation_op)
# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n")
op_library(fusion_lstm_op)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\nUSE_CPU_ONLY_OP(fusion_lstm);\n")
if (WITH_GPU)
# fused_bn_activation_op needs cudnn 7.4.1 above

@ -18,6 +18,9 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
@ -145,8 +148,17 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
}
void FusionLSTMOpMaker::Make() {
@ -235,6 +247,9 @@ void FusionLSTMOpMaker::Make() {
"`tanh` by default.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.

File diff suppressed because it is too large Load Diff

@ -0,0 +1,229 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using paddle::framework::LoDTensor;
using paddle::framework::Tensor;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CreateKey;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using platform::to_void_cast;
template <typename T, typename T_alg, typename T_out = T>
class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
public:
RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const LoDTensor* input,
const Tensor* weight_h, const Tensor* h0,
const bool is_reverse, const int64_t N, const int64_t Ti,
const int64_t IC, const int64_t OC, const int64_t G,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, T_alg>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
N(N),
Ti(Ti),
IC(IC),
OC(OC),
G(G) {
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
// Is it int8 kernel
const bool is_INT8 = std::is_same<T, uint8_t>::value;
if (is_INT8) {
// Int8 attributes
const float scale_data = ctx.Attr<float>("Scale_data");
const float shift_data = ctx.Attr<float>("Shift_data");
const auto scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
const int weights_scale_mask =
0 +
(1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo`
+
(1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
attr_.set_rnn_data_qparams(scale_data, shift_data);
attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights);
}
}
bool is_NTC() {
return (platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc()) ==
dnnl::memory::format_tag::ntc);
}
void reorderRNNdata(void* input_data, void* output_data,
std::vector<size_t> lod, const bool is_reverse,
platform::RNNReorderType reorder_type) {
switch (reorder_type) {
// Reorder input memory [WORDS, C] + LoD -> [N, T, C]
case platform::RNNReorderType::PP_NTC: {
auto* input_data_iter = reinterpret_cast<T*>(input_data);
auto* output_data_iter = reinterpret_cast<T*>(output_data);
for (int n = 0; n < N; ++n) {
const auto num_elements = (lod[n + 1] - lod[n]) * IC;
const auto offset = is_reverse ? (Ti * IC - num_elements) : 0;
memcpy(output_data_iter + n * Ti * IC + offset, input_data_iter,
sizeof(T) * num_elements);
input_data_iter += num_elements;
}
} break;
// Reorder input memory [WORDS, C] + LoD -> [T, N, C]
case platform::RNNReorderType::PP_TNC: {
auto* input_data_iter = reinterpret_cast<T*>(input_data);
auto* output_data_iter = reinterpret_cast<T*>(output_data);
for (int n = 0; n < N; ++n) {
const auto num_elements = (lod[n + 1] - lod[n]);
const auto offset = is_reverse ? (Ti - num_elements) : 0;
for (size_t t = 0; t < num_elements; ++t) {
memcpy(output_data_iter + (t + offset) * N * IC + n * IC,
input_data_iter, sizeof(T) * IC);
input_data_iter += IC;
}
}
} break;
// Reorder output values to PP format [N, T, C] -> [WORDS, C]
case platform::RNNReorderType::NTC_PP: {
auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
for (int n = 0; n < N; ++n) {
const auto num_elements = (lod[n + 1] - lod[n]) * OC;
const auto offset = is_reverse ? (Ti * OC - num_elements) : 0;
memcpy(output_data_iter, input_data_iter + n * Ti * OC + offset,
sizeof(T_out) * num_elements);
output_data_iter += num_elements;
}
} break;
// Reorder output values to PP format [T, N, C] -> [WORDS, C]
case platform::RNNReorderType::TNC_PP: {
auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
for (int n = 0; n < N; ++n) {
const auto num_elements = lod[n + 1] - lod[n];
const auto offset = is_reverse ? (Ti - num_elements) : 0;
for (size_t t = 0; t < num_elements; ++t) {
memcpy(output_data_iter,
input_data_iter + (t + offset) * N * OC + n * OC,
sizeof(T_out) * OC);
output_data_iter += OC;
}
}
} break;
}
}
std::shared_ptr<dnnl::memory> AcquireInputMemoryWithReorder(
const LoDTensor* input, const bool is_reverse) {
const auto name = this->key_ + "@input_mem";
auto memory_p =
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(name));
if (!memory_p) {
memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_desc(),
this->engine_);
this->dev_ctx_.SetBlob(name, memory_p);
}
const auto& input_lod = input->lod()[0];
auto* x_data = to_void_cast(input->data<T>());
auto* x_onednn_data = memory_p->get_data_handle();
memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);
if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) ==
dnnl::memory::format_tag::ntc) {
reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse,
platform::RNNReorderType::PP_NTC);
} else {
reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse,
platform::RNNReorderType::PP_TNC);
}
return memory_p;
}
std::shared_ptr<dnnl::memory> AcquireOutputMemory() {
const auto name = this->key_ + "@output_mem";
auto memory_p =
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(name));
if (!memory_p) {
memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->dst_desc(),
this->engine_);
this->dev_ctx_.SetBlob(name, memory_p);
}
return memory_p;
}
// TODO(grygielski) H0 is for now persistable
// TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does
// not support in yet)
std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
const std::string h0_key = memory_key_ + "@h0";
auto memory_p =
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(h0_key));
if (!memory_p) {
auto user_h0_memory = dnnl::memory();
if (h0) {
user_h0_memory =
dnnl::memory({{1, 1, N, OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
this->engine_, to_void_cast(h0->data<float>()));
} else {
user_h0_memory = dnnl::memory({{1, 1, N, OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
this->engine_);
memset(user_h0_memory.get_data_handle(), 0, sizeof(float) * N * OC);
}
memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
this->engine_);
dnnl::stream astream(this->engine_);
dnnl::reorder(user_h0_memory, *memory_p, attr_)
.execute(astream, user_h0_memory, *memory_p);
this->dev_ctx_.SetBlob(h0_key, memory_p);
}
return memory_p;
}
protected:
// RNN dimensions
// N - Batch Size
// Ti - Max sentence length
// IC - Input Channels
// OC - Output Channels
// G - Number of gates
const int64_t N, Ti, IC, OC, G;
// Memory size of weights, bias and h0 does not depend
// on Ti size, thus we need another key to cache them
std::string memory_key_;
dnnl::primitive_attr attr_;
};
} // namespace operators
} // namespace paddle

@ -75,4 +75,6 @@ class TestFusionGRUMKLDNNOpBS1(TestFusionGRUOp):
if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()

@ -0,0 +1,81 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from paddle.fluid.tests.unittests.test_fusion_lstm_op import TestFusionLSTMOp
class TestFusionLSTMONEDNNOp(TestFusionLSTMOp):
def set_conf(self):
self.use_mkldnn = True
def test_check_output(self):
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output(check_dygraph=False, no_check_set=["Cell"])
class TestFusionLSTMONEDNNOpReverse(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.is_reverse = True
self.use_mkldnn = True
class TestFusionLSTMONEDNNOpInitReverse(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.has_initial_state = True
self.is_reverse = True
self.use_mkldnn = True
class TestFusionLSTMONEDNNOpMD1(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.M = 36
self.D = 8
self.use_mkldnn = True
class TestFusionLSTMONEDNNOpMD2(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.M = 8
self.D = 8
self.use_mkldnn = True
class TestFusionLSTMONEDNNOpMD3(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.M = 15
self.D = 3
self.use_mkldnn = True
class TestFusionLSTMONEDNNOpBS1(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.lod = [[3]]
self.D = 16
self.use_mkldnn = True
class TestFusionLSTMONEDNNOpPeepholesInit(TestFusionLSTMONEDNNOp):
def set_conf(self):
self.use_peepholes = True
self.has_initial_state = True
self.use_mkldnn = True
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()

@ -144,4 +144,6 @@ class TestFusionGRUOpBS1(TestFusionGRUOp):
if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()

@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_mkldnn = False
self.set_conf()
T = sum(self.lod[0])
@ -110,7 +111,8 @@ class TestFusionLSTMOp(OpTest):
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
'candidate_activation': self.act_cand,
'use_mkldnn': self.use_mkldnn
}
def test_check_output(self):
@ -191,4 +193,6 @@ class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()

@ -29,4 +29,5 @@ no_check_set_white_list = [
'update_loss_scaling',
'cudnn_lstm',
'rnn',
'fusion_lstm',
]

@ -601,6 +601,7 @@ STATIC_MODE_TESTING_LIST = [
'test_bilinear_interp_mkldnn_op',
'test_fusion_gru_int8_mkldnn_op',
'test_fusion_gru_mkldnn_op',
'test_fusion_lstm_mkldnn_op',
'test_gaussian_random_mkldnn_op',
'test_lrn_mkldnn_op',
'test_matmul_mkldnn_op',

Loading…
Cancel
Save