add bernoulli op (#26511)
* add bernoulli op
* fix cuda kernel and add unit test
* refine doc
* fix uniformtest_feature_precision_test_c
parent f3909020de
commit aa2a9b5d89
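For context, a minimal dygraph usage sketch of the Python API this PR exposes (it mirrors the TestBernoulliApi.test_dygraph case in the unit test below; the shape is illustrative):

    import paddle

    paddle.disable_static()
    x = paddle.rand([2, 3])     # element-wise probabilities in [0, 1)
    out = paddle.bernoulli(x)   # each element is 0 or 1, with P(out[i] = 1) = x[i]
    print(out.numpy())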
paddle/fluid/operators/bernoulli_op.cc
@@ -0,0 +1,88 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/bernoulli_op.h"

#include <algorithm>
#include <random>  // std::mt19937_64, std::uniform_real_distribution
#include <string>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {

class BernoulliOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "A tensor with the probabilities used to generate the random "
             "binary numbers");
    AddOutput("Out", "A Tensor filled with random binary numbers");
    AddComment(R"DOC(
This OP returns a Tensor filled with random binary (0 or 1) numbers sampled from a Bernoulli distribution.

Out ~ Bernoulli(X)

)DOC");
  }
};

class BernoulliOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    // Out keeps the shape of X.
    return UnaryOpUnchangedInferShape(ctx);
  }
};
// It seems that Eigen::Tensor::random on the GPU can SEGFAULT.
// Use std::random on the CPU and thrust::random on the GPU (Thrust ships
// with CUDA) to implement uniform sampling.
template <typename T>
class BernoulliOpKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto x = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    auto *in_data = x->data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());

    int64_t size = x->numel();
    std::uniform_real_distribution<T> dist(0.0, 1.0);
    auto gen_ptr = framework::Generator::GetInstance();
    std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine();

    // Draw one uniform sample per element and threshold it against the
    // input probability (see BernoulliFunctor in bernoulli_op.h).
    for (int64_t i = 0; i < size; ++i) {
      out_data[i] = BernoulliFunctor(in_data[i], dist(gen_engine));
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
    bernoulli, ops::BernoulliOp, ops::BernoulliOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(bernoulli,
                       ops::BernoulliOpKernel<plat::CPUDeviceContext, float>,
                       ops::BernoulliOpKernel<plat::CPUDeviceContext, double>);
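With the CPU kernel registered, the op is also usable from static-graph mode; a sketch mirroring the TestBernoulliApi.test_static case in the unit test below:

    import paddle

    paddle.enable_static()
    x = paddle.rand([1024, 1024])
    out = paddle.bernoulli(x)

    exe = paddle.static.Executor(paddle.CPUPlace())
    res = exe.run(paddle.static.default_main_program(), fetch_list=[out.name])
    print(res[0].mean())  # close to 0.5 for uniform inputs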
paddle/fluid/operators/bernoulli_op.cu
@@ -0,0 +1,72 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>

#include <random>  // std::random_device

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
// It can be made consistent with the CPU kernel once a CUDA Generator is
// provided.
template <typename T>
struct BernoulliCudaFunctor {
  unsigned int seed_;
  __host__ __device__ BernoulliCudaFunctor(unsigned int seed) : seed_(seed) {}

  __host__ __device__ T operator()(const unsigned int n, const T p) const {
    // Re-seed a cheap engine and discard the first n draws, so element n
    // deterministically consumes the n-th value of a single virtual stream;
    // no generator state needs to be shared between GPU threads.
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
    rng.discard(n);
    return static_cast<T>(dist(rng) < p);
  }
};
template <typename T>
class BernoulliOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Seed from std::random_device on every call; results are not
    // reproducible until a CUDA Generator is wired in (see the note above).
    std::random_device rd;
    auto seed = rd();
    const auto x = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    auto* in_data = x->data<T>();
    auto* out_data = out->mutable_data<T>(ctx.GetPlace());

    int64_t size = x->numel();
    // Transform pairs each flat index with its probability and writes the
    // Bernoulli sample produced by the functor.
    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
    platform::Transform<platform::CUDADeviceContext> trans;
    auto* context =
        static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
    trans(*context, index_sequence_begin, index_sequence_begin + size, in_data,
          out_data, BernoulliCudaFunctor<T>(seed));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    bernoulli, ops::BernoulliOpKernel<plat::CUDADeviceContext, float>,
    ops::BernoulliOpKernel<plat::CUDADeviceContext, double>);
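The counting-iterator plus rng.discard(n) pattern above is equivalent to letting element n read the n-th value of one seeded uniform stream. A small Python sketch of that idea (illustrative only; bernoulli_like_cuda is a hypothetical helper for exposition, not the CUDA implementation):

    import numpy as np

    def bernoulli_like_cuda(probs, seed):
        # One seeded stream; element n consumes the n-th uniform value,
        # so the result depends only on (seed, n), not on thread order.
        flat = probs.ravel()
        stream = np.random.default_rng(seed).uniform(size=flat.size)
        return (stream < flat).astype(probs.dtype).reshape(probs.shape)

    a = bernoulli_like_cuda(np.full((2, 3), 0.5), seed=42)
    b = bernoulli_like_cuda(np.full((2, 3), 0.5), seed=42)
    assert (a == b).all()  # same seed -> identical samples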
paddle/fluid/operators/bernoulli_op.h
@@ -0,0 +1,39 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

/**
 * Samples a Bernoulli distribution given a probability input
 */

template <typename T>
inline HOSTDEVICE T BernoulliFunctor(T p, T rand) {
  PADDLE_ENFORCE_LE(p, 1, platform::errors::OutOfRange(
                              "The probability should be <= 1, but got %f", p));
  PADDLE_ENFORCE_GE(p, 0, platform::errors::OutOfRange(
                              "The probability should be >= 0, but got %f", p));
  // Return 1 with probability p, otherwise 0.
  return static_cast<T>(rand < p);
}

template <typename DeviceContext, typename T>
class BernoulliOpKernel;

}  // namespace operators
}  // namespace paddle
python/paddle/fluid/tests/unittests/test_bernoulli_op.py
@@ -0,0 +1,76 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
from op_test import OpTest
import numpy as np


def output_hist(out):
    # With X ~ Uniform(0, 1), P(out == 1) = E[X] = 0.5, so the 0s and 1s
    # should each fill about half of the two-bin histogram.
    hist, _ = np.histogram(out, bins=2)
    hist = hist.astype("float32")
    hist /= float(out.size)
    prob = 0.5 * np.ones((2))
    return hist, prob
class TestBernoulliOp(OpTest):
    def setUp(self):
        self.op_type = "bernoulli"
        self.inputs = {"X": np.random.uniform(size=(1000, 784))}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}

    def init_attrs(self):
        self.attrs = {}
        self.output_hist = output_hist

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        hist, prob = self.output_hist(np.array(outs[0]))
        self.assertTrue(
            np.allclose(
                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))


class TestBernoulliApi(unittest.TestCase):
    def test_dygraph(self):
        paddle.disable_static()
        x = paddle.rand([1024, 1024])
        out = paddle.bernoulli(x)
        paddle.enable_static()

        hist, prob = output_hist(out.numpy())
        self.assertTrue(
            np.allclose(
                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))

    def test_static(self):
        x = paddle.rand([1024, 1024])
        out = paddle.bernoulli(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        out = exe.run(paddle.static.default_main_program(),
                      fetch_list=[out.name])
        hist, prob = output_hist(out[0])
        self.assertTrue(
            np.allclose(
                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))


if __name__ == "__main__":
    unittest.main()
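As a sanity check on the histogram expectation used in these tests: the inputs are uniform on [0, 1), so the marginal probability of sampling a 1 is E[X] = 0.5, and with ~10^6 samples the atol=0.01 tolerance is comfortably wide. A quick standalone numpy check of that reasoning (illustrative; not part of the test suite):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.uniform(size=(1000, 784))       # same input distribution as the test
    out = (rng.uniform(size=x.shape) < x)   # Bernoulli(x), as in BernoulliFunctor
    print(out.mean())                       # ~0.5; std error ~ 0.5/sqrt(784000) ~ 0.0006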