Add seq2seq api related code (#19820)
parent
e87cabb7f2
commit
dfd1eee7f7
@ -0,0 +1,78 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/gather_tree_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class GatherTreeOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Ids"),
|
||||
"Input(Ids) of GatherTreeOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Parents"),
|
||||
"Input(Parents) of GatherTreeOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of GatherTreeOp should not be null.");
|
||||
|
||||
auto ids_dims = ctx->GetInputDim("Ids");
|
||||
auto parents_dims = ctx->GetInputDim("Parents");
|
||||
PADDLE_ENFORCE(ids_dims == parents_dims,
|
||||
"The shape of Input(Parents) must be same with the shape of "
|
||||
"Input(Ids).");
|
||||
ctx->SetOutputDim("Out", ids_dims);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<Tensor>("Ids")->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class GatherTreeOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Ids",
|
||||
"The Tensor with shape [length, batch_size, beam_size] containing "
|
||||
"the selected ids of all time steps.");
|
||||
AddInput("Parents",
|
||||
"The Tensor has the same shape as Ids and contains the parents "
|
||||
"corresponding to selected ids when searching among beams.");
|
||||
AddOutput(
|
||||
"Out",
|
||||
"A Tensor with shape [length, batch_size, beam_size] containing the "
|
||||
"full sequences. The sequences is collected by backtracing from the "
|
||||
"last time step of Ids.");
|
||||
AddComment(R"DOC(
|
||||
GatherTree Operator.
|
||||
|
||||
Backtrace from the last time step and generate the full sequences by collecting beam search
|
||||
selected ids.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel<int32_t>,
|
||||
ops::GatherTreeOpKernel<int64_t>);
|
@ -0,0 +1,80 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/gather_tree_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename T>
|
||||
__global__ void GatherTree(const T *ids_data, const T *parents_data,
|
||||
T *out_data, const int64_t max_length,
|
||||
const int64_t batch_size, const int64_t beam_size) {
|
||||
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_size) {
|
||||
int batch = i / beam_size;
|
||||
int beam = i % beam_size;
|
||||
auto idx =
|
||||
(max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
|
||||
out_data[idx] = ids_data[idx];
|
||||
auto parent = parents_data[idx];
|
||||
for (int step = max_length - 2; step >= 0; step--) {
|
||||
idx = step * batch_size * beam_size + batch * beam_size;
|
||||
out_data[idx + beam] = ids_data[idx + parent];
|
||||
parent = parents_data[idx + parent];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class GatherTreeOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto *ids = ctx.Input<Tensor>("Ids");
|
||||
auto *parents = ctx.Input<Tensor>("Parents");
|
||||
auto *out = ctx.Output<Tensor>("Out");
|
||||
|
||||
const auto *ids_data = ids->data<T>();
|
||||
const auto *parents_data = parents->data<T>();
|
||||
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
auto &ids_dims = ids->dims();
|
||||
int64_t max_length = ids_dims[0];
|
||||
int64_t batch_size = ids_dims[1];
|
||||
int64_t beam_size = ids_dims[2];
|
||||
|
||||
auto &dev_ctx = ctx.cuda_device_context();
|
||||
|
||||
const int block = 512;
|
||||
int max_threads =
|
||||
std::min(static_cast<int64_t>(dev_ctx.GetMaxPhysicalThreadCount()),
|
||||
batch_size * beam_size);
|
||||
const int grid = std::max(max_threads / block, 1);
|
||||
GatherTree<<<grid, block>>>(ids_data, parents_data, out_data, max_length,
|
||||
batch_size, beam_size);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(gather_tree, ops::GatherTreeOpCUDAKernel<int32_t>,
|
||||
ops::GatherTreeOpCUDAKernel<int64_t>);
|
@ -0,0 +1,58 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
class GatherTreeOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto *ids = ctx.Input<Tensor>("Ids");
|
||||
auto *parents = ctx.Input<Tensor>("Parents");
|
||||
auto *out = ctx.Output<Tensor>("Out");
|
||||
|
||||
const auto *ids_data = ids->data<T>();
|
||||
const auto *parents_data = parents->data<T>();
|
||||
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
auto &ids_dims = ids->dims();
|
||||
auto max_length = ids_dims[0];
|
||||
auto batch_size = ids_dims[1];
|
||||
auto beam_size = ids_dims[2];
|
||||
|
||||
for (int batch = 0; batch < batch_size; batch++) {
|
||||
for (int beam = 0; beam < beam_size; beam++) {
|
||||
auto idx = (max_length - 1) * batch_size * beam_size +
|
||||
batch * beam_size + beam;
|
||||
out_data[idx] = ids_data[idx];
|
||||
auto parent = parents_data[idx];
|
||||
for (int step = max_length - 2; step >= 0; step--) {
|
||||
idx = step * batch_size * beam_size + batch * beam_size;
|
||||
out_data[idx + beam] = ids_data[idx + parent];
|
||||
parent = parents_data[idx + parent];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class TestGatherTreeOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "gather_tree"
|
||||
max_length, batch_size, beam_size = 5, 2, 2
|
||||
ids = np.random.randint(
|
||||
0, high=10, size=(max_length, batch_size, beam_size))
|
||||
parents = np.random.randint(
|
||||
0, high=beam_size, size=(max_length, batch_size, beam_size))
|
||||
self.inputs = {"Ids": ids, "Parents": parents}
|
||||
self.outputs = {'Out': self.backtrace(ids, parents)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
@staticmethod
|
||||
def backtrace(ids, parents):
|
||||
out = np.zeros_like(ids)
|
||||
(max_length, batch_size, beam_size) = ids.shape
|
||||
for batch in range(batch_size):
|
||||
for beam in range(beam_size):
|
||||
out[max_length - 1, batch, beam] = ids[max_length - 1, batch,
|
||||
beam]
|
||||
parent = parents[max_length - 1, batch, beam]
|
||||
for step in range(max_length - 2, -1, -1):
|
||||
out[step, batch, beam] = ids[step, batch, parent]
|
||||
parent = parents[step, batch, parent]
|
||||
return out
|
||||
|
||||
|
||||
class TestGatherTreeOpAPI(OpTest):
|
||||
def test_case(self):
|
||||
ids = fluid.layers.data(
|
||||
name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False)
|
||||
parents = fluid.layers.data(
|
||||
name='parents',
|
||||
shape=[5, 2, 2],
|
||||
dtype='int64',
|
||||
append_batch_size=False)
|
||||
final_sequences = fluid.layers.gather_tree(ids, parents)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,249 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid.core as core
|
||||
|
||||
from paddle.fluid.executor import Executor
|
||||
from paddle.fluid import framework
|
||||
|
||||
from paddle.fluid.layers.rnn import LSTMCell, GRUCell, RNNCell
|
||||
from paddle.fluid.layers import rnn as dynamic_rnn
|
||||
from paddle.fluid import contrib
|
||||
from paddle.fluid.contrib.layers import basic_lstm
|
||||
import paddle.fluid.layers.utils as utils
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestLSTMCell(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.batch_size = 4
|
||||
self.input_size = 16
|
||||
self.hidden_size = 16
|
||||
|
||||
def test_run(self):
|
||||
inputs = fluid.data(
|
||||
name='inputs', shape=[None, self.input_size], dtype='float32')
|
||||
pre_hidden = fluid.data(
|
||||
name='pre_hidden', shape=[None, self.hidden_size], dtype='float32')
|
||||
pre_cell = fluid.data(
|
||||
name='pre_cell', shape=[None, self.hidden_size], dtype='float32')
|
||||
|
||||
cell = LSTMCell(self.hidden_size)
|
||||
lstm_hidden_new, lstm_states_new = cell(inputs, [pre_hidden, pre_cell])
|
||||
|
||||
lstm_unit = contrib.layers.rnn_impl.BasicLSTMUnit(
|
||||
"basicLSTM", self.hidden_size, None, None, None, None, 1.0,
|
||||
"float32")
|
||||
lstm_hidden, lstm_cell = lstm_unit(inputs, pre_hidden, pre_cell)
|
||||
|
||||
if core.is_compiled_with_cuda():
|
||||
place = core.CUDAPlace(0)
|
||||
else:
|
||||
place = core.CPUPlace()
|
||||
exe = Executor(place)
|
||||
exe.run(framework.default_startup_program())
|
||||
|
||||
inputs_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.input_size)).astype('float32')
|
||||
pre_hidden_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32')
|
||||
pre_cell_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32')
|
||||
|
||||
param_names = [[
|
||||
"LSTMCell/BasicLSTMUnit_0.w_0", "basicLSTM/BasicLSTMUnit_0.w_0"
|
||||
], ["LSTMCell/BasicLSTMUnit_0.b_0", "basicLSTM/BasicLSTMUnit_0.b_0"]]
|
||||
|
||||
for names in param_names:
|
||||
param = np.array(fluid.global_scope().find_var(names[0]).get_tensor(
|
||||
))
|
||||
param = np.random.uniform(
|
||||
-0.1, 0.1, size=param.shape).astype('float32')
|
||||
fluid.global_scope().find_var(names[0]).get_tensor().set(param,
|
||||
place)
|
||||
fluid.global_scope().find_var(names[1]).get_tensor().set(param,
|
||||
place)
|
||||
|
||||
out = exe.run(feed={
|
||||
'inputs': inputs_np,
|
||||
'pre_hidden': pre_hidden_np,
|
||||
'pre_cell': pre_cell_np
|
||||
},
|
||||
fetch_list=[lstm_hidden_new, lstm_hidden])
|
||||
|
||||
self.assertTrue(np.allclose(out[0], out[1], rtol=1e-4, atol=0))
|
||||
|
||||
|
||||
class TestGRUCell(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.batch_size = 4
|
||||
self.input_size = 16
|
||||
self.hidden_size = 16
|
||||
|
||||
def test_run(self):
|
||||
inputs = fluid.data(
|
||||
name='inputs', shape=[None, self.input_size], dtype='float32')
|
||||
pre_hidden = layers.data(
|
||||
name='pre_hidden',
|
||||
shape=[None, self.hidden_size],
|
||||
append_batch_size=False,
|
||||
dtype='float32')
|
||||
|
||||
cell = GRUCell(self.hidden_size)
|
||||
gru_hidden_new, _ = cell(inputs, pre_hidden)
|
||||
|
||||
gru_unit = contrib.layers.rnn_impl.BasicGRUUnit(
|
||||
"basicGRU", self.hidden_size, None, None, None, None, "float32")
|
||||
gru_hidden = gru_unit(inputs, pre_hidden)
|
||||
|
||||
if core.is_compiled_with_cuda():
|
||||
place = core.CUDAPlace(0)
|
||||
else:
|
||||
place = core.CPUPlace()
|
||||
exe = Executor(place)
|
||||
exe.run(framework.default_startup_program())
|
||||
|
||||
inputs_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.input_size)).astype('float32')
|
||||
pre_hidden_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32')
|
||||
|
||||
param_names = [
|
||||
["GRUCell/BasicGRUUnit_0.w_0", "basicGRU/BasicGRUUnit_0.w_0"],
|
||||
["GRUCell/BasicGRUUnit_0.w_1", "basicGRU/BasicGRUUnit_0.w_1"],
|
||||
["GRUCell/BasicGRUUnit_0.b_0", "basicGRU/BasicGRUUnit_0.b_0"],
|
||||
["GRUCell/BasicGRUUnit_0.b_1", "basicGRU/BasicGRUUnit_0.b_1"]
|
||||
]
|
||||
|
||||
for names in param_names:
|
||||
param = np.array(fluid.global_scope().find_var(names[0]).get_tensor(
|
||||
))
|
||||
param = np.random.uniform(
|
||||
-0.1, 0.1, size=param.shape).astype('float32')
|
||||
fluid.global_scope().find_var(names[0]).get_tensor().set(param,
|
||||
place)
|
||||
fluid.global_scope().find_var(names[1]).get_tensor().set(param,
|
||||
place)
|
||||
|
||||
out = exe.run(feed={'inputs': inputs_np,
|
||||
'pre_hidden': pre_hidden_np},
|
||||
fetch_list=[gru_hidden_new, gru_hidden])
|
||||
|
||||
self.assertTrue(np.allclose(out[0], out[1], rtol=1e-4, atol=0))
|
||||
|
||||
|
||||
class TestRnn(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.batch_size = 4
|
||||
self.input_size = 16
|
||||
self.hidden_size = 16
|
||||
self.seq_len = 4
|
||||
|
||||
def test_run(self):
|
||||
inputs_basic_lstm = fluid.data(
|
||||
name='inputs_basic_lstm',
|
||||
shape=[None, None, self.input_size],
|
||||
dtype='float32')
|
||||
sequence_length = fluid.data(
|
||||
name="sequence_length", shape=[None], dtype='int64')
|
||||
|
||||
inputs_dynamic_rnn = layers.transpose(inputs_basic_lstm, perm=[1, 0, 2])
|
||||
cell = LSTMCell(self.hidden_size, name="LSTMCell_for_rnn")
|
||||
output, final_state = dynamic_rnn(
|
||||
cell=cell,
|
||||
inputs=inputs_dynamic_rnn,
|
||||
sequence_length=sequence_length,
|
||||
is_reverse=False)
|
||||
output_new = layers.transpose(output, perm=[1, 0, 2])
|
||||
|
||||
rnn_out, last_hidden, last_cell = basic_lstm(inputs_basic_lstm, None, None, self.hidden_size, num_layers=1, \
|
||||
batch_first = False, bidirectional=False, sequence_length=sequence_length, forget_bias = 1.0)
|
||||
|
||||
if core.is_compiled_with_cuda():
|
||||
place = core.CUDAPlace(0)
|
||||
else:
|
||||
place = core.CPUPlace()
|
||||
exe = Executor(place)
|
||||
exe.run(framework.default_startup_program())
|
||||
|
||||
inputs_basic_lstm_np = np.random.uniform(
|
||||
-0.1, 0.1,
|
||||
(self.seq_len, self.batch_size, self.input_size)).astype('float32')
|
||||
sequence_length_np = np.ones(
|
||||
self.batch_size, dtype='int64') * self.seq_len
|
||||
|
||||
inputs_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.input_size)).astype('float32')
|
||||
pre_hidden_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32')
|
||||
pre_cell_np = np.random.uniform(
|
||||
-0.1, 0.1, (self.batch_size, self.hidden_size)).astype('float32')
|
||||
|
||||
param_names = [[
|
||||
"LSTMCell_for_rnn/BasicLSTMUnit_0.w_0",
|
||||
"basic_lstm_layers_0/BasicLSTMUnit_0.w_0"
|
||||
], [
|
||||
"LSTMCell_for_rnn/BasicLSTMUnit_0.b_0",
|
||||
"basic_lstm_layers_0/BasicLSTMUnit_0.b_0"
|
||||
]]
|
||||
|
||||
for names in param_names:
|
||||
param = np.array(fluid.global_scope().find_var(names[0]).get_tensor(
|
||||
))
|
||||
param = np.random.uniform(
|
||||
-0.1, 0.1, size=param.shape).astype('float32')
|
||||
fluid.global_scope().find_var(names[0]).get_tensor().set(param,
|
||||
place)
|
||||
fluid.global_scope().find_var(names[1]).get_tensor().set(param,
|
||||
place)
|
||||
|
||||
out = exe.run(feed={
|
||||
'inputs_basic_lstm': inputs_basic_lstm_np,
|
||||
'sequence_length': sequence_length_np,
|
||||
'inputs': inputs_np,
|
||||
'pre_hidden': pre_hidden_np,
|
||||
'pre_cell': pre_cell_np
|
||||
},
|
||||
fetch_list=[output_new, rnn_out])
|
||||
|
||||
self.assertTrue(np.allclose(out[0], out[1], rtol=1e-4))
|
||||
|
||||
|
||||
class TestRnnUtil(unittest.TestCase):
|
||||
"""
|
||||
Test cases for rnn apis' utility methods for coverage.
|
||||
"""
|
||||
|
||||
def test_case(self):
|
||||
inputs = {"key1": 1, "key2": 2}
|
||||
func = lambda x: x + 1
|
||||
outputs = utils.map_structure(func, inputs)
|
||||
utils.assert_same_structure(inputs, outputs)
|
||||
try:
|
||||
inputs["key3"] = 3
|
||||
utils.assert_same_structure(inputs, outputs)
|
||||
except ValueError as identifier:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,214 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid.core as core
|
||||
|
||||
from paddle.fluid.executor import Executor
|
||||
from paddle.fluid import framework
|
||||
|
||||
from paddle.fluid.layers.rnn import LSTMCell, GRUCell, RNNCell, BeamSearchDecoder, dynamic_decode
|
||||
from paddle.fluid.layers import rnn as dynamic_rnn
|
||||
from paddle.fluid import contrib
|
||||
from paddle.fluid.contrib.layers import basic_lstm
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EncoderCell(RNNCell):
|
||||
def __init__(self, num_layers, hidden_size, dropout_prob=0.):
|
||||
self.num_layers = num_layers
|
||||
self.hidden_size = hidden_size
|
||||
self.dropout_prob = dropout_prob
|
||||
self.lstm_cells = []
|
||||
for i in range(num_layers):
|
||||
self.lstm_cells.append(LSTMCell(hidden_size))
|
||||
|
||||
def call(self, step_input, states):
|
||||
new_states = []
|
||||
for i in range(self.num_layers):
|
||||
out, new_state = self.lstm_cells[i](step_input, states[i])
|
||||
step_input = layers.dropout(
|
||||
out, self.dropout_prob) if self.dropout_prob > 0 else out
|
||||
new_states.append(new_state)
|
||||
return step_input, new_states
|
||||
|
||||
@property
|
||||
def state_shape(self):
|
||||
return [cell.state_shape for cell in self.lstm_cells]
|
||||
|
||||
|
||||
class DecoderCell(RNNCell):
|
||||
def __init__(self, num_layers, hidden_size, dropout_prob=0.):
|
||||
self.num_layers = num_layers
|
||||
self.hidden_size = hidden_size
|
||||
self.dropout_prob = dropout_prob
|
||||
self.lstm_cells = []
|
||||
for i in range(num_layers):
|
||||
self.lstm_cells.append(LSTMCell(hidden_size))
|
||||
|
||||
def attention(self, hidden, encoder_output, encoder_padding_mask):
|
||||
query = layers.fc(hidden,
|
||||
size=encoder_output.shape[-1],
|
||||
bias_attr=False)
|
||||
attn_scores = layers.matmul(
|
||||
layers.unsqueeze(query, [1]), encoder_output, transpose_y=True)
|
||||
if encoder_padding_mask is not None:
|
||||
attn_scores = layers.elementwise_add(attn_scores,
|
||||
encoder_padding_mask)
|
||||
attn_scores = layers.softmax(attn_scores)
|
||||
attn_out = layers.squeeze(
|
||||
layers.matmul(attn_scores, encoder_output), [1])
|
||||
attn_out = layers.concat([attn_out, hidden], 1)
|
||||
attn_out = layers.fc(attn_out, size=self.hidden_size, bias_attr=False)
|
||||
return attn_out
|
||||
|
||||
def call(self,
|
||||
step_input,
|
||||
states,
|
||||
encoder_output,
|
||||
encoder_padding_mask=None):
|
||||
lstm_states, input_feed = states
|
||||
new_lstm_states = []
|
||||
step_input = layers.concat([step_input, input_feed], 1)
|
||||
for i in range(self.num_layers):
|
||||
out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
|
||||
step_input = layers.dropout(
|
||||
out, self.dropout_prob) if self.dropout_prob > 0 else out
|
||||
new_lstm_states.append(new_lstm_state)
|
||||
out = self.attention(step_input, encoder_output, encoder_padding_mask)
|
||||
return out, [new_lstm_states, out]
|
||||
|
||||
|
||||
class TestDynamicDecode(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.batch_size = 4
|
||||
self.input_size = 16
|
||||
self.hidden_size = 16
|
||||
self.seq_len = 4
|
||||
|
||||
def test_run(self):
|
||||
start_token = 0
|
||||
end_token = 1
|
||||
src_vocab_size = 10
|
||||
trg_vocab_size = 10
|
||||
num_layers = 1
|
||||
hidden_size = self.hidden_size
|
||||
beam_size = 8
|
||||
max_length = self.seq_len
|
||||
|
||||
src = layers.data(name="src", shape=[-1, 1], dtype='int64')
|
||||
src_len = layers.data(name="src_len", shape=[-1], dtype='int64')
|
||||
|
||||
trg = layers.data(name="trg", shape=[-1, 1], dtype='int64')
|
||||
trg_len = layers.data(name="trg_len", shape=[-1], dtype='int64')
|
||||
|
||||
src_embeder = lambda x: fluid.embedding(
|
||||
x,
|
||||
size=[src_vocab_size, hidden_size],
|
||||
param_attr=fluid.ParamAttr(name="src_embedding"))
|
||||
|
||||
trg_embeder = lambda x: fluid.embedding(
|
||||
x,
|
||||
size=[trg_vocab_size, hidden_size],
|
||||
param_attr=fluid.ParamAttr(name="trg_embedding"))
|
||||
|
||||
# use basic_lstm
|
||||
encoder_cell = EncoderCell(num_layers, hidden_size)
|
||||
encoder_output, encoder_final_state = dynamic_rnn(
|
||||
cell=encoder_cell,
|
||||
inputs=src_embeder(src),
|
||||
sequence_length=src_len,
|
||||
is_reverse=False)
|
||||
|
||||
src_mask = layers.sequence_mask(
|
||||
src_len, maxlen=layers.shape(src)[1], dtype='float32')
|
||||
encoder_padding_mask = (src_mask - 1.0) * 1000000000
|
||||
encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
|
||||
|
||||
decoder_cell = DecoderCell(num_layers, hidden_size)
|
||||
decoder_initial_states = [
|
||||
encoder_final_state, decoder_cell.get_initial_states(
|
||||
batch_ref=encoder_output, shape=[hidden_size])
|
||||
]
|
||||
|
||||
decoder_output, _ = dynamic_rnn(
|
||||
cell=decoder_cell,
|
||||
inputs=trg_embeder(trg),
|
||||
initial_states=decoder_initial_states,
|
||||
sequence_length=None,
|
||||
encoder_output=encoder_output,
|
||||
encoder_padding_mask=encoder_padding_mask)
|
||||
|
||||
output_layer = lambda x: layers.fc(x,
|
||||
size=trg_vocab_size,
|
||||
num_flatten_dims=len(x.shape) - 1,
|
||||
param_attr=fluid.ParamAttr(
|
||||
name="output_w"),
|
||||
bias_attr=False)
|
||||
|
||||
# inference
|
||||
encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch(
|
||||
encoder_output, beam_size)
|
||||
encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch(
|
||||
encoder_padding_mask, beam_size)
|
||||
beam_search_decoder = BeamSearchDecoder(
|
||||
decoder_cell,
|
||||
start_token,
|
||||
end_token,
|
||||
beam_size,
|
||||
embedding_fn=trg_embeder,
|
||||
output_fn=output_layer)
|
||||
outputs, _ = dynamic_decode(
|
||||
beam_search_decoder,
|
||||
inits=decoder_initial_states,
|
||||
max_step_num=max_length,
|
||||
encoder_output=encoder_output,
|
||||
encoder_padding_mask=encoder_padding_mask)
|
||||
|
||||
if core.is_compiled_with_cuda():
|
||||
place = core.CUDAPlace(0)
|
||||
else:
|
||||
place = core.CPUPlace()
|
||||
exe = Executor(place)
|
||||
exe.run(framework.default_startup_program())
|
||||
|
||||
src_np = np.random.randint(
|
||||
0, src_vocab_size, (self.batch_size, max_length)).astype('int64')
|
||||
src_len_np = np.ones(self.batch_size, dtype='int64') * max_length
|
||||
trg_np = np.random.randint(
|
||||
0, trg_vocab_size, (self.batch_size, max_length)).astype('int64')
|
||||
trg_len_np = np.ones(self.batch_size, dtype='int64') * max_length
|
||||
|
||||
out = exe.run(feed={
|
||||
'src': src_np,
|
||||
'src_len': src_len_np,
|
||||
'trg': trg_np,
|
||||
'trg_len': trg_len_np
|
||||
},
|
||||
fetch_list=[outputs])
|
||||
|
||||
self.assertTrue(out[0].shape[0] == self.batch_size)
|
||||
self.assertTrue(out[0].shape[1] <= max_length + 1)
|
||||
self.assertTrue(out[0].shape[2] == beam_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue