Add histc op (#24562)

* add histc operator, test=develop
* update english doc to 2.0 API, test=develop
* update API from histc to histogram, test=develop

Co-authored-by: root <root@yq01-gpu-255-129-15-00.epc.baidu.com>
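The resulting Python API is exercised by the unit test at the end of this commit; as a quick orientation, a minimal dygraph usage sketch (assuming a Paddle build that contains this operator) looks like this:

import numpy as np
import paddle
import paddle.fluid as fluid

# Count the elements of a 2x3 int64 tensor into 5 equal-width bins over [1, 5].
with fluid.dygraph.guard():
    inputs = fluid.dygraph.to_variable(
        np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64))
    out = paddle.histogram(inputs, bins=5, min=1, max=5)
    print(out.numpy())  # expected: [0 3 0 2 1]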
parent 1f032c53d5
commit 704cad6a66
paddle/fluid/operators/histogram_op.cc
@ -0,0 +1,92 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/histogram_op.h"

#include <string>
#include <unordered_map>
#include <vector>

namespace paddle {
namespace operators {

using framework::OpKernelType;
using framework::Tensor;

class HistogramOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "histogram");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "histogram");
    const auto &nbins = ctx->Attrs().Get<int64_t>("bins");
    const auto &minval = ctx->Attrs().Get<int>("min");
    const auto &maxval = ctx->Attrs().Get<int>("max");

    PADDLE_ENFORCE_GE(nbins, 1,
                      platform::errors::InvalidArgument(
                          "The bins should be greater than or equal to 1. "
                          "But received nbins is %d.",
                          nbins));
    PADDLE_ENFORCE_GE(maxval, minval,
                      platform::errors::InvalidArgument(
                          "max must be larger than or equal to min. "
                          "But received max is %d, min is %d.",
                          maxval, minval));

    ctx->SetOutputDim("Out", framework::make_ddim({nbins}));
    ctx->ShareLoD("X", /*->*/ "Out");
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class HistogramOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The input tensor of Histogram op.");
    AddOutput("Out", "(Tensor) The output tensor of Histogram op.");
    AddAttr<int64_t>("bins", "(int64_t) number of histogram bins")
        .SetDefault(100)
        .EqualGreaterThan(1);
    AddAttr<int>("min", "(int) lower end of the range (inclusive)")
        .SetDefault(0);
    AddAttr<int>("max", "(int) upper end of the range (inclusive)")
        .SetDefault(0);
    AddComment(R"DOC(
          Histogram Operator.
          Computes the histogram of a tensor. The elements are sorted
          into equal-width bins between min and max. If min and max are
          both zero, the minimum and maximum values of the data are used.
      )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    histogram, ops::HistogramOp, ops::HistogramOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
    histogram, ops::HistogramKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HistogramKernel<paddle::platform::CPUDeviceContext, double>,
    ops::HistogramKernel<paddle::platform::CPUDeviceContext, int>,
    ops::HistogramKernel<paddle::platform::CPUDeviceContext, int64_t>);
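The DOC block above fixes the operator's semantics: equal-width bins over [min, max] (max inclusive), with min == max falling back to the data's own range. As a hedged reference, not part of the commit, the same behavior can be mirrored with np.histogram, which is also what the unit test below compares against (parameter names mirror the operator attributes):

import numpy as np

def histogram_ref(x, bins=100, min=0, max=0):
    # min == max (the documented case is both zero): use the data's range.
    if min == max:
        min, max = x.min(), x.max()
    if min == max:  # all elements identical: widen so the bins are valid
        min, max = min - 1, max + 1
    hist, _ = np.histogram(x, bins=bins, range=(min, max))
    return hist.astype(np.int64)

x = np.array([[2, 4, 2], [2, 5, 4]], dtype=np.int64)
print(histogram_ref(x, bins=5, min=1, max=5))  # [0 3 0 2 1]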
paddle/fluid/operators/histogram_op.cu
@ -0,0 +1,147 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/histogram_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using IndexType = int64_t;
using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;

#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

inline int GET_BLOCKS(const int N) {
  return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}

template <typename T, typename IndexType>
__device__ static IndexType GetBin(T bVal, T minvalue, T maxvalue,
                                   int64_t nbins) {
  IndexType bin = static_cast<IndexType>((bVal - minvalue) * nbins /
                                         (maxvalue - minvalue));
  // A value equal to maxvalue computes to nbins; fold it into the last bin.
  if (bin == nbins) bin -= 1;
  return bin;
}

template <typename T, typename IndexType>
__global__ void KernelHistogram(const T* input, const int totalElements,
                                const int64_t nbins, const T minvalue,
                                const T maxvalue, int64_t* output) {
  CUDA_KERNEL_LOOP(linearIndex, totalElements) {
    // Index with the grid-stride loop variable; recomputing
    // threadIdx.x + blockIdx.x * blockDim.x here would never advance.
    const auto inputVal = input[linearIndex];
    if (inputVal >= minvalue && inputVal <= maxvalue) {
      const IndexType bin =
          GetBin<T, IndexType>(inputVal, minvalue, maxvalue, nbins);
      const IndexType outputIdx = bin < nbins - 1 ? bin : nbins - 1;
      paddle::platform::CudaAtomicAdd(&output[outputIdx], 1);
    }
  }
}

template <typename DeviceContext, typename T>
class HistogramCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));

    const Tensor* input = context.Input<framework::Tensor>("X");
    Tensor* output = context.Output<framework::Tensor>("Out");
    auto& nbins = context.Attr<int64_t>("bins");
    auto& minval = context.Attr<int>("min");
    auto& maxval = context.Attr<int>("max");

    const T* input_data = input->data<T>();
    const int input_numel = input->numel();

    T output_min = static_cast<T>(minval);
    T output_max = static_cast<T>(maxval);

    if (output_min == output_max) {
      // min == max: reduce the input to its min/max on the device, then
      // bring the two scalars back to the host.
      auto input_x = framework::EigenVector<T>::Flatten(*input);

      framework::Tensor input_min_t, input_max_t;
      input_min_t.mutable_data<T>({1}, context.GetPlace());
      input_max_t.mutable_data<T>({1}, context.GetPlace());
      auto input_min_scala = framework::EigenScalar<T>::From(input_min_t);
      auto input_max_scala = framework::EigenScalar<T>::From(input_max_t);

      auto* place =
          context.template device_context<DeviceContext>().eigen_device();
      input_min_scala.device(*place) = input_x.minimum();
      input_max_scala.device(*place) = input_x.maximum();

      Tensor input_min_cpu, input_max_cpu;
      TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
      TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);

      output_min = input_min_cpu.data<T>()[0];
      output_max = input_max_cpu.data<T>()[0];
    }
    if (output_min == output_max) {
      // All elements are identical; widen the range so the bins are valid.
      output_min = output_min - 1;
      output_max = output_max + 1;
    }

    PADDLE_ENFORCE_EQ(
        (std::isinf(static_cast<float>(output_min)) ||
         std::isnan(static_cast<float>(output_min)) ||
         std::isinf(static_cast<float>(output_max)) ||
         std::isnan(static_cast<float>(output_max))),
        false, platform::errors::OutOfRange("range of min, max is not finite"));
    PADDLE_ENFORCE_GE(
        output_max, output_min,
        platform::errors::InvalidArgument(
            "max must be larger than or equal to min. If min and max are "
            "both zero, the minimum and maximum values of the data are "
            "used. But received max is %d, min is %d.",
            maxval, minval));

    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, int64_t>()(
        context.template device_context<platform::CUDADeviceContext>(), output,
        static_cast<int64_t>(0));

    auto stream =
        context.template device_context<platform::CUDADeviceContext>()
            .stream();
    KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
                                    PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        input_data, input_numel, nbins, output_min, output_max, out_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    histogram,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::HistogramCUDAKernel<paddle::platform::CUDADeviceContext, double>);
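One detail worth calling out in GetBin: the bin index is floor((v - min) * nbins / (max - min)), so a value exactly equal to max computes to nbins and has to be folded back into the last (inclusive) bin. An illustrative Python sketch of that clamping, not part of the commit:

def get_bin(v, minvalue, maxvalue, nbins):
    # Mirrors GetBin in the CUDA kernel above: v == maxvalue would
    # compute to nbins, so it is clamped into the last bin.
    b = int((v - minvalue) * nbins / (maxvalue - minvalue))
    return min(b, nbins - 1)

assert get_bin(1, 1, 5, 5) == 0  # v == min lands in the first bin
assert get_bin(5, 1, 5, 5) == 4  # v == max lands in the last bin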
paddle/fluid/operators/histogram_op.h
@ -0,0 +1,82 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class HistogramKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<framework::Tensor>("X");
    Tensor* output = context.Output<framework::Tensor>("Out");
    auto& nbins = context.Attr<int64_t>("bins");
    auto& minval = context.Attr<int>("min");
    auto& maxval = context.Attr<int>("max");

    const T* input_data = input->data<T>();
    auto input_numel = input->numel();

    T output_min = static_cast<T>(minval);
    T output_max = static_cast<T>(maxval);
    if (output_min == output_max) {
      // min == max: derive the range from the data itself.
      output_min = *std::min_element(input_data, input_data + input_numel);
      output_max = *std::max_element(input_data, input_data + input_numel);
    }
    if (output_min == output_max) {
      // All elements are identical; widen the range so the bins are valid.
      output_min = output_min - 1;
      output_max = output_max + 1;
    }

    PADDLE_ENFORCE_EQ(
        (std::isinf(static_cast<float>(output_min)) ||
         std::isnan(static_cast<float>(output_min)) ||
         std::isinf(static_cast<float>(output_max)) ||
         std::isnan(static_cast<float>(output_max))),
        false, platform::errors::OutOfRange("range of min, max is not finite"));
    PADDLE_ENFORCE_GE(
        output_max, output_min,
        platform::errors::InvalidArgument(
            "max must be larger than or equal to min. If min and max are "
            "both zero, the minimum and maximum values of the data are "
            "used. But received max is %d, min is %d.",
            maxval, minval));

    int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
    math::SetConstant<DeviceContext, int64_t>()(
        context.template device_context<DeviceContext>(), output,
        static_cast<int64_t>(0));

    for (int64_t i = 0; i < input_numel; i++) {
      if (input_data[i] >= output_min && input_data[i] <= output_max) {
        const int64_t bin = static_cast<int64_t>(
            (input_data[i] - output_min) * nbins / (output_max - output_min));
        // Values equal to output_max compute to nbins; clamp into the last bin.
        out_data[std::min(bin, nbins - 1)] += 1;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
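The CPU and CUDA kernels resolve the histogram range identically before binning: the min/max attributes are used as-is unless they are equal (the documented case is both zero), in which case the data's min/max are taken; if the range is still empty, it is widened by one on each side. A hedged sketch of just that resolution step, with a hypothetical helper name:

def resolve_range(data_min, data_max, attr_min=0, attr_max=0):
    # Mirrors the shared min/max handling in histogram_op.h and .cu.
    lo, hi = attr_min, attr_max
    if lo == hi:  # fall back to the data's own range
        lo, hi = data_min, data_max
    if lo == hi:  # constant input: widen so nbins bins exist
        lo, hi = lo - 1, hi + 1
    return lo, hi

assert resolve_range(3, 3) == (2, 4)        # constant input
assert resolve_range(0, 9, 1, 5) == (1, 5)  # explicit range wins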
python/paddle/fluid/tests/unittests/test_histogram_op.py
@ -0,0 +1,87 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest


class TestHistogramOpAPI(unittest.TestCase):
    """Test histogram api."""

    def test_static_graph(self):
        startup_program = fluid.Program()
        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            inputs = fluid.data(name='input', dtype='int64', shape=[2, 3])
            output = paddle.histogram(inputs, bins=5, min=1, max=5)
            place = fluid.CPUPlace()
            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            exe = fluid.Executor(place)
            exe.run(startup_program)
            img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
            res = exe.run(train_program,
                          feed={'input': img},
                          fetch_list=[output])
            actual = np.array(res[0])
            expected = np.array([0, 3, 0, 2, 1]).astype(np.int64)
            self.assertTrue(
                (actual == expected).all(),
                msg='histogram output is wrong, out =' + str(actual))

    def test_dygraph(self):
        with fluid.dygraph.guard():
            inputs_np = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int64)
            inputs = fluid.dygraph.to_variable(inputs_np)
            actual = paddle.histogram(inputs, bins=5, min=1, max=5)
            expected = np.array([0, 3, 0, 2, 1]).astype(np.int64)
            self.assertTrue(
                (actual.numpy() == expected).all(),
                msg='histogram output is wrong, out =' + str(actual.numpy()))


class TestHistogramOp(OpTest):
    def setUp(self):
        self.op_type = "histogram"
        self.init_test_case()
        np_input = np.random.randint(
            low=0, high=20, size=self.in_shape, dtype=np.int64)
        self.inputs = {"X": np_input}
        self.init_attrs()
        Out, _ = np.histogram(
            np_input, bins=self.bins, range=(self.min, self.max))
        self.outputs = {"Out": Out.astype(np.int64)}

    def init_test_case(self):
        self.in_shape = (10, 12)
        self.bins = 5
        self.min = 1
        self.max = 5

    def init_attrs(self):
        self.attrs = {"bins": self.bins, "min": self.min, "max": self.max}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()