commit 9b558a8035
@@ -0,0 +1,87 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/similarity_focus_op.h"

namespace paddle {
namespace operators {
class SimilarityFocusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 4-D tensor with shape"
             " [BatchSize, X, Y, Z].");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), the similarity focus mask"
              " with the same shape as input X.");
    AddAttr<int>("axis",
                 "(int32), the dimension to select. It can"
                 " only be 1, 2, or 3.");
    AddAttr<std::vector<int>>("indexes",
                              "(std::vector<int32>), the indexes"
                              " of the selected dimension.");
    AddComment(R"DOC(
SimilarityFocus Operator.

Generate a similarity focus mask with the same shape as the input using the
following method:
1. Extract the 3-D tensor (here the first dimension is BatchSize) corresponding
   to the axis according to the indexes. For example, if axis=1 and indexes=[a],
   this selects the matrix T=X[:, a, :, :]. In this case, if the shape of input X
   is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, such that each row
   and each column contains at most one selected number: once the largest number
   has been found in the i-th row and the j-th column, the remaining numbers in
   the i-th row and the j-th column are skipped, and the next largest number is
   selected from the remaining entries; in total min(B, C) numbers are selected.
   Mark the corresponding positions of the 3-D similarity focus mask as 1 and
   all other positions as 0, then element-wise OR the masks over all indexes.
3. Broadcast the 3-D similarity focus mask to the same shape as input X.

Refer to `Similarity Focus Layer <http://www.aclweb.org/anthology/N16-1108>`_
)DOC");
  }
};

class SimilarityFocusOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4.");
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
        platform::CPUPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(similarity_focus, ops::SimilarityFocusOp,
                  ops::SimilarityFocusOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(similarity_focus, ops::SimilarityFocusKernel<float>,
                       ops::SimilarityFocusKernel<double>);
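The greedy selection in step 2 of the operator comment is easiest to see on a single 2-D slice. Below is a minimal NumPy sketch of that step (the helper name similarity_focus_slice is hypothetical, for illustration only, and not part of this PR): it repeatedly takes the largest remaining value whose row and column are both still unused, so exactly min(B, C) positions end up marked.

import numpy as np

def similarity_focus_slice(t):
    """Return a 0/1 mask of t's shape with at most one 1 per row and column."""
    mask = np.zeros_like(t)
    order = np.argsort(t, axis=None)[::-1]     # flat indexes, largest first
    used_rows, used_cols = set(), set()
    for flat in order:
        i, j = np.unravel_index(flat, t.shape)
        if i in used_rows or j in used_cols:
            continue                           # row or column already taken
        mask[i, j] = 1
        used_rows.add(i)
        used_cols.add(j)
        if len(used_rows) == min(t.shape):     # min(B, C) picks found
            break
    return mask

t = np.array([[0.8, 0.1], [0.4, 0.5]])
print(similarity_focus_slice(t))               # ones at (0, 0) and (1, 1)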
@@ -0,0 +1,168 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstring>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

template <typename T>
class SimilarityFocusKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    Tensor* out = context.Output<Tensor>("Out");
    const Tensor* x = context.Input<Tensor>("X");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* x_data = x->data<T>();

    int axis = context.Attr<int>("axis");
    std::vector<int> indexes = context.Attr<std::vector<int>>("indexes");

    int64_t batch_size = x->dims()[0];
    int64_t dim[4];
    for (int i = 1; i <= 3; ++i) {
      dim[i] = x->dims()[i];
    }

    if (indexes.size() < 1) {
      PADDLE_THROW("Indexes' size cannot be 0.");
    }
    for (auto index : indexes) {
      // Indexes are 0-based, so a valid index is strictly less than dim[axis].
      if (dim[axis] <= index) {
        PADDLE_THROW("Index exceeds tensor shape limit.");
      }
    }

    // Number of elements in one 2-D slice perpendicular to `axis`.
    int64_t array_size = 1;
    for (int i = 1; i <= 3; ++i) {
      if (i != axis) {
        array_size *= dim[i];
      }
    }

    // Each entry holds (value, flattened position within the slice).
    std::vector<std::pair<T, int64_t>> array(array_size);

    // Order slice elements by value, largest first.
    bool (*cmp)(std::pair<T, int64_t>, std::pair<T, int64_t>) = [](
        std::pair<T, int64_t> x, std::pair<T, int64_t> y) {
      return x.first > y.first;
    };

    // Flatten a (d1, d2, d3, d4) coordinate (d1 is the batch index) into a
    // row-major offset.
    int64_t (*compute_index)(int64_t*, int, int, int, int) = [](
        int64_t* dim, int d1, int d2, int d3, int d4) {
      return d1 * dim[1] * dim[2] * dim[3] + d2 * dim[2] * dim[3] +
             d3 * dim[3] + d4;
    };

    memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
    for (int i = 0; i < batch_size; ++i) {
      for (auto index : indexes) {
        if (axis == 1) {
          for (int j = 0; j < dim[2]; ++j) {
            for (int k = 0; k < dim[3]; ++k) {
              array[j * dim[3] + k] = std::make_pair(
                  x_data[compute_index(dim, i, index, j, k)], j * dim[3] + k);
            }
          }

          std::sort(array.begin(), array.end(), cmp);
          int tag_num = 0;
          std::vector<bool> tag2(dim[2]), tag3(dim[3]);
          // Greedily take the largest remaining value whose row and column
          // are both unused, then mark the whole fiber along `axis` with 1.
          for (auto x : array) {
            int idx2 = x.second / dim[3];
            int idx3 = x.second % dim[3];
            if (tag2[idx2] || tag3[idx3]) {
              continue;
            }
            tag_num++;
            tag2[idx2] = true;
            tag3[idx3] = true;
            for (int j = 0; j < dim[1]; ++j) {
              out_data[compute_index(dim, i, j, idx2, idx3)] = 1;
            }
            if (tag_num == std::min(dim[2], dim[3])) {
              break;
            }
          }
        } else if (axis == 2) {
          for (int j = 0; j < dim[1]; ++j) {
            for (int k = 0; k < dim[3]; ++k) {
              array[j * dim[3] + k] = std::make_pair(
                  x_data[compute_index(dim, i, j, index, k)], j * dim[3] + k);
            }
          }

          std::sort(array.begin(), array.end(), cmp);
          int tag_num = 0;
          std::vector<bool> tag1(dim[1]), tag3(dim[3]);
          for (auto x : array) {
            int idx1 = x.second / dim[3];
            int idx3 = x.second % dim[3];
            if (tag1[idx1] || tag3[idx3]) {
              continue;
            }
            tag_num++;
            tag1[idx1] = true;
            tag3[idx3] = true;
            for (int j = 0; j < dim[2]; ++j) {
              out_data[compute_index(dim, i, idx1, j, idx3)] = 1;
            }
            if (tag_num == std::min(dim[1], dim[3])) {
              break;
            }
          }
        } else if (axis == 3) {
          for (int j = 0; j < dim[1]; ++j) {
            for (int k = 0; k < dim[2]; ++k) {
              array[j * dim[2] + k] = std::make_pair(
                  x_data[compute_index(dim, i, j, k, index)], j * dim[2] + k);
            }
          }

          std::sort(array.begin(), array.end(), cmp);
          int tag_num = 0;
          std::vector<bool> tag1(dim[1]), tag2(dim[2]);
          for (auto x : array) {
            int idx1 = x.second / dim[2];
            int idx2 = x.second % dim[2];
            if (tag1[idx1] || tag2[idx2]) {
              continue;
            }
            tag_num++;
            tag1[idx1] = true;
            tag2[idx2] = true;
            for (int j = 0; j < dim[3]; ++j) {
              out_data[compute_index(dim, i, idx1, idx2, j)] = 1;
            }
            if (tag_num == std::min(dim[1], dim[2])) {
              break;
            }
          }
        } else {
          PADDLE_THROW("Axis must be 1, 2, or 3.");
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
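For concreteness, on the first unit test below (axis=1, indexes=[0]) the selected slice of batch 0 is [[0.8, 0.1], [0.4, 0.5]]: the greedy pass picks 0.8 at (0, 0), skips the rest of row 0 and column 0, then picks 0.5 at (1, 1), and the kernel writes the resulting 2-D mask along every position of axis 1. A short sketch of that broadcast step, reusing the hypothetical similarity_focus_slice helper from the sketch above:

import numpy as np

x = np.array([[[[0.8, 0.1], [0.4, 0.5]],
               [[0.9, 0.7], [0.9, 0.9]],
               [[0.8, 0.9], [0.1, 0.2]]]])        # shape (1, 3, 2, 2), axis=1
mask2d = similarity_focus_slice(x[0, 0])          # [[1, 0], [0, 1]]
out = np.broadcast_to(mask2d, x.shape[1:])[None]  # same mask in every channel
print(out.shape)                                  # (1, 3, 2, 2)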
@@ -0,0 +1,246 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>

#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"

namespace paddle {
namespace operators {
using framework::Tensor;

// Expose each item of a LoDTensorArray as a separate scope variable that
// shares the item's data, collecting the generated variable names.
void LodTensorArray2LodTensorVector(const framework::Scope &scope,
                                    const std::string &base_name,
                                    const std::string &lod_tensor_array_name,
                                    std::vector<std::string> *res_names) {
  auto &inx =
      scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
  for (size_t i = 0; i < inx.size(); i++) {
    std::string var_name = base_name + std::to_string(i);
    framework::Variable *g_feed_value =
        const_cast<framework::Scope &>(scope).Var(var_name);
    auto &feed_input =
        *(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
    feed_input.ShareDataWith(inx[i]);
    res_names->push_back(var_name);
  }
}

// Create one scope variable per LoDTensorArray item, resized to the item's
// dims but without sharing its data.
void LodTensorVectorResizeFromLodTensorArray(
    const framework::Scope &scope, const std::string &base_name,
    const std::string &lod_tensor_array_name,
    std::vector<std::string> *res_names) {
  auto &inx =
      scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
  for (size_t i = 0; i < inx.size(); i++) {
    std::string var_name = base_name + std::to_string(i);
    framework::Variable *g_feed_value =
        const_cast<framework::Scope &>(scope).Var(var_name);
    auto &feed_input =
        *(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
    auto dims = inx[i].dims();
    feed_input.Resize(dims);
    res_names->push_back(var_name);
  }
}

// Fill the output LoDTensorArray with one LoDTensor per input item.
void LodTensorArrayCreateFromLodTensorArray(
    const framework::Scope &scope,
    const std::string &input_lod_tensor_array_name,
    const std::string &output_lod_tensor_array_name) {
  auto &inx = scope.FindVar(input_lod_tensor_array_name)
                  ->Get<framework::LoDTensorArray>();
  auto &grad_inx = *scope.FindVar(output_lod_tensor_array_name)
                        ->GetMutable<framework::LoDTensorArray>();

  for (size_t i = 0; i < inx.size(); i++) {
    std::string var_name = output_lod_tensor_array_name + std::to_string(i);
    framework::Variable *g_feed_value =
        const_cast<framework::Scope &>(scope).Var(var_name);
    auto &feed_input =
        *(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
    grad_inx.push_back(feed_input);
  }
}

class LoDTensorArray2TensorOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto axis = Attr<int>("axis");

    framework::AttributeMap attrs;
    attrs["axis"] = axis;

    auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    auto &out =
        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
    auto &out_inx =
        *scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>();

    const size_t n = inx.size();
    PADDLE_ENFORCE_GT(n, 0, "Input tensor array size should be > 0.");

    std::string base_name = Inputs("X")[0];
    std::vector<std::string> names;

    // Record each input array item's size along `axis` in out_inx.
    auto out_inx_dim = out_inx.dims();
    out_inx_dim[0] = inx.size();
    out_inx.Resize(out_inx_dim);

    std::string var_name = "out_index";
    framework::Variable *tmp_index_var =
        const_cast<framework::Scope &>(scope).Var(var_name);
    auto &tmp_index_tensor =
        *(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
    tmp_index_tensor.Resize(out_inx_dim);
    int *tmp_index_data =
        tmp_index_tensor.mutable_data<int>(platform::CPUPlace());

    auto out_dims = inx[0].dims();
    size_t out_dim_sum = 0;
    for (size_t index = 0; index < inx.size(); index++) {
      auto inx_dims = inx[index].dims();
      out_dim_sum += inx_dims[axis];
      tmp_index_data[index] = inx_dims[axis];
    }
    out_inx.ShareDataWith(tmp_index_tensor);

    // The output size along `axis` is the sum of the input items' sizes.
    out_dims[axis] = out_dim_sum;
    out.Resize(out_dims);

    LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
    // Invoke the concat op to do the actual concatenation.
    auto concat_op = framework::OpRegistry::CreateOp(
        "concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);

    concat_op->Run(scope, place);
  }
};

class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input LoDTensorArray of tensor_array_to_tensor operator.");
    AddOutput("Out", "Output tensor of tensor_array_to_tensor operator.");
    AddOutput("OutIndex",
              "Output of tensor_array_to_tensor operator: the sizes along "
              "`axis` of the input LoDTensorArray items.");
    AddAttr<int>("axis",
                 "The axis along which the input tensors will be concatenated.")
        .SetDefault(0);
    AddComment(R"DOC(
tensor_array_to_tensor Operator.

Concatenate the input LoDTensorArray along dimension axis into the output Tensor.

Examples:
  Input = {[1,2], [3,4], [5,6]}
  axis = 0
  Output = [[1,2],
            [3,4],
            [5,6]]
  OutputIndex = [1,1,1]

)DOC");
  }
};

class LoDTensorArray2TensorOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {}
};

class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {}
};

class LoDTensorArray2TensorGradInferVarType
    : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) {
      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
    }
  }
};

class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto axis = Attr<int>("axis");
    framework::AttributeMap attrs;
    attrs["axis"] = axis;

    auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    const size_t n = inx.size();
    PADDLE_ENFORCE_GT(n, 0, "Input tensor array size should be > 0.");

    std::string base_name = Inputs("X")[0];
    std::vector<std::string> names;

    LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);

    // Gradient: split Out@GRAD back into per-item gradients via concat_grad.
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));

    std::vector<std::string> grad_names;

    LodTensorVectorResizeFromLodTensorArray(scope, "grad_name", Input("X"),
                                            &grad_names);

    auto concat_grad_op = framework::OpRegistry::CreateOp(
        "concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}},
        {{"X@GRAD", grad_names}}, attrs);

    concat_grad_op->Run(scope, place);

    LodTensorArrayCreateFromLodTensorArray(scope, Input("X"), dx_name);
    auto &grad_inx =
        *scope.FindVar(dx_name)->GetMutable<framework::LoDTensorArray>();

    for (size_t i = 0; i < grad_names.size(); i++) {
      std::string var_name = grad_names[i];
      auto &feed_input = scope.FindVar(var_name)->Get<framework::LoDTensor>();
      grad_inx[i].ShareDataWith(feed_input);
    }
  }
};

}  // namespace operators
}  // namespace paddle
USE_OP(concat);

namespace ops = paddle::operators;
REGISTER_OPERATOR(tensor_array_to_tensor, ops::LoDTensorArray2TensorOp,
                  ops::LoDTensorArray2TensorOpMaker,
                  ops::LoDTensorArray2TensorOpInferShape,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(tensor_array_to_tensor_grad, ops::LoDTensorArray2TensorGradOp,
                  ops::LoDTensorArray2TensorGradInferShape,
                  ops::LoDTensorArray2TensorGradInferVarType);
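Stripped of the scope plumbing, the forward pass registered above is a concat along `axis` plus bookkeeping of each item's size along that axis. A minimal NumPy sketch of the semantics (the function name is a hypothetical stand-in, not the Paddle API), matching the Example in the op comment:

import numpy as np

def tensor_array_to_tensor(arr, axis=0):
    # Out: items concatenated along `axis`; OutIndex: each item's size there.
    out = np.concatenate(arr, axis=axis)
    out_index = np.array([t.shape[axis] for t in arr])
    return out, out_index

arr = [np.array([[1, 2]]), np.array([[3, 4]]), np.array([[5, 6]])]
out, idx = tensor_array_to_tensor(arr, axis=0)
print(out)   # [[1 2] [3 4] [5 6]]
print(idx)   # [1 1 1]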
@@ -0,0 +1,217 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest


class TestSimilarityFocusOp(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 2
        x_dim, y_dim, z_dim = 3, 2, 2
        self.inputs = {
            'X': np.array([[[[0.8, 0.1], [0.4, 0.5]], [[0.9, 0.7], [0.9, 0.9]],
                            [[0.8, 0.9], [0.1, 0.2]]],
                           [[[0.2, 0.5], [0.3, 0.4]], [[0.9, 0.7], [0.8, 0.4]],
                            [[0.0, 0.2], [0.4, 0.7]]]]),
        }
        self.attrs = {
            'axis': 1,
            'indexes': [0],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy()
                tag1 = [0 for i in range(y_dim)]
                tag2 = [0 for i in range(z_dim)]
                cnt = 0
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // z_dim
                    idx2 = index % z_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(y_dim, z_dim):
                            break
                    channel[index] = -1
            res = res.reshape(1, y_dim, z_dim).repeat([x_dim], axis=0)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


class TestSimilarityFocusOp_axis1(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 3
        x_dim, y_dim, z_dim = 4, 5, 6
        self.inputs = {
            'X': np.random.random(
                (batch_size, x_dim, y_dim, z_dim)).astype("float32"),
        }
        self.attrs = {
            'axis': 1,
            'indexes': [0, 3],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((1, y_dim, z_dim)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][batch, index, :, :].reshape(-1).copy()
                tag1 = [0 for i in range(y_dim)]
                tag2 = [0 for i in range(z_dim)]
                cnt = 0
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // z_dim
                    idx2 = index % z_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(y_dim, z_dim):
                            break
                    channel[index] = -1
            res = res.reshape(1, y_dim, z_dim)
            res = res.repeat([x_dim], axis=0)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


class TestSimilarityFocusOp_axis2(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 6
        x_dim, y_dim, z_dim = 7, 8, 9
        self.inputs = {
            'X': np.random.random(
                (batch_size, x_dim, y_dim, z_dim)).astype("float32"),
        }
        self.attrs = {
            'axis': 2,
            'indexes': [0, 3, 5],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((x_dim, 1, z_dim)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][batch, :, index, :].reshape(-1).copy()
                tag1 = [0 for i in range(x_dim)]
                tag2 = [0 for i in range(z_dim)]
                cnt = 0
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // z_dim
                    idx2 = index % z_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(x_dim, z_dim):
                            break
                    channel[index] = -1
            res = res.reshape(x_dim, 1, z_dim)
            res = res.repeat([y_dim], axis=1)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


class TestSimilarityFocusOp_axis3(OpTest):
    def setUp(self):
        self.op_type = "similarity_focus"
        batch_size = 64
        x_dim, y_dim, z_dim = 48, 48, 13
        self.inputs = {
            'X': np.random.random(
                (batch_size, x_dim, y_dim, z_dim)).astype("float32"),
        }
        self.attrs = {
            'axis': 3,
            'indexes': [0, 2, 7, 9],
        }

        output = None
        for batch in range(batch_size):
            res = np.zeros((x_dim, y_dim, 1)).astype("float32").reshape(-1)
            for index in self.attrs['indexes']:
                channel = self.inputs['X'][batch, :, :, index].reshape(-1).copy()
                tag1 = [0 for i in range(x_dim)]
                tag2 = [0 for i in range(y_dim)]
                cnt = 0
                for i in range(channel.size):
                    index = channel.argmax()
                    idx1 = index // y_dim
                    idx2 = index % y_dim
                    if tag1[idx1] + tag2[idx2] == 0:
                        tag1[idx1] = 1
                        tag2[idx2] = 1
                        res[index] = 1
                        cnt += 1
                        if cnt == min(x_dim, y_dim):
                            break
                    channel[index] = -1
            res = res.reshape(x_dim, y_dim, 1)
            res = res.repeat([z_dim], axis=2)
            res = res.reshape(1, x_dim, y_dim, z_dim)
            if output is not None:
                output = np.concatenate((output, res), axis=0)
            else:
                output = res
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,142 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor


class TestLoDTensorArrayConcat(unittest.TestCase):
    def setUp(self):
        self.op_type = "tensor_array_to_tensor"
        self.attrs = {"axis": 0}
        self.outputs = ["Out"]

    def test_get_set(self):
        scope = core.Scope()
        program = fluid.Program()
        block = program.global_block()

        input_arr = block.create_var(
            name="tmp_lod_tensor_array",
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
        input_arr.persistable = True
        input_arr_var = scope.var('tmp_lod_tensor_array')
        input_tensor_array = input_arr_var.get_lod_tensor_array()
        self.assertEqual(0, len(input_tensor_array))

        cpu = core.CPUPlace()
        for i in range(10):
            t = core.LoDTensor()
            if i == 0:
                t.set(numpy.array([[i], [i]], dtype='float32'), cpu)
            else:
                t.set(numpy.array([[i]], dtype='float32'), cpu)
            input_tensor_array.append(t)

        self.assertEqual(10, len(input_tensor_array))

        random_grad = numpy.random.random_sample([11]).astype(numpy.float32)

        y_out = block.create_var(name="Out")
        y_out.persistable = True
        y_out_index = block.create_var(name="OutIndex")
        y_out_index.persistable = True

        y_grad_arr = block.create_var(
            name='Out@GRAD', dtype='float32', shape=[11])
        y_grad_arr.persistable = True
        y_grad = scope.var('Out@GRAD')
        y_grad_tensor = y_grad.get_tensor()
        y_grad_tensor.set(random_grad, cpu)

        op = block.append_op(
            type=self.op_type,
            inputs={"X": input_arr},
            outputs={"Out": y_out,
                     "OutIndex": y_out_index},
            attrs=self.attrs)

        out_grad = block.create_var(
            name="tmp_lod_tensor_array@GRAD",
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
        out_grad.persistable = True

        grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
                                                                  set(), [])
        grad_op_desc = grad_op_desc_list[0]
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc)
        for var_name in grad_op_desc.output_arg_names():
            block.desc.var(var_name.encode("ascii"))

        grad_op_desc.infer_var_type(block.desc)
        grad_op_desc.infer_shape(block.desc)
        for arg in grad_op_desc.output_arg_names():
            grad_var = block.desc.find_var(arg.encode("ascii"))
            grad_var.set_dtype(core.VarDesc.VarType.FP32)

        fetch_list = []
        fetch_list.append(block.var('Out'))
        fetch_list.append(block.var('OutIndex'))

        exe = fluid.Executor(fluid.CPUPlace())
        out = exe.run(program, fetch_list=fetch_list, scope=scope)

        # Test forward: Out is the concatenation of the ten array items and
        # OutIndex records each item's size along axis 0.
        tensor_res = numpy.array(out[0])
        tensor_res_out_idx = numpy.array(out[1])
        tensor_gt = numpy.array(
            [0] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float32')

        self.assertEqual(len(tensor_res), len(tensor_gt))
        self.assertEqual(len(tensor_res_out_idx), 10)

        for i in range(len(tensor_res)):
            self.assertEqual(tensor_res[i], tensor_gt[i])

        for i in range(len(tensor_res_out_idx)):
            if i == 0:
                self.assertEqual(tensor_res_out_idx[i], 2)
            else:
                self.assertEqual(tensor_res_out_idx[i], 1)

        # Test backward: the gradient array gets one slice of Out@GRAD per
        # input item, sized by OutIndex.
        grad_tensor = scope.var('tmp_lod_tensor_array@GRAD')
        grad_tensor_array = grad_tensor.get_lod_tensor_array()

        self.assertEqual(10, len(grad_tensor_array))

        for i in range(len(grad_tensor_array)):
            if i == 0:
                self.assertEqual(
                    numpy.array(grad_tensor_array[i])[0],
                    numpy.array(random_grad[i]))
                self.assertEqual(
                    numpy.array(grad_tensor_array[i])[1],
                    numpy.array(random_grad[i + 1]))
            if i == 1:
                self.assertEqual(
                    numpy.array(grad_tensor_array[i]),
                    numpy.array(random_grad[i + 1]))


if __name__ == '__main__':
    unittest.main()
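The backward expectations above follow from the gradient of concatenation: Out@GRAD is split back into per-item gradients using the sizes recorded in OutIndex (here [2, 1, ..., 1], since the first array item has two rows). A minimal NumPy sketch with a hypothetical helper name:

import numpy as np

def tensor_array_to_tensor_grad(d_out, out_index, axis=0):
    # Split the concatenated gradient back into one piece per array item.
    return np.split(d_out, np.cumsum(out_index)[:-1], axis=axis)

d_out = np.arange(11, dtype=np.float32)   # stands in for random_grad
d_items = tensor_array_to_tensor_grad(d_out, [2] + [1] * 9)
print([g.tolist() for g in d_items])      # [[0.0, 1.0], [2.0], ..., [10.0]]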