parent d0fdcb2f6d
commit a7f94ec794

paddle/fluid/operators/similarity_focus_op.cc
@@ -0,0 +1,83 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/similarity_focus_op.h"

namespace paddle {
namespace operators {
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 4-D tensor with shape"
             " [BatchSize, X, Y, Z].");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), the similarity focus mask"
              " with the same shape as input X.");
    AddAttr<int>("axis",
                 "(int32), indicating the dimension to be selected. It can"
                 " only be 1, 2, or 3.");
    AddAttr<std::vector<int>>("indexes",
                              "(std::vector<int32>), indicating the indexes"
                              " of the selected dimension.");
    AddComment(R"DOC(
SimilarityFocus Operator.

Generate a similarity focus mask with the same shape as the input using the following method:
1. Extract the 3-D tensor (here the first dimension is BatchSize) corresponding
   to the axis according to the indexes. For example, if axis=1 and indexes=[a],
   it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
   is (BatchSize, A, B, C), the shape of matrix T is (BatchSize, B, C).
2. For each index, find the largest numbers in the matrix T, so that each row
   and each column has at most one of them (obviously there will be min(B, C)
   such numbers), and mark the corresponding positions of the 3-D similarity
   focus mask as 1, all other positions as 0. Combine the masks of the
   different indexes with an element-wise OR.
3. Broadcast the 3-D similarity focus mask to the same shape as input X.

Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
)DOC");
  }
};
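
// A small worked example of the procedure described in the DOC string above,
// using hypothetical values: axis = 1, indexes = {0}, and X of shape
// [1, 2, 2, 2], with
//
//   T = X[:, 0, :, :] = [[0.1, 0.9],
//                        [0.7, 0.3]].
//
// The largest entry, 0.9, marks (row 0, col 1); the next largest entry whose
// row and column are both still unmarked is 0.7 at (row 1, col 0). After
// min(B, C) = 2 picks the 2-D mask is [[0, 1], [1, 0]], and broadcasting it
// along axis 1 gives Out[:, c, :, :] = [[0, 1], [1, 0]] for c = 0 and c = 1.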

class SimilarityFocusOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Only a CPU kernel is registered, so the op always runs on CPUPlace.
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
        platform::CPUPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
                  ops::SimilarityFocusOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
                       ops::SimilarityFocusKernel<double>);

paddle/fluid/operators/similarity_focus_op.h
@@ -0,0 +1,168 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstring>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

template <typename T>
class SimilarityFocusKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    Tensor* out = context.Output<Tensor>("Out");
    const Tensor* x = context.Input<Tensor>("X");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* x_data = x->data<T>();

    int axis = context.Attr<int>("axis");
    std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");

    int64_t batch_size = x->dims()[0];
    int64_t dim[4];
    for (int i = 1; i <= 3; ++i) {
      dim[i] = x->dims()[i];
    }

    if (indexes.size() < 1) {
      PADDLE_THROW("Indexes' size cannot be 0.");
    }
    for (auto index : indexes) {
      // Valid indexes lie in [0, dim[axis]).
      if (dim[axis] <= index) {
        PADDLE_THROW("Index exceeds tensor shape limit.");
      }
    }

    // Number of elements in one 2-D slice obtained by fixing the batch
    // dimension and the selected axis.
    int64_t array_size = 1;
    for (int i = 1; i <= 3; ++i) {
      if (i != axis) {
        array_size *= dim[i];
      }
    }

    std::vector<std::pair<T, int64_t>> array(array_size);

    // Sort slice entries by value in descending order.
    auto cmp = [](const std::pair<T, int64_t>& x,
                  const std::pair<T, int64_t>& y) { return x.first > y.first; };

    // Row-major offset of element (d1, d2, d3, d4) in the 4-D input/output.
    auto compute_index = [](int64_t* dim, int d1, int d2, int d3, int d4) {
      return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] +
             d3 * dim[3] + d4;
    };

    memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
    for (int i = 0; i < batch_size; ++i) {
      for (auto index : indexes) {
        if (axis == 1) {
          // Flatten the slice X[i, index, :, :], keeping each element's
          // position within the slice.
          for (int j = 0; j < dim[2]; ++j) {
            for (int k = 0; k < dim[3]; ++k) {
              array[j * dim[3] + k] = std::make_pair(
                  x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k);
            }
          }

          std::sort(array.begin(), array.end(), cmp);
          int tag_num = 0;
          std::vector<bool> tag2(dim[2]), tag3(dim[3]);
          // Greedily take the largest values whose row and column are both
          // unused, and set the whole fiber along the selected axis to 1.
          for (auto x : array) {
            int idx2 = x.second / dim[3];
            int idx3 = x.second % dim[3];
            if (tag2[idx2] || tag3[idx3]) {
              continue;
            }
            tag_num++;
            tag2[idx2] = true;
            tag3[idx3] = true;
            for (int j = 0; j < dim[1]; ++j) {
              out_data[compute_index(dim, i, j, idx2, idx3)] = 1;
            }
            if (tag_num == std::min(dim[2], dim[3])) {
              break;
            }
          }
        } else if (axis == 2) {
          for (int j = 0; j < dim[1]; ++j) {
            for (int k = 0; k < dim[3]; ++k) {
              array[j * dim[3] + k] = std::make_pair(
                  x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k);
            }
          }

          std::sort(array.begin(), array.end(), cmp);
          int tag_num = 0;
          std::vector<bool> tag1(dim[1]), tag3(dim[3]);
          for (auto x : array) {
            int idx1 = x.second / dim[3];
            int idx3 = x.second % dim[3];
            if (tag1[idx1] || tag3[idx3]) {
              continue;
            }
            tag_num++;
            tag1[idx1] = true;
            tag3[idx3] = true;
            for (int j = 0; j < dim[2]; ++j) {
              out_data[compute_index(dim, i, idx1, j, idx3)] = 1;
            }
            if (tag_num == std::min(dim[1], dim[3])) {
              break;
            }
          }
        } else if (axis == 3) {
          for (int j = 0; j < dim[1]; ++j) {
            for (int k = 0; k < dim[2]; ++k) {
              array[j * dim[2] + k] = std::make_pair(
                  x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k);
            }
          }

          std::sort(array.begin(), array.end(), cmp);
          int tag_num = 0;
          std::vector<bool> tag1(dim[1]), tag2(dim[2]);
          for (auto x : array) {
            int idx1 = x.second / dim[2];
            int idx2 = x.second % dim[2];
            if (tag1[idx1] || tag2[idx2]) {
              continue;
            }
            tag_num++;
            tag1[idx1] = true;
            tag2[idx2] = true;
            for (int j = 0; j < dim[3]; ++j) {
              out_data[compute_index(dim, i, idx1, idx2, j)] = 1;
            }
            if (tag_num == std::min(dim[1], dim[2])) {
              break;
            }
          }
        } else {
          PADDLE_THROW("Axis must be 1, 2, or 3.");
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

python/paddle/fluid/tests/unittests/test_similarity_focus_op.py
@@ -0,0 +1,168 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest


class TestSimilarityFocusOp_axis1(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 3
        x_dim, y_dim, z_dim = 4, 5, 6
        self.inputs = {
            'X': np.random.random(
                (batch_size, x_dim, y_dim, z_dim)).astype("float32"),
        }
        self.attrs = {
            'axis': 1,
            'indexes': [0, 3],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][
                    batch, index, :, :].reshape(-1).copy()
                tag1 = [0 for i in range(y_dim)]
                tag2 = [0 for i in range(z_dim)]
                cnt = 0
                # Greedily mark the largest values whose row and column are
                # both unused, then blank the picked value and repeat.
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // z_dim
                    idx2 = index % z_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(y_dim, z_dim):
                            break
                    channel[index] = -1
            res = res.reshape(1, y_dim, z_dim)
            res = res.repeat([x_dim], axis=0)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


class TestSimilarityFocusOp_axis2(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 6
        x_dim, y_dim, z_dim = 7, 8, 9
        self.inputs = {
            'X': np.random.random(
                (batch_size, x_dim, y_dim, z_dim)).astype("float32"),
        }
        self.attrs = {
            'axis': 2,
            'indexes': [0, 3, 5],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][
                    batch, :, index, :].reshape(-1).copy()
                tag1 = [0 for i in range(x_dim)]
                tag2 = [0 for i in range(z_dim)]
                cnt = 0
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // z_dim
                    idx2 = index % z_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(x_dim, z_dim):
                            break
                    channel[index] = -1
            res = res.reshape(x_dim, 1, z_dim)
            res = res.repeat([y_dim], axis=1)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


class TestSimilarityFocusOp_axis3(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 64
        x_dim, y_dim, z_dim = 48, 48, 13
        self.inputs = {
            'X': np.random.random(
                (batch_size, x_dim, y_dim, z_dim)).astype("float32"),
        }
        self.attrs = {
            'axis': 3,
            'indexes': [0, 2, 7, 9],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][
                    batch, :, :, index].reshape(-1).copy()
                tag1 = [0 for i in range(x_dim)]
                tag2 = [0 for i in range(y_dim)]
                cnt = 0
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // y_dim
                    idx2 = index % y_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(x_dim, y_dim):
                            break
                    channel[index] = -1
            res = res.reshape(x_dim, y_dim, 1)
            res = res.repeat([z_dim], axis=2)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
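
For reference, here is a minimal NumPy sketch of the per-slice selection that the
three tests above build by hand. The helper name and the example values are
illustrative only, not part of the operator's or the test suite's API.

import numpy as np

def greedy_unique_max_mask(slice_2d):
    """Mark greedily chosen maxima whose rows and columns are all distinct."""
    rows, cols = slice_2d.shape
    mask = np.zeros((rows, cols), dtype="float32")
    flat = slice_2d.reshape(-1).copy()
    used_rows, used_cols = set(), set()
    for _ in range(flat.size):
        pos = int(flat.argmax())
        r, c = pos // cols, pos % cols
        if r not in used_rows and c not in used_cols:
            mask[r, c] = 1
            used_rows.add(r)
            used_cols.add(c)
            if len(used_rows) == min(rows, cols):
                break
        flat[pos] = -np.inf  # exclude this position from later argmax calls
    return mask

# Matches the worked example in the comment in similarity_focus_op.cc:
print(greedy_unique_max_mask(np.array([[0.1, 0.9], [0.7, 0.3]], "float32")))
# [[0. 1.]
#  [1. 0.]]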