Merge pull request from tensor-tang/feature/op/fusion_gru

add fusion gru
createGenDocLib
tensor-tang 7 years ago committed by GitHub
commit 89d6d69ce4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionGRUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

@ -0,0 +1,133 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
from op_test import OpTest
from test_gru_op import gru
from test_fusion_lstm_op import fc, ACTIVATION
def fusion_gru(
x, # T x M
lod, # 1 x N
h0, # N x D
wx, # M x 3D
wh, # D x 3D
bias, # 1 x 3D
is_reverse,
act_state,
act_gate):
return gru(fc(x, wx, bias),
lod,
h0,
wh,
np.zeros(
(1, wh.shape[1]), dtype='float64'),
is_reverse,
act_state,
act_gate)
class TestFusionGRUOp(OpTest):
def set_confs(self):
pass
def setUp(self):
self.op_type = "fusion_gru"
self.lod = [[2, 4, 3]]
self.M = 3
self.D = 5
self.is_reverse = False
self.with_h0 = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.set_confs()
T = sum(self.lod[0])
N = len(self.lod[0])
x = np.random.rand(T, self.M).astype('float64')
wx = np.random.rand(self.M, 3 * self.D).astype('float64')
wh = np.random.rand(self.D, 3 * self.D).astype('float64')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
_, _, _, hidden = fusion_gru(
x, self.lod, h0, wx, wh, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
self.outputs = {'Hidden': (hidden, self.lod)}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
}
def test_check_output(self):
self.check_output(atol=1e-8)
class TestFusionGRUOpNoInitial(TestFusionGRUOp):
def set_confs(self):
self.with_h0 = False
class TestFusionGRUOpNoBias(TestFusionGRUOp):
def set_confs(self):
self.with_bias = False
class TestFusionGRUOpReverse(TestFusionGRUOp):
def set_confs(self):
self.is_reverse = True
class TestFusionGRUOpMD1(TestFusionGRUOp):
def set_confs(self):
self.M = 36
self.D = 8
class TestFusionGRUOpMD2(TestFusionGRUOp):
def set_confs(self):
self.M = 8
self.D = 8
class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self):
self.lod = [[3]]
self.D = 16
if __name__ == "__main__":
unittest.main()

@ -19,22 +19,19 @@ import numpy as np
import math
import functools
from op_test import OpTest
from test_lstm_op import identity, sigmoid, tanh, relu
class TestGRUOp(OpTest):
lod = [[2, 4, 3]]
batch_size = sum(lod[0])
frame_size = 5
activate = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
}
@staticmethod
def seq_to_batch(lod, is_reverse):
from test_lstm_op import ACTIVATION
def gru(
input, # T x 3D
lod, # 1 x N
h0, # N x D
weight, # D x 3D
bias, # 1 x 3D
is_reverse,
act_state,
act_gate):
def _seq_to_batch(lod, is_reverse):
idx_in_seq_list = []
seq_lens = lod[0]
seq_starts = [0]
@ -56,121 +53,125 @@ class TestGRUOp(OpTest):
idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list, sorted_seqs
def gru_step(self, x, h_p, w, b):
batch_size = x.shape[0]
frame_size = w.shape[0]
g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2))
u_r = self.activate[self.attrs['gate_activation']](np.dot(
h_p, w_u_r) + g[:, :frame_size * 2])
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
def _step(x, h_p, w, b, act_state, act_gate):
T = x.shape[0]
D = w.shape[0]
g = x + np.tile(b, (T, 1))
w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2))
u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
u = u_r[:, :D]
r = u_r[:, D:D * 2]
r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
(frame_size, frame_size))
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
w_c = w.flatten()[D * D * 2:].reshape((D, D))
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p
return g, r_h_p, h
def gru(self):
input, lod = self.inputs['Input']
w = self.inputs['Weight']
b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros(
(1, self.frame_size * 3))
batch_gate = self.outputs['BatchGate']
batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'][
self.sorted_seqs] if 'H0' in self.inputs else np.zeros(
(len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = self.gru_step(x, h_p, w, b)
if batch_idx < (num_batch - 1):
h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
start_idx = end_idx
end_idx = start_idx + len(idx_in_seq_list[batch_idx])
batch_gate[start_idx:end_idx] = g
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_hidden[start_idx:end_idx] = h
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = self.lod
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
lod, self.is_reverse)
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
h0 = np.random.rand(len(self.idx_in_seq_list[0]),
frame_size).astype('float64')
weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
bias = np.random.rand(1, frame_size * 3).astype('float64')
self.inputs = {
'Input': (input, lod),
'H0': h0,
'Weight': weight,
'Bias': bias
}
T = sum(lod[0])
N = len(lod[0])
D = weight.shape[0]
batch_gate = np.zeros((T, 3 * D), dtype='float64')
batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
batch_hidden = np.zeros((T, D), dtype='float64')
hidden = np.zeros((T, D), dtype='float64')
idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
h_p = h0[sorted_seqs]
max_seq_len = len(idx_in_seq_list)
assert len(idx_in_seq_list[0]) == N
end_idx = 0
for batch_idx in range(max_seq_len):
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = _step(x, h_p, weight, bias, act_state, act_gate)
if batch_idx < (max_seq_len - 1):
h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
start_idx = end_idx
end_idx = start_idx + len(idx_in_seq_list[batch_idx])
batch_gate[start_idx:end_idx] = g
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_hidden[start_idx:end_idx] = h
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden
self.outputs = {
'BatchGate': np.zeros(
(batch_size, frame_size * 3), dtype='float64'),
'BatchResetHiddenPrev': np.zeros(
(batch_size, frame_size), dtype='float64'),
'BatchHidden': np.zeros(
(batch_size, frame_size), dtype='float64'),
'Hidden': np.zeros(
(batch_size, frame_size), dtype='float64')
}
class TestGRUOp(OpTest):
def set_confs(self):
self.is_reverse = False
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
pass
def setUp(self):
self.op_type = "gru"
self.lod = [[2, 4, 3]]
self.D = 5
self.is_reverse = False
self.with_h0 = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.set_confs()
self.set_data()
self.gru()
T = sum(self.lod[0])
N = len(self.lod[0])
input = np.random.rand(T, 3 * self.D).astype('float64')
weight = np.random.rand(self.D, 3 * self.D).astype('float64')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
self.outputs = {
'Hidden': (hidden, self.lod),
'BatchGate': batch_gate,
'BatchResetHiddenPrev': batch_reset_hidden_prev,
'BatchHidden': batch_hidden,
}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
}
def test_check_output(self):
self.check_output()
self.check_output(atol=1e-8)
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoInitial(TestGRUOp):
def set_data(self):
super(TestGRUOpNoInitial, self).set_data()
self.inputs.pop('H0')
def set_confs(self):
self.with_h0 = False
def test_check_grad(self):
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoBias(TestGRUOp):
def set_confs(self):
self.with_bias = False
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight'], ['Hidden'])
class TestGRUOpReverse(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
if __name__ == "__main__":

Loading…
Cancel
Save