Merge pull request #14071 from barrierye/add_similarity_focus_op
Add similarity focus oprevert-14324-fix_vlog
commit
ff28b1ffc0
@ -0,0 +1,87 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/similarity_focus_op.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
void Make() override {
|
||||||
|
AddInput("X",
|
||||||
|
"(Tensor, default Tensor<float>), a 4-D tensor with shape,"
|
||||||
|
" [BatchSize, X, Y, Z]");
|
||||||
|
AddOutput("Out",
|
||||||
|
"(Tensor, default Tensor<float>), the similarity focus mask"
|
||||||
|
" with the same shape of input X.");
|
||||||
|
AddAttr<int>("axis",
|
||||||
|
"(int32), indicating the dimension to be select. It can"
|
||||||
|
" only be 1, 2, or 3.");
|
||||||
|
AddAttr<std::vector<int>>("indexes",
|
||||||
|
"(std::vector<int32>), indicating the indexes"
|
||||||
|
" of the selected dimension.");
|
||||||
|
AddComment(R"DOC(
|
||||||
|
SimilarityFocus Operator.
|
||||||
|
|
||||||
|
Generate a similarity focus mask with the same shape of input using the following method:
|
||||||
|
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
|
||||||
|
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
|
||||||
|
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
|
||||||
|
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
|
||||||
|
2. For each index, find the largest numbers in the tensor T, so that the same
|
||||||
|
row and same column has at most one number(what it means is that if the
|
||||||
|
largest number has been found in the i-th row and the j-th column, then
|
||||||
|
the numbers in the i-th row or j-th column will be skipped. And then the
|
||||||
|
next largest number will be selected from the remaining numbers. Obviously
|
||||||
|
there will be min(B, C) numbers), and mark the corresponding position of the
|
||||||
|
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
|
||||||
|
each index.
|
||||||
|
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
|
||||||
|
|
||||||
|
Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class SimilarityFocusOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
|
||||||
|
auto x_dims = ctx->GetInputDim("X");
|
||||||
|
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
|
||||||
|
ctx->SetOutputDim("Out", x_dims);
|
||||||
|
ctx->ShareLoD("X", /*->*/ "Out");
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(
|
||||||
|
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
|
||||||
|
platform::CPUPlace());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
|
||||||
|
ops::SimilarityFocusOpMaker,
|
||||||
|
paddle::framework::EmptyGradOpMaker);
|
||||||
|
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
|
||||||
|
ops::SimilarityFocusKernel<double>);
|
@ -0,0 +1,168 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstring>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/eigen.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class SimilarityFocusKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& context) const override {
|
||||||
|
Tensor* out = context.Output<Tensor>("Out");
|
||||||
|
const Tensor* x = context.Input<Tensor>("X");
|
||||||
|
T* out_data = out->mutable_data<T>(context.GetPlace());
|
||||||
|
const T* x_data = x->data<T>();
|
||||||
|
|
||||||
|
int axis = context.Attr<int>("axis");
|
||||||
|
std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");
|
||||||
|
|
||||||
|
int64_t batch_size = x->dims()[0];
|
||||||
|
int64_t dim[4];
|
||||||
|
for (int i = 1; i <= 3; ++i) {
|
||||||
|
dim[i] = x->dims()[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indexes.size() < 1) {
|
||||||
|
PADDLE_THROW("Indexes' size can not be 0.");
|
||||||
|
}
|
||||||
|
for (auto index : indexes) {
|
||||||
|
if (dim[axis] < index) {
|
||||||
|
PADDLE_THROW("Index exceeds tensor shape limit.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t array_size = 1;
|
||||||
|
for (int i = 1; i <= 3; ++i) {
|
||||||
|
if (i != axis) {
|
||||||
|
array_size *= dim[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<T, int64_t>> array(array_size);
|
||||||
|
|
||||||
|
bool (*cmp)(std::pair<T, int64_t>, std::pair<T, int64_t>) = [](
|
||||||
|
std::pair<T, int64_t> x, std::pair<T, int64_t> y) {
|
||||||
|
return x.first > y.first;
|
||||||
|
};
|
||||||
|
|
||||||
|
int64_t (*compute_index)(int64_t*, int, int, int, int) = [](
|
||||||
|
int64_t* dim, int d1, int d2, int d3, int d4) {
|
||||||
|
return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] +
|
||||||
|
d3 * dim[3] + d4;
|
||||||
|
};
|
||||||
|
|
||||||
|
memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
for (auto index : indexes) {
|
||||||
|
if (axis == 1) {
|
||||||
|
for (int j = 0; j < dim[2]; ++j) {
|
||||||
|
for (int k = 0; k < dim[3]; ++k) {
|
||||||
|
array[j * dim[3] + k] = std::make_pair(
|
||||||
|
x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(array.begin(), array.end(), cmp);
|
||||||
|
int tag_num = 0;
|
||||||
|
std::vector<bool> tag2(dim[2]), tag3(dim[3]);
|
||||||
|
for (auto x : array) {
|
||||||
|
int idx2 = x.second / dim[3];
|
||||||
|
int idx3 = x.second % dim[3];
|
||||||
|
if (tag2[idx2] || tag3[idx3]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
tag_num++;
|
||||||
|
tag2[idx2] = true;
|
||||||
|
tag3[idx3] = true;
|
||||||
|
for (int j = 0; j < dim[1]; ++j) {
|
||||||
|
out_data[compute_index(dim, i, j, idx2, idx3)] = 1;
|
||||||
|
}
|
||||||
|
if (tag_num == std::min(dim[2], dim[3])) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (axis == 2) {
|
||||||
|
for (int j = 0; j < dim[1]; ++j) {
|
||||||
|
for (int k = 0; k < dim[3]; ++k) {
|
||||||
|
array[j * dim[3] + k] = std::make_pair(
|
||||||
|
x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(array.begin(), array.end(), cmp);
|
||||||
|
int tag_num = 0;
|
||||||
|
std::vector<bool> tag1(dim[1]), tag3(dim[3]);
|
||||||
|
for (auto x : array) {
|
||||||
|
int idx1 = x.second / dim[3];
|
||||||
|
int idx3 = x.second % dim[3];
|
||||||
|
if (tag1[idx1] || tag3[idx3]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
tag_num++;
|
||||||
|
tag1[idx1] = true;
|
||||||
|
tag3[idx3] = true;
|
||||||
|
for (int j = 0; j < dim[2]; ++j) {
|
||||||
|
out_data[compute_index(dim, i, idx1, j, idx3)] = 1;
|
||||||
|
}
|
||||||
|
if (tag_num == std::min(dim[1], dim[3])) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (axis == 3) {
|
||||||
|
for (int j = 0; j < dim[1]; ++j) {
|
||||||
|
for (int k = 0; k < dim[2]; ++k) {
|
||||||
|
array[j * dim[2] + k] = std::make_pair(
|
||||||
|
x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(array.begin(), array.end(), cmp);
|
||||||
|
int tag_num = 0;
|
||||||
|
std::vector<bool> tag1(dim[1]), tag2(dim[2]);
|
||||||
|
for (auto x : array) {
|
||||||
|
int idx1 = x.second / dim[2];
|
||||||
|
int idx2 = x.second % dim[2];
|
||||||
|
if (tag1[idx1] || tag2[idx2]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
tag_num++;
|
||||||
|
tag1[idx1] = true;
|
||||||
|
tag2[idx2] = true;
|
||||||
|
for (int j = 0; j < dim[3]; ++j) {
|
||||||
|
out_data[compute_index(dim, i, idx1, idx2, j)] = 1;
|
||||||
|
}
|
||||||
|
if (tag_num == std::min(dim[1], dim[2])) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
PADDLE_THROW("Axis must be 1 or 2 or 3");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,217 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import paddle.fluid.core as core
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarityFocusOp(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = "similarity_focus"
|
||||||
|
batch_size = 2
|
||||||
|
x_dim, y_dim, z_dim = 3, 2, 2
|
||||||
|
self.inputs = {
|
||||||
|
'X': np.array([[[[0.8, 0.1], [0.4, 0.5]], [[0.9, 0.7], [0.9, 0.9]],
|
||||||
|
[[0.8, 0.9], [0.1, 0.2]]],
|
||||||
|
[[[0.2, 0.5], [0.3, 0.4]], [[0.9, 0.7], [0.8, 0.4]],
|
||||||
|
[[0.0, 0.2], [0.4, 0.7]]]]),
|
||||||
|
}
|
||||||
|
self.attrs = {
|
||||||
|
'axis': 1,
|
||||||
|
'indexes': [0],
|
||||||
|
}
|
||||||
|
|
||||||
|
output = None
|
||||||
|
for batch in range(batch_size):
|
||||||
|
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
|
||||||
|
for index in self.attrs['indexes']:
|
||||||
|
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
|
||||||
|
)
|
||||||
|
tag1 = [0 for i in range(y_dim)]
|
||||||
|
tag2 = [0 for i in range(z_dim)]
|
||||||
|
cnt = 0
|
||||||
|
for i in range(channel.size):
|
||||||
|
index = channel.argmax()
|
||||||
|
idx1 = index // z_dim
|
||||||
|
idx2 = index % z_dim
|
||||||
|
if tag1[idx1] + tag2[idx2] == 0:
|
||||||
|
tag1[idx1] = 1
|
||||||
|
tag2[idx2] = 1
|
||||||
|
res[index] = 1
|
||||||
|
cnt += 1
|
||||||
|
if cnt == min(y_dim, z_dim):
|
||||||
|
break
|
||||||
|
channel[index] = -1
|
||||||
|
res = res.reshape(1, y_dim, z_dim).repeat([x_dim], axis=0)
|
||||||
|
res = res.reshape(1, x_dim, y_dim, z_dim)
|
||||||
|
if output is not None:
|
||||||
|
output = np.concatenate((output, res), axis=0)
|
||||||
|
else:
|
||||||
|
output = res
|
||||||
|
self.outputs = {'Out': output}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarityFocusOp_axis1(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = "similarity_focus"
|
||||||
|
batch_size = 3
|
||||||
|
x_dim, y_dim, z_dim = 4, 5, 6
|
||||||
|
self.inputs = {
|
||||||
|
'X': np.random.random(
|
||||||
|
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
|
||||||
|
}
|
||||||
|
self.attrs = {
|
||||||
|
'axis': 1,
|
||||||
|
'indexes': [0, 3],
|
||||||
|
}
|
||||||
|
|
||||||
|
output = None
|
||||||
|
for batch in range(batch_size):
|
||||||
|
res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
|
||||||
|
for index in self.attrs['indexes']:
|
||||||
|
channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy(
|
||||||
|
)
|
||||||
|
tag1 = [0 for i in range(y_dim)]
|
||||||
|
tag2 = [0 for i in range(z_dim)]
|
||||||
|
cnt = 0
|
||||||
|
for i in range(channel.size):
|
||||||
|
index = channel.argmax()
|
||||||
|
idx1 = index // z_dim
|
||||||
|
idx2 = index % z_dim
|
||||||
|
if tag1[idx1] + tag2[idx2] == 0:
|
||||||
|
tag1[idx1] = 1
|
||||||
|
tag2[idx2] = 1
|
||||||
|
res[index] = 1
|
||||||
|
cnt += 1
|
||||||
|
if cnt == min(y_dim, z_dim):
|
||||||
|
break
|
||||||
|
channel[index] = -1
|
||||||
|
res = res.reshape(1, y_dim, z_dim)
|
||||||
|
res = res.repeat([x_dim], axis=0)
|
||||||
|
res = res.reshape(1, x_dim, y_dim, z_dim)
|
||||||
|
if output is not None:
|
||||||
|
output = np.concatenate((output, res), axis=0)
|
||||||
|
else:
|
||||||
|
output = res
|
||||||
|
self.outputs = {'Out': output}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarityFocusOp_axis2(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = "similarity_focus"
|
||||||
|
batch_size = 6
|
||||||
|
x_dim, y_dim, z_dim = 7, 8, 9
|
||||||
|
self.inputs = {
|
||||||
|
'X': np.random.random(
|
||||||
|
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
|
||||||
|
}
|
||||||
|
self.attrs = {
|
||||||
|
'axis': 2,
|
||||||
|
'indexes': [0, 3, 5],
|
||||||
|
}
|
||||||
|
|
||||||
|
output = None
|
||||||
|
for batch in range(batch_size):
|
||||||
|
res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1)
|
||||||
|
for index in self.attrs['indexes']:
|
||||||
|
channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy(
|
||||||
|
)
|
||||||
|
tag1 = [0 for i in range(x_dim)]
|
||||||
|
tag2 = [0 for i in range(z_dim)]
|
||||||
|
cnt = 0
|
||||||
|
for i in range(channel.size):
|
||||||
|
index = channel.argmax()
|
||||||
|
idx1 = index // z_dim
|
||||||
|
idx2 = index % z_dim
|
||||||
|
if tag1[idx1] + tag2[idx2] == 0:
|
||||||
|
tag1[idx1] = 1
|
||||||
|
tag2[idx2] = 1
|
||||||
|
res[index] = 1
|
||||||
|
cnt += 1
|
||||||
|
if cnt == min(x_dim, z_dim):
|
||||||
|
break
|
||||||
|
channel[index] = -1
|
||||||
|
res = res.reshape(x_dim, 1, z_dim)
|
||||||
|
res = res.repeat([y_dim], axis=1)
|
||||||
|
res = res.reshape(1, x_dim, y_dim, z_dim)
|
||||||
|
if output is not None:
|
||||||
|
output = np.concatenate((output, res), axis=0)
|
||||||
|
else:
|
||||||
|
output = res
|
||||||
|
self.outputs = {'Out': output}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarityFocusOp_axis3(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = "similarity_focus"
|
||||||
|
batch_size = 64
|
||||||
|
x_dim, y_dim, z_dim = 48, 48, 13
|
||||||
|
self.inputs = {
|
||||||
|
'X': np.random.random(
|
||||||
|
(batch_size, x_dim, y_dim, z_dim)).astype("float32"),
|
||||||
|
}
|
||||||
|
self.attrs = {
|
||||||
|
'axis': 3,
|
||||||
|
'indexes': [0, 2, 7, 9],
|
||||||
|
}
|
||||||
|
|
||||||
|
output = None
|
||||||
|
for batch in range(batch_size):
|
||||||
|
res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1)
|
||||||
|
for index in self.attrs['indexes']:
|
||||||
|
channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy(
|
||||||
|
)
|
||||||
|
tag1 = [0 for i in range(x_dim)]
|
||||||
|
tag2 = [0 for i in range(y_dim)]
|
||||||
|
cnt = 0
|
||||||
|
for i in range(channel.size):
|
||||||
|
index = channel.argmax()
|
||||||
|
idx1 = index // y_dim
|
||||||
|
idx2 = index % y_dim
|
||||||
|
if tag1[idx1] + tag2[idx2] == 0:
|
||||||
|
tag1[idx1] = 1
|
||||||
|
tag2[idx2] = 1
|
||||||
|
res[index] = 1
|
||||||
|
cnt += 1
|
||||||
|
if cnt == min(x_dim, y_dim):
|
||||||
|
break
|
||||||
|
channel[index] = -1
|
||||||
|
res = res.reshape(x_dim, y_dim, 1)
|
||||||
|
res = res.repeat([z_dim], axis=2)
|
||||||
|
res = res.reshape(1, x_dim, y_dim, z_dim)
|
||||||
|
if output is not None:
|
||||||
|
output = np.concatenate((output, res), axis=0)
|
||||||
|
else:
|
||||||
|
output = res
|
||||||
|
self.outputs = {'Out': output}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue