Add the argmax, argmin for the API 2.0

* Add the new API and op for argmax, argmin
wawltor 5 years ago committed by GitHub
parent d26ae9ad87
commit 6b28456ed0

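A minimal usage sketch of the 2.0-style API added by this change, assuming a Paddle build that contains this commit; the inputs and expected outputs are taken from the docstring examples further down in the diff.

```python
import numpy as np
import paddle

# Imperative mode, matching the docstring examples in this commit.
paddle.disable_static()

data = np.array([[5, 8, 9, 5],
                 [0, 0, 1, 7],
                 [6, 9, 2, 4]])
x = paddle.to_variable(data)

# axis=None flattens the input and returns the index of the global extreme.
print(paddle.argmax(x).numpy())          # 2  (value 9 at flattened index 2)
print(paddle.argmin(x).numpy())          # 4  (value 0 at flattened index 4)

# axis / keepdim behave like their NumPy counterparts.
print(paddle.argmax(x, axis=1).numpy())  # [2 3 1]
print(paddle.argmin(x, axis=-1).numpy()) # [0 0 2]

# dtype selects the index type; int64 is the default, int32 on request.
print(paddle.argmax(x, axis=0, dtype="int32").numpy())
```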
@@ -53,9 +53,9 @@ using Tensor = framework::Tensor;
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const IndType height, // n * h
const IndType width, // c
const IndType post_size, // h
__global__ void ArgCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const Reducer reducer, const T init, const T* in,
IndType* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
@@ -79,10 +79,10 @@ __global__ void ArgCUDAKernel(const IndType height, // n * h
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
Tensor* indices, const IndType pre, const IndType post,
const IndType n) {
Tensor* indices, const int64_t pre, const int64_t post,
const int64_t n) {
auto cu_stream = ctx.stream();
auto ComputeBlockSize = [](IndType col) {
auto ComputeBlockSize = [](int64_t col) {
if (col > 512)
return 1024;
else if (col > 256)
@@ -101,10 +101,10 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
return 8;
};
int max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
int height = pre * post;
int width = n;
int grid_size = height < max_grid_dimx ? height : max_grid_dimx;
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());
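From the branches visible above, `ComputeBlockSize` rounds the reduced width up to a power of two in [8, 1024], and the grid is capped at the device's maximum x dimension. A small Python sketch of that selection rule, inferred from the visible branches rather than a verbatim port of the lambda:

```python
def compute_block_size(col):
    """Pick the smallest power of two >= col, clamped to [8, 1024].
    Inferred from the visible branches of the ComputeBlockSize lambda."""
    size = 8
    while size < col and size < 1024:
        size *= 2
    return size

assert compute_block_size(1) == 8
assert compute_block_size(300) == 512     # col > 256  -> 512
assert compute_block_size(600) == 1024    # col > 512  -> 1024
assert compute_block_size(4096) == 1024   # clamped at the 1024-thread limit
```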
@@ -129,31 +129,60 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
}
template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
struct VisitDataCudaArgMinMaxFunctor {
const framework::ExecutionContext& ctx;
explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename IndType>
void apply() const {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int64_t>("axis");
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const bool& flatten = ctx.Attr<bool>("flatten");
framework::DDim input_dims;
if (flatten) {
input_dims = framework::make_ddim({input->numel()});
// when flatten is set, treat the input as 1-D and reduce over axis 0
axis = 0;
} else {
input_dims = input->dims();
if (axis < 0) axis += input->dims().size();
}
int64_t numel = input->numel();
int64_t groups = numel / in_dims[axis];
int64_t groups = numel / input_dims[axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = in_dims[axis];
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= in_dims[i];
pre *= input_dims[i];
}
for (int i = axis + 1; i < in_dims.size(); i++) {
post *= in_dims[i];
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
const auto& dev_ctx = ctx.cuda_device_context();
ComputeFullArg<T, int64_t, Reducer>(dev_ctx, *input, output, pre, post, n);
ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
}
};
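The functor above reduces the problem to three sizes: `pre` (product of dims before `axis`), `n` (the reduced dim), and `post` (product of dims after `axis`); the CUDA kernel then treats the input as a `(pre * post) x n` matrix (`height` x `width`). A NumPy sketch of that decomposition, as an illustrative reference rather than Paddle code:

```python
import numpy as np

def argmax_via_pre_n_post(x, axis):
    """Reference for the decomposition used by ComputeFullArg: view the input
    as (pre, n, post), reduce over the middle axis (width = n), and recover
    the same result as np.argmax along `axis`."""
    dims = x.shape
    pre = int(np.prod(dims[:axis], dtype=np.int64))       # dims before axis
    n = dims[axis]                                        # reduced dimension
    post = int(np.prod(dims[axis + 1:], dtype=np.int64))  # dims after axis
    out = x.reshape(pre, n, post).argmax(axis=1)          # height = pre * post rows
    return out.reshape(dims[:axis] + dims[axis + 1:])

x = np.random.rand(2, 3, 4).astype("float32")
assert (argmax_via_pre_n_post(x, axis=1) == np.argmax(x, axis=1)).all()
```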
template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataType(static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
return;
}
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
}
};

@@ -38,8 +38,9 @@ struct ArgMinMaxFunctor {};
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
framework::LoDTensor* out, framework::DDim x_dims, \
int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
@@ -68,16 +69,26 @@ struct VisitDataArgMinMaxFunctor {
out.template mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
auto keepdims = ctx.Attr<bool>("keepdims");
auto x_rank = x.dims().size();
if (axis < 0) axis += x_rank;
const bool& flatten = ctx.Attr<bool>("flatten");
// if flatten, construct new dims for the calculation
framework::DDim x_dims;
if (flatten) {
x_dims = framework::make_ddim({x.numel()});
// when flatten is set, treat the input as 1-D and reduce over axis 0
axis = 0;
} else {
x_dims = x.dims();
if (axis < 0) axis += x_dims.size();
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, &out, axis, keepdims)
functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)
switch (x.dims().size()) {
switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
@@ -141,6 +152,7 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
const auto& x_dims = ctx->GetInputDim("X");
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
bool keepdims = ctx->Attrs().Get<bool>("keepdims");
const bool& flatten = ctx->Attrs().Get<bool>("flatten");
PADDLE_ENFORCE_GE(axis, -x_dims.size(),
platform::errors::InvalidArgument(
@@ -152,14 +164,21 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
std::vector<int64_t> vec;
for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
if (keepdims) {
vec.push_back(static_cast<int64_t>(1));
if (flatten) {
// if flatten, the output contains only one element
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
} else {
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
ctx->SetOutputDim("Out", framework::make_ddim(vec));
}
};
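The reshaped InferShape logic above amounts to: drop the reduced axis (or keep it with extent 1 under keepdims), and with flatten collapse everything to a single index. A Python mirror of that shape rule, for illustration only (the helper name is made up here):

```python
def arg_min_max_out_shape(x_dims, axis, keepdims=False, flatten=False):
    """Illustrative mirror of ArgMinMaxOp::InferShape (hypothetical helper)."""
    if flatten:
        # the whole input collapses to a single index
        return [1] if keepdims else []
    rank = len(x_dims)
    if axis < 0:
        axis += rank
    out = list(x_dims[:axis])
    if keepdims:
        out.append(1)          # keep the reduced axis with extent 1
    out.extend(x_dims[axis + 1:])
    return out

assert arg_min_max_out_shape([3, 4], axis=0) == [4]
assert arg_min_max_out_shape([3, 4], axis=1, keepdims=True) == [3, 1]
assert arg_min_max_out_shape([3, 4], axis=1, flatten=True) == []
```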
@@ -176,6 +195,9 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int64_t>("axis", "The axis along which to compute the arg indices.");
AddAttr<bool>("keepdims", "Keep the reduced dim in the output.").SetDefault(false);
AddAttr<int>("dtype", "The data type of the output index tensor (int32 or int64); a negative value means the default int64.").SetDefault(-1);
AddAttr<bool>("flatten",
"Flatten the input value, and search the min or max indices")
.SetDefault(false);
AddComment(string::Sprintf(R"DOC(
%s Operator.

@@ -201,107 +201,5 @@ class BaseTestComplex2_2(OpTest):
}
class APT_ArgMaxTest(unittest.TestCase):
def test_output_result(self):
with fluid.program_guard(fluid.Program()):
data1 = fluid.data(name="X", shape=[3, 4], dtype="float32")
data2 = fluid.data(name="Y", shape=[3], dtype="int64")
out = paddle.argmax(input=data1, out=data2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(
feed={"X": np.random.rand(3, 4).astype("float32")},
fetch_list=[data2, out])
self.assertEqual((result[0] == result[1]).all(), True)
def test_basic(self):
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=[3, 4], dtype="float32")
out = paddle.argmax(input=data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = np.random.rand(3, 4).astype("float32")
expected_result = np.argmax(np_input, axis=1)
result, = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=[3, 4], dtype="float32")
out = paddle.argmax(input=data, axis=0)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = np.random.rand(3, 4).astype("float32")
expected_result = np.argmax(np_input, axis=0)
result = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=[3, 4], dtype="float32")
out = paddle.argmax(input=data, dtype="int32")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = np.random.rand(3, 4).astype("float32")
expected_result = np.argmax(np_input, axis=1).astype(np.int32)
result = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
data1 = fluid.data(name="X", shape=[3, 4], dtype="float32")
data2 = fluid.data(name="Y", shape=[3], dtype="int64")
out = paddle.argmax(input=data, out=data2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(
feed={"X": np.random.rand(3, 4).astype("float32")},
fetch_list=[data2, out])
self.assertEqual((result[0] == result[1]).all(), True)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[100], dtype="float32")
y_1 = paddle.argmax(x, name='arg_max_res')
self.assertEqual(('arg_max_res' in y_1.name), True)
def test_errors(self):
def test_dtype1():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
paddle.argmax(data, dtype="float32")
self.assertRaises(TypeError, test_dtype1)
def test_dtype2():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float64")
paddle.argmax(data, dtype="float32")
self.assertRaises(TypeError, test_dtype2)
class TestArgMinMaxOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_argmax_x_type():
x1 = [1, 2, 3]
output = fluid.layers.argmax(x=x1)
self.assertRaises(TypeError, test_argmax_x_type)
def test_argmin_x_type():
x2 = [1, 2, 3]
output = fluid.layers.argmin(x=x2)
self.assertRaises(TypeError, test_argmin_x_type)
if __name__ == '__main__':
unittest.main()

@@ -125,95 +125,168 @@ def argsort(x, axis=-1, descending=False, name=None):
return ids
def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None):
def argmax(x, axis=None, dtype=None, keepdim=False, name=None):
"""
:alias_main: paddle.argmax
:alias: paddle.argmax,paddle.tensor.argmax,paddle.tensor.search.argmax
This OP computes the indices of the max elements of the input tensor
along the provided axis.
Args:
input(Variable): An input N-D Tensor with type float32, float64, int16,
x(Tensor): An input N-D Tensor with type float32, float64, int16,
int32, int64, uint8.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is Rank(input). when axis<0, it works the same way
as axis+R. Default is None, it will use the last dim to select indices of max value.
dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor which can
is [-R, R), where R is x.ndim. When axis < 0, it works the same way
as axis + R. Default is None; in that case the input `x` is flattened and the index of the max value over all elements is returned.
dtype(str): Data type of the output tensor which can
be int32, int64. The default value is None, and it will
return the int64 indices.
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result. Defalut is None.
keepdims(bool, optional): Keep the axis that do the select max.
keepdim(bool, optional): Whether to keep the reduced axis in the output. The default value is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Variable: A Tensor with data type int64.
Tensor, a tensor with data type `int32` if :attr:`dtype` is set to `int32`, otherwise a tensor with data type `int64`.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
import paddle
in1 = np.array([[[5,8,9,5],
[0,0,1,7],
[6,9,2,4]],
[[5,2,4,2],
[4,7,7,9],
[1,7,0,6]]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(in1)
out1 = paddle.argmax(input=x, axis=-1)
out2 = paddle.argmax(input=x, axis=0)
out3 = paddle.argmax(input=x, axis=1)
out4 = paddle.argmax(input=x, axis=2)
out5 = paddle.argmax(input=x, axis=2, keepdims=True)
print(out1.numpy())
# [[2 3 1]
# [0 3 1]]
print(out2.numpy())
# [[0 0 0 0]
# [1 1 1 1]
# [0 0 0 1]]
print(out3.numpy())
# [[2 2 0 1]
# [0 1 1 1]]
print(out4.numpy())
# [[2 3 1]
# [0 3 1]]
print(out5.numpy())
#array([[[2],
# [3],
# [1]],
# [[0],
# [3],
# [1]]])
paddle.disable_static()
data = np.array([[5,8,9,5],
[0,0,1,7],
[6,9,2,4]])
x = paddle.to_variable(data)
out1 = paddle.argmax(x)
print(out1.numpy()) # 2
out2 = paddle.argmax(x, axis=1)
print(out2.numpy())
# [2 3 1]
out3 = paddle.argmax(x, axis=-1)
print(out3.numpy())
# [2 3 1]
"""
helper = LayerHelper("arg_max", **locals())
flatten = False
if axis is None:
flatten = True
axis = 0
if in_dygraph_mode():
if dtype != None:
var_dtype = convert_np_dtype_to_dtype_(dtype)
out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype,
'keepdim', keepdim, 'flatten', flatten)
else:
out = core.ops.arg_max(x, 'axis', axis, 'keepdim', keepdim,
'flatten', flatten)
return out
helper = LayerHelper("argmax", **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
'paddle.argmax')
var_dtype = None
attrs = {}
if dtype is not None:
check_dtype(dtype, 'create data type', ['int32', 'int64'], 'arg_max')
if dtype not in ['int32', 'int64']:
raise ValueError(
"The value of 'dtype' in argmax op must be int32, int64, but received of {}".
format(dtype))
var_dtype = convert_np_dtype_to_dtype_(dtype)
attrs["dtype"] = var_dtype
else:
var_dtype = VarDesc.VarType.INT64
if out is None:
out = helper.create_variable_for_type_inference(var_dtype)
out = helper.create_variable_for_type_inference(var_dtype)
attrs['keepdims'] = keepdim
attrs['axis'] = axis
attrs['flatten'] = flatten
helper.append_op(
type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
out.stop_gradient = True
return out
def argmin(x, axis=None, dtype=None, keepdim=False, name=None):
"""
This OP computes the indices of the min elements of the input tensor
along the provided axis.
Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16,
int32, int64, uint8.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. When axis < 0, it works the same way
as axis + R. Default is None; in that case the input `x` is flattened and the index of the min value over all elements is returned.
dtype(str): Data type of the output tensor which can
be int32, int64. The default value is None, and it will
return the int64 indices.
keepdim(bool, optional): Whether to keep the reduced axis in the output. The default value is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor, a tensor with data type `int32` if :attr:`dtype` is set to `int32`, otherwise a tensor with data type `int64`.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
data = np.array([[5,8,9,5],
[0,0,1,7],
[6,9,2,4]])
x = paddle.to_variable(data)
out1 = paddle.argmin(x)
print(out1.numpy()) # 4
out2 = paddle.argmin(x, axis=1)
print(out2.numpy())
# [0 0 2]
out3 = paddle.argmin(x, axis=-1)
print(out3.numpy())
# [0 0 2]
"""
flatten = False
if axis is None:
axis = -1
attrs['keepdims'] = keepdims
flatten = True
axis = 0
if in_dygraph_mode():
if dtype != None:
var_dtype = convert_np_dtype_to_dtype_(dtype)
out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype,
'keepdim', keepdim, 'flatten', flatten)
else:
out = core.ops.arg_min(x, 'axis', axis, 'keepdim', keepdim,
'flatten', flatten)
return out
helper = LayerHelper("argmin", **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
'paddle.argmin')
var_dtype = None
attrs = {}
if dtype is not None:
if dtype not in ['int32', 'int64']:
raise ValueError(
"The value of 'dtype' in argmin op must be int32, int64, but received of {}".
format(dtype))
var_dtype = convert_np_dtype_to_dtype_(dtype)
attrs["dtype"] = var_dtype
else:
var_dtype = VarDesc.VarType.INT64
out = helper.create_variable_for_type_inference(var_dtype)
attrs['keepdims'] = keepdim
attrs['axis'] = axis
attrs['flatten'] = flatten
helper.append_op(
type='arg_max',
inputs={'X': input},
outputs={'Out': [out]},
attrs=attrs)
type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
out.stop_gradient = True
return out
