parent
638d924d89
commit
426912df5a
@ -0,0 +1,154 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/index_sample_op.h"
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "Input(Tensor), dtype support int32/int64/float/double");
|
||||
AddInput("Index", "Index(Tensor), dtype support int32/int64");
|
||||
AddOutput("Out", "Return the element of input at index");
|
||||
|
||||
AddComment(R"DOC(
|
||||
IndexSample OP returns the element of the specified location of X,
|
||||
and the location is specified by Index.
|
||||
|
||||
X tensor and Index tensor's shape must be 2-D,
|
||||
dimension at 0 which usually is batch size must be equal.
|
||||
|
||||
The returned tensor has the same shape and dimensions as the Index tensor.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class IndexSampleOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(Input) of FindByIndex should not be null."));
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(Index) of FindByIndex should not be null."));
|
||||
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
input_dims.size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(X) shape of IndexSample op should be 2-D, but "
|
||||
"got X's shape = [%s], please check X shape.",
|
||||
input_dims));
|
||||
|
||||
auto index_dims = ctx->GetInputDim("Index");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
input_dims.size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(Index) shape of IndexSample op should be 2-D, but "
|
||||
"got Index's shape [%s] , please check index shape.",
|
||||
input_dims));
|
||||
if (ctx->IsRuntime()) {
|
||||
PADDLE_ENFORCE_EQ(input_dims[0], index_dims[0],
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(X)'s value of dimension 0 must same with "
|
||||
"Inputs(Index)'s value of dimension 0, but "
|
||||
"got %d of Inputs(X), and got %d of Inputs(Index), "
|
||||
"please check Inputs shape.",
|
||||
input_dims[0], index_dims[0]));
|
||||
}
|
||||
ctx->SetOutputDim("Out", index_dims);
|
||||
auto type = ctx->GetInputsVarType("Index")[0];
|
||||
if (type == framework::proto::VarType::LOD_TENSOR) {
|
||||
ctx->ShareLoD("Index", /*->*/ "Out");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
||||
return framework::OpKernelType(data_type, ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class IndexSampleGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->HasInput("Index"), true,
|
||||
platform::errors::InvalidArgument("Input(Index) should be not null."));
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(Out@GRAD) should be not null."));
|
||||
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Output(X@GRAD) should be not null."));
|
||||
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = OperatorWithKernel::IndicateVarDataType(
|
||||
ctx, framework::GradVarName("Out"));
|
||||
return framework::OpKernelType(data_type, ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class IndexSampleGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> op) const override {
|
||||
op->SetType("index_sample_grad");
|
||||
op->SetInput("X", this->Input("X"));
|
||||
op->SetInput("Index", this->Input("Index"));
|
||||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
||||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||
}
|
||||
};
|
||||
|
||||
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X");
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker,
|
||||
ops::IndexSampleGradMaker<paddle::framework::OpDesc>,
|
||||
ops::IndexSampleGradMaker<paddle::imperative::OpBase>);
|
||||
REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp,
|
||||
ops::IndexSampleGradNoNeedBufferVarInferer);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
index_sample, ops::IndexSampleKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::IndexSampleKernel<paddle::platform::CPUPlace, double>,
|
||||
ops::IndexSampleKernel<paddle::platform::CPUPlace, int>,
|
||||
ops::IndexSampleKernel<paddle::platform::CPUPlace, int64_t>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
index_sample_grad,
|
||||
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, double>,
|
||||
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, int>,
|
||||
ops::IndexSampleGradKernel<paddle::platform::CPUPlace, int64_t>);
|
@ -0,0 +1,186 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
using DDim = framework::DDim;
|
||||
|
||||
template <typename T, typename IndexT = int>
|
||||
void IndexSampleInner(const framework::ExecutionContext &context,
|
||||
const LoDTensor &input, const LoDTensor &index,
|
||||
LoDTensor *output) {
|
||||
auto input_dims = input.dims();
|
||||
auto index_dims = index.dims();
|
||||
|
||||
int batch_size = input_dims[0];
|
||||
auto value_length = input_dims[1];
|
||||
auto index_length = index_dims[1];
|
||||
int index_ids_num = index.numel();
|
||||
auto *input_data = input.data<T>();
|
||||
auto *index_data = index.data<IndexT>();
|
||||
|
||||
std::vector<T> res{};
|
||||
for (int i = 0; i < index_ids_num; i++) {
|
||||
int b = floor(i / index_length);
|
||||
PADDLE_ENFORCE_GE(
|
||||
index_data[i], 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"Variable value (index) of OP(index_sample) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
value_length, index_data[i]));
|
||||
PADDLE_ENFORCE_LT(
|
||||
index_data[i], value_length,
|
||||
platform::errors::InvalidArgument(
|
||||
"Variable value (index) of OP(index_sample) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
value_length, index_data[i]));
|
||||
|
||||
int v_i = b * value_length + static_cast<int>(index_data[i]);
|
||||
T v = input_data[v_i];
|
||||
VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i
|
||||
<< " value = " << v;
|
||||
res.push_back(v);
|
||||
}
|
||||
|
||||
auto ddim = framework::make_ddim({batch_size, index_length});
|
||||
output->Resize(ddim);
|
||||
T *out_data = output->mutable_data<T>(context.GetPlace());
|
||||
|
||||
memcpy(out_data, &res[0], sizeof(T) * index_ids_num);
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class IndexSampleKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto *input_var = ctx.InputVar("X");
|
||||
auto *index_var = ctx.InputVar("Index");
|
||||
|
||||
auto &input_tensor = input_var->Get<LoDTensor>();
|
||||
auto &index_tensor = index_var->Get<LoDTensor>();
|
||||
|
||||
auto *out_var = ctx.OutputVar("Out");
|
||||
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
|
||||
|
||||
const auto &index_type = index_tensor.type();
|
||||
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
|
||||
index_type == framework::proto::VarType::INT64;
|
||||
PADDLE_ENFORCE_EQ(index_type_match, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(Index) holds the wrong type, it holds %s, but "
|
||||
"desires to be %s or %s",
|
||||
paddle::framework::DataTypeToString(index_type),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT32),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT64)));
|
||||
if (index_type == framework::proto::VarType::INT32) {
|
||||
IndexSampleInner<T, int>(ctx, input_tensor, index_tensor, out_tensor);
|
||||
} else if (index_type == framework::proto::VarType::INT64) {
|
||||
IndexSampleInner<T, int64_t>(ctx, input_tensor, index_tensor, out_tensor);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename IndexT = int>
|
||||
void IndexSampleGradInner(const framework::ExecutionContext &context,
|
||||
const LoDTensor &out_grad, const LoDTensor &index,
|
||||
LoDTensor *x_grad) {
|
||||
auto index_dims = index.dims();
|
||||
auto x_grad_dims = x_grad->dims();
|
||||
|
||||
int batch_size = x_grad_dims[0];
|
||||
auto value_length = x_grad_dims[1];
|
||||
auto index_length = index_dims[1];
|
||||
int index_ids_num = index.numel();
|
||||
|
||||
T *x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
|
||||
auto *out_grad_data = out_grad.data<T>();
|
||||
auto *index_data = index.data<IndexT>();
|
||||
|
||||
memset(x_grad_data, 0, batch_size * value_length * sizeof(T));
|
||||
|
||||
for (int i = 0; i < index_ids_num; i++) {
|
||||
int b = floor(i / index_length);
|
||||
PADDLE_ENFORCE_GE(
|
||||
index_data[i], 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"Variable value (index) of OP(index_sample_grad) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
value_length, index_data[i]));
|
||||
PADDLE_ENFORCE_LT(
|
||||
index_data[i], value_length,
|
||||
platform::errors::InvalidArgument(
|
||||
"Variable value (index) of OP(index_sample_grad) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
value_length, index_data[i]));
|
||||
int v_i = b * value_length + static_cast<int>(index_data[i]);
|
||||
x_grad_data[v_i] += out_grad_data[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class IndexSampleGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
auto *index_var = context.InputVar("Index");
|
||||
auto *x_grad_var = context.OutputVar(framework::GradVarName("X"));
|
||||
auto *out_grad_var = context.InputVar(framework::GradVarName("Out"));
|
||||
|
||||
auto &index_tensor = index_var->Get<LoDTensor>();
|
||||
auto &out_grad_tensor = out_grad_var->Get<LoDTensor>();
|
||||
auto *x_grad_tensor = x_grad_var->GetMutable<framework::LoDTensor>();
|
||||
|
||||
const auto &index_type = index_tensor.type();
|
||||
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
|
||||
index_type == framework::proto::VarType::INT64;
|
||||
PADDLE_ENFORCE_EQ(index_type_match, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(Index) holds the wrong type, it holds %s, but "
|
||||
"desires to be %s or %s",
|
||||
paddle::framework::DataTypeToString(index_type),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT32),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT64)));
|
||||
if (index_type == framework::proto::VarType::INT32) {
|
||||
IndexSampleGradInner<T, int>(context, out_grad_tensor, index_tensor,
|
||||
x_grad_tensor);
|
||||
} else if (index_type == framework::proto::VarType::INT64) {
|
||||
IndexSampleGradInner<T, int64_t>(context, out_grad_tensor, index_tensor,
|
||||
x_grad_tensor);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,127 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestIndexSampleOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "index_sample"
|
||||
self.config()
|
||||
xnp = np.random.random(self.x_shape).astype(self.x_type)
|
||||
indexnp = np.random.randint(
|
||||
low=0, high=self.x_shape[1],
|
||||
size=self.index_shape).astype(self.index_type)
|
||||
self.inputs = {'X': xnp, 'Index': indexnp}
|
||||
index_array = []
|
||||
for i in range(self.index_shape[0]):
|
||||
for j in indexnp[i]:
|
||||
index_array.append(xnp[i, j])
|
||||
out = np.reshape(index_array, self.index_shape)
|
||||
self.outputs = {'Out': out}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], 'Out')
|
||||
|
||||
def config(self):
|
||||
"""
|
||||
For multi-dimension input
|
||||
"""
|
||||
self.x_shape = (10, 20)
|
||||
self.x_type = "float64"
|
||||
self.index_shape = (10, 10)
|
||||
self.index_type = "int32"
|
||||
|
||||
|
||||
class TestCase1(TestIndexSampleOp):
|
||||
def config(self):
|
||||
"""
|
||||
For one dimension input
|
||||
"""
|
||||
self.x_shape = (100, 1)
|
||||
self.x_type = "float64"
|
||||
self.index_shape = (100, 1)
|
||||
self.index_type = "int32"
|
||||
|
||||
|
||||
class TestCase2(TestIndexSampleOp):
|
||||
def config(self):
|
||||
"""
|
||||
For int64_t index type
|
||||
"""
|
||||
self.x_shape = (10, 100)
|
||||
self.x_type = "float64"
|
||||
self.index_shape = (10, 10)
|
||||
self.index_type = "int64"
|
||||
|
||||
|
||||
class TestCase3(TestIndexSampleOp):
|
||||
def config(self):
|
||||
"""
|
||||
For int index type
|
||||
"""
|
||||
self.x_shape = (10, 100)
|
||||
self.x_type = "float64"
|
||||
self.index_shape = (10, 10)
|
||||
self.index_type = "int32"
|
||||
|
||||
|
||||
class TestCase4(TestIndexSampleOp):
|
||||
def config(self):
|
||||
"""
|
||||
For int64 index type
|
||||
"""
|
||||
self.x_shape = (10, 100)
|
||||
self.x_type = "float64"
|
||||
self.index_shape = (10, 10)
|
||||
self.index_type = "int64"
|
||||
|
||||
|
||||
class TestIndexSampleShape(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
import paddle.fluid as fluid
|
||||
import paddle
|
||||
|
||||
# create x value
|
||||
x_shape = (2, 5)
|
||||
x_type = "float64"
|
||||
x_np = np.random.random(x_shape).astype(x_type)
|
||||
|
||||
# create index value
|
||||
index_shape = (2, 3)
|
||||
index_type = "int32"
|
||||
index_np = np.random.randint(
|
||||
low=0, high=x_shape[1], size=index_shape).astype(index_type)
|
||||
|
||||
x = fluid.data(name='x', shape=[-1, 5], dtype='float64')
|
||||
index = fluid.data(name='index', shape=[-1, 3], dtype='int32')
|
||||
output = paddle.index_sample(x=x, index=index)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place=place)
|
||||
exe.run(fluid.default_startup_program())
|
||||
|
||||
feed = {'x': x_np, 'index': index_np}
|
||||
res = exe.run(feed=feed, fetch_list=[output])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue