add gather_nd op and unit test (#19366)
* fixed the code for coverage
* fixed the document, test=document_preview, test=develop, fix_crf_doc
parent
ecd9f330c9
commit
85914f7a88
@@ -0,0 +1,182 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_nd_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"

namespace paddle {
namespace operators {

class GatherNdOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      "Input(X) of GatherNdOp should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
                      "Input(Index) of GatherNdOp should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of GatherNdOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto x_dims_size = x_dims.size();
    auto index_dims = ctx->GetInputDim("Index");
    auto index_dims_size = index_dims.size();

    // Check the rank of Index before reading its last dimension.
    PADDLE_ENFORCE_GE(index_dims_size, 2UL,
                      "The rank of Input(Index) should be greater than 1");
    PADDLE_ENFORCE_LE(index_dims[index_dims_size - 1], x_dims_size,
                      "Input(Index).shape[-1] <= Input(X).rank");

    // The result shape is
    //   Index.shape[:-1] + X.shape[Index.shape[-1]:]
    std::vector<int64_t> result_dims;
    for (int i = 0; i < index_dims_size - 1; ++i) {
      result_dims.emplace_back(index_dims[i]);
    }
    for (int i = index_dims[index_dims_size - 1]; i < x_dims_size; ++i) {
      result_dims.emplace_back(x_dims[i]);
    }

    ctx->SetOutputDim("Out", framework::make_ddim(result_dims));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context());
  }
};
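For reference, the shape rule that InferShape implements above can be restated in a few lines of plain Python. This is an illustrative sketch of the rule, not code from this patch:

    def gather_nd_out_shape(x_shape, index_shape):
        # Mirrors GatherNdOp::InferShape: validate, then build Out's shape as
        # Index.shape[:-1] + X.shape[Index.shape[-1]:].
        assert len(index_shape) >= 2, "The rank of Index should be greater than 1"
        k = index_shape[-1]  # number of leading X axes addressed per index row
        assert k <= len(x_shape), "Input(Index).shape[-1] <= Input(X).rank"
        return list(index_shape[:-1]) + list(x_shape[k:])

    print(gather_nd_out_shape((2, 3, 4), (1, 2)))  # [1, 4]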

class GatherNdGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // The gradient w.r.t. X has the same shape (and LoD) as X itself.
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
};

class GatherNdOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The source input of gather_nd op");
    AddInput("Index", "The index input of gather_nd op");
    AddOutput("Out", "The output of gather_nd op");
    AddComment(R"DOC(
Gather_Nd Operator.

This op is a high-dimensional extension of gather and supports
simultaneous indexing along multiple axes. Out is obtained by
gathering slices from X into a tensor with shape

    Index.shape[:-1] + X.shape[Index.shape[-1]:].

Example:

Given:
    X = [[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],
         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]]

    X.shape = (2, 3, 4)

* Case 1:
    Index = [[1]]

  we get:
    Out = [[[12, 13, 14, 15],
            [16, 17, 18, 19],
            [20, 21, 22, 23]]]        # Out.shape = (1, 3, 4)

* Case 2:
    Index = [[0, 2]]

  we get:
    Out = [[8, 9, 10, 11]]            # Out.shape = (1, 4)

* Case 3:
    Index = [[1, 2, 3]]

  we get:
    Out = [23]                        # Out.shape = (1,)

)DOC");
  }
};
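The three documented cases can be cross-checked with NumPy's advanced indexing, which is also the trick the unit tests below rely on. A hedged verification sketch (NumPy only, not part of the patch):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)

    def gather_nd_ref(x, index):
        # Split the last axis of Index into one index array per X axis.
        index = np.asarray(index)
        return x[tuple(np.moveaxis(index, -1, 0))]

    print(gather_nd_ref(x, [[1]]).shape)  # (1, 3, 4)      -- Case 1
    print(gather_nd_ref(x, [[0, 2]]))     # [[ 8  9 10 11]] -- Case 2
    print(gather_nd_ref(x, [[1, 2, 3]]))  # [23]            -- Case 3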

class GatherNdGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("gather_nd_grad");
    op->SetInput("Index", Input("Index"));
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

// The backward pass only needs X's shape, not its contents, so X's buffer
// may be reused or freed once the forward pass has run.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GatherNdGradNoNeedBufferVarInference,
                                      "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker,
                  ops::GatherNdGradOpDescMaker);

REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
                  ops::GatherNdGradNoNeedBufferVarInference);

REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
                       ops::GatherNdOpKernel<double>,
                       ops::GatherNdOpKernel<int64_t>,
                       ops::GatherNdOpKernel<int>,
                       ops::GatherNdOpKernel<uint8_t>);

REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
                       ops::GatherNdGradOpKernel<double>,
                       ops::GatherNdGradOpKernel<int64_t>,
                       ops::GatherNdGradOpKernel<int>,
                       ops::GatherNdGradOpKernel<uint8_t>);
@@ -0,0 +1,105 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_nd_op.h"
#include "paddle/fluid/operators/scatter.cu.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      "This kernel only runs on GPU device.");
    auto *x = ctx.Input<Tensor>("X");
    auto *index = ctx.Input<Tensor>("Index");
    auto *output = ctx.Output<Tensor>("Out");

    output->mutable_data<T>(ctx.GetPlace());
    if (x->numel() == 0) return;

    // Only int32 and int64 indices are supported; dispatch on the type.
    const auto &index_type = index->type();
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;
    PADDLE_ENFORCE_EQ(
        index_type_match, true,
        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
        paddle::framework::DataTypeToString(index_type),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
    if (index_type == framework::proto::VarType::INT32) {
      GPUGatherNd<DeviceContext, T, int>(ctx, *x, *index, output);
    } else if (index_type == framework::proto::VarType::INT64) {
      GPUGatherNd<DeviceContext, T, int64_t>(ctx, *x, *index, output);
    }
  }
};
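GPUGatherNd is defined in gather.cu.h and is not part of this hunk. A plausible mental model is that it linearizes each index row into a flat offset over X's leading axes and then copies the trailing slice wholesale; the NumPy sketch below illustrates that model (an assumption about the helper's strategy, not its actual code):

    import numpy as np

    def gather_nd_flat(x, index):
        index = np.asarray(index)
        k = index.shape[-1]
        slice_shape = x.shape[k:]
        # Flatten X's leading k axes; each output element is one whole slice.
        x_flat = x.reshape((-1,) + slice_shape)
        offsets = np.ravel_multi_index(tuple(index.reshape(-1, k).T), x.shape[:k])
        return x_flat[offsets].reshape(index.shape[:-1] + slice_shape)

    x = np.arange(24).reshape(2, 3, 4)
    print(gather_nd_flat(x, [[0, 2]]))  # [[ 8  9 10 11]]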

template <typename DeviceContext, typename T>
class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      "This kernel only runs on GPU device.");
    auto *index = ctx.Input<Tensor>("Index");
    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));

    // Zero-initialize dX, then scatter-add dOut back into it.
    dX->mutable_data<T>(ctx.GetPlace());
    auto dxt = framework::EigenVector<T>::Flatten(*dX);
    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();
    dxt.device(place) = dxt.constant(static_cast<T>(0));
    if (dO->numel() == 0) return;

    const auto &index_type = index->type();
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;

    PADDLE_ENFORCE_EQ(
        index_type_match, true,
        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
        paddle::framework::DataTypeToString(index_type),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));

    if (index_type == framework::proto::VarType::INT32) {
      GPUScatterNdAdd<DeviceContext, T, int>(ctx, *dO, *index, dX);
    } else if (index_type == framework::proto::VarType::INT64) {
      GPUScatterNdAdd<DeviceContext, T, int64_t>(ctx, *dO, *index, dX);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
using CUDA = paddle::platform::CUDADeviceContext;

// Note: the GPU kernels cover plat::float16, where the CPU kernels cover
// uint8_t instead.
REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
                        ops::GatherNdOpCUDAKernel<CUDA, double>,
                        ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
                        ops::GatherNdOpCUDAKernel<CUDA, int>,
                        ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);

REGISTER_OP_CUDA_KERNEL(gather_nd_grad,
                        ops::GatherNdGradOpCUDAKernel<CUDA, float>,
                        ops::GatherNdGradOpCUDAKernel<CUDA, double>,
                        ops::GatherNdGradOpCUDAKernel<CUDA, int64_t>,
                        ops::GatherNdGradOpCUDAKernel<CUDA, int>,
                        ops::GatherNdGradOpCUDAKernel<CUDA, plat::float16>);
@@ -0,0 +1,91 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/scatter.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class GatherNdOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      "This kernel only runs on CPU.");

    auto *x = ctx.Input<Tensor>("X");
    auto *index = ctx.Input<Tensor>("Index");
    auto *output = ctx.Output<Tensor>("Out");

    output->mutable_data<T>(ctx.GetPlace());
    if (x->numel() == 0) return;

    const auto &index_type = index->type();
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;
    PADDLE_ENFORCE_EQ(
        index_type_match, true,
        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
        paddle::framework::DataTypeToString(index_type),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
    if (index_type == framework::proto::VarType::INT32) {
      CPUGatherNd<T, int>(ctx.device_context(), *x, *index, output);
    } else if (index_type == framework::proto::VarType::INT64) {
      CPUGatherNd<T, int64_t>(ctx.device_context(), *x, *index, output);
    }
  }
};

template <typename T>
class GatherNdGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      "This kernel only runs on CPU.");
    auto *index = ctx.Input<Tensor>("Index");
    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));

    // Zero-initialize dX, then scatter-add dOut back into it.
    dX->mutable_data<T>(ctx.GetPlace());
    auto dxt = framework::EigenVector<T>::Flatten(*dX);
    auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    dxt.device(place) = dxt.constant(static_cast<T>(0));
    if (dO->numel() == 0) return;

    const auto &index_type = index->type();
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;
    PADDLE_ENFORCE_EQ(
        index_type_match, true,
        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
        paddle::framework::DataTypeToString(index_type),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
    if (index_type == framework::proto::VarType::INT32) {
      ScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX);
    } else if (index_type == framework::proto::VarType::INT64) {
      ScatterNdAdd<T, int64_t>(ctx, *dO, *index, dX);
    }
  }
};

}  // namespace operators
}  // namespace paddle
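Both the CPU and CUDA grad kernels implement the same contract: dX starts as zeros, and the incoming gradient of every gathered slice is added back at the location its index row selected, so duplicate index rows accumulate. A NumPy sketch of that contract (illustrative only; the real ScatterNdAdd lives in scatter.h):

    import numpy as np

    def gather_nd_grad_ref(x_shape, index, dout):
        index = np.asarray(index)
        dx = np.zeros(x_shape, dtype=dout.dtype)
        # np.add.at accumulates on repeated indices, matching ScatterNdAdd.
        np.add.at(dx, tuple(np.moveaxis(index, -1, 0)), dout)
        return dx

    dout = np.ones((2, 4))
    print(gather_nd_grad_ref((2, 3, 4), [[0, 1], [0, 1]], dout)[0, 1])
    # [2. 2. 2. 2.]  <- the two identical index rows added up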
@@ -0,0 +1,169 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid


class TestGatherNdOpWithEmptyIndex(OpTest):
    """
    Index rows are empty, which means copying the entire tensor.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        xnp = np.array(
            [[65, 17, 2], [-14, -25, -1], [76, 22, 3]]).astype("float32")
        # Index.shape = (2, 0): each of the two empty index rows selects all
        # of X, so Out stacks two copies of X.
        self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
        self.outputs = {
            'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
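For this edge case the shape rule gives Index.shape[:-1] + X.shape[0:] = (2, 3, 3): an index row with zero coordinates selects all of X. A quick NumPy illustration of the expected output constructed in setUp above:

    import numpy as np

    x = np.array([[65, 17, 2], [-14, -25, -1], [76, 22, 3]], dtype="float32")

    # Index.shape = (2, 0) -> Out.shape = (2,) + x.shape; each row copies X.
    expected = np.vstack((x[np.newaxis, :], x[np.newaxis, :]))
    same = np.broadcast_to(x, (2,) + x.shape)  # an equivalent construction
    print(expected.shape, np.array_equal(expected, same))  # (2, 3, 3) True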

class TestGatherNdOpWithLowIndex(OpTest):
    """
    Index.shape[-1] is smaller than X's rank, so each index row selects a
    sub-tensor (here, a row of X).
    """

    def setUp(self):
        self.op_type = "gather_nd"
        xnp = np.array(
            [[65, 17, 2], [14, 25, 1], [76, 22, 3]]).astype("float32")
        index = np.array([[1], [2]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}

        # xnp[tuple(index.T)] == [[14, 25, 1], [76, 22, 3]]
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestGatherNdOpWithSameIndexAsX(OpTest):
    """
    Index.shape[-1] equals X's rank, so each index row selects a scalar.
    """

    def setUp(self):
        self.op_type = "gather_nd"
        xnp = np.array(
            [[65, 17, 2], [14, 25, 1], [76, 22, 3]]).astype("float64")
        index = np.array([[1, 1], [2, 1]]).astype("int64")

        self.inputs = {'X': xnp, 'Index': index}
        # xnp[tuple(index.T)] == [25, 22]
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestGatherNdOpWithHighRankSame(OpTest):
    """
    Both Index and X have high rank, and Rank(Index) = Rank(X)
    """

    def setUp(self):
        self.op_type = "gather_nd"
        shape = (20, 9, 8, 1, 31)
        xnp = np.random.rand(*shape)
        # 150 random index rows, one coordinate per axis of X.
        index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T

        self.inputs = {'X': xnp, 'Index': index.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)]}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

class TestGatherNdOpWithHighRankDiff(OpTest):
    """
    Both Index and X have high rank, and Rank(Index) < Rank(X)
    """

    def setUp(self):
        self.op_type = "gather_nd"
        shape = (20, 9, 8, 1, 31)
        xnp = np.random.rand(*shape).astype("double")
        index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T
        # Reshape the 1000 index rows to (10, 5, 20, 5); since every axis of X
        # is indexed, Out.shape = Index.shape[:-1] = (10, 5, 20).
        index_re = index.reshape([10, 5, 20, 5])

        self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
        self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
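The expected shape in the test above follows directly from the op's shape rule; a one-off check (illustrative):

    # Index.shape[-1] == 5 == len(X.shape), so all of X's axes are indexed:
    index_shape, x_shape = (10, 5, 20, 5), (20, 9, 8, 1, 31)
    k = index_shape[-1]
    print(index_shape[:-1] + x_shape[k:])  # (10, 5, 20)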

# Test the Python API
class TestGatherNdOpAPI(OpTest):
    def test_case1(self):
        x1 = fluid.layers.data(
            name='x1', shape=[30, 40, 50, 60], dtype='float32')
        index1 = fluid.layers.data(name='index1', shape=[2, 4], dtype='int32')
        output1 = fluid.layers.gather_nd(x1, index1)

    def test_case2(self):
        x2 = fluid.layers.data(name='x2', shape=[30, 40, 50], dtype='float32')
        index2 = fluid.layers.data(name='index2', shape=[2, 2], dtype='int64')
        output2 = fluid.layers.gather_nd(x2, index2)

    def test_case3(self):
        x3 = fluid.layers.data(name='x3', shape=[3, 4, 5], dtype='float32')
        index3 = fluid.layers.data(name='index3', shape=[2, 1], dtype='int32')
        output3 = fluid.layers.gather_nd(x3, index3, name="gather_nd_layer")
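The cases above only build the graph. Below is a minimal end-to-end sketch of actually running the layer; it assumes the same 1.x fluid API the tests use, with append_batch_size=False so the declared shapes are taken literally:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(
        name='x', shape=[2, 3, 4], dtype='float32', append_batch_size=False)
    index = fluid.layers.data(
        name='index', shape=[1, 2], dtype='int32', append_batch_size=False)
    out = fluid.layers.gather_nd(x, index)

    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(fluid.default_main_program(),
                   feed={
                       'x': np.arange(24).reshape(2, 3, 4).astype('float32'),
                       'index': np.array([[0, 2]], dtype='int32')
                   },
                   fetch_list=[out])
    print(res)  # [[ 8.  9. 10. 11.]], matching Case 2 in the op's DOC string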

# Test that an invalid Index shape raises an error
class TestGatherNdOpRaise(OpTest):
    def test_check_raise(self):
        def check_raise_is_test():
            try:
                x = fluid.layers.data(
                    name='x', shape=[3, 4, 5], dtype='float32')
                index = fluid.layers.data(
                    name='index', shape=[2, 10], dtype='int32')
                # Index.shape[-1] (10) exceeds X's rank (4, counting the
                # implicit batch dim), so InferShape should complain.
                output = fluid.layers.gather_nd(x, index)
            except Exception as e:
                t = "Input(Index).shape[-1] <= Input(X).rank"
                if t in str(e):
                    raise IndexError

        self.assertRaises(IndexError, check_raise_is_test)

if __name__ == "__main__":
    unittest.main()