parent
907e6d04de
commit
e85c513307
@ -0,0 +1,61 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/sequence_erase_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class SequenceEraseOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of SequenceEraseOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of SequenceEraseOp should not be null.");
|
||||
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||
}
|
||||
};
|
||||
|
||||
class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X",
|
||||
"(LoDTensor) 2-D input LoDTensor with the 2-nd dimension "
|
||||
"of length 1.");
|
||||
AddOutput("Out",
|
||||
"(LoDTensor) 2-D output LoDTensor with the 2-nd dimension "
|
||||
"of length 1.");
|
||||
AddAttr<std::vector<int>>("tokens",
|
||||
"(vector<int>) "
|
||||
"Tokens to be removed from input.");
|
||||
AddComment(R"DOC(
|
||||
Sequence Erase Operator.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
|
||||
ops::SequenceEraseOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
sequence_erase,
|
||||
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>);
|
@ -0,0 +1,80 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/softmax.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SequenceEraseKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* in = ctx.Input<LoDTensor>("X");
|
||||
auto* out = ctx.Output<LoDTensor>("Out");
|
||||
|
||||
auto lod = in->lod();
|
||||
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
|
||||
// auto dims = x->dims();
|
||||
/*
|
||||
const size_t level = lod.size() - 1;
|
||||
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
|
||||
"The first dimension of Input(X) should be equal to the "
|
||||
"sum of all sequences' lengths.");
|
||||
PADDLE_ENFORCE_EQ(dims[0], x->numel(),
|
||||
"The width of each timestep in Input(X) of "
|
||||
"SequenceEraseOp should be 1.");
|
||||
out->mutable_data<T>(ctx.GetPlace());
|
||||
*/
|
||||
auto tokens = ctx.Attr<std::vector<int>>("tokens");
|
||||
auto in_len = in->numel();
|
||||
auto in_dat = in->data<T>();
|
||||
auto lod0 = lod[0];
|
||||
std::vector<size_t> num_erased(in_len + 1, 0);
|
||||
for (int64_t i = 1; i < in_len + 1; ++i) {
|
||||
num_erased[i] = num_erased[i - 1];
|
||||
if (std::find(tokens.begin(), tokens.end(), in_dat[i - 1]) !=
|
||||
tokens.end()) {
|
||||
num_erased[i] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> out_lod0(lod0.size(), 0);
|
||||
for (size_t i = 1; i < lod0.size(); ++i) {
|
||||
out_lod0[i] = lod0[i] - num_erased[lod0[i]];
|
||||
}
|
||||
|
||||
auto out_len = in_len - num_erased[in_len];
|
||||
out->Resize({static_cast<int64_t>(out_len), 1});
|
||||
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
for (size_t i = 0; i < in_len; ++i) {
|
||||
if (num_erased[i] == num_erased[i + 1]) {
|
||||
out_dat[i - num_erased[i]] = in_dat[i];
|
||||
}
|
||||
}
|
||||
framework::LoD out_lod;
|
||||
out_lod.push_back(out_lod0);
|
||||
out->set_lod(out_lod);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,58 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def sequence_erase(in_seq, lod0, tokens):
|
||||
# num_erased[i]: the number of elments to be removed before #i elements
|
||||
num_erased = [0] * (len(in_seq) + 1)
|
||||
for i in range(1, len(in_seq) + 1):
|
||||
num_erased[i] = num_erased[i - 1]
|
||||
if in_seq[i - 1] in tokens:
|
||||
num_erased[i] += 1
|
||||
|
||||
# recalculate lod information
|
||||
new_lod0 = [0] * len(lod0)
|
||||
for i in range(1, len(lod0)):
|
||||
new_lod0[i] = lod0[i] - num_erased[lod0[i]]
|
||||
|
||||
out_seq = np.zeros(
|
||||
(len(in_seq) - num_erased[len(in_seq)], 1)).astype("int32")
|
||||
for i in range(0, len(in_seq)):
|
||||
if num_erased[i] == num_erased[i + 1]:
|
||||
out_seq[i - num_erased[i]] = in_seq[i]
|
||||
# else in_seq[i] needs to be removed
|
||||
return out_seq, new_lod0
|
||||
|
||||
|
||||
class TestSequenceEraseOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "sequence_erase"
|
||||
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
|
||||
lod = [[0, 5, 15, 30]]
|
||||
tokens = [2, 5]
|
||||
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
|
||||
|
||||
self.attrs = {'tokens': tokens}
|
||||
self.inputs = {'X': (in_seq, lod)}
|
||||
self.outputs = {'Out': (out_seq, [new_lod0])}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
|
||||
lod0 = [0, 5, 15, 30]
|
||||
tokens = [2, 5]
|
||||
out_seq, new_lod = sequence_erase(in_seq, lod0, tokens)
|
||||
|
||||
print lod0, new_lod
|
||||
print("compare")
|
||||
for i in range(0, len(lod0)-1):
|
||||
print(np.transpose(in_seq[lod0[i] : lod0[i+1]]))
|
||||
print(np.transpose(out_seq[new_lod[i] : new_lod[i+1]]))
|
||||
print("\n")
|
||||
"""
|
||||
unittest.main()
|
Loading…
Reference in new issue