update cudnn rnn weights, test=develop (#23929)

Xing Wu 5 years ago committed by GitHub
parent 720d18990c
commit f6e8bf0d24

@@ -39,7 +39,6 @@ class LSTMCell(Layer):
         \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
         c_t &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t}
         h_t &= o_t \\odot tanh(c_t)
     Args:
         hidden_size (integer): The hidden size used in the Cell.
         input_size (integer): The input size used in the Cell.
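
These docstring equations are the standard LSTM update. For orientation, a minimal NumPy sketch of one step; the i/f/o gate definitions sit above this hunk, so the conventional ones are assumed here rather than quoted from the file:

    import numpy as np

    def sigmoid(x):
        return 1. / (1. + np.exp(-x))

    def lstm_step(x_t, h_prev, c_prev, w_x, w_h, b):
        # w_x: [input_size, 4*hidden], w_h: [hidden, 4*hidden], b: [4*hidden]
        gates = x_t @ w_x + h_prev @ w_h + b
        i, f, g, o = np.split(gates, 4, axis=-1)  # input, forget, candidate, output
        c_t = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # c_t = f*c_{t-1} + i*~c_t
        h_t = sigmoid(o) * np.tanh(c_t)                      # h_t = o*tanh(c_t)
        return h_t, c_t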
@@ -64,30 +63,25 @@ class LSTMCell(Layer):
     Returns:
         None
     Examples:
         .. code-block:: python
             from paddle import fluid
             import paddle.fluid.core as core
             from paddle.fluid.dygraph.rnn import LSTMCell
             import numpy as np
             batch_size = 64
             input_size = 128
             hidden_size = 256
             step_input_np = np.random.uniform(-0.1, 0.1, (
                 batch_size, input_size)).astype('float64')
             pre_hidden_np = np.random.uniform(-0.1, 0.1, (
                 batch_size, hidden_size)).astype('float64')
             pre_cell_np = np.random.uniform(-0.1, 0.1, (
                 batch_size, hidden_size)).astype('float64')
             if core.is_compiled_with_cuda():
                 place = core.CUDAPlace(0)
             else:
                 place = core.CPUPlace()
             with fluid.dygraph.guard(place):
                 cudnn_lstm = LSTMCell(hidden_size, input_size)
                 step_input_var = fluid.dygraph.to_variable(step_input_np)
@@ -139,12 +133,12 @@ class LSTMCell(Layer):
             self._weight_ih = self.create_parameter(
                 attr=weight_ih_param_attr,
-                shape=[self._input_size, 4 * self._hidden_size],
+                shape=[4 * self._hidden_size, self._input_size],
                 dtype=self._dtype)
             self._weight_hh = self.create_parameter(
                 attr=weight_hh_param_attr,
-                shape=[self._hidden_size, 4 * self._hidden_size],
+                shape=[4 * self._hidden_size, self._hidden_size],
                 dtype=self._dtype)
             self._bias_ih = self.create_parameter(
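
The shape flip above stores weight_ih as [4 * hidden_size, input_size] and weight_hh as [4 * hidden_size, hidden_size], the gate-packed row layout cuDNN expects for its fused RNN parameters; the forward pass below compensates with transpose_y=True, so the computed values are unchanged. A minimal NumPy sketch of why (sizes illustrative):

    import numpy as np

    batch, input_size, hidden_size = 64, 128, 256
    x = np.random.rand(batch, input_size)

    w_ih_old = np.random.rand(input_size, 4 * hidden_size)  # pre-patch layout
    w_ih_new = w_ih_old.T                                    # [4*hidden, input]

    # Identical projection; only the storage layout differs.
    assert np.allclose(x @ w_ih_old, x @ w_ih_new.T)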
@@ -180,10 +174,10 @@ class LSTMCell(Layer):
     def forward(self, input, pre_hidden, pre_cell):
         if self._use_cudnn_impl:
-            igates = layers.matmul(input, y=self._weight_ih)
+            igates = layers.matmul(input, y=self._weight_ih, transpose_y=True)
             igates = layers.elementwise_add(igates, self._bias_ih)
-            hgates = layers.matmul(pre_hidden, self._weight_hh)
+            hgates = layers.matmul(
+                pre_hidden, self._weight_hh, transpose_y=True)
             hgates = layers.elementwise_add(hgates, self._bias_hh)
             chunked_igates = layers.split(igates, num_or_sections=4, dim=1)
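
To check the rewritten matmuls end to end, the new form can be compared against NumPy in dygraph. A sketch, assuming a Paddle 1.x install (the fluid calls are the same ones used in this hunk):

    import numpy as np
    from paddle import fluid
    import paddle.fluid.layers as layers

    x_np = np.random.rand(64, 128).astype('float64')
    w_np = np.random.rand(4 * 256, 128).astype('float64')  # new [4*hidden, input] layout

    with fluid.dygraph.guard(fluid.CPUPlace()):
        x = fluid.dygraph.to_variable(x_np)
        w = fluid.dygraph.to_variable(w_np)
        out = layers.matmul(x, y=w, transpose_y=True)
        # Matches a plain matmul against the old [input, 4*hidden] layout.
        assert np.allclose(out.numpy(), x_np @ w_np.T)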
@@ -264,28 +258,23 @@ class GRUCell(Layer):
     Returns:
         None
     Examples:
         .. code-block:: python
             from paddle import fluid
             import paddle.fluid.core as core
             from paddle.fluid.dygraph.rnn import GRUCell
             import numpy as np
             batch_size = 64
             input_size = 128
             hidden_size = 256
             step_input_np = np.random.uniform(-0.1, 0.1, (
                 batch_size, input_size)).astype('float64')
             pre_hidden_np = np.random.uniform(-0.1, 0.1, (
                 batch_size, hidden_size)).astype('float64')
             if core.is_compiled_with_cuda():
                 place = core.CUDAPlace(0)
             else:
                 place = core.CPUPlace()
             with fluid.dygraph.guard(place):
                 cudnn_gru = GRUCell(hidden_size, input_size)
                 step_input_var = fluid.dygraph.to_variable(step_input_np)
@@ -334,12 +323,12 @@ class GRUCell(Layer):
             self._weight_ih = self.create_parameter(
                 attr=weight_ih_param_attr,
-                shape=[self._input_size, 3 * self._hidden_size],
+                shape=[3 * self._hidden_size, self._input_size],
                 dtype=self._dtype)
             self._weight_hh = self.create_parameter(
                 attr=weight_hh_param_attr,
-                shape=[self._hidden_size, 3 * self._hidden_size],
+                shape=[3 * self._hidden_size, self._hidden_size],
                 dtype=self._dtype)
             self._bias_ih = self.create_parameter(
@@ -402,9 +391,10 @@ class GRUCell(Layer):
         if self._use_cudnn_impl:
-            igates = layers.matmul(input, y=self._weight_ih)
+            igates = layers.matmul(input, y=self._weight_ih, transpose_y=True)
             igates = layers.elementwise_add(igates, self._bias_ih)
-            hgates = layers.matmul(pre_hidden, self._weight_hh)
+            hgates = layers.matmul(
+                pre_hidden, self._weight_hh, transpose_y=True)
             hgates = layers.elementwise_add(hgates, self._bias_hh)
             chunked_igates = layers.split(igates, num_or_sections=3, dim=1)

@@ -34,9 +34,9 @@ def tanh(x):
 def cudnn_step(step_input_np, pre_hidden_np, weight_ih, bias_ih, weight_hh,
                bias_hh):
-    igates = np.matmul(step_input_np, weight_ih)
+    igates = np.matmul(step_input_np, weight_ih.transpose(1, 0))
     igates += bias_ih
-    hgates = np.matmul(pre_hidden_np, weight_hh)
+    hgates = np.matmul(pre_hidden_np, weight_hh.transpose(1, 0))
     hgates += bias_hh
     chunked_igates = np.split(igates, indices_or_sections=3, axis=1)
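
The hunk ends right after the split; the three chunks then combine in cuDNN's reset/update/candidate order. A sketch of the remainder under that assumption (standard cuDNN GRU algebra, helper name invented; sigmoid and tanh are the module-level helpers defined above):

    def cudnn_gru_tail(chunked_igates, chunked_hgates, pre_hidden_np):
        reset_gate = sigmoid(chunked_igates[0] + chunked_hgates[0])
        update_gate = sigmoid(chunked_igates[1] + chunked_hgates[1])
        new_gate = tanh(chunked_igates[2] + reset_gate * chunked_hgates[2])
        # h' = u * h + (1 - u) * c
        return update_gate * pre_hidden_np + (1. - update_gate) * new_gate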

@@ -32,7 +32,12 @@ def tanh(x):
     return 2. * sigmoid(2. * x) - 1.
-def cudnn_step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0):
+def non_cudnn_step(step_in,
+                   pre_hidden,
+                   pre_cell,
+                   gate_w,
+                   gate_b,
+                   forget_bias=1.0):
     concat_1 = np.concatenate([step_in, pre_hidden], 1)
     gate_input = np.matmul(concat_1, gate_w)
@@ -45,12 +50,12 @@ def cudnn_step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0):
     return new_hidden, new_cell
-def non_cudnn_step(step_input_np, pre_hidden_np, pre_cell_np, weight_ih,
-                   bias_ih, weight_hh, bias_hh):
+def cudnn_step(step_input_np, pre_hidden_np, pre_cell_np, weight_ih, bias_ih,
+               weight_hh, bias_hh):
-    igates = np.matmul(step_input_np, weight_ih)
+    igates = np.matmul(step_input_np, weight_ih.transpose(1, 0))
     igates = igates + bias_ih
-    hgates = np.matmul(pre_hidden_np, weight_hh)
+    hgates = np.matmul(pre_hidden_np, weight_hh.transpose(1, 0))
     hgates = hgates + bias_hh
     chunked_igates = np.split(igates, indices_or_sections=4, axis=1)
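
Likewise, the LSTM reference is cut off at the four-way split; cuDNN packs LSTM gates as input/forget/candidate/output. A sketch of how the remainder would combine them (standard cuDNN LSTM algebra, helper name invented):

    def cudnn_lstm_tail(chunked_igates, chunked_hgates, pre_cell_np):
        ingate = sigmoid(chunked_igates[0] + chunked_hgates[0])
        forgetgate = sigmoid(chunked_igates[1] + chunked_hgates[1])
        cellgate = tanh(chunked_igates[2] + chunked_hgates[2])
        outgate = sigmoid(chunked_igates[3] + chunked_hgates[3])
        new_cell = pre_cell_np * forgetgate + ingate * cellgate
        new_hidden = tanh(new_cell) * outgate
        return new_hidden, new_cell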
@@ -102,7 +107,6 @@ class TestCudnnLSTM(unittest.TestCase):
         bias_ih_name = "_bias_ih"
         weight_hh_name = "_weight_hh"
         bias_hh_name = "_bias_hh"
         weight_ih = param_list[weight_ih_name].numpy()
         weight_ih = np.random.uniform(
             -0.1, 0.1, size=weight_ih.shape).astype('float64')
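
For context, the truncated test body follows a common dygraph pattern: read each parameter, draw fresh random values of the same shape, and push them back so the NumPy reference runs on identical weights. A sketch of that pattern, where the set_value call is inferred from the dygraph Parameter API rather than visible in this hunk:

    weight_ih = param_list[weight_ih_name].numpy()
    weight_ih = np.random.uniform(
        -0.1, 0.1, size=weight_ih.shape).astype('float64')
    param_list[weight_ih_name].set_value(weight_ih)  # assumed write-back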
@@ -146,10 +150,9 @@ class TestCudnnLSTM(unittest.TestCase):
         named_api_hidden_out = named_api_out[0]
         named_api_cell_out = named_api_out[1]
-        np_hidden_out, np_cell_out = non_cudnn_step(
+        np_hidden_out, np_cell_out = cudnn_step(
             step_input_np, pre_hidden_np, pre_cell_np, weight_ih, bias_ih,
             weight_hh, bias_hh)
         self.assertTrue(
             np.allclose(
                 api_hidden_out.numpy(), np_hidden_out, rtol=1e-5, atol=0))
@@ -230,7 +233,7 @@ class TestNonCudnnLSTM(unittest.TestCase):
         named_api_hidden_out = named_api_out[0]
         named_api_cell_out = named_api_out[1]
-        np_hidden_out, np_cell_out = cudnn_step(
+        np_hidden_out, np_cell_out = non_cudnn_step(
             step_input_np, pre_hidden_np, pre_cell_np, gate_w, gate_b)
         self.assertTrue(
