Merge pull request #12878 from tensor-tang/feature/op/attention_lstm

Add attention lstm cpu forward
revert-12864-feature/process_lod_grad
tensor-tang 7 years ago committed by GitHub
commit f0f06992c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -138,12 +138,6 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release)
endif()
if(WITH_MKL)
option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
if (MKL_SPLIT_GEMM)
add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
endif()
endif()
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
if (WITH_MKL AND AVX2_FOUND)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,41 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class AttentionLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {

@ -90,6 +90,11 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
template <typename T>
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
@ -109,6 +114,10 @@ class Blas {
void GEMM_FREE(T* data) const;
#endif
template <typename T>
void MatMul(const int M, const int N, const int K, const T* A, const T* B,
T* C) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
@ -140,10 +149,19 @@ class Blas {
template <typename T>
void VCOPY(int n, const T* x, T* y) const;
template <typename T>
void VEXP(int n, const T* x, T* y) const;
template <typename T>
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
T* C) const;
template <typename T>
T DOT(int n, const T* x, const T* y) const;
template <typename T>
void SCAL(int n, const T a, T* x) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C,
@ -215,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VCOPY<T>(args...);
}
template <typename... ARGS>
void VEXP(ARGS... args) const {
Base()->template VEXP<T>(args...);
}
template <typename... ARGS>
void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...);
}
template <typename... ARGS>
T DOT(ARGS... args) const {
return Base()->template DOT<T>(args...);
}
template <typename... ARGS>
void SCAL(ARGS... args) const {
Base()->template SCAL<T>(args...);
}
template <typename... ARGS>
void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...);

@ -73,6 +73,16 @@ struct CBlas<float> {
platform::dynload::cblas_sgemv(args...);
}
template <typename... ARGS>
static float DOT(ARGS... args) {
return platform::dynload::cblas_sdot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_sscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
@ -87,6 +97,11 @@ struct CBlas<float> {
static void VMUL(ARGS... args) {
platform::dynload::vsMul(args...);
}
template <typename... ARGS>
static void VEXP(ARGS... args) {
platform::dynload::vsExp(args...);
}
};
template <>
@ -138,6 +153,16 @@ struct CBlas<double> {
platform::dynload::cblas_dgemv(args...);
}
template <typename... ARGS>
static double DOT(ARGS... args) {
return platform::dynload::cblas_ddot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_dscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...);
@ -152,6 +177,11 @@ struct CBlas<double> {
static void VMUL(ARGS... args) {
platform::dynload::vdMul(args...);
}
template <typename... ARGS>
static void VEXP(ARGS... args) {
platform::dynload::vdExp(args...);
}
};
#else
@ -210,6 +240,9 @@ struct CBlas<platform::float16> {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@ -217,64 +250,6 @@ struct CBlas<platform::float16> {
#endif
};
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
bool transb, const T &alpha, const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom
constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
std::abs<T>(alpha - static_cast<T>(1) >
std::numeric_limits<T>::epsilon()) ||
std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
return false;
} else {
return true;
}
#endif
return false;
}
template <>
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
bool transa, bool transb,
const platform::float16 &alpha,
const platform::float16 &beta) {
return false;
}
template <typename T>
inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
const T *A, int lda, const T *B, int ldb, T beta, T *C,
int ldc) {
#ifdef PADDLE_WITH_LIBXSMM
if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) {
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc);
return;
}
#endif
#ifdef PADDLE_MKL_SPLIT_GEMM
constexpr int bs = 2;
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
for (int off = 0; off < M; off += bs) {
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
}
return;
}
#endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
#ifdef PADDLE_WITH_MKLML
template <>
template <typename T>
@ -319,8 +294,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
@ -329,9 +304,20 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <typename DeviceContext>
@ -399,6 +385,47 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VEXP(n, x, y);
#else
// try to find if openblas support vexp
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
#endif
}
template <>
template <typename T>
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
#ifdef PADDLE_WITH_MKLML
return CBlas<T>::DOT(n, x, 1, y, 1);
#else
// try to find if openblas support cblas_dot
T sum = 0;
for (int i = 0; i < n; ++i) {
sum += x[i] * y[i];
}
return sum;
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::SCAL(n, a, x, 1);
#else
// try to find if openblas support cblas_scal
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
@ -440,6 +467,42 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
const T *A, const T *B, T *C) const {
this->template GEMM<T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
N);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
const int K, const T *A,
const T *B, T *C) const {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
// Since the matrix is very small,
// so the unit of calculation is already very fast,
// and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
// use xsmm directly.
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
const T alpha = static_cast<T>(1);
const T beta = static_cast<T>(0);
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
C, &N);
return;
#endif
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,

