Merge pull request #12737 from tensor-tang/feature/op/fusion_lstm

add fusion lstm
tensor-tang authored 7 years ago, committed by GitHub
commit e955361267

@@ -15,8 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
DECLARE_int32(paddle_num_threads);
#include "paddle/fluid/operators/math/fc_compute.h"
namespace paddle {
namespace operators {
@@ -110,13 +109,8 @@ void FCOpMaker::Make() {
AddComment(R"DOC(
Fully Connected Operator.
The fully connected operation calculates the output based on the input, weights and bias attribute.
The fully connected operation calculates the output based on the input, weights and bias.
The size of each dimension of the parameters checked in the infer-shape.
The matrix of bias is generated by the mkldnn framework, when the bias_attr is True.
Additional parametrs are use_mkldnn and bias_attr.
The input(X) size and output(Out) size may be diffrent.
The fully connected layer only supports MKLDNN version
)DOC");
}
@@ -133,26 +127,15 @@ class FCOpKernel : public framework::OpKernel<T> {
auto in_dims = input->dims();
auto w_dims = w->dims();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
math::FCCompute<platform::CPUDeviceContext, T>(
blas, in_dims[0], w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL);
blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
static_cast<T>(1), input_data, w_data, static_cast<T>(0),
output_data);
if (bias) {
const T* bias_data = bias->data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int bs = 0; bs < in_dims[0]; bs++) {
blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
output_data + bs * w_dims[1]);
}
}
// TODO(TJ): fuse act
}
};

File diff suppressed because it is too large

@@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

@@ -0,0 +1,43 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
DECLARE_int32(paddle_num_threads);
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL) {
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
static_cast<T>(0), Y);
if (B) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle
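
For reference, FCCompute amounts to one GEMM over the whole mini-batch followed by a per-row bias add (the AXPY loop). Below is a minimal NumPy sketch of the same math, with M/N/K named after the template arguments above; fc_ref and the concrete sizes are illustrative only, not part of the patch.

import numpy as np

def fc_ref(X, W, B=None):
    # X: (M, K), W: (K, N), optional B: (N,) -> Y: (M, N)
    Y = X.dot(W)       # the CblasNoTrans x CblasNoTrans GEMM
    if B is not None:
        Y = Y + B      # bias broadcast over the M rows (the AXPY loop)
    return Y

M, N, K = 4, 16, 8
Y = fc_ref(np.random.rand(M, K), np.random.rand(K, N), np.random.rand(N))
assert Y.shape == (M, N)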

@@ -64,27 +64,47 @@ class TestFCOp(OpTest):
self.check_output()
class TestFCOpBiasBoth(TestFCOp):
class TestFCOpNoBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
for with_bias in {True, False}:
self.with_bias = with_bias
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
self.with_bias = False
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOp1(TestFCOpBiasBoth):
class TestFCOpWithBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
self.with_bias = True
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOp1(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(2, 8, 10, 1, 1)
class TestFCOp2(TestFCOpBiasBoth):
class TestFCOp2(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOp4(TestFCOpBiasBoth):
class TestFCOp4(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(1, 32, 64, 3, 3)
class TestFCOpWithBias1(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(3, 8, 10, 2, 1)
class TestFCOpWithBias2(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOpWithBias3(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(1, 64, 32, 3, 3)
if __name__ == "__main__":
unittest.main()

@@ -0,0 +1,151 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_lstm_op import lstm, ACTIVATION
def fc(x, w, b):
return np.dot(x, w) + b
def fusion_lstm(
x, # T x M
lod, # 1 x N
wx=None, # M x 4D
bx=None, # 1 x 4D
h0=None, # N x D
c0=None, # N x D
w_h=None, # D x 4D
w_b=None, # 1 x 4D
w_c=None, # 1 x 3D
is_reverse=False,
act_gate=None,
act_cell=None,
act_cand=None):
return lstm(
fc(x, wx, bx), lod, h0, c0, w_h, w_b, w_c, is_reverse, act_gate,
act_cell, act_cand)
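
The fusion is visible in this reference: fc(x, wx, bx) turns the whole packed LoD batch into gate pre-activations with a single matrix multiply before the recurrence runs. A small NumPy shape check with the default sizes used below (T = sum(lod[0]); all names are local to this sketch):

import numpy as np

T, M, D = 7, 8, 16                # total steps in the LoD batch, input size, hidden size
x = np.random.rand(T, M)          # packed input sequences (T x M)
wx = np.random.rand(M, 4 * D)     # input-projection weight (M x 4D)
bx = np.random.rand(1, 4 * D)     # input-projection bias (1 x 4D)
gates = np.dot(x, wx) + bx        # one GEMM produces every step's gate pre-activations
assert gates.shape == (T, 4 * D)  # i, f, c~, o blocks for all T steps
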
class TestLstmOp(OpTest):
def set_argument(self):
self.lod = [[2, 3, 2]]
def setUp(self):
self.op_type = 'fusion_lstm'
self.lod = [[2, 3, 2]]
self.M = 8
self.D = 16
self.has_initial_state = False
self.is_reverse = False
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_peepholes = False
self.set_argument()
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float64')
if self.has_initial_state:
h0 = np.random.normal(size=(bs, self.D)).astype('float64')
c0 = np.random.normal(size=(bs, self.D)).astype('float64')
else:
h0 = np.zeros((bs, self.D)).astype('float64')
c0 = np.zeros((bs, self.D)).astype('float64')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# wx is the weight of the fc (the input projection)
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
# bx is the bias of that fc; it has to be folded into the Bias input of this
# fusion LSTM (see the equivalence sketch after test_check_output below)
bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate],
ACTIVATION[self.act_cell], ACTIVATION[self.act_cand])
self.inputs = {
'X': (x, self.lod),
'WeightX': wx,
'WeightH': wh,
'Bias': b
}
if self.has_initial_state:
self.inputs['H0'] = h0
self.inputs['C0'] = c0
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
}
self.attrs = {
'use_peepholes': self.use_peepholes,
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
}
def test_check_output(self):
self.check_output(atol=1e-8)
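
The bias folding in setUp (b[0, 0:4 * self.D] += bx[0, :]) works because the fc bias and the LSTM gate bias are added to the same 4*D gate pre-activations, so applying bx before the recurrence or folding it into the gate bias gives identical results. A minimal NumPy check of that equivalence (all names are local to this sketch):

import numpy as np

M, D = 8, 16
x = np.random.rand(3, M)
wx = np.random.rand(M, 4 * D)
bx = np.random.rand(1, 4 * D)      # fc bias
wb = np.random.rand(1, 4 * D)      # LSTM gate bias
pre_a = (np.dot(x, wx) + bx) + wb  # fc bias applied first, gate bias later
pre_b = np.dot(x, wx) + (bx + wb)  # fc bias folded into the gate bias up front
assert np.allclose(pre_a, pre_b)
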
class TestLstmOpInitReverse(TestLstmOp):
def set_argument(self):
self.has_initial_state = True
self.is_reverse = True
class TestLstmOpMD1(TestLstmOp):
def set_argument(self):
self.M = 36
self.D = 8
class TestLstmOpMD2(TestLstmOp):
def set_argument(self):
self.M = 8
self.D = 8
class TestLstmOpMD3(TestLstmOp):
def set_argument(self):
self.M = 15
self.D = 3
class TestLstmOpBS1(TestLstmOp):
def set_argument(self):
self.lod = [[3]]
self.D = 16
if __name__ == '__main__':
unittest.main()