Add flattern weight of lstm (#27192)

* add flattern weight of lstm
my_2.0rc
GaoWei8 4 years ago committed by GitHub
parent 7779790c61
commit 36bb056ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,6 +15,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
@ -25,7 +26,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
@ -122,7 +122,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
" cudnn concatenate all the weight to one Tensor")
.AsDispensable();
AddInput("WeightList",
"(vector<Tensor>), stores weight and bias data when the weight "
"use the list format. ")
.AsDispensable()
.AsDuplicable();
AddInput("SequenceLength",
"(Tensor) When the input data is padding, "
"set this parameter. This parameter represents "
@ -216,7 +222,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
@ -228,7 +233,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
};
SetOutGradDim("Input");
SetOutGradDim("W");
if (ctx->HasInputs("WeightList")) {
ctx->SetOutputsDim(framework::GradVarName("WeightList"),
ctx->GetInputsDim("WeightList"));
}
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
@ -251,7 +259,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Input", this->Input("Input"));
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("WeightList")) {
op->SetInput("WeightList", this->Input("WeightList"));
}
if (this->HasInput("SequenceLength")) {
op->SetInput("SequenceLength", this->Input("SequenceLength"));
}
@ -262,8 +272,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
if (this->HasInput("WeightList")) {
op->SetOutput(framework::GradVarName("WeightList"),
this->InputGrad("WeightList", false));
}
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
op->SetAttrMap(this->Attrs());
@ -290,3 +304,20 @@ REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);
// TODO(Shixiaowei02) Add ModifyInput support
REGISTER_OP_VERSION(cudnn_lstm)
.AddCheckpoint(
R"ROC(
Upgrade cudnn_lstm add a new input [WeightList] and modify input [W] to dispensable.)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput(
"WeightList",
"The WeightList stores weight and bias data. WeightList is "
"dispensable.")
.NewInput("SequenceLength",
"When the input data is padding, set this parameter. "
"SequenceLength is dispensable.")
.NewOutput("StateOut", "Store the global drop state when training")
.NewOutput("Reserve",
"A temporary output Tensor to store the reserve_data"));

File diff suppressed because it is too large Load Diff

