Add sequence_conv_op

revert-4814-Add_sequence_project_op
chengduoZH 8 years ago
parent 0ab2c436ae
commit f2ccef26bf

@ -115,7 +115,8 @@ set(DEPS_OPS
softmax_with_cross_entropy_op
sum_op
pool_op
pool_with_index_op)
pool_with_index_op
sequence_conv_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
@ -126,6 +127,8 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(sequence_conv_op DEPS sequence_project)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})

@ -12,34 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_project_op.h"
#include "paddle/operators/sequence_conv_op.h"
namespace paddle {
namespace operators {
class SequenceProjectOp : public framework::OperatorWithKernel {
class SequenceConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceProjectOp should not be null.");
"Input(X) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceProjectOp should not be null.");
"Output(Out) of SequenceConvOp should not be null.");
// PaddingData mast be not empty. Otherwise(EnforceNotMet: enforce numel() >
// 0 failed, 0 <= 0)
PADDLE_ENFORCE(
ctx->HasInput("PaddingData"),
"Input(PaddingData) of SequenceProjectOp should not be null.");
auto in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(in_dims.size() == 2, "Input(X) should be 2-D tensor.");
PADDLE_ENFORCE(ctx->HasInput("PaddingData"),
"Input(PaddingData) of SequenceConvOp should not be null.");
int context_length = ctx->Attrs().Get<int>("context_length");
bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
int context_start = ctx->Attrs().Get<int>("context_start");
auto in_dims = ctx->GetInputDim("X");
auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(
filter_dims[0] == context_length && filter_dims[1] == in_dims[1],
"Filter's shape should be (context_length x "
"number_of_input_features).");
if (padding_trainable) {
framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
int up_pad = std::max(0, -context_start);
@ -60,12 +67,12 @@ class SequenceProjectOp : public framework::OperatorWithKernel {
"and 'context_length'.");
}
in_dims[1] = in_dims[1] * context_length;
in_dims[1] = 1;
ctx->SetOutputDim("Out", in_dims);
}
};
class SequenceProjectGradOp : public framework::OperatorWithKernel {
class SequenceConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@ -77,60 +84,66 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
if (ctx->Attrs().Get<bool>("padding_trainable") &&
ctx->HasOutput(framework::GradVarName("PaddingData"))) {
auto padding_dims = ctx->GetInputDim("PaddingData");
ctx->SetOutputDim(framework::GradVarName("PaddingData"), padding_dims);
ctx->SetOutputDim(framework::GradVarName("PaddingData"),
ctx->GetInputDim("PaddingData"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"),
ctx->GetInputDim("Filter"));
}
}
};
class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceProjectOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequenceConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(A float LoDTensor) the input of SequenceProjectOp, a vector of "
"(A float LoDTensor) the input of SequenceConvOp, a vector of "
"2-D matrix of size (minibatch, number_of_input_features).");
AddOutput("Out",
"(A float LoDTensor) the output of SequenceProjectOp, a vector "
"of 2-D matrix of size (minibatch, number_of_input_features x "
"context_length).");
AddInput("PaddingData",
"(A float LoDTensor) the input of SequenceProjectOp, a vector of "
"(A float LoDTensor) the input of SequenceConvOp, a vector of "
"2-D matrix of size (up_pad + down_pad, "
"number_of_input_features). ");
AddInput("Filter",
"(A float LoDTensor) the input of SequenceConvOp, a vector of "
"2-D matrix of size (context_length x number_of_input_features).");
AddOutput("Out",
"(A float LoDTensor) the output of SequenceConvOp, a vector "
"of 2-D matrix of size (minibatch, 1).");
AddAttr<bool>("padding_trainable",
"(bool, default false) the padding data of SequenceProjectOp "
"(bool, default false) the padding data of SequenceConvOp "
"is trainable or not.")
.SetDefault(false);
AddAttr<int>("context_length",
"(int, default 3) the context_length of SequenceProjectOp.")
"(int, default 3) the context_length of SequenceConvOp.")
.SetDefault(3)
.GreaterThan(0);
AddAttr<int>("context_start",
"(int, default 0) the context_start of SequenceProjectOp.")
"(int, default 0) the context_start of SequenceConvOp.")
.SetDefault(0);
AddAttr<int>("context_stride",
"(int, default 1) the context_stride of SequenceProjectOp. "
"(int, default 1) the context_stride of SequenceConvOp. "
"Currently, sequence_project_op only support "
"context_stride=1.")
.SetDefault(1)
.GreaterThan(0);
AddComment(R"DOC(
SequenceProjectOp projects features of context_length time-steps of each instance.
SequenceConvOp projects features of context_length time-steps of each instance.
For a mini-batch of 2 variable lengths sentences, containing 3, and 1 time-steps:
Assumed input (X) is a [4, M, N] float LoDTensor, and X->lod()[0] = [0, 3, 4].
Besides, for the sake of simplicity, we assume M=1 and N=2.
X = [[a1, a2,
b1, b2.
X = [[a1, a2;
b1, b2;
c1, c2]
[d1, d2]]
@ -141,19 +154,19 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
If context_start is -1 and padding_trainable is false, we use zero to pad instead of learned weight to pad,
and the context_lenth is 3, the output (Out) is:
Out = [0, 0, a1, a2, b1, b2;
Out =[[0, 0, a1, a2, b1, b2;
a1, a2, b1, b2, c1, c2;
b1, b2, c1, c2, 0, 0;
0, 0, d1, d2, 0, 0]
b1, b2, c1, c2, 0, 0 ]
[0, 0, d1, d2, 0, 0 ]]
- Case2:
If context_start is -1 and padding_trainable is true, we use learned weight to pad,
and the context_lenth is 3, the output (Out) is:
Out = [w1, w2, a1, a2, b1, b2;
Out = [[w1, w2, a1, a2, b1, b2;
a1, a2, b1, b2, c1, c2;
b1, b2, c1, c2, w3, w4;
w1, w2, d1, d2, w3, w4]
b1, b2, c1, c2, w3, w4]
[w1, w2, d1, d2, w3, w4]]
)DOC");
}
@ -163,13 +176,11 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_project, ops::SequenceProjectOp,
ops::SequenceProjectOpMaker, sequence_project_grad,
ops::SequenceProjectGradOp);
REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
sequence_conv_grad, ops::SequenceConvGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_project,
ops::SequenceProjectKernel<paddle::platform::CPUPlace, float>);
sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_project_grad,
ops::SequenceProjectGradKernel<paddle::platform::CPUPlace, float>);
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);

