Merge pull request #4814 from chengduoZH/Add_sequence_project_op
Add sequence_conv_op and sequence_projection functor
commit 8e3ecf5d11
paddle/operators/math/context_project.cc
@@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/context_project.h"

namespace paddle {
namespace operators {
namespace math {

template class ContextProjectFunctor<platform::CPUPlace, float>;
template class ContextProjectFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
paddle/operators/math/context_project.cu
@@ -0,0 +1,28 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/math/context_project.h"

namespace paddle {
namespace operators {
namespace math {

template class ContextProjectFunctor<platform::GPUPlace, float>;
template class ContextProjectFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
paddle/operators/math/context_project.h
@@ -0,0 +1,231 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/im2col.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
/*
 * \brief Context projection concatenates features in adjacent time steps of
 * a sequence. The i-th row of the output is the concatenation of
 * context_length rows of the input, namely the consecutive rows starting
 * from the (i + context_start)-th row.
 *
 * \param in            Input data.
 * \param Shape         The shape of Input data:
 *                      [minibatch, number_of_input_features].
 * \param type          A float LoDTensor.
 *
 * \param padding_data  Padding data.
 * \param Shape         The shape of Padding data:
 *                      [up_pad + down_pad, number_of_input_features].
 * \param type          A float Tensor.
 *
 * \param col           Col data.
 * \param Shape         The shape of Col data:
 *                      [minibatch, context_length * number_of_input_features].
 * \param type          A float Tensor.
 *
 * For a mini-batch of 2 variable-length sentences, containing 3 and 1
 * time-steps:
 *
 * Assume the input (X) is a [4, M, N] float LoDTensor, and
 * X->lod()[0] = [0, 3, 4]. For simplicity, we further assume M=1 and N=2.
 *
 * X = [[a1, a2;
 *       b1, b2;
 *       c1, c2]
 *      [d1, d2]]
 *
 * That is, the input (X) has 4 words, and the dimension of each word
 * representation is 2.
 *
 * - Case1:
 *   If context_start is -1 and padding_trainable is false, we pad with zeros
 *   instead of learned weights; with context_length 3 the output (Out) is:
 *
 *   Out = [[0,  0,  a1, a2, b1, b2;
 *           a1, a2, b1, b2, c1, c2;
 *           b1, b2, c1, c2, 0,  0 ]
 *          [0,  0,  d1, d2, 0,  0 ]]
 *
 * - Case2:
 *   If context_start is -1 and padding_trainable is true, we pad with learned
 *   weights; with context_length 3 the output (Out) is:
 *
 *   Out = [[w1, w2, a1, a2, b1, b2;
 *           a1, a2, b1, b2, c1, c2;
 *           b1, b2, c1, c2, w3, w4]
 *          [w1, w2, d1, d2, w3, w4]]
 *
 */

template <typename Place, typename T>
class ContextProjectFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  framework::LoDTensor& in, framework::Tensor& padding_data,
                  framework::Tensor& col, bool padding_trainable,
                  int context_start, int context_length, int context_stride,
                  int up_pad, int down_pad, bool gradient, bool input_grad,
                  bool pad_grad) {
    auto lod_level_0 = in.lod()[0];

    paddle::operators::math::Im2ColFunctor<
        paddle::operators::math::ColFormat::kOCF, Place, T>
        im2col_ocf;
    paddle::operators::math::Col2ImFunctor<
        paddle::operators::math::ColFormat::kOCF, Place, T>
        col2im_ocf;

    int input_row_begin, input_row_end;
    int sequence_height, sequence_width;
    sequence_width = in.dims()[1];
    input_grad = gradient && input_grad;
    pad_grad = gradient && pad_grad;

    if (!gradient || input_grad) {
      for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
        input_row_begin = (context_start > 0)
                              ? static_cast<int>(lod_level_0[i]) + context_start
                              : static_cast<int>(lod_level_0[i]);
        input_row_end = static_cast<int>(lod_level_0[i + 1]);

        framework::Tensor out_t =
            col.Slice(static_cast<int>(lod_level_0[i]),
                      static_cast<int>(lod_level_0[i + 1]));

        sequence_height = static_cast<int>(out_t.dims()[0]);

        if (input_row_begin < input_row_end) {
          framework::Tensor in_t = in.Slice(input_row_begin, input_row_end);

          std::vector<int64_t> output_shape(
              {sequence_height, 1, 1, context_length,
               sequence_width});  // output_height, output_width,
                                  // input_channels, filter_height, filter_width

          out_t.Resize(framework::make_ddim(output_shape));

          std::vector<int64_t> input_shape(
              {1, input_row_end - input_row_begin,
               sequence_width});  // input_channels, input_height, input_width
          in_t.Resize(framework::make_ddim(input_shape));

          if (gradient) {
            col2im_ocf(context, in_t, out_t,
                       /*stride_height*/ context_stride, /*stride_width*/ 1,
                       up_pad, down_pad, 0, 0);
          } else {
            im2col_ocf(context, in_t, out_t,
                       /*stride_height*/ context_stride, /*stride_width*/ 1,
                       up_pad, down_pad, 0, 0);
          }
          out_t.Resize({sequence_height, context_length * sequence_width});
        }
      }
    }
    if (!gradient || pad_grad) {
      if (padding_trainable) {
        for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
          framework::Tensor out_t =
              col.Slice(static_cast<int>(lod_level_0[i]),
                        static_cast<int>(lod_level_0[i + 1]));

          sequence_height = static_cast<int>(out_t.dims()[0]);

          // add up trainable data
          out_t.Resize({sequence_height * context_length, sequence_width});

          if (up_pad > 0) {  // add up pad
            int padding_rows = std::min(
                up_pad, static_cast<int>(lod_level_0[i + 1] - lod_level_0[i]));

            for (int k = 0; k < padding_rows; ++k) {
              int padding_size =
                  k + context_length < up_pad ? context_length : up_pad - k;
              framework::Tensor out_t_sub = out_t.Slice(
                  k * context_length, k * context_length + padding_size);
              framework::Tensor w_sub = padding_data.Slice(k, k + padding_size);
              // in this block, using EigenVector<T>::Flatten is ok too.
              auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
              auto w_sub_e = EigenMatrix<T>::From(w_sub);
              if (gradient) {
                w_sub_e.device(*context.GetEigenDevice<Place>()) =
                    w_sub_e + out_t_sub_e;
              } else {
                out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
              }
            }
          }
          if (down_pad > 0) {  // add down pad
            int down_pad_begin_row =
                std::max(
                    0, (sequence_height - context_start - context_length) + 1) +
                1;
            int padding_begin = std::max(0, context_start - sequence_height);
            int padding_size =
                sequence_height - context_start >= context_length
                    ? 1
                    : context_length - (sequence_height - context_start);
            if (context_start >= sequence_height) padding_size = context_length;
            int padding_idx = padding_begin;
            for (int t = 0; t + down_pad_begin_row <= sequence_height;
                 ++t, ++padding_size) {
              if (context_start >= sequence_height)
                padding_size = context_length;
              if (padding_size > context_length) {
                padding_size = context_length;
                padding_idx++;
              }
              if (padding_begin > 0 || sequence_height == context_start)
                padding_idx = padding_begin + t;
              framework::Tensor out_t_sub = out_t.Slice(
                  (down_pad_begin_row + t) * context_length - padding_size,
                  (down_pad_begin_row + t) * context_length);
              framework::Tensor w_sub = padding_data.Slice(
                  up_pad + padding_idx, up_pad + padding_idx + padding_size);
              auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
              auto w_sub_e = EigenMatrix<T>::From(w_sub);
              if (gradient) {
                w_sub_e.device(*context.GetEigenDevice<Place>()) =
                    w_sub_e + out_t_sub_e;
              } else {
                out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
              }
            }
          }
          out_t.Resize({sequence_height, context_length * sequence_width});
        }
      }
    }
  }
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
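To make the Case1 example in the doc comment above concrete, here is a minimal NumPy sketch (illustrative only, not part of this diff; the helper name context_project is hypothetical) that reproduces the zero-padded output for context_start = -1 and context_length = 3:

import numpy as np

def context_project(x, lod, context_start, context_length):
    # Each output row i concatenates context_length input rows starting at
    # row i + context_start; rows outside the sequence stay zero (padding).
    rows, width = x.shape
    out = np.zeros((rows, context_length * width), dtype=x.dtype)
    for begin, end in zip(lod[:-1], lod[1:]):
        for i in range(begin, end):
            for j in range(context_length):
                src = i + context_start + j
                if begin <= src < end:
                    out[i, j * width:(j + 1) * width] = x[src]
    return out

# X from the comment: rows a, b, c form sequence 1; row d forms sequence 2.
x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)
print(context_project(x, [0, 3, 4], context_start=-1, context_length=3))
# First row is [0, 0, a1, a2, b1, b2], matching Out in Case1.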
paddle/operators/sequence_conv_op.cc
@@ -0,0 +1,177 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sequence_conv_op.h"

namespace paddle {
namespace operators {

class SequenceConvOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SequenceConvOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Filter"),
                   "Input(Filter) of SequenceConvOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SequenceConvOp should not be null.");

    int context_length = ctx->Attrs().Get<int>("context_length");
    bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
    int context_start = ctx->Attrs().Get<int>("context_start");

    auto in_dims = ctx->GetInputDim("X");
    auto filter_dims = ctx->GetInputDim("Filter");
    PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
                   "Input(X, Filter) should be 2-D tensor.");
    PADDLE_ENFORCE(filter_dims[0] == context_length * in_dims[1],
                   "Filter's height should be context_length * "
                   "number_of_input_features.");

    if (padding_trainable) {
      PADDLE_ENFORCE(
          ctx->HasInput("PaddingData"),
          "Input(PaddingData) of SequenceConvOp should not be null.");
      framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
      int up_pad = std::max(0, -context_start);
      int down_pad = std::max(0, context_start + context_length - 1);
      int total_pad = up_pad + down_pad;
      int input_width = static_cast<int>(in_dims[1]);

      if (context_start == 0 && context_length == 1) {
        PADDLE_THROW(
            "If context_start is 0 and context_length is 1, padding_trainable "
            "should be false.");
      }
      PADDLE_ENFORCE(padding_dim.size() == 2,
                     "Input(PaddingData) should be 2-D tensor.");
      PADDLE_ENFORCE(
          padding_dim[0] == total_pad && padding_dim[1] == input_width,
          "Input(PaddingData)'s shape is not consistent with 'context_start' "
          "and 'context_length'.");
    }

    in_dims[1] = filter_dims[1];
    ctx->SetOutputDim("Out", in_dims);
    ctx->ShareLoD("X", "Out");
  }
};

class SequenceConvGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Gradient of output(Out) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null.");

    if (ctx->Attrs().Get<bool>("padding_trainable") &&
        ctx->HasOutput(framework::GradVarName("PaddingData"))) {
      ctx->SetOutputDim(framework::GradVarName("PaddingData"),
                        ctx->GetInputDim("PaddingData"));
    }
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
      ctx->SetOutputDim(framework::GradVarName("Filter"),
                        ctx->GetInputDim("Filter"));
    }
  }
};

class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SequenceConvOpMaker(framework::OpProto* proto,
                      framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(LoDTensor) the input(X) is a LoDTensor, which supports "
        "variable-length input sequences. The underlying tensor in "
        "this LoDTensor is a matrix with shape (T, D), where T is the "
        "total number of time steps in this mini-batch and D is the input "
        "feature size.");
    AddInput("PaddingData",
             "(Tensor, optional) the input(PaddingData) is an optional "
             "parameter, and it is learnable. "
             "This is a tensor with shape (N, D), where N is "
             "up_pad + down_pad and D is the input feature size. In order to "
             "ensure the equal length of sequence before and after "
             "convolution, it is necessary to fill the top and bottom of each "
             "sequence according to context_length, context_stride and "
             "context_start.")
        .AsDispensable();
    AddInput("Filter",
             "(Tensor) the input(Filter) is a learnable parameter. "
             "This is a tensor with shape (N, D), where N is "
             "context_length * input_feature_size and D is the output "
             "feature size.");
    AddOutput(
        "Out",
        "(LoDTensor) the output(Out) is a LoDTensor, which supports "
        "variable-length output sequences. The underlying tensor in "
        "this LoDTensor is a matrix with shape (T, D), where T is the "
        "total number of time steps in this mini-batch and D is the output "
        "feature size.");

    AddAttr<bool>("padding_trainable",
                  "(bool, default false) whether the padding data of "
                  "SequenceConvOp is trainable.")
        .SetDefault(false);
    AddAttr<int>("context_length",
                 "(int, default 3) the context_length of SequenceConvOp is "
                 "the height of the convolution kernel.")
        .SetDefault(3)
        .GreaterThan(0);
    AddAttr<int>("context_start",
                 "(int, default 0) the context_start of SequenceConvOp is "
                 "the offset, which can be negative, of the first row of the "
                 "convolution window relative to the current time-step.")
        .SetDefault(0);
    AddAttr<int>("context_stride",
                 "(int, default 1) the context_stride of SequenceConvOp "
                 "is the step length of the convolution. "
                 "Currently, SequenceConvOp only supports "
                 "context_stride=1.")
        .SetDefault(1)
        .GreaterThan(0);

    AddComment(R"DOC(
SequenceConvOp performs convolution operation on features of
context_length time-steps of each instance.
The convolution operation calculates the output based on the input, filter,
strides and paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape. In order to ensure the equal
length of sequence before and after convolution, it is necessary to fill
the top and bottom of each sequence according to context_length,
context_stride and context_start.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
            sequence_conv_grad, ops::SequenceConvGradOp);

REGISTER_OP_CPU_KERNEL(
    sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    sequence_conv_grad,
    ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);
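As a side note on the InferShape logic above: up_pad and down_pad are derived from context_start and context_length, so a trainable PaddingData must have exactly up_pad + down_pad rows. A small sketch of that arithmetic (the helper padding_shape is hypothetical, not part of the diff):

def padding_shape(context_start, context_length, input_width):
    # Mirrors SequenceConvOp::InferShape: PaddingData rows come from the
    # part of the context window that falls outside the sequence.
    up_pad = max(0, -context_start)
    down_pad = max(0, context_start + context_length - 1)
    return (up_pad + down_pad, input_width)

# context_start=-1, context_length=3 -> one padded row above and one below,
# so PaddingData must be [2, input_width].
assert padding_shape(-1, 3, 23) == (2, 23)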
paddle/operators/sequence_conv_op.cu
@@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/sequence_conv_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    sequence_conv_grad,
    ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);
paddle/operators/sequence_conv_op.h
@@ -0,0 +1,170 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/context_project.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename Place, typename T>
class SequenceConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    auto filter = *context.Input<Tensor>("Filter");

    out->mutable_data<T>(context.GetPlace());
    context.ShareLoD("X", "Out");

    int context_start = context.Attr<int>("context_start");
    int context_length = context.Attr<int>("context_length");
    int context_stride = context.Attr<int>("context_stride");
    bool padding_trainable = context.Attr<bool>("padding_trainable");

    // InferShape by in_lod
    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
                      "Only support one level sequence now.");

    const Tensor* padding_data = nullptr;
    if (padding_trainable) {
      padding_data = context.Input<Tensor>("PaddingData");
    }

    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);
    int sequence_width = static_cast<int>(in->dims()[1]);

    // Use col_shape in the im2col calculation.
    framework::DDim col_shape = {in->dims()[0],
                                 sequence_width * context_length};
    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    math::SetConstant<Place, T> set_zero;
    // If padding_trainable is false, the padding data should be zeros.
    set_zero(context.device_context(), &col, static_cast<T>(0));

    paddle::operators::math::ContextProjectFunctor<Place, T>
        seq_project_functor;
    LoDTensor* input = const_cast<LoDTensor*>(in);
    Tensor* pad_data = const_cast<Tensor*>(padding_data);

    seq_project_functor(context.device_context(), *input, *pad_data, col,
                        padding_trainable, context_start, context_length,
                        context_stride, up_pad, down_pad, false, false, false);

    math::matmul<Place, T>(context.device_context(), col, false, filter, false,
                           static_cast<T>(1.0), out, static_cast<T>(0.0));
  }
};

template <typename Place, typename T>
class SequenceConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
    auto* filter_g = context.Output<Tensor>(framework::GradVarName("Filter"));
    auto* padding_data_g =
        context.Output<Tensor>(framework::GradVarName("PaddingData"));
    auto* in = context.Input<LoDTensor>("X");
    auto* filter = context.Input<Tensor>("Filter");

    int context_start = context.Attr<int>("context_start");
    int context_length = context.Attr<int>("context_length");
    int context_stride = context.Attr<int>("context_stride");
    bool padding_trainable = context.Attr<bool>("padding_trainable");

    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
                      "Only support one level sequence now.");
    auto lod_g_level_0 = in->lod()[0];

    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);
    int sequence_width = static_cast<int>(in->dims()[1]);

    math::SetConstant<Place, T> set_zero;
    // Use col_shape in the im2col calculation.
    framework::DDim col_shape = {in->dims()[0],
                                 sequence_width * context_length};
    Tensor col;

    if (in_g || filter_g || (padding_trainable && padding_data_g)) {
      col.mutable_data<T>(col_shape, context.GetPlace());
      // If padding_trainable is false, the padding data should be zeros.
      set_zero(context.device_context(), &col, static_cast<T>(0));
      math::matmul<Place, T>(context.device_context(), *out_g, false, *filter,
                             true, T(1.0), &col, T(1.0));
    }
    paddle::operators::math::ContextProjectFunctor<Place, T>
        seq_project_functor;

    if (in_g) {
      in_g->mutable_data<T>(context.GetPlace());
      in_g->set_lod(in->lod());
      set_zero(context.device_context(), in_g, static_cast<T>(0));

      seq_project_functor(context.device_context(), *in_g, *padding_data_g, col,
                          padding_trainable, context_start, context_length,
                          context_stride, up_pad, down_pad, true, true, false);
    }

    if (padding_trainable && padding_data_g) {
      padding_data_g->mutable_data<T>(context.GetPlace());
      set_zero(context.device_context(), padding_data_g, static_cast<T>(0));

      LoDTensor* input = const_cast<LoDTensor*>(in);
      seq_project_functor(context.device_context(), *input, *padding_data_g,
                          col, padding_trainable, context_start, context_length,
                          context_stride, up_pad, down_pad, true, false, true);
    }

    if (filter_g) {
      filter_g->mutable_data<T>(context.GetPlace());
      set_zero(context.device_context(), filter_g, static_cast<T>(0));

      Tensor filter_grad = *filter_g;
      LoDTensor out_grad = *out_g;

      const Tensor* padding_data = nullptr;
      if (padding_trainable) {
        padding_data = context.Input<Tensor>("PaddingData");
      }

      sequence_width = static_cast<int>(in->dims()[1]);

      LoDTensor* input = const_cast<LoDTensor*>(in);
      Tensor* pad_data = const_cast<Tensor*>(padding_data);

      seq_project_functor(context.device_context(), *input, *pad_data, col,
                          padding_trainable, context_start, context_length,
                          context_stride, up_pad, down_pad, false, false,
                          false);

      math::matmul<Place, T>(context.device_context(), col, true, out_grad,
                             false, T(1.0), &filter_grad, T(1.0));
    }
  }
};

}  // namespace operators
}  // namespace paddle
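Putting the forward kernel above in one picture: SequenceConvKernel::Compute is a context projection into col followed by a single matmul with Filter. A self-contained NumPy sketch under that reading (seq_conv and all names here are hypothetical, not the PR's API):

import numpy as np

def seq_conv(x, lod, context_start, context_length, filter):
    # Step 1: context projection into `col` (zero padding, as in the forward
    # kernel when padding_trainable is false).
    rows, width = x.shape
    col = np.zeros((rows, context_length * width), dtype=x.dtype)
    for begin, end in zip(lod[:-1], lod[1:]):
        for i in range(begin, end):
            for j in range(context_length):
                src = i + context_start + j
                if begin <= src < end:
                    col[i, j * width:(j + 1) * width] = x[src]
    # Step 2: one matmul, matching math::matmul(col, false, filter, false).
    return col.dot(filter)

x = np.random.rand(4, 2).astype(np.float32)
w = np.random.rand(3 * 2, 8).astype(np.float32)  # [context_length * D, out_dim]
out = seq_conv(x, [0, 3, 4], context_start=-1, context_length=3, filter=w)
assert out.shape == (4, 8)  # Out keeps T rows; columns become filter_dims[1]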
python/paddle/v2/framework/tests/test_seq_conv.py
@@ -0,0 +1,198 @@
import unittest
import numpy as np
import random
from op_test import OpTest


class TestSeqProject(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = 'sequence_conv'

        if self.context_length == 1 \
                and self.context_start == 0 \
                and self.padding_trainable:
            print("If context_start is 0 and context_length is 1, "
                  "padding_trainable should be false.")
            return

        # one level, batch size
        x = np.random.uniform(0.1, 1, [self.input_size[0],
                                       self.input_size[1]]).astype('float32')
        w = np.random.uniform(0.1, 1, [
            self.context_length * self.input_size[1], self.output_represention
        ]).astype('float32')

        begin_pad = np.max([0, -self.context_start])
        end_pad = np.max([0, self.context_start + self.context_length - 1])
        total_pad = begin_pad + end_pad
        padding_data = np.random.uniform(
            0.1, 1, [total_pad, self.input_size[1]]).astype('float32')
        self.pad_data = padding_data
        self.inputs = {
            'X': (x, self.lod),
            'Filter': w,
        }
        self.inputs_val = ['X', 'Filter']
        self.inputs_val_no_x = ['Filter']
        self.inputs_val_no_f = ['X']

        if total_pad != 0:
            self.inputs['PaddingData'] = padding_data
            self.inputs_val = ['X', 'PaddingData', 'Filter']
            self.inputs_val_no_x = ['PaddingData', 'Filter']
            self.inputs_val_no_f = ['PaddingData', 'X']

        self.attrs = {
            'context_start': self.context_start,
            'context_length': self.context_length,
            'padding_trainable': self.padding_trainable,
            'context_stride': self.context_stride
        }
        out = np.zeros(
            (self.input_size[0], self.output_represention)).astype('float32')
        self.outputs = {'Out': out}
        self.compute()

    def compute(self):
        x, lod = self.inputs['X']
        filter = self.inputs['Filter']
        padding_data = self.pad_data
        out = np.zeros((self.input_size[0], self.context_length *
                        self.input_size[1])).astype('float32')
        lod = lod[0]
        begin_pad = np.max([0, -self.context_start])

        for i in range(len(lod) - 1):
            for j in range(self.context_length):
                in_begin = lod[i] + self.context_start + j
                in_end = lod[i + 1] + self.context_start + j
                out_begin = lod[i]
                out_end = lod[i + 1]
                if in_begin < lod[i]:
                    pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]])
                    if self.padding_trainable:
                        sub_w = padding_data[j:j + pad_size, :]
                        out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:(
                            j + 1) * self.input_size[1]] = sub_w
                    out_begin = lod[i] + pad_size
                    in_begin = lod[i]

                if in_end > lod[i + 1]:
                    pad_size = np.min(
                        [in_end - lod[i + 1], lod[i + 1] - lod[i]])
                    if self.padding_trainable:
                        sub_w = padding_data[begin_pad + self.context_start +
                                             j - pad_size:begin_pad +
                                             self.context_start + j, :]
                        out[lod[i + 1] - pad_size:lod[i + 1], j * self.
                            input_size[1]:(j + 1) * self.input_size[1]] = sub_w
                    in_end = lod[i + 1]
                    out_end = lod[i + 1] - pad_size
                if in_end <= in_begin:
                    continue

                in_sub = x[in_begin:in_end, :]
                out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
                    self.input_size[1]] += in_sub

        np.dot(out, filter, out=self.outputs['Out'])

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        if self.padding_trainable:
            self.check_grad(
                set(self.inputs_val), 'Out', max_relative_error=0.05)

    def test_check_grad_input(self):
        self.check_grad(
            ['X'],
            'Out',
            max_relative_error=0.05,
            no_grad_set=set(self.inputs_val_no_x))

    def test_check_grad_padding_data(self):
        if self.padding_trainable:
            self.check_grad(
                ['PaddingData'],
                'Out',
                max_relative_error=0.05,
                no_grad_set=set(['X', 'Filter']))

    def test_check_grad_Filter(self):
        self.check_grad(
            ['Filter'],
            'Out',
            max_relative_error=0.05,
            no_grad_set=set(self.inputs_val_no_f))

    def test_check_grad_input_filter(self):
        if self.padding_trainable:
            self.check_grad(
                ['X', 'Filter'],
                'Out',
                max_relative_error=0.05,
                no_grad_set=set(['PaddingData']))

    def test_check_grad_padding_input(self):
        if self.padding_trainable:
            self.check_grad(
                self.inputs_val_no_f,
                'Out',
                max_relative_error=0.05,
                no_grad_set=set(['Filter']))

    def test_check_grad_padding_filter(self):
        if self.padding_trainable:
            self.check_grad(
                self.inputs_val_no_x,
                'Out',
                max_relative_error=0.05,
                no_grad_set=set(['X']))

    def init_test_case(self):
        self.input_row = 11
        self.context_start = 0
        self.context_length = 1
        self.padding_trainable = False
        self.context_stride = 1

        self.input_size = [self.input_row, 23]
        self.lod = [[0, 4, 5, 8, self.input_row]]
        self.output_represention = 8  # output feature size


class TestSeqProjectCase1(TestSeqProject):
    def init_test_case(self):
        self.input_row = 11
        self.context_start = -1
        self.context_length = 3
        self.padding_trainable = True
        self.context_stride = 1

        self.input_size = [self.input_row, 23]
        self.lod = [[0, 4, 5, 8, self.input_row]]
        self.output_represention = 8  # output feature size


class TestSeqProjectCase2(TestSeqProject):
    def init_test_case(self):
        self.input_row = 25
        self.context_start = 2
        self.context_length = 3
        self.padding_trainable = True
        self.context_stride = 1

        self.input_size = [self.input_row, 23]
        idx = list(range(1, self.input_size[0]))
        self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
                    [self.input_size[0]]]
        self.output_represention = 8  # output feature size


if __name__ == '__main__':
    unittest.main()