@ -0,0 +1,105 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
template <typename T>
inline T sigmoid(T x) {
return 1. / (1. + exp(-x));
}
template <typename T>
inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_identity(const int n, const T* x, T* y) {
// do nothing
return;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_sigmoid(const int n, const T* x, T* y) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 1.0 / (1.0 + std::exp(-tmp));
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_tanh(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = tanh<T>(x[i]);
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_relu(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <>
inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
float* y) {
// TODO(TJ): complete me
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
float* y) {
// TODO(TJ): complete me
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
class VecActivations {
public:
std::function<void(const int, const T*, T*)> operator()(
const std::string& type) {
if (type == "sigmoid") {
return vec_sigmoid<T, isa>;
} else if (type == "relu") {
return vec_relu<T, isa>;
} else if (type == "tanh") {
return vec_tanh<T, isa>;
} else if (type == "identity" || type == "") {
return vec_identity<T, isa>;
}
PADDLE_THROW("Not support type %s.", type);
}
};
} // namespace math
} // namespace operators
} // namespace paddle

@ -25,17 +25,25 @@ namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL) {
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
static_cast<T>(0), Y);
if (B) {
const T* B = NULL, bool relu = false) {
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
if (!relu) {
return;
}
// TODO(TJ): fuse relu
LOG(FATAL) << "Not implemented!";
}
} // namespace math

@ -103,15 +103,16 @@ size_t CUDAPinnedMaxChunkSize() {
return CUDAPinnedMaxAllocSize() / 256;
}
#ifdef PADDLE_WITH_XBYAK
namespace jit {
#ifdef PADDLE_WITH_XBYAK
static Xbyak::util::Cpu cpu;
bool MayIUse(const cpu_isa_t cpu_isa) {
using namespace Xbyak::util; // NOLINT
switch (cpu_isa) {
case sse42:
return cpu.has(Cpu::tSSE42);
case avx:
return cpu.has(Cpu::tAVX);
case avx2:
return cpu.has(Cpu::tAVX2);
case avx512_common:
@ -134,8 +135,16 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
}
return false;
}
#else
bool MayIUse(const cpu_isa_t cpu_isa) {
if (cpu_isa == isa_any) {
return true;
} else {
return false;
}
}
#endif
} // namespace jit
#endif
} // namespace platform
} // namespace paddle

@ -37,12 +37,11 @@ size_t CUDAPinnedMinChunkSize();
//! Get the maximum chunk size for buddy allocator.
size_t CUDAPinnedMaxChunkSize();
#ifdef PADDLE_WITH_XBYAK
namespace jit {
typedef enum {
isa_any,
sse42,
avx,
avx2,
avx512_common,
avx512_core,
@ -55,7 +54,6 @@ typedef enum {
inline bool MayIUse(const cpu_isa_t cpu_isa);
} // namespace jit
#endif
} // namespace platform
} // namespace paddle

@ -66,10 +66,16 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_free); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm_batch); \
__macro(cblas_sdot); \
__macro(cblas_ddot); \
__macro(cblas_sscal); \
__macro(cblas_dscal); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsMul); \
__macro(vdMul); \
__macro(vsExp); \
__macro(vdExp); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);

