commit
831909ce69
@ -0,0 +1,33 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||
|
||||
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
|
||||
paddle::operators::ArgMaxOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
arg_max,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||
int64_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||
int32_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||
int16_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||
uint8_t>);
|
@ -0,0 +1,31 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
arg_max,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||
double>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||
int64_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||
int32_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||
int16_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||
size_t>,
|
||||
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||
uint8_t>);
|
@ -0,0 +1,160 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
#include "paddle/fluid/string/printf.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
enum ArgMinMaxType { kArgMin, kArgMax };
|
||||
|
||||
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
|
||||
ArgMinMaxType argMinMaxValue>
|
||||
struct ArgMinMaxFunctor {};
|
||||
|
||||
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
|
||||
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
|
||||
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
|
||||
enum_argminmax_value> { \
|
||||
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
|
||||
framework::LoDTensor* out, int64_t axis) { \
|
||||
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
|
||||
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
|
||||
out_eigen.device(*(ctx.eigen_device())) = \
|
||||
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
|
||||
} \
|
||||
}
|
||||
|
||||
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
|
||||
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
|
||||
|
||||
template <typename DeviceContext, typename T, typename Tout,
|
||||
ArgMinMaxType EnumArgMinMaxValue>
|
||||
class ArgMinMaxKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
|
||||
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
|
||||
out.mutable_data<Tout>(ctx.GetPlace());
|
||||
auto axis = ctx.Attr<int64_t>("axis");
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
|
||||
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
|
||||
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
|
||||
functor##rank; \
|
||||
functor##rank(dev_ctx, x, &out, axis)
|
||||
|
||||
switch (x.dims().size()) {
|
||||
case 1:
|
||||
CALL_ARG_MINMAX_FUNCTOR(1);
|
||||
break;
|
||||
case 2:
|
||||
CALL_ARG_MINMAX_FUNCTOR(2);
|
||||
break;
|
||||
case 3:
|
||||
CALL_ARG_MINMAX_FUNCTOR(3);
|
||||
break;
|
||||
case 4:
|
||||
CALL_ARG_MINMAX_FUNCTOR(4);
|
||||
break;
|
||||
case 5:
|
||||
CALL_ARG_MINMAX_FUNCTOR(5);
|
||||
break;
|
||||
case 6:
|
||||
CALL_ARG_MINMAX_FUNCTOR(6);
|
||||
break;
|
||||
default:
|
||||
PADDLE_THROW(
|
||||
"%s operator doesn't supports tensors whose ranks are greater "
|
||||
"than 6.",
|
||||
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
|
||||
break;
|
||||
#undef CALL_ARG_MINMAX_FUNCTOR
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
using ArgMinKernel =
|
||||
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
using ArgMaxKernel =
|
||||
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;
|
||||
|
||||
class ArgMinMaxOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
|
||||
const auto& x_dims = ctx->GetInputDim("X");
|
||||
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
|
||||
PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(),
|
||||
"'axis' must be inside [-Rank(X), Rank(X))");
|
||||
|
||||
auto x_rank = x_dims.size();
|
||||
if (axis < 0) axis += x_rank;
|
||||
|
||||
std::vector<int64_t> vec;
|
||||
for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
|
||||
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(vec));
|
||||
}
|
||||
};
|
||||
|
||||
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
protected:
|
||||
virtual const char* OpName() const = 0;
|
||||
virtual const char* Name() const = 0;
|
||||
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "Input tensor.");
|
||||
AddOutput("Out", "Output tensor.");
|
||||
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
|
||||
AddComment(string::Sprintf(R"DOC(
|
||||
%s Operator.
|
||||
|
||||
Computes the indices of the %s elements of the input tensor's element
|
||||
along the provided axis.
|
||||
)DOC",
|
||||
OpName(), Name()));
|
||||
}
|
||||
};
|
||||
|
||||
class ArgMinOpMaker : public BaseArgMinMaxOpMaker {
|
||||
protected:
|
||||
const char* OpName() const override { return "ArgMin"; }
|
||||
const char* Name() const override { return "min"; }
|
||||
};
|
||||
|
||||
class ArgMaxOpMaker : public BaseArgMinMaxOpMaker {
|
||||
protected:
|
||||
const char* OpName() const override { return "ArgMax"; }
|
||||
const char* Name() const override { return "max"; }
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,33 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||
|
||||
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
|
||||
paddle::operators::ArgMinOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
arg_min,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||
int64_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||
int32_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||
int16_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||
uint8_t>);
|
@ -0,0 +1,31 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
arg_min,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||
double>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||
int64_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||
int32_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||
int16_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||
size_t>,
|
||||
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||
uint8_t>);
|
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class BaseTestCase(OpTest):
|
||||
def initTestCase(self):
|
||||
self.op_type = 'arg_min'
|
||||
self.dims = (3, 4, 5)
|
||||
self.dtype = 'float32'
|
||||
self.axis = 0
|
||||
|
||||
def setUp(self):
|
||||
self.initTestCase()
|
||||
self.x = (1000 * np.random.random(self.dims)).astype(self.dtype)
|
||||
self.inputs = {'X': self.x}
|
||||
self.attrs = {'axis': self.axis}
|
||||
if self.op_type == "arg_min":
|
||||
self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
|
||||
else:
|
||||
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestCase0(BaseTestCase):
|
||||
def initTestCase(self):
|
||||
self.op_type = 'arg_max'
|
||||
self.dims = (3, 4, 5)
|
||||
self.dtype = 'float32'
|
||||
self.axis = 0
|
||||
|
||||
|
||||
class TestCase1(BaseTestCase):
|
||||
def initTestCase(self):
|
||||
self.op_type = 'arg_min'
|
||||
self.dims = (3, 4)
|
||||
self.dtype = 'float64'
|
||||
self.axis = 1
|
||||
|
||||
|
||||
class TestCase2(BaseTestCase):
|
||||
def initTestCase(self):
|
||||
self.op_type = 'arg_max'
|
||||
self.dims = (3, 4)
|
||||
self.dtype = 'int64'
|
||||
self.axis = 0
|
||||
|
||||
|
||||
class TestCase3(BaseTestCase):
|
||||
def initTestCase(self):
|
||||
self.op_type = 'arg_max'
|
||||
self.dims = (3, )
|
||||
self.dtype = 'int64'
|
||||
self.axis = 0
|
||||
|
||||
|
||||
class TestCase4(BaseTestCase):
|
||||
def initTestCase(self):
|
||||
self.op_type = 'arg_min'
|
||||
self.dims = (1, )
|
||||
self.dtype = 'int32'
|
||||
self.axis = 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue