add new api for Paddle2.0: nonzero, index_select, roll, cross (#23176)
parent f11af6a935
commit 2e4196f647
paddle/fluid/operators/cross_op.cc
@@ -0,0 +1,169 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cross_op.h"
#include <memory>

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::DDim;

class CrossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of CrossOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
                      platform::errors::InvalidArgument(
                          "Input(Y) of CrossOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of CrossOp should not be null."));

    auto x_dim = ctx->GetInputDim("X");
    auto y_dim = ctx->GetInputDim("Y");
    auto dim = ctx->Attrs().Get<int>("dim");

    bool dims_match = CheckDims(x_dim, y_dim);
    PADDLE_ENFORCE_EQ(dims_match, true,
                      platform::errors::InvalidArgument(
                          "The 'shape' of Input(X) should be equal to "
                          "the 'shape' of Input(Y). But received "
                          "Input(X).dimensions = [%s], "
                          "Input(Y).dimensions = [%s]",
                          x_dim, y_dim));

    if (dim != kDefaultDim) {
      PADDLE_ENFORCE_EQ(
          dim < x_dim.size() && dim >= (0 - x_dim.size()), true,
          platform::errors::OutOfRange(
              "Attr(dim) is out of range, It's expected "
              "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
              x_dim.size(), x_dim.size() - 1, dim));
      if (dim < 0) {
        dim += x_dim.size();
      }
      PADDLE_ENFORCE_EQ(x_dim[dim] == 3 && y_dim[dim] == 3, true,
                        platform::errors::InvalidArgument(
                            "Input(X/Y).dims()[dim] should be equal to 3. "
                            "But received Input(X/Y).dims()[dim] = %d.",
                            x_dim[dim]));
    }

    ctx->SetOutputDim("Out", x_dim);
    auto type = ctx->GetInputsVarType("X")[0];
    if (type == framework::proto::VarType::LOD_TENSOR) {
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class CrossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::InvalidArgument("Input(X) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::InvalidArgument("Input(Y) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::InvalidArgument(
                          "Input(Out@GRAD) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                      platform::errors::InvalidArgument(
                          "Output(X@GRAD) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Y")), true,
                      platform::errors::InvalidArgument(
                          "Output(Y@GRAD) should be not null."));

    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

class CrossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) the input tensor.");
    AddInput("Y", "(Tensor) the second input tensor.");
    AddOutput("Out", "(Tensor), the output tensor.");
    AddAttr<int>("dim", "the dimension to take the cross-product in.")
        .SetDefault(kDefaultDim);
    AddComment(R"DOC(
    Returns the cross product of vectors in dimension dim of
    input and other. Input and other must have the same size,
    and the size of their dim dimension should be 3.
    If dim is not given, it defaults to the first dimension
    found with the size 3.
    )DOC");
  }
};

template <typename T>
class CrossGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("cross_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Input("Y"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    op->SetAttrMap(this->Attrs());
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(cross, ops::CrossOp, ops::CrossOpMaker,
                  ops::CrossGradMaker<paddle::framework::OpDesc>,
                  ops::CrossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cross_grad, ops::CrossGradOp);
REGISTER_OP_CPU_KERNEL(
    cross, ops::CrossKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CrossKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CrossKernel<paddle::platform::CPUDeviceContext, int>,
    ops::CrossKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    cross_grad, ops::CrossGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CrossGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CrossGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::CrossGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
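
Usage note: the dygraph cases in the new test_cross_op.py (included later in this commit) exercise this operator through the paddle.cross Python API; the printed values below are the ones asserted in those tests.

import numpy as np
import paddle
import paddle.fluid as fluid

x_data = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
y_data = np.ones((3, 3))

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(x_data)
    y = fluid.dygraph.to_variable(y_data)
    # default dim: the first axis whose size is 3 (axis 0 here)
    print(paddle.cross(x, y).numpy())
    # [[-1. -1. -1.] [ 2.  2.  2.] [-1. -1. -1.]]
    # explicit dim=1: every row is parallel to (1, 1, 1), so each cross is zero
    print(paddle.cross(x, y, dim=1).numpy())
    # [[0. 0. 0.] [0. 0. 0.] [0. 0. 0.]]
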
paddle/fluid/operators/cross_op.cu
@@ -0,0 +1,28 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cross_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    cross, ops::CrossKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CrossKernel<paddle::platform::CUDADeviceContext, double>,
    ops::CrossKernel<paddle::platform::CUDADeviceContext, int>,
    ops::CrossKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    cross_grad,
    ops::CrossGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CrossGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::CrossGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::CrossGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
paddle/fluid/operators/cross_op.h
@@ -0,0 +1,222 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
const int kDefaultDim = framework::DDim::kMaxRank;

inline bool CheckDims(const DDim& dims_x, const DDim& dims_y) {
  if (dims_x.size() != dims_y.size()) {
    return false;
  }
  for (int i = 0; i < dims_x.size(); i++) {
    if (dims_x[i] != dims_y[i]) {
      return false;
    }
  }
  return true;
}

template <typename DeviceContext, typename T>
class CrossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input_x_var = context.InputVar("X");
    auto* input_y_var = context.InputVar("Y");
    auto* output_var = context.OutputVar("Out");

    auto& input_x = input_x_var->Get<LoDTensor>();
    auto& input_y = input_y_var->Get<LoDTensor>();
    auto* output = output_var->GetMutable<LoDTensor>();
    int dim = context.Attr<int>("dim");

    auto input_x_dims = input_x.dims();
    auto input_y_dims = input_y.dims();
    bool dims_match = CheckDims(input_x_dims, input_y_dims);
    PADDLE_ENFORCE_EQ(dims_match, true,
                      platform::errors::InvalidArgument(
                          "The 'shape' of Input(X) should be equal to "
                          "the 'shape' of Input(Y). But received "
                          "Input(X).dimensions = [%s], "
                          "Input(Y).dimensions = [%s]",
                          input_x_dims, input_y_dims));

    if (dim != kDefaultDim) {
      PADDLE_ENFORCE_EQ(
          dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), true,
          platform::errors::OutOfRange(
              "Attr(dim) is out of range, It's expected "
              "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
              input_x_dims.size(), input_x_dims.size() - 1, dim));
      if (dim < 0) {
        dim += input_x_dims.size();
      }

      PADDLE_ENFORCE_EQ(
          input_x_dims[dim] == 3, true,
          platform::errors::InvalidArgument(
              "Input(X/Y).dims[dim] must be equal to 3. But received: "
              "Input(X/Y).dims[dim] = [%d].",
              input_x_dims[dim]));
    } else {
      for (auto i = 0; i < input_x_dims.size(); i++) {
        if (input_x_dims[i] == 3) {
          dim = i;
          break;
        }
      }
      PADDLE_ENFORCE_EQ(dim == kDefaultDim, false,
                        platform::errors::InvalidArgument(
                            "There must be at least one dimension 'd' so that "
                            "Input(X/Y).dims()[d] is equal to 3. "
                            "But received: Input(X/Y).dims() == [%s].",
                            input_x_dims));
    }
    auto outer_loops = 1;
    for (auto i = 0; i < dim; i++) {
      outer_loops *= input_x_dims[i];
    }
    auto slice_size = 1;
    for (auto i = dim + 1; i < input_x_dims.size(); i++) {
      slice_size *= input_x_dims[i];
    }

    std::vector<T> input_x_vec, input_y_vec;
    framework::TensorToVector(input_x, context.device_context(), &input_x_vec);
    framework::TensorToVector(input_y, context.device_context(), &input_y_vec);
    std::vector<T> out_vec(output->numel());

    output->mutable_data<T>(context.GetPlace());

    for (auto i = 0; i < outer_loops; i++) {
      for (auto j = 0; j < 3; j++) {
        auto dst_pos = (3 * i + j) * slice_size;
        auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
        auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;

        for (auto k = 0; k < slice_size; k++) {
          out_vec[dst_pos + k] =
              input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] -
              input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k];
        }
      }
    }
    framework::TensorFromVector(out_vec, context.device_context(), output);
    output->Resize(input_x_dims);
  }
};

template <typename DeviceContext, typename T>
class CrossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input_x_var = context.InputVar("X");
    auto* input_y_var = context.InputVar("Y");
    auto* input_out_grad_var = context.InputVar(framework::GradVarName("Out"));
    auto* output_x_grad_var = context.OutputVar(framework::GradVarName("X"));
    auto* output_y_grad_var = context.OutputVar(framework::GradVarName("Y"));

    auto& input_x = input_x_var->Get<LoDTensor>();
    auto& input_y = input_y_var->Get<LoDTensor>();
    auto& input_out_grad = input_out_grad_var->Get<LoDTensor>();
    auto* output_x_grad = output_x_grad_var->GetMutable<LoDTensor>();
    auto* output_y_grad = output_y_grad_var->GetMutable<LoDTensor>();

    int dim = context.Attr<int>("dim");
    auto input_x_dims = input_x.dims();
    if (dim != kDefaultDim) {
      PADDLE_ENFORCE_EQ(
          dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()), true,
          platform::errors::OutOfRange(
              "Attr(dim) is out of range, It's expected "
              "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
              input_x_dims.size(), input_x_dims.size() - 1, dim));
      if (dim < 0) {
        dim += input_x_dims.size();
      }

      PADDLE_ENFORCE_EQ(
          input_x_dims[dim] == 3, true,
          platform::errors::InvalidArgument(
              "Input(X/Y).dims[dim] must be equal to 3. But received: "
              "Input(X/Y).dims[dim] = [%d].",
              input_x_dims[dim]));
    } else {
      for (auto i = 0; i < input_x_dims.size(); i++) {
        if (input_x_dims[i] == 3) {
          dim = i;
          break;
        }
      }
      PADDLE_ENFORCE_EQ(dim == kDefaultDim, false,
                        platform::errors::InvalidArgument(
                            "There must be at least one dimension 'd' "
                            "so that Input(X/Y).dims()[d] is equal to 3. "
                            "But received: Input(X/Y).dims() == [%s].",
                            input_x_dims));
    }
    auto outer_loops = 1;
    for (auto i = 0; i < dim; i++) {
      outer_loops *= input_x_dims[i];
    }
    auto slice_size = 1;
    for (auto i = dim + 1; i < input_x_dims.size(); i++) {
      slice_size *= input_x_dims[i];
    }

    std::vector<T> input_x_vec, input_y_vec, input_dout_vec;
    framework::TensorToVector(input_x, context.device_context(), &input_x_vec);
    framework::TensorToVector(input_y, context.device_context(), &input_y_vec);
    framework::TensorToVector(input_out_grad, context.device_context(),
                              &input_dout_vec);
    std::vector<T> out_dx_vec(output_x_grad->numel());
    std::vector<T> out_dy_vec(output_y_grad->numel());

    output_x_grad->mutable_data<T>(context.GetPlace());
    output_y_grad->mutable_data<T>(context.GetPlace());

    for (auto i = 0; i < outer_loops; i++) {
      for (auto j = 0; j < 3; j++) {
        auto dst_pos = (3 * i + j) * slice_size;
        auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
        auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
        for (auto k = 0; k < slice_size; k++) {
          out_dx_vec[dst_pos + k] =
              input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] -
              input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k];
          out_dy_vec[dst_pos + k] =
              input_dout_vec[in_pos1 + k] * input_x_vec[in_pos2 + k] -
              input_dout_vec[in_pos2 + k] * input_x_vec[in_pos1 + k];
        }
      }
    }
    framework::TensorFromVector(out_dx_vec, context.device_context(),
                                output_x_grad);
    framework::TensorFromVector(out_dy_vec, context.device_context(),
                                output_y_grad);
    output_x_grad->Resize(input_x_dims);
    output_y_grad->Resize(input_x_dims);
  }
};

}  // namespace operators
}  // namespace paddle
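
The kernel above walks the flattened tensor as outer_loops blocks of 3 * slice_size elements, with out[j] = x[(j+1)%3] * y[(j+2)%3] - x[(j+2)%3] * y[(j+1)%3] per slice. A NumPy sketch of that traversal (illustrative only, not part of the patch; cross_ref is a hypothetical helper name) that agrees with np.cross:

import numpy as np

def cross_ref(x, y, dim):
    # outer_loops and slice_size exactly as computed in CrossKernel
    outer = int(np.prod(x.shape[:dim], dtype=np.int64))
    slice_size = int(np.prod(x.shape[dim + 1:], dtype=np.int64))
    xf, yf, out = x.ravel(), y.ravel(), np.empty(x.size, dtype=x.dtype)
    for i in range(outer):
        for j in range(3):
            dst = (3 * i + j) * slice_size
            p1 = (3 * i + (j + 1) % 3) * slice_size  # in_pos1
            p2 = (3 * i + (j + 2) % 3) * slice_size  # in_pos2
            out[dst:dst + slice_size] = (
                xf[p1:p1 + slice_size] * yf[p2:p2 + slice_size] -
                xf[p2:p2 + slice_size] * yf[p1:p1 + slice_size])
    return out.reshape(x.shape)

x, y = np.random.rand(4, 3, 5), np.random.rand(4, 3, 5)
assert np.allclose(cross_ref(x, y, 1), np.cross(x, y, axis=1))
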
paddle/fluid/operators/index_select_op.cc
@@ -0,0 +1,163 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/index_select_op.h"
#include <memory>

namespace paddle {
namespace operators {

using framework::Tensor;

class IndexSelectOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of IndexSelectOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
                      platform::errors::InvalidArgument(
                          "Input(Index) of IndexSelectOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of IndexSelectOp should not be null."));

    auto input_dim = ctx->GetInputDim("X");
    auto index_dim = ctx->GetInputDim("Index");
    auto dim = ctx->Attrs().Get<int>("dim");

    PADDLE_ENFORCE_EQ(
        dim < input_dim.size() && dim >= (0 - input_dim.size()), true,
        platform::errors::OutOfRange(
            "Attr(dim) is out of range, It's expected "
            "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
            input_dim.size(), input_dim.size() - 1, dim));

    PADDLE_ENFORCE_EQ(
        index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
        true, platform::errors::InvalidArgument(
                  "The 'shape' of Input(Index) must be a 1-D tensor. "
                  "But received: the 'shape' of Input(Index) is [%s], "
                  "the dimension of Input(Index) is [%d].",
                  index_dim, index_dim.size()));

    auto output_dim = framework::vectorize(input_dim);
    if (dim < 0) {
      dim += input_dim.size();
    }
    output_dim[dim] = index_dim[0];
    ctx->SetOutputDim("Out", framework::make_ddim(output_dim));
    auto type = ctx->GetInputsVarType("X")[0];
    if (type == framework::proto::VarType::LOD_TENSOR) {
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class IndexSelectGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Index"), true,
        platform::errors::InvalidArgument("Input(Index) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::InvalidArgument(
                          "Input(Out@GRAD) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                      platform::errors::InvalidArgument(
                          "Output(X@GRAD) should be not null."));

    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

class IndexSelectOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) the input tensor.");
    AddInput("Index", "the 1-D tensor containing the indices to index.");
    AddOutput("Out", "the output tensor.");
    AddAttr<int>("dim", "the dimension in which we index.").SetDefault(0);
    AddComment(R"DOC(
    Returns a new tensor which indexes the input tensor
    along dimension dim using the entries in index which
    is a Tensor.

    The returned tensor has the same number of dimensions
    as the original tensor (input). The dim-th dimension
    has the same size as the length of index; other dimensions
    have the same size as in the original tensor.
    )DOC");
  }
};

template <typename T>
class IndexSelectGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("index_select_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("Index", this->Input("Index"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInference,
                                    "X");
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
                  ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
                  ops::IndexSelectGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
                  ops::IndexSelectGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
    index_select,
    ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
    ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, double>,
    ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int>,
    ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    index_select_grad,
    ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
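
Usage note: the dygraph cases in the new test_index_select_op.py (later in this commit) show the paddle.index_select API this operator backs; the printed values are the ones asserted there.

import numpy as np
import paddle
import paddle.fluid as fluid

x_data = np.array([[1.0, 2.0, 3.0, 4.0],
                   [5.0, 6.0, 7.0, 8.0],
                   [9.0, 10.0, 11.0, 12.0]])
index_data = np.array([0, 1, 1]).astype('int32')

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(x_data)
    index = fluid.dygraph.to_variable(index_data)
    # default dim=0: gather rows 0, 1, 1
    print(paddle.index_select(x, index).numpy())
    # [[1. 2. 3. 4.] [5. 6. 7. 8.] [5. 6. 7. 8.]]
    # dim=1: gather columns 0, 1, 1
    print(paddle.index_select(x, index, dim=1).numpy())
    # [[1. 2. 2.] [5. 6. 6.] [9. 10. 10.]]
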
paddle/fluid/operators/index_select_op.cu
@@ -0,0 +1,29 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/index_select_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    index_select,
    ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, float>,
    ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, double>,
    ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, int>,
    ops::IndexSelectKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    index_select_grad,
    ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::IndexSelectGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
paddle/fluid/operators/index_select_op.h
@@ -0,0 +1,203 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

template <typename T, typename IndexT = int>
void IndexSelectInner(const framework::ExecutionContext& context,
                      const LoDTensor& input, const LoDTensor& index,
                      LoDTensor* output, int dim) {
  auto input_dim = input.dims();
  auto input_dim_size = input_dim.size();
  auto output_dim = output->dims();

  auto slice_size = 1;
  for (auto i = dim + 1; i < input_dim_size; i++) {
    slice_size *= input_dim[i];
  }

  auto input_width = slice_size * input_dim[dim];
  auto output_width = slice_size * output_dim[dim];

  auto outer_nums = 1;
  for (auto i = 0; i < dim; i++) {
    outer_nums *= input_dim[i];
  }

  auto index_size = index.dims()[0];

  std::vector<T> input_vec;
  std::vector<IndexT> index_vec;
  TensorToVector(input, context.device_context(), &input_vec);
  TensorToVector(index, context.device_context(), &index_vec);
  std::vector<T> out_vec(output->numel());

  VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
          << "; slice_size: " << slice_size << "; input_width: " << input_width
          << "; output_width: " << output_width
          << "; index_size: " << index_size;

  for (auto i = 0; i < outer_nums; i++) {
    auto input_start_offset = i * input_width;
    auto output_start_offset = i * output_width;

    for (auto j = 0; j < index_size; j++) {
      IndexT index_value = index_vec[j];
      for (auto k = 0; k < slice_size; k++) {
        out_vec[output_start_offset + j * slice_size + k] =
            input_vec[input_start_offset + index_value * slice_size + k];
      }
    }
  }
  output->mutable_data<T>(context.GetPlace());
  framework::TensorFromVector(out_vec, context.device_context(), output);
  output->Resize(output_dim);
}

template <typename DeviceContext, typename T>
class IndexSelectKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* inputs_var = context.InputVar("X");
    auto* index_var = context.InputVar("Index");
    auto* output_var = context.OutputVar("Out");

    auto& inputs = inputs_var->Get<LoDTensor>();
    auto& index = index_var->Get<LoDTensor>();
    auto* output = output_var->GetMutable<framework::LoDTensor>();

    int dim = context.Attr<int>("dim");
    if (dim < 0) {
      dim += inputs.dims().size();
    }

    const auto& index_type = index.type();
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;
    PADDLE_ENFORCE_EQ(index_type_match, true,
                      platform::errors::InvalidArgument(
                          "Input(Index) holds the wrong type, it holds %s, but "
                          "desires to be %s or %s",
                          paddle::framework::DataTypeToString(index_type),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT32),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT64)));

    if (index_type == framework::proto::VarType::INT32) {
      IndexSelectInner<T, int>(context, inputs, index, output, dim);
    } else if (index_type == framework::proto::VarType::INT64) {
      IndexSelectInner<T, int64_t>(context, inputs, index, output, dim);
    }
  }
};

template <typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context,
                          const LoDTensor& out_grad, const LoDTensor& index,
                          LoDTensor* x_grad, int dim) {
  std::vector<T> input_vec;
  std::vector<IndexT> index_vec;
  TensorToVector(out_grad, context.device_context(), &input_vec);
  TensorToVector(index, context.device_context(), &index_vec);

  auto input_dim = out_grad.dims();
  auto input_dim_size = input_dim.size();
  auto output_dim = x_grad->dims();
  std::vector<T> out_vec(x_grad->numel(), 0);

  auto slice_size = 1;
  for (auto i = dim + 1; i < input_dim_size; i++) {
    slice_size *= input_dim[i];
  }

  auto input_width = slice_size * input_dim[dim];
  auto output_width = slice_size * output_dim[dim];

  auto outer_nums = 1;
  for (auto i = 0; i < dim; i++) {
    outer_nums *= input_dim[i];
  }

  auto index_size = index.dims()[0];
  VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
          << "; slice_size: " << slice_size << "; input_width: " << input_width
          << "; output_width: " << output_width
          << "; index_size: " << index_size;

  for (auto i = 0; i < outer_nums; i++) {
    auto input_start_offset = i * input_width;
    auto output_start_offset = i * output_width;

    for (auto j = 0; j < index_size; j++) {
      IndexT index_value = index_vec[j];
      for (auto k = 0; k < slice_size; k++) {
        out_vec[output_start_offset + index_value * slice_size + k] +=
            input_vec[input_start_offset + j * slice_size + k];
      }
    }
  }
  x_grad->mutable_data<T>(context.GetPlace());
  framework::TensorFromVector(out_vec, context.device_context(), x_grad);
  x_grad->Resize(output_dim);
}

template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* index_var = context.InputVar("Index");
    auto* x_grad_var = context.OutputVar(framework::GradVarName("X"));
    auto* out_grad_var = context.InputVar(framework::GradVarName("Out"));

    auto& index = index_var->Get<LoDTensor>();
    auto& out_grad = out_grad_var->Get<LoDTensor>();
    auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
    int dim = context.Attr<int>("dim");
    if (dim < 0) {
      dim += out_grad.dims().size();
    }

    const auto& index_type = index.type();
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;
    PADDLE_ENFORCE_EQ(index_type_match, true,
                      platform::errors::InvalidArgument(
                          "Input(Index) holds the wrong type, it holds %s, but "
                          "desires to be %s or %s",
                          paddle::framework::DataTypeToString(index_type),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT32),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT64)));

    if (index_type == framework::proto::VarType::INT32) {
      IndexSelectGradInner<T, int>(context, out_grad, index, x_grad, dim);
    } else if (index_type == framework::proto::VarType::INT64) {
      IndexSelectGradInner<T, int64_t>(context, out_grad, index, x_grad, dim);
    }
  }
};

}  // namespace operators
}  // namespace paddle
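
IndexSelectInner is a gather along dim and IndexSelectGradInner is the matching scatter-add: duplicate indices accumulate via the `+=`. A NumPy sketch of the pair (illustrative only; the helper names are hypothetical):

import numpy as np

def index_select_ref(x, index, dim):
    # forward: gather slices along `dim`, same as IndexSelectInner
    return np.take(x, index, axis=dim)

def index_select_grad_ref(out_grad, index, dim, dim_size):
    # backward: scatter-add each output slice back to its source slice;
    # np.add.at accumulates when `index` contains duplicates
    shape = out_grad.shape[:dim] + (dim_size,) + out_grad.shape[dim + 1:]
    x_grad = np.zeros(shape, dtype=out_grad.dtype)
    np.add.at(x_grad, (slice(None),) * dim + (index,), out_grad)
    return x_grad

x = np.random.rand(5, 4, 6)
idx = np.array([0, 2, 2])
out = index_select_ref(x, idx, 1)                      # shape (5, 3, 6)
gx = index_select_grad_ref(np.ones_like(out), idx, 1, 4)
assert gx[:, 2].sum() == 2 * 5 * 6                     # slice 2 gathered twice
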
paddle/fluid/operators/roll_op.cc
@@ -0,0 +1,143 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/roll_op.h"
#include <memory>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class RollOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of RollOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of RollOp should not be null."));

    auto dims = ctx->Attrs().Get<std::vector<int64_t>>("dims");
    auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");

    PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
                      platform::errors::InvalidArgument(
                          "Attr(dims).size() should be equal to "
                          "Attr(shifts).size(). But received "
                          "Attr(dims).size() = %d, Attr(shifts).size() = %d",
                          dims.size(), shifts.size()));

    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    auto type = ctx->GetInputsVarType("X")[0];
    if (type == framework::proto::VarType::LOD_TENSOR) {
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class RollGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::InvalidArgument(
                          "Input(Out@GRAD) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                      platform::errors::InvalidArgument(
                          "Output(X@GRAD) should be not null."));

    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

class RollOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) the input tensor.");
    AddOutput("Out", "(Tensor), the output tensor.");
    AddAttr<std::vector<int64_t>>("shifts",
                                  "The number of places by which the elements "
                                  "of the tensor are shifted.")
        .SetDefault({});
    AddAttr<std::vector<int64_t>>(
        "dims",
        "Axis along which to roll. It must have the same size "
        "with shifts.")
        .SetDefault({});
    AddComment(R"DOC(
    Roll the tensor along the given dimension(s).
    Elements that are shifted beyond the last position
    are re-introduced at the first position. If a dimension
    is not specified, the tensor will be flattened before
    rolling and then restored to the original shape.
    )DOC");
  }
};

template <typename T>
class RollGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("roll_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(RollGradNoNeedBufferVarsInference, "X");
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(roll, ops::RollOp, ops::RollOpMaker,
                  ops::RollGradMaker<paddle::framework::OpDesc>,
                  ops::RollGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(roll_grad, ops::RollGradOp,
                  ops::RollGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
    roll, ops::RollKernel<paddle::platform::CPUDeviceContext, float>,
    ops::RollKernel<paddle::platform::CPUDeviceContext, double>,
    ops::RollKernel<paddle::platform::CPUDeviceContext, int>,
    ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    roll_grad, ops::RollGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
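
Usage note: the dygraph cases in the new test_roll_op.py (later in this commit) exercise both behaviors described in the DOC string above; the printed values are the ones asserted there.

import numpy as np
import paddle
import paddle.fluid as fluid

x_data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(x_data)
    # no dims: flatten, roll by 1, then restore the original shape
    print(paddle.roll(x, shifts=1).numpy())
    # [[9. 1. 2.] [3. 4. 5.] [6. 7. 8.]]
    # dims=0: rotate whole rows
    print(paddle.roll(x, shifts=1, dims=0).numpy())
    # [[7. 8. 9.] [1. 2. 3.] [4. 5. 6.]]
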
paddle/fluid/operators/roll_op.cu
@@ -0,0 +1,27 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/roll_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>,
    ops::RollKernel<paddle::platform::CUDADeviceContext, double>,
    ops::RollKernel<paddle::platform::CUDADeviceContext, int>,
    ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
paddle/fluid/operators/roll_op.h
@@ -0,0 +1,135 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

template <typename T>
inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
                            int64_t shift) {
  if (dim < 0) {
    dim += input_dim.size();
  }
  shift = shift % input_dim[dim];
  if (shift < 0) {
    shift += input_dim[dim];
  }

  auto outer_loops = 1;
  for (auto i = 0; i < dim; i++) {
    outer_loops *= input_dim[i];
  }
  auto slice_width = 1;
  for (auto i = dim + 1; i < input_dim.size(); i++) {
    slice_width *= input_dim[i];
  }

  VLOG(3) << "shift_along_dim_debug: input_dim: " << input_dim
          << "; dim: " << dim << "; shift: " << shift
          << "; outer_loops: " << outer_loops
          << "; slice_width: " << slice_width;
  if (shift == 0) {
    return;
  }

  std::vector<T> head;
  auto head_size = slice_width * (input_dim[dim] - shift);
  head.resize(head_size);

  for (auto i = 0; i < outer_loops; i++) {
    for (auto j = 0; j < head_size; j++) {
      head[j] = data[i * input_dim[dim] * slice_width + j];
    }
    for (auto j = input_dim[dim] - shift; j < input_dim[dim]; j++) {
      auto dst_pos = j - input_dim[dim] + shift;
      for (auto k = 0; k < slice_width; k++) {
        data[(i * input_dim[dim] + dst_pos) * slice_width + k] =
            data[(i * input_dim[dim] + j) * slice_width + k];
      }
    }
    for (auto j = 0; j < head_size; j++) {
      data[(i * input_dim[dim] + shift) * slice_width + j] = head[j];
    }
  }
}

template <typename DeviceContext, typename T>
class RollKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input_var = context.InputVar("X");
    auto* output_var = context.OutputVar("Out");
    auto& input = input_var->Get<LoDTensor>();
    auto* output = output_var->GetMutable<LoDTensor>();
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");

    std::vector<T> out_vec;
    TensorToVector(input, context.device_context(), &out_vec);

    size_t nums = shifts.size();
    const DDim input_dim = input.dims();

    for (size_t i = 0; i < nums; i++) {
      PADDLE_ENFORCE_EQ(
          dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
          platform::errors::OutOfRange(
              "Attr(dims[%d]) is out of range, It's expected "
              "to be in range of [-%d, %d]. But received Attr(dims[%d]) = %d.",
              i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
      shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
    }
    output->mutable_data<T>(context.GetPlace());
    framework::TensorFromVector(out_vec, context.device_context(), output);
    output->Resize(input_dim);
  }
};

template <typename DeviceContext, typename T>
class RollGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input_var = context.InputVar(framework::GradVarName("Out"));
    auto* output_var = context.OutputVar(framework::GradVarName("X"));
    auto& input = input_var->Get<LoDTensor>();
    auto* output = output_var->GetMutable<LoDTensor>();
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");

    std::vector<T> out_vec;
    TensorToVector(input, context.device_context(), &out_vec);

    size_t nums = shifts.size();
    const DDim input_dim = input.dims();

    for (size_t i = 0; i < nums; i++) {
      shift_along_dim(out_vec.data(), input_dim, dims[i], 0 - shifts[i]);
    }
    output->mutable_data<T>(context.GetPlace());
    framework::TensorFromVector(out_vec, context.device_context(), output);
    output->Resize(input_dim);
  }
};

}  // namespace operators
}  // namespace paddle
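
shift_along_dim rotates each outer block of input_dim[dim] * slice_width contiguous elements in place, stashing the leading part in a temporary head buffer. A sketch of that scheme on one flat block (illustrative only; rotate_block is a hypothetical helper):

import numpy as np

def rotate_block(data, dim_size, slice_width, shift):
    # normalize shift the same way the kernel does
    shift %= dim_size
    if shift == 0:
        return data
    head_size = slice_width * (dim_size - shift)
    head = data[:head_size].copy()   # part that ends up at the back
    tail = data[head_size:].copy()   # part that wraps around to the front
    data[:shift * slice_width] = tail
    data[shift * slice_width:] = head
    return data

a = np.arange(12.0)
assert np.array_equal(rotate_block(a, 12, 1, 5), np.roll(np.arange(12.0), 5))
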
python/paddle/fluid/tests/unittests/test_cross_op.py
@@ -0,0 +1,130 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestCrossOp(OpTest):
    def setUp(self):
        self.op_type = "cross"
        self.initTestCase()
        self.inputs = {
            'X': np.random.random(self.shape).astype(self.dtype),
            'Y': np.random.random(self.shape).astype(self.dtype)
        }
        self.init_output()

    def initTestCase(self):
        self.attrs = {'dim': -2}
        self.dtype = np.float64
        self.shape = (1024, 3, 1)

    def init_output(self):
        x = np.squeeze(self.inputs['X'], 2)
        y = np.squeeze(self.inputs['Y'], 2)
        z_list = []
        for i in range(1024):
            z_list.append(np.cross(x[i], y[i]))
        self.outputs = {'Out': np.array(z_list).reshape(self.shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out')


class TestCrossOpCase1(TestCrossOp):
    def initTestCase(self):
        self.shape = (2048, 3)
        self.dtype = np.float32

    def init_output(self):
        z_list = []
        for i in range(2048):
            z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
        self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


class TestCrossAPI(unittest.TestCase):
    def input_data(self):
        self.data_x = np.array(
            [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
        self.data_y = np.array(
            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])

    def test_cross_api(self):
        self.input_data()

        # case 1:
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 3])
            y = fluid.layers.data(name='y', shape=[-1, 3])
            z = paddle.cross(x, y, dim=1)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': self.data_x,
                                 'y': self.data_y},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

        # case 2:
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 3])
            y = fluid.layers.data(name='y', shape=[-1, 3])
            z = paddle.cross(x, y)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': self.data_x,
                                 'y': self.data_y},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[-1.0, -1.0, -1.0], [2.0, 2.0, 2.0],
                               [-1.0, -1.0, -1.0]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

    def test_dygraph_api(self):
        self.input_data()
        # case 1:
        with fluid.dygraph.guard():
            x = fluid.dygraph.to_variable(self.data_x)
            y = fluid.dygraph.to_variable(self.data_y)
            z = paddle.cross(x, y)
            np_z = z.numpy()
        expect_out = np.array([[-1.0, -1.0, -1.0], [2.0, 2.0, 2.0],
                               [-1.0, -1.0, -1.0]])
        self.assertTrue(np.allclose(expect_out, np_z))

        # case 2:
        with fluid.dygraph.guard():
            x = fluid.dygraph.to_variable(self.data_x)
            y = fluid.dygraph.to_variable(self.data_y)
            z = paddle.cross(x, y, dim=1)
            np_z = z.numpy()
        expect_out = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]])
        self.assertTrue(np.allclose(expect_out, np_z))


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/test_index_select_op.py
@@ -0,0 +1,135 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestIndexSelectOp(OpTest):
    def setUp(self):
        self.op_type = "index_select"
        self.init_dtype_type()
        index_np = np.random.randint(
            low=0, high=self.x_shape[self.dim], size=self.index_size)
        x_np = np.random.random(self.x_shape).astype(self.x_type)
        self.inputs = {'X': x_np, 'Index': index_np}
        self.attrs = {'dim': self.dim}
        outer_loop = np.prod(self.x_shape[:self.dim])
        x_reshape = [outer_loop] + list(self.x_shape[self.dim:])
        x_np_reshape = np.reshape(x_np, tuple(x_reshape))
        out_list = []
        for i in range(outer_loop):
            for j in range(self.index_size):
                out_list.append(x_np_reshape[i, index_np[j]])
        self.out_shape = list(self.x_shape)
        self.out_shape[self.dim] = self.index_size
        self.out_shape = tuple(self.out_shape)

        out = np.reshape(out_list, self.out_shape)
        self.outputs = {'Out': out}

    def init_dtype_type(self):
        self.dim = 1
        self.x_type = np.float64
        self.index_type = np.int64
        self.x_shape = (100, 4, 5)
        self.index_size = 100

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')


class TestIndexSelectOpCase2(TestIndexSelectOp):
    def init_dtype_type(self):
        self.x_type = np.float32
        self.index_type = np.int32
        self.dim = -2
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10


class TestIndexSelectAPI(unittest.TestCase):
    def input_data(self):
        self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
                                [9.0, 10.0, 11.0, 12.0]])
        self.data_index = np.array([0, 1, 1]).astype('int32')

    def test_index_select_api(self):
        self.input_data()

        # case 1:
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 4])
            index = fluid.layers.data(
                name='index', shape=[3], dtype='int32', append_batch_size=False)
            z = paddle.index_select(x, index, dim=1)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': self.data_x,
                                 'index': self.data_index},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
                               [9.0, 10.0, 10.0]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

        # case 2:
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 4])
            index = fluid.layers.data(
                name='index', shape=[3], dtype='int32', append_batch_size=False)
            z = paddle.index_select(x, index)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': self.data_x,
                                 'index': self.data_index},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array(
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

    def test_dygraph_api(self):
        self.input_data()
        # case 1:
        with fluid.dygraph.guard():
            x = fluid.dygraph.to_variable(self.data_x)
            index = fluid.dygraph.to_variable(self.data_index)
            z = paddle.index_select(x, index)
            np_z = z.numpy()
        expect_out = np.array(
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0]])
        self.assertTrue(np.allclose(expect_out, np_z))

        # case 2:
        with fluid.dygraph.guard():
            x = fluid.dygraph.to_variable(self.data_x)
            index = fluid.dygraph.to_variable(self.data_index)
            z = paddle.index_select(x, index, dim=1)
            np_z = z.numpy()
        expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
                               [9.0, 10.0, 10.0]])
        self.assertTrue(np.allclose(expect_out, np_z))


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/test_nonzero_api.py
@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestNonZeroAPI(unittest.TestCase):
    def test_nonzero_api_as_tuple(self):
        data = np.array([[True, False], [False, True]])
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 2])
            y = paddle.nonzero(x, as_tuple=True)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), 2)
            z = fluid.layers.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())

            res, = exe.run(feed={'x': data},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[0, 0], [1, 1]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

        data = np.array([True, True, False])
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1])
            y = paddle.nonzero(x, as_tuple=True)
            self.assertEqual(type(y), tuple)
            self.assertEqual(len(y), 1)
            z = fluid.layers.concat(list(y), axis=1)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': data},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[0], [1]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

    def test_nonzero_api(self):
        data = np.array([[True, False], [False, True]])
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 2])
            y = paddle.nonzero(x)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': data},
                           fetch_list=[y.name],
                           return_numpy=False)
        expect_out = np.array([[0, 0], [1, 1]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

        data = np.array([True, True, False])
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1])
            y = paddle.nonzero(x)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': data},
                           fetch_list=[y.name],
                           return_numpy=False)
        expect_out = np.array([[0], [1]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))


if __name__ == "__main__":
    unittest.main()
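
These cases line up with NumPy's nonzero conventions; a sketch of the equivalence being asserted (illustrative only):

import numpy as np

data = np.array([[True, False], [False, True]])
# paddle.nonzero(x) yields an [n, rank] matrix of coordinates,
# the column-stacked form of np.nonzero:
coords = np.transpose(np.nonzero(data))          # [[0 0] [1 1]]
# as_tuple=True yields one [n, 1] tensor per axis; concatenating them
# on axis 1, as the test does, recovers the same matrix:
per_axis = [idx.reshape(-1, 1) for idx in np.nonzero(data)]
assert np.array_equal(np.concatenate(per_axis, axis=1), coords)
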
python/paddle/fluid/tests/unittests/test_roll_op.py
@@ -0,0 +1,112 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestRollOp(OpTest):
    def setUp(self):
        self.op_type = "roll"
        self.init_dtype_type()
        self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
        self.attrs = {'shifts': self.shifts, 'dims': self.dims}
        self.outputs = {
            'Out': np.roll(self.inputs['X'], self.attrs['shifts'],
                           self.attrs['dims'])
        }

    def init_dtype_type(self):
        self.dtype = np.float64
        self.x_shape = (100, 4, 5)
        self.shifts = [101, -1]
        self.dims = [0, -2]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')


class TestRollOpCase2(TestRollOp):
    def init_dtype_type(self):
        self.dtype = np.float32
        self.x_shape = (100, 100, 5)
        self.shifts = [8, -1]
        self.dims = [-1, -2]


class TestRollAPI(unittest.TestCase):
    def input_data(self):
        self.data_x = np.array(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

    def test_roll_api(self):
        self.input_data()

        # case 1:
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 3])
            z = paddle.roll(x, shifts=1)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': self.data_x},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[9.0, 1.0, 2.0], [3.0, 4.0, 5.0],
                               [6.0, 7.0, 8.0]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

        # case 2:
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[-1, 3])
            z = paddle.roll(x, shifts=1, dims=0)
            exe = fluid.Executor(fluid.CPUPlace())
            res, = exe.run(feed={'x': self.data_x},
                           fetch_list=[z.name],
                           return_numpy=False)
        expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
                               [4.0, 5.0, 6.0]])
        self.assertTrue(np.allclose(expect_out, np.array(res)))

    def test_dygraph_api(self):
        self.input_data()
        # case 1:
        with fluid.dygraph.guard():
            x = fluid.dygraph.to_variable(self.data_x)
            z = paddle.roll(x, shifts=1)
            np_z = z.numpy()
        expect_out = np.array([[9.0, 1.0, 2.0], [3.0, 4.0, 5.0],
                               [6.0, 7.0, 8.0]])
        self.assertTrue(np.allclose(expect_out, np_z))

        # case 2:
        with fluid.dygraph.guard():
            x = fluid.dygraph.to_variable(self.data_x)
            z = paddle.roll(x, shifts=1, dims=0)
            np_z = z.numpy()
        expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
                               [4.0, 5.0, 6.0]])
        self.assertTrue(np.allclose(expect_out, np_z))


if __name__ == "__main__":
    unittest.main()