parent
64d48f4d6a
commit
d2e5395b97
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/sequence_enumerate_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class SequenceEnumerateOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("X"),
|
||||
"Input(X) of SequecceEnumerate operator should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("Out"),
|
||||
"Output(X) of SequenceEnumerate operator should not be null.");
|
||||
|
||||
const auto x_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE(
|
||||
x_dims.size() == 2 && x_dims[1] == 1,
|
||||
"Input(X) of SequenceEnumerate operator should be a 2-D LoDTensor "
|
||||
"with the 2nd dimension equal to 1.");
|
||||
|
||||
const auto win_size = ctx->Attrs().Get<int>("win_size");
|
||||
PADDLE_ENFORCE(win_size <= x_dims[0],
|
||||
"The enumerate window size should be less than or equal to "
|
||||
"input sequence length.");
|
||||
ctx->SetOutputDim("Out", {x_dims[0], win_size});
|
||||
ctx->ShareLoD("X", "Out");
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, outputs, attributes and documentation of the
// sequence_enumerate operator.
class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(2-D LoDTensor with the 2nd dimension equal to 1) "
             "Input LoDTensor of SequenceEnumerate operator.");
    AddOutput("Out",
              "(2-D LoDTensor with the 2nd dimension equal to 1) "
              "Output LoDTensor of SequenceEnumerate operator.");
    AddAttr<int>("win_size", "(int) The enumerate sequence window size.")
        .AddCustomChecker([](const int& win_size) {
          // The checker accepts win_size >= 2; the message must agree.
          PADDLE_ENFORCE(win_size >= 2,
                         "The window size should be not less than 2.");
        });
    AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
        .SetDefault(0);
    AddComment(R"DOC(
Sequence Enumerate Operator.

Sequence enumerate operator generates a new LoDTensor
with the same 1st dimension length as the original LoDTensor,
and with the 2nd dimension equal to the input window length,
the new sub-sequence on 2nd dimension is enumerated one by one on the original sequence.
The values of the last insufficient part are all filled with the input pad_value.

Examples:
Case 1:
  Input:
    X.lod = [[0, 3, 5]]
    X.data = [1, 2, 3, 4, 5]
    X.dims = [5, 1]
  Attrs:
    win_size = 2
    pad_value = 0
  Output:
    Out.lod = [[0, 3, 5]]
    Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    Out.dims = [5, 2]

Currently, only 1-level LoDTensor is supported.

)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp,
|
||||
ops::SequenceEnumerateOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
sequence_enumerate,
|
||||
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int32_t>,
|
||||
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
#include "paddle/fluid/operators/sequence_enumerate_op.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using platform::PADDLE_CUDA_NUM_THREADS;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
// Fills one output row per input element: row `index` holds the
// win_size-long sub-sequence starting at in_data[index]; positions past
// the end of the input are filled with pad_value.
template <typename T>
__global__ void CalcOutPut(const T* in_data, const int64_t in_len,
                           const int64_t win_size, const int64_t pad_value,
                           T* out_data) {
  // Use int64_t throughout: the original mixed `int` index / `size_t`
  // loop counter with int64_t lengths, which narrows on long sequences
  // and mixes signedness in the comparisons.
  int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < in_len) {
    for (int64_t i = 0; i < win_size; ++i) {
      int64_t word_pos = index + i;
      out_data[index * win_size + i] =
          word_pos < in_len ? in_data[word_pos] : pad_value;
    }
  }
}
|
||||
|
||||
// CUDA kernel wrapper for sequence_enumerate: enumerates win_size-long
// sub-sequences of the input and pads the tail with pad_value.
template <typename T>
class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    const int win_size = context.Attr<int>("win_size");
    const int pad_value = context.Attr<int>("pad_value");

    auto in_dims = in->dims();
    auto in_lod = in->lod();
    PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
                      "Only support one level sequence now.");
    PADDLE_ENFORCE_EQ(
        static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
        "The actual input data's size mismatched with LoD information.");

    /* Generate enumerate sequence set */
    auto in_len = in->numel();
    const T* in_data = in->data<T>();
    T* out_data = out->mutable_data<T>(context.GetPlace());
    auto stream = context.cuda_device_context().stream();
    // One thread per input element.
    CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        in_data, in_len, win_size, pad_value, out_data);
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
// Register the CUDA kernels for the two supported element types.
REGISTER_OP_CUDA_KERNEL(sequence_enumerate,
                        ops::SequenceEnumerateOpCUDAKernel<int32_t>,
                        ops::SequenceEnumerateOpCUDAKernel<int64_t>);