@ -2443,23 +2443,17 @@ def lstm(input,
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
num_dirrection = 2 if is_bidirec == True else 1
for i in range(num_layers):
if i == 0:
input_weight_size = (input_size * hidden_size) * 4
input_weight_size = (input_size * hidden_size) * 4 * num_dirrection
else:
if is_bidirec:
input_weight_size = (hidden_size * 2 * hidden_size) * 4
else:
input_weight_size = (hidden_size * hidden_size) * 4
input_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection
hidden_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection
hidden_weight_size = (hidden_size * hidden_size) * 4
if is_bidirec:
weight_size += (input_weight_size + hidden_weight_size) * 2
weight_size += hidden_size * 8 * 2
else:
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8 * num_dirrection
weight = helper.create_parameter(
attr=helper.param_attr,

@ -20,14 +20,44 @@ import math
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
random.seed(2)
np.set_printoptions(threshold=np.inf)
paddle.enable_static()
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
class RandomWeight:
def __init__(self):
pass
def updata_weight(self, hidden_size, input_size, dtype):
std = 1.0 / math.sqrt(hidden_size)
self.hidden_size = hidden_size
self.input_size = input_size
self.dtype = dtype
self.weight_ih = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size,
self.input_size)).astype(dtype)
self.weight_hh = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size,
self.hidden_size)).astype(dtype)
self.bias_ih = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype)
self.bias_hh = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype)
weight = RandomWeight()
class LayerMixin(object):
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
@ -51,16 +81,13 @@ class LSTMCell(LayerMixin):
self.bias = bias
self.dtype = np.float64
self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.ones(
(4 * hidden_size, input_size), dtype=self.dtype)
self.weight_hh = np.ones((4 * hidden_size,
hidden_size)).astype(self.dtype)
self.weight_ih = weight.weight_ih
self.weight_hh = weight.weight_hh
self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh
if bias:
self.bias_ih = np.ones((4 * hidden_size)).astype(self.dtype)
self.bias_hh = np.ones((4 * hidden_size)).astype(self.dtype)
self.bias_ih = weight.bias_ih
self.bias_hh = weight.bias_hh
self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh
else:
@ -353,24 +380,26 @@ class LSTM(RNNMixin):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp(OpTest):
#TODO(GaoWei8): Need to satisfy the result through the new interface
def get_weight_names(self):
weight_names = []
for i in range(2 * self.num_layers):
weight_names.append('weight{}'.format(i))
for i in range(2 * self.num_layers):
weight_names.append('bias{}'.format(i))
return weight_names
def setUp(self):
self.op_type = "cudnn_lstm"
self.dtype = np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1
self.set_attrs()
seq_length = 12
batch_size = 5
input_size = 21
hidden_size = 21
input_weight_size = (hidden_size * hidden_size) * 4
hidden_weight_size = (hidden_size * hidden_size) * 4
weight_size = input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size *= self.num_layers
input = np.random.uniform(
low=-0.1, high=0.1,
size=(seq_length, batch_size, input_size)).astype(self.dtype)
@ -379,17 +408,39 @@ class TestCUDNNLstmOp(OpTest):
input[9][3:][:] = 0
input[8][4:][:] = 0
weight.updata_weight(hidden_size, input_size, self.dtype)
rnn1 = LSTM(
input_size,
hidden_size,
self.num_layers,
num_layers=self.num_layers,
time_major=True,
direction="forward")
output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length)
flat_w = np.ones((weight_size)).astype(self.dtype)
flat_w = []
num = 0
for i in range(self.num_layers):
if i == 0:
weight_ih = weight.weight_ih
else:
weight_ih = weight.weight_hh
flat_w.append(("weight" + str(num), weight_ih))
num += 1
for i in range(self.num_layers):
weight_hh = weight.weight_hh
flat_w.append(("weight" + str(num), weight_hh))
num += 1
num = 0
for i in range(self.num_layers):
bias_ih = weight.bias_ih
flat_w.append(("bias" + str(num), bias_ih))
num += 1
for i in range(self.num_layers):
bias_hh = weight.bias_hh
flat_w.append(("bias" + str(num), bias_hh))
num += 1
init_h = np.zeros((self.num_layers, batch_size,
hidden_size)).astype(self.dtype)
init_c = np.zeros((self.num_layers, batch_size,
@ -398,7 +449,7 @@ class TestCUDNNLstmOp(OpTest):
self.inputs = {
'Input': input,
'W': flat_w,
'WeightList': flat_w,
'InitH': init_h,
'InitC': init_c,
'SequenceLength': self.sequence_length
@ -408,7 +459,7 @@ class TestCUDNNLstmOp(OpTest):
'is_bidirec': False,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': 1,
'num_layers': self.num_layers,
}
self.outputs = {
'Out': output,
@ -428,16 +479,42 @@ class TestCUDNNLstmOp(OpTest):
def test_grad_with_place(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place,
set(['Input', 'W', 'InitH', 'InitC']),
['Out', 'LastH', 'LastC'])
var_name_list = self.get_weight_names()
for var_name in var_name_list:
self.check_grad_with_place(
place,
set(['Input', var_name, 'InitH', 'InitC']),
['Out', 'LastH', 'LastC'])
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp2(TestCUDNNLstmOp):
def set_attrs(self):
self.num_layers = 2
class TestCUDNNlstmAPI(unittest.TestCase):
def test_lstm(self):
seq_len = 20
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
input = fluid.data(
name='input',
shape=[seq_len, batch_size, hidden_size],
dtype='float64')
init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
hidden_size, num_layers,
dropout_prob, False)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
low=-0.1, high=0.1, size=(seq_len, batch_size,
hidden_size)).astype("float64")
out = exe.run(fluid.default_main_program(),
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
@unittest.skipIf(not core.is_compiled_with_cuda(),
@ -448,7 +525,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
num_layers = 2
input = fluid.data(
name='input',
shape=[seq_len, batch_size, hidden_size],

Loading…
Cancel
Save