commit
c198523885
@ -0,0 +1,130 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/sequence_reshape_op.h"
|
||||||
|
#include "paddle/framework/ddim.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class SequenceReshapeOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||||
|
"Input(X) of SequenceReshapeOp should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||||
|
"Output(Out) of SequenceReshapeOp should not be null.");
|
||||||
|
auto x_dims = ctx->GetInputDim("X");
|
||||||
|
auto x_numel = product(x_dims);
|
||||||
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
|
||||||
|
int new_dim = ctx->Attrs().Get<int>("new_dim");
|
||||||
|
ctx->SetOutputDim("Out",
|
||||||
|
{x_numel / new_dim, static_cast<int64_t>(new_dim)});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declares the proto of the sequence_reshape operator: one input (X), one
// output (Out), the integer attribute `new_dim` and the user-facing
// documentation string.
class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Input tensor.
    AddInput("X",
             "(LoDTensor, default LoDTensor<float>) "
             "A 2-D LoDTensor with shape being [N, M].");

    // Output tensor.
    AddOutput("Out",
              "(LoDTensor, default LoDTensor<float>) "
              "A 2-D LoDTensor with shape [T, new_dim] "
              "where T is calculated based on X.lod, M and new_dim.");

    // Width of every row of the output.
    AddAttr<int>("new_dim", "Sequence dimension of the output LoDTensor.");

    AddComment(R"DOC(
Sequence Reshape Operator.

This operator will rearrange the input sequences. The new dimension is set by
attribute and length of each sequence may change longer or shorter which is
decided by original length, original dimension and new dimension. The following
example will help to illustrate the function of this operator:

x is a LoDTensor:
x.lod = [[0, 2, 6]]
x.data = [[1, 2], [3, 4],
[5, 6], [7, 8], [9, 10], [11, 12]]
x.dims = [6, 2]

set new_dim = 4

then out is a LoDTensor:
out.lod = [[0, 1, 3]]
out.data = [[1, 2, 3, 4],
[5, 6, 7, 8], [9, 10, 11, 12]]
out.dims = [3, 4]

Currently, only 1-level LoDTensor is supported and please make sure (original
length * original dimension) can be divided by new_dim with no remainder for
each sequence.

)DOC");
  }
};
|
||||||
|
|
||||||
|
class SequenceReshapeGradOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(
|
||||||
|
ctx->HasInput(framework::GradVarName("Out")),
|
||||||
|
"Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||||
|
"Input(X) of SequenceReshapeGradOp should not be null.");
|
||||||
|
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||||
|
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
|
||||||
|
public:
|
||||||
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||||
|
auto* op_desc_ptr = new framework::OpDesc();
|
||||||
|
op_desc_ptr->SetType("sequence_reshape_grad");
|
||||||
|
op_desc_ptr->SetInput("X", Input("X"));
|
||||||
|
op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
||||||
|
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
||||||
|
op_desc_ptr->SetAttrMap(Attrs());
|
||||||
|
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
|
||||||
|
ops::SequenceReshapeOpMaker, ops::SequenceReshapeGradOpMaker);
|
||||||
|
REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
sequence_reshape,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
sequence_reshape_grad,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>);
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/sequence_reshape_op.h"
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
sequence_reshape,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
|
||||||
|
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
sequence_reshape_grad,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, double>,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int64_t>,
|
||||||
|
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>);
|
@ -0,0 +1,86 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
#include "paddle/operators/math/math_function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using LoDTensor = framework::LoDTensor;
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class SequenceReshapeKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& context) const override {
|
||||||
|
auto* in = context.Input<LoDTensor>("X");
|
||||||
|
auto* out = context.Output<LoDTensor>("Out");
|
||||||
|
int out_width = context.Attr<int>("new_dim");
|
||||||
|
|
||||||
|
auto in_dims = in->dims();
|
||||||
|
int64_t in_width = in_dims[1];
|
||||||
|
auto& in_lod = in->lod();
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
|
||||||
|
"Only support one level sequence now.");
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
in_dims[0], in_lod[0].back(),
|
||||||
|
"Inconsistent size between X.shape[0] and X.lod()[0].back().");
|
||||||
|
|
||||||
|
auto in_lod_l0 = in_lod[0];
|
||||||
|
int seq_num = in_lod_l0.size() - 1;
|
||||||
|
|
||||||
|
if (in_width == out_width) {
|
||||||
|
out->set_lod(in->lod());
|
||||||
|
} else {
|
||||||
|
auto& out_lod = *out->mutable_lod();
|
||||||
|
out_lod.resize(1);
|
||||||
|
out_lod[0].resize(seq_num + 1);
|
||||||
|
out_lod[0][0] = 0;
|
||||||
|
for (int i = 0; i < seq_num; ++i) {
|
||||||
|
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
|
||||||
|
size_t offset = 0;
|
||||||
|
offset = (seq_len * in_width) / out_width;
|
||||||
|
PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
|
||||||
|
"Please make sure (sequence_length * dimension) can "
|
||||||
|
"be divided by new_dim with no remainder for each "
|
||||||
|
"sequence. The %dth sequence is invalid.",
|
||||||
|
i + 1);
|
||||||
|
out_lod[0][i + 1] = out_lod[0][i] + offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
framework::Copy(*in, context.GetPlace(), out);
|
||||||
|
out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class SequenceReshapeGradKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& context) const override {
|
||||||
|
auto* x_tensor_ptr = context.Input<LoDTensor>("X");
|
||||||
|
auto* outg_tensor_ptr =
|
||||||
|
context.Input<LoDTensor>(framework::GradVarName("Out"));
|
||||||
|
auto* xg_tensor_ptr =
|
||||||
|
context.Output<LoDTensor>(framework::GradVarName("X"));
|
||||||
|
|
||||||
|
xg_tensor_ptr->mutable_data<T>(context.GetPlace());
|
||||||
|
framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
|
||||||
|
xg_tensor_ptr->Resize(x_tensor_ptr->dims());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceReshape(OpTest):
    """Tests sequence_reshape with new_dim (12) half the input width (24)."""

    def setUp(self):
        self.op_type = 'sequence_reshape'
        dimension = 12
        x_lod = [[0, 4, 5, 8, 11]]
        x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'new_dim': dimension}

        out, out_lod = self.compute_output(x, x_lod, dimension)

        self.outputs = {'Out': (out, out_lod)}

    def compute_output(self, x, x_lod, dimension):
        """Build the reference output for the operator.

        Args:
            x: 2-D input array.
            x_lod: one-level LoD of ``x`` (a list holding one offset list).
            dimension: row width of the output (the ``new_dim`` attribute).

        Returns:
            A tuple ``(out, out_lod)``: ``x``'s data laid out with
            ``dimension`` columns, and the matching one-level LoD.

        Raises:
            AssertionError: if a sequence's element count is not a multiple
                of ``dimension``.
        """
        x_width = x.shape[1]
        out_lod = [[0]]
        # range and divmod behave the same on Python 2 and 3, unlike
        # xrange and the `/` operator.
        for i in range(len(x_lod[0]) - 1):
            seq_len = x_lod[0][i + 1] - x_lod[0][i]
            offset, remainder = divmod(seq_len * x_width, dimension)
            assert remainder == 0
            out_lod[0].append(out_lod[0][-1] + offset)
        out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
        out.ravel()[:] = x.ravel()[:]
        return out, out_lod

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceReshape_reduce(TestSequenceReshape):
    """Variant with new_dim (24) twice the input width (12)."""

    def setUp(self):
        self.op_type = 'sequence_reshape'

        new_dim = 24
        lod = [[0, 4, 6, 8, 12]]
        data = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
        expected, expected_lod = self.compute_output(data, lod, new_dim)

        self.inputs = {'X': (data, lod)}
        self.attrs = {'new_dim': new_dim}
        self.outputs = {'Out': (expected, expected_lod)}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceReshape_same(TestSequenceReshape):
    """Variant with new_dim (12) equal to the input width (12)."""

    def setUp(self):
        self.op_type = 'sequence_reshape'

        new_dim = 12
        lod = [[0, 4, 6, 8, 12]]
        data = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
        expected, expected_lod = self.compute_output(data, lod, new_dim)

        self.inputs = {'X': (data, lod)}
        self.attrs = {'new_dim': new_dim}
        self.outputs = {'Out': (expected, expected_lod)}
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in new issue