|
@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SequenceEnumerateKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* in = context.Input<LoDTensor>("X");
|
||||
auto* out = context.Output<LoDTensor>("Out");
|
||||
int win_size = context.Attr<int>("win_size");
|
||||
int pad_value = context.Attr<int>("pad_value");
|
||||
|
||||
auto in_dims = in->dims();
|
||||
auto in_lod = in->lod();
|
||||
|
||||
PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
|
||||
"Only support one level sequence now.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
|
||||
"The actual input data's size mismatched with LoD information.");
|
||||
|
||||
// Generate enumerate sequence set
|
||||
auto seq_length = in_dims[0];
|
||||
auto in_data = in->data<T>();
|
||||
auto out_data = out->mutable_data<T>(context.GetPlace());
|
||||
for (int idx = 0; idx < seq_length; ++idx) {
|
||||
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
|
||||
int word_pos = idx + word_idx;
|
||||
out_data[win_size * idx + word_idx] =
|
||||
word_pos < seq_length ? in_data[word_pos] : pad_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,79 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def sequence_enumerate(input_seq, lod0, win_size, pad_value):
    """Reference implementation of the sequence_enumerate op.

    For every position in ``input_seq`` emits the ``win_size``-long
    window starting there; positions past the end of the sequence are
    filled with ``pad_value``.  ``lod0`` is accepted for signature
    parity with the op but is not consulted: padding is applied only at
    the very end of the flattened sequence, matching the kernels.
    """
    n = len(input_seq)
    return [
        [input_seq[pos] if pos < n else pad_value
         for pos in range(start, start + win_size)]
        for start in range(n)
    ]
|
||||
|
||||
|
||||
class TestSequenceEnumerateOp(OpTest):
    """Checks sequence_enumerate against the Python reference (int32)."""

    def setUp(self):
        self.op_type = "sequence_enumerate"
        self.init_test_case()
        self.attrs = {'win_size': self.win_size, 'pad_value': self.pad_value}
        self.inputs = {'X': (self.in_seq, self.lod)}
        self.outputs = {'Out': (self.out_seq, self.lod)}

    def test_check_output(self):
        self.check_output()

    def init_test_case(self):
        # int32 input, window of 2, zero padding.
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 2
        self.pad_value = 0
        expected = sequence_enumerate(self.in_seq, self.lod[0],
                                      self.win_size, self.pad_value)
        self.out_seq = np.array(expected).astype("int32")
|
||||
|
||||
|
||||
# Fixes the class-name typo "TesSequenceEnumerateOpInt64" so unittest
# discovery and greps find it under the conventional "Test" prefix.
class TestSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
    """Same as the base case but with int64 input data."""

    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 2
        self.pad_value = 0
        out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
                                     self.pad_value)
        self.out_seq = np.array(out_seq).astype("int64")
|
||||
|
||||
|
||||
class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
    """Boundary case: window size equal to the whole sequence length."""

    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        # win_size == len(in_seq): every row but the first is padded.
        self.win_size = 30
        self.pad_value = 0
        expected = sequence_enumerate(self.in_seq, self.lod[0],
                                      self.win_size, self.pad_value)
        self.out_seq = np.array(expected).astype("int32")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue