commit
831909ce69
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||||
|
|
||||||
|
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
|
||||||
|
paddle::operators::ArgMaxOpMaker,
|
||||||
|
paddle::framework::EmptyGradOpMaker);
|
||||||
|
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
arg_max,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
int64_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
int32_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
int16_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
uint8_t>);
|
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||||
|
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
arg_max,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
double>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int64_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int32_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int16_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
size_t>,
|
||||||
|
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
uint8_t>);
|
@ -0,0 +1,160 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/ddim.h"
|
||||||
|
#include "paddle/fluid/framework/eigen.h"
|
||||||
|
#include "paddle/fluid/framework/lod_tensor.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/framework/operator.h"
|
||||||
|
#include "paddle/fluid/platform/enforce.h"
|
||||||
|
#include "paddle/fluid/string/printf.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
// Tag that selects whether a functor/kernel computes arg-min or arg-max.
enum ArgMinMaxType { kArgMin, kArgMax };

// Primary template is intentionally empty; only the two specializations
// generated by DECLARE_ARG_MIN_MAX_FUNCTOR below are usable.
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
          ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

// Generates an ArgMinMaxFunctor specialization whose operator() evaluates
// Eigen's argmin/argmax reduction along `axis` on the rank-Rank input and
// casts the resulting index tensor (rank Rank - 1) to Tout.
// NOTE: `out` must already have its data allocated by the caller.
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)     \
  template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
  struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank,                      \
                          enum_argminmax_value> {                            \
    void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
                    framework::LoDTensor* out, int64_t axis) {               \
      auto in_eigen = framework::EigenTensor<T, Rank>::From(in);             \
      auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out);   \
      out_eigen.device(*(ctx.eigen_device())) =                              \
          in_eigen.eigen_op_type(axis).template cast<Tout>();                \
    }                                                                        \
  }

// Instantiate the two concrete specializations.
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T, typename Tout,
|
||||||
|
ArgMinMaxType EnumArgMinMaxValue>
|
||||||
|
class ArgMinMaxKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
|
||||||
|
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
|
||||||
|
out.mutable_data<Tout>(ctx.GetPlace());
|
||||||
|
auto axis = ctx.Attr<int64_t>("axis");
|
||||||
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||||
|
|
||||||
|
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
|
||||||
|
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
|
||||||
|
functor##rank; \
|
||||||
|
functor##rank(dev_ctx, x, &out, axis)
|
||||||
|
|
||||||
|
switch (x.dims().size()) {
|
||||||
|
case 1:
|
||||||
|
CALL_ARG_MINMAX_FUNCTOR(1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
CALL_ARG_MINMAX_FUNCTOR(2);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
CALL_ARG_MINMAX_FUNCTOR(3);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
CALL_ARG_MINMAX_FUNCTOR(4);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
CALL_ARG_MINMAX_FUNCTOR(5);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
CALL_ARG_MINMAX_FUNCTOR(6);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PADDLE_THROW(
|
||||||
|
"%s operator doesn't supports tensors whose ranks are greater "
|
||||||
|
"than 6.",
|
||||||
|
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
|
||||||
|
break;
|
||||||
|
#undef CALL_ARG_MINMAX_FUNCTOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convenience alias: arg-min kernel producing int64_t indices.
template <typename DeviceContext, typename T>
using ArgMinKernel =
    ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;

// Convenience alias: arg-max kernel producing int64_t indices.
template <typename DeviceContext, typename T>
using ArgMaxKernel =
    ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;
|
||||||
|
|
||||||
|
class ArgMinMaxOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
|
||||||
|
const auto& x_dims = ctx->GetInputDim("X");
|
||||||
|
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
|
||||||
|
PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(),
|
||||||
|
"'axis' must be inside [-Rank(X), Rank(X))");
|
||||||
|
|
||||||
|
auto x_rank = x_dims.size();
|
||||||
|
if (axis < 0) axis += x_rank;
|
||||||
|
|
||||||
|
std::vector<int64_t> vec;
|
||||||
|
for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
|
||||||
|
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
|
||||||
|
ctx->SetOutputDim("Out", framework::make_ddim(vec));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
protected:
|
||||||
|
virtual const char* OpName() const = 0;
|
||||||
|
virtual const char* Name() const = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
void Make() override {
|
||||||
|
AddInput("X", "Input tensor.");
|
||||||
|
AddOutput("Out", "Output tensor.");
|
||||||
|
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
|
||||||
|
AddComment(string::Sprintf(R"DOC(
|
||||||
|
%s Operator.
|
||||||
|
|
||||||
|
Computes the indices of the %s elements of the input tensor's element
|
||||||
|
along the provided axis.
|
||||||
|
)DOC",
|
||||||
|
OpName(), Name()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Proto maker for arg_min: plugs the "min" wording into the shared base.
class ArgMinOpMaker : public BaseArgMinMaxOpMaker {
 protected:
  const char* OpName() const override { return "ArgMin"; }
  const char* Name() const override { return "min"; }
};
|
||||||
|
|
||||||
|
// Proto maker for arg_max: plugs the "max" wording into the shared base.
class ArgMaxOpMaker : public BaseArgMinMaxOpMaker {
 protected:
  const char* OpName() const override { return "ArgMax"; }
  const char* Name() const override { return "max"; }
};
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||||
|
|
||||||
|
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
|
||||||
|
paddle::operators::ArgMinOpMaker,
|
||||||
|
paddle::framework::EmptyGradOpMaker);
|
||||||
|
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
arg_min,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
int64_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
int32_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
int16_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
uint8_t>);
|
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/arg_min_max_op_base.h"
|
||||||
|
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
arg_min,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
double>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int64_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int32_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
int16_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
size_t>,
|
||||||
|
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
|
||||||
|
uint8_t>);
|
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTestCase(OpTest):
    """Common harness: runs an arg_min/arg_max op on a random tensor and
    checks its output against the matching numpy reduction."""

    def initTestCase(self):
        # Default configuration; subclasses override to vary op/shape/dtype/axis.
        self.op_type = 'arg_min'
        self.dims = (3, 4, 5)
        self.dtype = 'float32'
        self.axis = 0

    def setUp(self):
        self.initTestCase()
        self.x = (1000 * np.random.random(self.dims)).astype(self.dtype)
        self.inputs = {'X': self.x}
        self.attrs = {'axis': self.axis}
        # Reference output comes from numpy's own arg-reduction.
        reference = np.argmin if self.op_type == "arg_min" else np.argmax
        self.outputs = {'Out': reference(self.x, axis=self.axis)}

    def test_check_output(self):
        self.check_output()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase0(BaseTestCase):
    """arg_max over axis 0 of a float32 (3, 4, 5) tensor."""

    def initTestCase(self):
        self.op_type, self.axis = 'arg_max', 0
        self.dims, self.dtype = (3, 4, 5), 'float32'
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase1(BaseTestCase):
    """arg_min over axis 1 of a float64 (3, 4) tensor."""

    def initTestCase(self):
        self.op_type, self.axis = 'arg_min', 1
        self.dims, self.dtype = (3, 4), 'float64'
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase2(BaseTestCase):
    """arg_max over axis 0 of an int64 (3, 4) tensor."""

    def initTestCase(self):
        self.op_type, self.axis = 'arg_max', 0
        self.dims, self.dtype = (3, 4), 'int64'
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase3(BaseTestCase):
    """arg_max over a rank-1 int64 tensor of length 3."""

    def initTestCase(self):
        self.op_type, self.axis = 'arg_max', 0
        self.dims, self.dtype = (3, ), 'int64'
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase4(BaseTestCase):
    """arg_min over a single-element int32 tensor (degenerate case)."""

    def initTestCase(self):
        self.op_type, self.axis = 'arg_min', 0
        self.dims, self.dtype = (1, ), 'int32'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Run all test cases under the standard unittest runner.
    unittest.main()
|
Loading…
Reference in new issue