@ -0,0 +1,208 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_fusion_lstm_op import fc, ACTIVATION
from test_softmax_op import stable_softmax
def attention_lstm(
x, # T x M
lod, # 1 x N
h0, # N x D
c0, # N x D
fcws, # (M+D) x 1, 1x1
fcbs, # 1 x 1, 1x1
w, # (M+D) x 4D
b, # 1 x 4D
act_gate,
act_cell,
act_cand):
T = sum(lod[0])
N = len(lod[0])
M = x.shape[1]
D = b.shape[1] / 4
assert T == x.shape[0]
assert len(fcws) == len(fcbs)
hidden = []
cell = []
start_offset = 0
for bid in range(N):
seq_len = lod[0][bid]
xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len,
M)
prev_cell = np.copy(c0[bid]).reshape([1, D])
prev_hidden = np.copy(h0[bid]).reshape([1, D])
for step in range(seq_len):
expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
tmp = np.concatenate((xi, expanded_cell), axis=1)
assert tmp.shape[0] == seq_len
assert tmp.shape[1] == M + D
for fcid in range(len(fcbs)):
tmp = fc(tmp, fcws[fcid], fcbs[fcid])
tmp = ACTIVATION['relu'](tmp)
tmp = np.reshape(tmp, (1, seq_len))
tmp = stable_softmax(tmp).reshape(seq_len, 1)
lstmx = xi * tmp # seq * M
lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
g_f = act_gate(g_f).reshape([1, D])
g_i = act_gate(g_i).reshape([1, D])
g_o = act_gate(g_o).reshape([1, D])
cand = act_cand(cand).reshape([1, D])
cell_t = (prev_cell * g_f) + (g_i * cand)
hidden_t = g_o * act_cell(cell_t)
hidden.append(hidden_t.flatten())
cell.append(cell_t.flatten())
prev_cell = cell_t.reshape([1, D])
prev_hidden = hidden_t.reshape([1, D])
start_offset += seq_len
hidden = np.array(hidden).astype('float32').reshape([T, D])
cell = np.array(cell).astype('float32').reshape([T, D])
return hidden, cell
class TestAttentionLSTMOp(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'attention_lstm'
self.lod = [[3]]
self.M = 30
self.D = 15
self.has_initial_hidden = True
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.set_conf()
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float32')
c0 = np.random.normal(size=(bs, self.D)).astype('float32')
if self.has_initial_hidden:
h0 = np.random.normal(size=(bs, self.D)).astype('float32')
else:
h0 = np.zeros((bs, self.D)).astype('float32')
fcw1 = np.random.normal(size=(self.M + self.D, 1)).astype('float32')
fcb1 = np.random.normal(size=(1, 1)).astype('float32')
fcw2 = np.random.normal(size=(1, 1)).astype('float32')
fcb2 = np.random.normal(size=(1, 1)).astype('float32')
# lstm weight and bias
w = np.random.normal(size=(self.M + self.D,
self.D * 4)).astype('float32')
b = np.random.normal(size=(1, self.D * 4)).astype('float32')
h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2],
w, b, ACTIVATION[self.act_gate],
ACTIVATION[self.act_cell],
ACTIVATION[self.act_cand])
self.inputs = {
'X': (x, self.lod),
'C0': c0,
'AttentionWeight': fcw1,
'AttentionBias': fcb1,
'AttentionScalar': fcw2,
'AttentionScalarBias': fcb2,
'LSTMWeight': w,
'LSTMBias': b
}
if self.has_initial_hidden:
self.inputs['H0'] = h0
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
}
self.attrs = {
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
}
def test_check_output(self):
self.check_output()
class TestAttentionOpNonInit(TestAttentionLSTMOp):
def set_conf(self):
self.has_initial_hidden = False
class TestAttentionOpAct(TestAttentionLSTMOp):
def set_conf(self):
self.M = 3
self.D = 2
self.act_gate = 'relu'
self.act_cell = 'tanh'
self.act_cand = 'sigmoid'
class TestAttentionOpMD1(TestAttentionLSTMOp):
def set_conf(self):
self.M = 36
self.D = 8
class TestAttentionOpMD2(TestAttentionLSTMOp):
def set_conf(self):
self.M = 8
self.D = 8
class TestAttentionOpMD3(TestAttentionLSTMOp):
def set_conf(self):
self.M = 15
self.D = 30
class TestAttentionOpBS1(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[5]]
self.M = 16
self.D = 32
class TestAttentionOpBS2(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[3, 6]]
class TestAttentionOpBS5(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[3, 2, 4, 7, 5]]
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save