@ -14,12 +14,11 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_project_op.h"
#include "paddle/operators/sequence_conv_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_project,
ops::SequenceProjectKernel<paddle::platform::GPUPlace, float>);
sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_project_grad,
ops::SequenceProjectGradKernel<paddle::platform::GPUPlace, float>);
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);

@ -1,212 +0,0 @@
import unittest
import numpy as np
import random
from op_test import OpTest
class TestSeqProject(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = 'sequence_project'
if self.context_length == 1 and self.context_start == 0 and self.padding_trainable:
print "If context_start is 0 and context_length is 1, padding_trainable should be false."
return
# one level, batch size
x = np.random.uniform(
0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32')
self.begin_pad = np.max([0, -self.context_start])
self.end_pad = np.max([0, self.context_start + self.context_length - 1])
self.total_pad = self.begin_pad + self.end_pad
if self.total_pad == 0:
self.total_pad = 1
# PaddingData mast be not empty. Otherwise(EnforceNotMet: enforce numel() > 0 failed, 0 <= 0)
padding_data = np.random.uniform(
0.1, 1, [self.total_pad, self.input_size[1]]).astype('float32')
self.inputs = {
'X': (x, self.lod),
'PaddingData': (padding_data, [[0, self.total_pad]])
}
self.attrs = {
'context_start': self.context_start,
'context_length': self.context_length,
'padding_trainable': self.padding_trainable,
'context_stride': self.context_stride
}
out = np.zeros((self.input_size[0], self.input_size[1] *
self.context_length)).astype('float32')
self.outputs = {'Out': out}
self.compute()
def compute(self):
x, lod = self.inputs['X']
pading_data, _ = self.inputs['PaddingData']
out = self.outputs['Out']
lod = lod[0]
begin_pad = np.max([0, -self.context_start])
for i in range(len(lod) - 1):
for j in range(self.context_length):
in_begin = lod[i] + self.context_start + j
in_end = lod[i + 1] + self.context_start + j
out_begin = lod[i]
out_end = lod[i + 1]
if in_begin < lod[i]:
pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]])
if self.padding_trainable:
sub_w = pading_data[j:j + pad_size, :]
out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:(
j + 1) * self.input_size[1]] = sub_w
out_begin = lod[i] + pad_size
in_begin = lod[i]
if in_end > lod[i + 1]:
pad_size = np.min(
[in_end - lod[i + 1], lod[i + 1] - lod[i]])
if self.padding_trainable:
sub_w = pading_data[begin_pad + self.context_start + j -
pad_size:begin_pad +
self.context_start + j, :]
out[lod[i + 1] - pad_size:lod[i + 1], j * self.
input_size[1]:(j + 1) * self.input_size[1]] = sub_w
in_end = lod[i + 1]
out_end = lod[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
self.input_size[1]] += in_sub
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.padding_trainable:
self.check_grad(
set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['X'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['PaddingData']))
def test_check_grad_no_input(self):
if self.padding_trainable:
self.check_grad(
['PaddingData'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['X']))
def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11
self.context_start = 0
self.context_length = 1
self.padding_trainable = False
self.context_stride = 1
self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]]
class TestSeqProjectCase1(TestSeqProject):
def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11
self.context_start = -1
self.context_length = 3
self.padding_trainable = True
self.context_stride = 1
self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]]
class TestSeqProjectCase2(TestSeqProject):
def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 25
self.context_start = 2
self.context_length = 3
self.padding_trainable = True
self.context_stride = 1
self.input_size = [self.input_row, 23]
idx = range(self.input_size[0])
del idx[0]
self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
[self.input_size[0]]]
'''
class TestSeqProjectCases(TestSeqProject):
def setUp(self):
self.init_test_case()
self.op_type = 'sequence_project'
num = 0
for context_start in [-5, -3, -1, 0, 3]:
for context_length in [1, 2, 5, 7]:
for batch_size in [1, 2, 5, 7]:
for padding_trainable in [False, True]:
if context_length == 1 and context_start == 0 and padding_trainable:
continue
self.context_start = context_start
self.context_length = context_length
self.padding_trainable = padding_trainable
self.input_size = [batch_size, 23]
x = np.random.uniform(0.1, 1,
self.input_size).astype('float32')
self.lod = [[0, self.input_size[0]]]
if self.input_size[0] > 2:
idx = range(self.input_size[0])
del idx[0]
self.lod = [
[0] + np.sort(random.sample(idx, 2)).tolist() +
[self.input_size[0]]
]
self.begin_pad = np.max([0, -self.context_start])
self.end_pad = np.max([0, self.context_start + self.context_length - 1])
self.total_pad = self.begin_pad + self.end_pad
if self.total_pad == 0:
self.total_pad = 1
# PaddingData mast be not empty. Otherwise(EnforceNotMet: enforce numel() > 0 failed, 0 <= 0)
padding_data = np.random.uniform(
0.1, 1, [self.total_pad, self.input_size[1]]).astype('float32')
self.inputs = {
'X': (x, self.lod),
'PaddingData': (padding_data, [[0, self.total_pad]])
}
self.attrs = {
'context_start': self.context_start,
'context_length': self.context_length,
'padding_trainable': self.padding_trainable,
'context_stride': self.context_stride
}
out = np.zeros((self.input_size[0], self.input_size[1] *
self.context_length)).astype('float32')
self.outputs = {'Out': out}
print num
print self.attrs
print batch_size
print padding_trainable
print "$$$$$$$$$$$$$"
self.compute()
self.test_check_output()
num += 1
'''
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save