commit 7b84c580e2
paddle/fluid/operators/sequence_pad_op.cc
@@ -0,0 +1,194 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sequence_pad_op.h"

namespace paddle {
namespace operators {

class SequencePadOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SequencePadOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("PadValue"),
                   "Input(PadValue) of SequencePadOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SequencePadOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_GE(x_dims.size(), 2,
                      "The rank of Input(X) can't be less than 2.");
    auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
    auto pad_value_dims = ctx->GetInputDim("PadValue");
    PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) ||
                       pad_value_dims == time_step_dims,
                   "The Input(PadValue) must be a scalar or a tensor whose "
                   "shape equals the shape of a time step in the sequences.");

    int out_dim_0 = -1;
    int out_dim_1 = -1;

    if (ctx->IsRuntime()) {
      // run time
      framework::Variable* x_var =
          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
      const auto& x_lod = x_var->Get<LoDTensor>().lod();
      PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
      const auto& x_lod_0 = x_lod[0];
      PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
                        "The Input(X)'s lod info is corrupted.");
      PADDLE_ENFORCE_EQ(
          x_dims[0], static_cast<int64_t>(x_lod_0.back()),
          "The Input(X)'s lod info mismatches the actual tensor shape.");

      int seq_num = x_lod_0.size() - 1;
      int max_seq_len = math::MaximumSequenceLength(x_lod_0);
      int padded_length = ctx->Attrs().Get<int>("padded_length");
      if (padded_length == -1) {
        padded_length = max_seq_len;
      }
      PADDLE_ENFORCE_GE(padded_length, max_seq_len,
                        "The Attr(padded_length) must be -1 or an int no less "
                        "than the length of the longest original sequence.");
      out_dim_0 = seq_num;
      out_dim_1 = padded_length;
    } else {
      // compile time
      framework::VarDesc* x_desc =
          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
      PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1);
    }

    std::vector<int> out_dims_vec{out_dim_0, out_dim_1};
    auto time_step_dims_vec = framework::vectorize2int(time_step_dims);
    out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(),
                        time_step_dims_vec.end());
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
  }
};

class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(LoDTensor, default LoDTensor<float>) Input variable which "
             "should contain lod information.");
    AddInput("PadValue",
             "(LoDTensor) This tensor holds the values that will be filled "
             "into the padded steps. It can be a scalar or a tensor whose "
             "shape equals the shape of a time step in the sequences. If it "
             "is a scalar, it will be automatically broadcast to the shape "
             "of a time step.");
    AddOutput(
        "Out",
        "(LoDTensor) The output variable, which contains padded sequences.");
    AddAttr<int>(
        "padded_length",
        "The length of padded sequences. It can be set to -1 or any positive "
        "int. When it is -1, all sequences will be padded up to the length "
        "of the longest one among them; when it is a positive value, it must "
        "be no less than the length of the longest original sequence.")
        .SetDefault(-1);
    AddComment(R"DOC(
Sequence Pad Operator

This operator pads sequences in the same batch to a consistent length.
The length is specified by the attribute 'padded_length'. New elements,
whose values are specified by input 'PadValue', are appended to the end
of each sequence to make all final lengths consistent.

The following cases explain how this works:

Case 1:

Given a 1-level LoDTensor input(X):
    X.lod = [[0, 2, 5]]
    X.data = [a, b, c, d, e]
and Input(PadValue):
    PadValue.data = [0]
and attribute 'padded_length' = 4,
then we get a LoDTensor:
    Out.data = [[a, b, 0, 0],
                [c, d, e, 0]]

Case 2:

Given a 1-level LoDTensor input(X):
    X.lod = [[0, 2, 5]]
    X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
and Input(PadValue):
    PadValue.data = [0]
and attribute 'padded_length' = -1, which means using the length of the
longest input sequence (3 in this case),
then we get a LoDTensor:
    Out.data = [[[a1, a2], [b1, b2], [0, 0]],
                [[c1, c2], [d1, d2], [e1, e2]]]

Case 3:

Given a 1-level LoDTensor input(X):
    X.lod = [[0, 2, 5]]
    X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
and Input(PadValue):
    PadValue.data = [p1, p2]
and attribute 'padded_length' = -1, which means using the length of the
longest input sequence (3 in this case),
then we get a LoDTensor:
    Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
                [[c1, c2], [d1, d2], [e1, e2]]]

)DOC");
  }
};

class SequencePadGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SequencePadGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) of SequencePadGradOp should not be null.");

    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
      ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp);
REGISTER_OP_CPU_KERNEL(
    sequence_pad,
    ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    sequence_pad_grad,
    ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
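
To make the DOC cases above concrete: a minimal NumPy sketch of the forward
semantics, with the output shape [seq_num, padded_length] plus the per-time-step
dims, exactly as InferShape computes it. `pad_sequences` is an illustrative
helper written for this note, not a Paddle API.

import numpy as np

def pad_sequences(x, lod_0, pad_value, padded_length=-1):
    # lod_0 is the level-0 offset vector from X.lod, e.g. [0, 2, 5].
    lengths = [lod_0[i + 1] - lod_0[i] for i in range(len(lod_0) - 1)]
    if padded_length == -1:  # mirror Attr(padded_length) == -1
        padded_length = max(lengths)
    assert padded_length >= max(lengths)
    time_step_shape = x.shape[1:]
    # A scalar PadValue broadcasts to the shape of one time step.
    pad = np.broadcast_to(np.asarray(pad_value, dtype=x.dtype),
                          time_step_shape)
    out = np.empty((len(lengths), padded_length) + time_step_shape, x.dtype)
    for i, (start, l) in enumerate(zip(lod_0, lengths)):
        out[i, :l] = x[start:start + l]  # copy the original sequence
        out[i, l:] = pad                 # fill the tail with PadValue
    return out

# Case 3 from the DOC block, with [p1, p2] = [9.0, 9.0]:
x = np.arange(10, dtype='float32').reshape(5, 2)
print(pad_sequences(x, [0, 2, 5], [9.0, 9.0]))  # shape (2, 3, 2)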
paddle/fluid/operators/sequence_pad_op.cu
@@ -0,0 +1,29 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sequence_pad_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    sequence_pad,
    ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    sequence_pad_grad,
    ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
paddle/fluid/operators/sequence_pad_op.h
@@ -0,0 +1,66 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

template <typename DeviceContext, typename T>
class SequencePadOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    const auto* pad_value = ctx.Input<LoDTensor>("PadValue");

    int padded_length = ctx.Attr<int>("padded_length");

    // Copy each sequence into a row of Out and fill the tail of the row
    // with PadValue.
    math::PaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
        padded_length, /*lod_level=*/0, /*norm_by_times=*/false,
        math::kBatchLengthWidth);
  }
};

template <typename DeviceContext, typename T>
class SequencePadGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    if (d_x) {
      const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
      d_x->mutable_data<T>(ctx.GetPlace());

      int padded_length = ctx.Attr<int>("padded_length");

      // Copy the first `length` steps of each padded row back into the
      // packed layout of X; padding positions get no gradient.
      math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
          ctx.template device_context<DeviceContext>(), *d_out, d_x,
          padded_length, /*lod_level=*/0, /*norm_by_times=*/false,
          math::kBatchLengthWidth);
    }
  }
};

}  // namespace operators
}  // namespace paddle
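
For intuition about the backward pass: the gradient with respect to X is just
the padded gradient with the padding steps dropped. A NumPy sketch of that
behavior; `unpad_sequences` is an illustrative helper, not a Paddle API.

import numpy as np

def unpad_sequences(d_out, lod_0):
    # d_out: [seq_num, padded_length, *time_step_shape]; returns the packed
    # [sum(lengths), *time_step_shape] gradient w.r.t. X.
    lengths = [lod_0[i + 1] - lod_0[i] for i in range(len(lod_0) - 1)]
    return np.concatenate([d_out[i, :l] for i, l in enumerate(lengths)],
                          axis=0)

d_out = np.ones((2, 3, 2), dtype='float32')     # upstream grad for Case 3
print(unpad_sequences(d_out, [0, 2, 5]).shape)  # (5, 2), matches X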
python/paddle/fluid/tests/unittests/test_sequence_pad_op.py
@@ -0,0 +1,131 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestSequencePadOp(OpTest):
    def set_attr(self):
        self.x_shape = [12, 4]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [1.0]
        self.padded_length = -1
        self.dtype = 'float32'

    def set_data(self):
        x_data = np.random.uniform(0.1, 0.5, self.x_shape).astype(self.dtype)
        pad_value_data = np.array(self.pad_value).astype(self.dtype)
        self.inputs = {
            'X': (x_data, self.x_len_lod),
            'PadValue': pad_value_data
        }
        self.attrs = {'padded_length': self.padded_length}

    def compute(self):
        # get padded length
        padded_length = self.padded_length
        x_len_lod_0 = self.x_len_lod[0]
        if padded_length == -1:
            max_seq_len = 0
            for l in x_len_lod_0:
                max_seq_len = max(max_seq_len, l)
            padded_length = max_seq_len

        # do padding
        x_data = self.inputs['X'][0]
        pad_value_data = self.inputs['PadValue']
        if pad_value_data.shape == (1, ):
            pad_value_data = np.broadcast_to(
                pad_value_data, shape=x_data.shape[1:])
        padded_sequences = []
        start_idx = 0
        for l in x_len_lod_0:
            end_idx = start_idx + l
            seq = x_data[start_idx:end_idx]
            to_pad_len = padded_length - l
            for _ in range(to_pad_len):
                seq = np.append(seq, pad_value_data[np.newaxis, :], axis=0)
            padded_sequences.append(seq)
            start_idx = end_idx

        out_data = np.array(padded_sequences)
        self.outputs = {'Out': out_data}

    def setUp(self):
        self.op_type = 'sequence_pad'
        self.set_attr()
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestSequencePadOp2(TestSequencePadOp):
    def set_attr(self):
        self.x_shape = [12, 4]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [1.0, 2.0, 3.0, 4.0]
        self.padded_length = -1
        self.dtype = 'float32'


class TestSequencePadOp3(TestSequencePadOp):
    def set_attr(self):
        self.x_shape = [12, 4]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [1.0]
        self.padded_length = 7
        self.dtype = 'float32'


class TestSequencePadOp4(TestSequencePadOp):
    def set_attr(self):
        self.x_shape = [12, 4]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [1.0, 2.0, 3.0, 4.0]
        self.padded_length = 7
        self.dtype = 'float32'


class TestSequencePadOp5(TestSequencePadOp):
    def set_attr(self):
        self.x_shape = [12, 2, 2]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [1.0]
        self.padded_length = -1
        self.dtype = 'float32'


class TestSequencePadOp6(TestSequencePadOp):
    def set_attr(self):
        self.x_shape = [12, 2, 2]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [[1.0, 2.0], [3.0, 4.0]]
        self.padded_length = -1
        self.dtype = 'float32'


class TestSequencePadOp7(TestSequencePadOp):
    def set_attr(self):
        self.x_shape = [12, 2, 2]
        self.x_len_lod = [[2, 3, 4, 3]]
        self.pad_value = [1.0]
        self.padded_length = 7
        self.dtype = 'float32'
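

# Assumed standard entry point for Paddle op tests.
if __name__ == '__main__':
    unittest.main()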