add multinomial op (#27219)
* add multinomial cpu kernel
* fix C++ notype error
* fix Windows CI array length error
* make array length const
* change array to vector
* add cuda kernel (num_distributions == 1; replacement=False not supported)
* add multinomial python api
* support num_distributions different multinomial distributions
* add multinomial python api unittest
* change output dtype to int64
* fix coverage problem
* optimize format
* fix output dtype error: should be int64_t
parent d2369dd91f
commit 7cd2c13f1b
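For reference, the API added by this commit draws int64 category indices from one or more rows of (possibly unnormalized) non-negative probabilities. A minimal dygraph usage sketch, with shapes mirroring the unit tests below (illustrative, not part of the diff):

import paddle

paddle.disable_static()
# 3 distributions over 4 categories; rows need not sum to 1, the op normalizes them.
x = paddle.rand([3, 4])
out = paddle.multinomial(x, num_samples=5, replacement=True)
print(out.shape, out.dtype)  # [3, 5], int64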
@ -0,0 +1,103 @@ paddle/fluid/operators/multinomial_op.cc
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/multinomial_op.h"

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"

namespace paddle {
namespace operators {

class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "A tensor containing the probabilities of categories");
    AddOutput("Out", "The output tensor of the multinomial op");
    AddAttr<int>("num_samples", "number of samples to generate")
        .SetDefault(1);
    AddAttr<bool>("replacement",
                  "whether a category can be sampled more than once")
        .SetDefault(false);
    AddComment(R"DOC(
This OP returns a Tensor filled with categories sampled according to the Multinomial probabilities.

      Out ~ Multinomial(X)

)DOC");
  }
};

class MultinomialOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial");

    auto x_dim = ctx->GetInputDim("X");
    int64_t x_rank = x_dim.size();
    std::vector<int64_t> out_dims(x_rank);
    for (int64_t i = 0; i < x_rank - 1; i++) {
      out_dims[i] = x_dim[i];
    }

    int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
    out_dims[x_rank - 1] = num_samples;

    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
  }
};

template <typename T>
class MultinomialOpKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto x = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    const int64_t num_samples = ctx.Attr<int>("num_samples");
    const bool replacement = ctx.Attr<bool>("replacement");

    auto *in_data = x->data<T>();
    int64_t *out_data = out->mutable_data<int64_t>(ctx.GetPlace());

    auto in_dims = x->dims();
    int64_t in_rank = in_dims.size();
    const int64_t num_categories = in_dims[in_rank - 1];
    const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;

    MultinomialFunctor<T>(out_data, in_data, num_samples, replacement,
                          num_categories, num_distributions);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
    multinomial, ops::MultinomialOp, ops::MultinomialOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
    multinomial, ops::MultinomialOpKernel<plat::CPUDeviceContext, float>,
    ops::MultinomialOpKernel<plat::CPUDeviceContext, double>);
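As the InferShape above shows, the output keeps every leading dimension of X and swaps the trailing (category) dimension for num_samples. The same rule as a plain Python function (a hypothetical helper, for illustration only):

def multinomial_out_shape(x_shape, num_samples):
    # Mirrors MultinomialOp::InferShape: keep leading dims, replace the last.
    return list(x_shape[:-1]) + [num_samples]

assert multinomial_out_shape([4], 100000) == [100000]
assert multinomial_out_shape([3, 4], 100000) == [3, 100000]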
@ -0,0 +1,245 @@ paddle/fluid/operators/multinomial_op.cu
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
                                     T* sum_rows) {
  int id = threadIdx.x + blockIdx.x * blockDim.x +
           blockIdx.y * gridDim.x * blockDim.x;
  norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
}

template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
                                   int64_t num_distributions,
                                   int64_t num_categories,
                                   T* cumulative_probs) {
  for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
    thrust::inclusive_scan(thrust::device,
                           norm_probs_data + id * num_categories,
                           norm_probs_data + (id + 1) * num_categories,
                           cumulative_probs + id * num_categories);
  }
}

template <typename T>
struct RandomGeneratorCudaFunctor {
  unsigned int seed_;
  __host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
    rng.discard(n);
    return dist(rng);
  }
};

template <typename T>
__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
                                   int num_categories, T rng_number) {
  int left = 0;
  int right = num_categories;

  while (right - left > 0) {
    int mid = left + (right - left) / 2;

    T temp_prob = cumulative_probs[mid];
    if (temp_prob < rng_number) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }

  if (left == num_categories) {
    left = num_categories - 1;
  }

  while (left >= 1 && norm_probs_data[left] == 0) left--;

  return left;
}

template <typename T>
__global__ void sampleMultinomialWithReplacement(
    T* rng_data, const int64_t num_samples, int64_t* out_data,
    const int64_t num_distributions, const int64_t num_categories,
    T* cumulative_probs, T* norm_probs_data) {
  // Use binary search to get the selected category's id, such that
  // cumulative_probs[id - 1] < rng_data < cumulative_probs[id].

  int idx = threadIdx.x + blockIdx.x * blockDim.x +
            blockIdx.y * gridDim.x * blockDim.x;

  // for every distribution
  for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
    // for every sample
    for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
         sample < num_samples; sample += blockDim.x * gridDim.x) {
      T rng_number = rng_data[sample + dist * num_samples];

      // Find the bucket that the uniform random number lies in.
      int selected_category = binarySearchFunctor<T>(
          cumulative_probs + dist * num_categories,
          norm_probs_data + dist * num_categories, num_categories, rng_number);

      out_data[sample + dist * num_samples] = selected_category;
    }
  }
}

template <typename T>
class MultinomialOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto x = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");

    const int64_t num_samples = ctx.Attr<int>("num_samples");
    const bool replacement = ctx.Attr<bool>("replacement");

    auto* in_data = x->data<T>();
    int64_t* out_data = out->mutable_data<int64_t>(ctx.GetPlace());

    auto in_dims = x->dims();
    int64_t in_rank = in_dims.size();
    const int64_t num_categories = in_dims[in_rank - 1];
    const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;

    // If replacement is False, every category can be sampled at most once,
    // so the probabilities of the distribution change after every sample and
    // the sampling cannot be parallelized. Fall back to the CPU
    // implementation ``MultinomialFunctor``.
    if (!replacement) {
      int64_t in_data_numel = x->numel();
      int64_t out_data_numel = out->numel();

      T* cpu_in_data = new T[in_data_numel];
      int64_t* cpu_out_data = new int64_t[out_data_numel];

      cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
                 cudaMemcpyDeviceToHost);

      MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples,
                            replacement, num_categories, num_distributions);
      cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t),
                 cudaMemcpyHostToDevice);

      delete[] cpu_in_data;
      delete[] cpu_out_data;
      return;
    }

    // The sum of the input may not be 1. To get probabilities in [0, 1],
    // compute the sum of each row of the input and use it to normalize.
    // sum_rows_data: sum of each row
    framework::Tensor sum_rows_tensor;
    auto* sum_rows_data =
        sum_rows_tensor.mutable_data<T>({num_distributions}, ctx.GetPlace());

    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();

    if (num_distributions == 1) {
      auto eigen_input = framework::EigenVector<T>::Flatten(*x);
      auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
      eigen_sum_rows.device(place) =
          eigen_input.sum(Eigen::DSizes<int, 1>(1))
              .eval()
              .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
    } else {
      auto eigen_input = framework::EigenMatrix<T>::From(*x);
      auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
      eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
    }

    // Normalize each row of the distribution to get probabilities in [0, 1].
    // norm_probs_data: probabilities of the distribution
    framework::Tensor norm_probs_tensor;
    auto* norm_probs_data = norm_probs_tensor.mutable_data<T>(
        {num_distributions, num_categories}, ctx.GetPlace());

    // number of threads in a block is min(num_categories, 512)
    dim3 block_norm(num_categories < 512 ? num_categories : 512);
    dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
    NormalizeProbability<
        T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
        norm_probs_data, in_data, sum_rows_data);

    // Get the cumulative probability of each distribution; this does the same
    // job as the ``cumsum`` op.
    framework::Tensor cumulative_probs_tensor;
    auto* cumulative_probs = cumulative_probs_tensor.mutable_data<T>(
        {num_distributions, num_categories}, ctx.GetPlace());
    dim3 block_cumsum(1);
    dim3 grid_cumsum(num_distributions);
    GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0,
                            ctx.cuda_device_context().stream()>>>(
        norm_probs_data, num_distributions, num_categories, cumulative_probs);

    // Generate a random number for each sample.
    std::random_device rd;
    auto seed = rd();

    framework::Tensor rng_data_tensor;
    auto* rng_data = rng_data_tensor.mutable_data<T>(
        {num_distributions, num_samples}, ctx.GetPlace());

    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
    platform::Transform<platform::CUDADeviceContext> trans;
    auto* context =
        static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
    trans(*context, index_sequence_begin,
          index_sequence_begin + num_distributions * num_samples, rng_data,
          RandomGeneratorCudaFunctor<T>(seed));

    // Sample the multinomial distributions.
    dim3 block_sample(128);
    dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
    sampleMultinomialWithReplacement<T><<<grid_sample, block_sample, 0,
                                          ctx.cuda_device_context().stream()>>>(
        rng_data, num_samples, out_data, num_distributions, num_categories,
        cumulative_probs, norm_probs_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, float>,
    ops::MultinomialOpKernel<plat::CUDADeviceContext, double>);
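The CUDA path above boils down to four steps: normalize each row, inclusive-scan it into a CDF, draw one uniform number per sample, and binary-search the CDF. A NumPy sketch of the same pipeline (illustration only; np.searchsorted stands in for binarySearchFunctor):

import numpy as np

def multinomial_with_replacement(probs, num_samples, rng=np.random):
    # probs: [num_distributions, num_categories], non-negative rows.
    norm = probs / probs.sum(axis=-1, keepdims=True)   # NormalizeProbability
    cdf = np.cumsum(norm, axis=-1)                     # GetCumulativeProbs
    u = rng.rand(probs.shape[0], num_samples)          # RandomGeneratorCudaFunctor
    out = np.empty((probs.shape[0], num_samples), dtype=np.int64)
    for i, row_cdf in enumerate(cdf):                  # sampleMultinomialWithReplacement
        # First index with cdf >= u, exactly what the device binary search finds.
        # (The device code additionally clamps index == num_categories and walks
        # back past zero-probability categories; with a proper CDF ending at 1.0
        # and u < 1, searchsorted stays in range.)
        out[i] = np.searchsorted(row_cdf, u[i], side='left')
    return out

samples = multinomial_with_replacement(np.random.rand(3, 4), 100000)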
@ -0,0 +1,127 @@ paddle/fluid/operators/multinomial_op.h
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

/**
 * Samples a multinomial distribution given a probability input
 */

template <typename T>
void MultinomialFunctor(int64_t* out_data, const T* in_data,
                        const int64_t num_samples, const bool replacement,
                        const int64_t num_categories,
                        const int64_t num_distributions) {
  std::vector<T> cumulative_probs(num_categories);

  std::uniform_real_distribution<T> dist(0, 1);
  auto gen_ptr = framework::DefaultCPUGenerator();
  auto engine = gen_ptr->GetCPUEngine();

  for (int64_t i = 0; i < num_distributions; i++) {
    T probs_sum = 0;
    T prob_value;
    int64_t num_zeros = 0;
    for (int64_t j = 0; j < num_categories; j++) {
      prob_value = in_data[i * num_categories + j];
      PADDLE_ENFORCE_GE(
          prob_value, 0.0,
          platform::errors::OutOfRange(
              "The input of multinomial distribution should be >= 0"));
      PADDLE_ENFORCE_EQ((std::isinf(static_cast<double>(prob_value)) ||
                         std::isnan(static_cast<double>(prob_value))),
                        false, platform::errors::OutOfRange(
                                   "The input of multinomial distribution "
                                   "should not be infinity or NaN"));
      probs_sum += prob_value;
      if (prob_value == 0) {
        num_zeros += 1;
      }
      cumulative_probs[j] = probs_sum;
    }
    PADDLE_ENFORCE_GT(probs_sum, 0.0,
                      platform::errors::OutOfRange(
                          "The sum of the input should not be 0"));
    PADDLE_ENFORCE_EQ(
        (replacement || (num_categories - num_zeros >= num_samples)), true,
        platform::errors::OutOfRange("When replacement is False, the number "
                                     "of samples should not exceed the "
                                     "number of non-zero categories"));

    for (int64_t j = 0; j < num_categories; j++) {
      cumulative_probs[j] /= probs_sum;
    }

    for (int64_t s = 0; s < num_samples; s++) {
      T uniform_rand = dist(*engine);
      // Use binary search to get the selected category's id, such that
      // cumulative_probs[id - 1] < uniform_rand < cumulative_probs[id].
      int64_t left = 0;
      int64_t right = num_categories;
      int64_t mid;
      int64_t sample_id;
      T temp_prob;
      cumulative_probs[(num_categories - 1)] = 1;

      while (right > left) {
        mid = left + (right - left) / 2;
        temp_prob = cumulative_probs[mid];
        if (temp_prob < uniform_rand) {
          left = mid + 1;
        } else {
          right = mid;
        }
      }
      sample_id = left;

      out_data[i * num_samples + s] = sample_id;

      // If replacement is false, the selected category should be removed.
      if (!replacement && s < num_samples - 1) {
        T sample_prob;
        T new_prob = 0;
        T new_sum;

        if (sample_id != 0) {
          new_prob = cumulative_probs[sample_id - 1];
        }
        sample_prob = cumulative_probs[sample_id] - new_prob;
        new_sum = 1.0 - sample_prob;

        for (int64_t j = 0; j < num_categories; j++) {
          new_prob = cumulative_probs[j];
          if (j >= sample_id) {
            new_prob -= sample_prob;
          }
          new_prob /= new_sum;
          cumulative_probs[j] = new_prob;
        }
      }
    }
  }
}

template <typename DeviceContext, typename T>
class MultinomialOpKernel;

}  // namespace operators
}  // namespace paddle
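For the replacement=False branch, MultinomialFunctor removes the sampled category by subtracting its probability mass from the tail of the CDF and rescaling by 1 - sample_prob. A small NumPy sketch of that single update step (an illustrative mirror of the inner C++ loop above, not the commit's code):

import numpy as np

def remove_sampled_category(cumulative_probs, sample_id):
    # Mass of the category just sampled.
    prev = cumulative_probs[sample_id - 1] if sample_id > 0 else 0.0
    sample_prob = cumulative_probs[sample_id] - prev
    new_sum = 1.0 - sample_prob
    # Subtract the sampled mass from every entry at or after sample_id,
    # then rescale so the CDF ends at 1 again.
    cdf = cumulative_probs.copy()
    cdf[sample_id:] -= sample_prob
    return cdf / new_sum

cdf = np.array([0.1, 0.4, 0.7, 1.0])
print(remove_sampled_category(cdf, 1))  # category 1 (mass 0.3) removed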
@ -0,0 +1,179 @@ python/paddle/fluid/tests/unittests/test_multinomial_op.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import paddle.fluid as fluid
from op_test import OpTest
import numpy as np


class TestMultinomialOp(OpTest):
    def setUp(self):
        self.op_type = "multinomial"
        self.init_data()
        self.inputs = {"X": self.input_np}

    def init_data(self):
        # input probability is a vector, and replacement is True
        self.input_np = np.random.rand(4)
        self.outputs = {"Out": np.zeros(100000).astype("int64")}
        self.attrs = {"num_samples": 100000, "replacement": True}

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def sample_output(self, out):
        # count the occurrences of each category
        sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
        sample_prob /= sample_prob.sum()
        return sample_prob

    def verify_output(self, outs):
        # normalize the input to get the probability
        prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True)
        sample_prob = self.sample_output(np.array(outs[0]))
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))


class TestMultinomialOp2(TestMultinomialOp):
    def init_data(self):
        # input probability is a matrix
        self.input_np = np.random.rand(3, 4)
        self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")}
        self.attrs = {"num_samples": 100000, "replacement": True}

    def sample_output(self, out):
        out_list = np.split(out, 3, axis=0)
        count_array = [0] * 3
        for i in range(3):
            count_array[i] = np.unique(
                out_list[i], return_counts=True)[1].astype("float32")
        sample_prob = np.stack(count_array, axis=0)
        sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
        return sample_prob


class TestMultinomialOp3(TestMultinomialOp):
    def init_data(self):
        # replacement is False. The number of samples must be less than the
        # number of categories.
        self.input_np = np.random.rand(1000)
        self.outputs = {"Out": np.zeros(100).astype("int64")}
        self.attrs = {"num_samples": 100, "replacement": False}

    def verify_output(self, outs):
        out = np.array(outs[0])
        unique_out = np.unique(out)
        self.assertEqual(
            len(unique_out), 100,
            "replacement is False. categories can't be sampled repeatedly")


class TestMultinomialApi(unittest.TestCase):
    def test_dygraph(self):
        # input probability is a vector, and replacement is True
        paddle.disable_static()
        x = paddle.rand([4])
        out = paddle.multinomial(x, num_samples=100000, replacement=True)
        x_numpy = x.numpy()
        paddle.enable_static()

        sample_prob = np.unique(
            out.numpy(), return_counts=True)[1].astype("float32")
        sample_prob /= sample_prob.sum()

        prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))

    def test_dygraph2(self):
        # input probability is a matrix, and replacement is True
        paddle.disable_static()
        x = paddle.rand([3, 4])
        out = paddle.multinomial(x, num_samples=100000, replacement=True)
        x_numpy = x.numpy()

        out_list = np.split(out.numpy(), 3, axis=0)
        count_array = [0] * 3
        for i in range(3):
            count_array[i] = np.unique(
                out_list[i], return_counts=True)[1].astype("float32")
        sample_prob = np.stack(count_array, axis=0)
        sample_prob /= sample_prob.sum(axis=-1, keepdims=True)

        prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
        paddle.enable_static()

    def test_dygraph3(self):
        # replacement is False. The number of samples must be less than the
        # number of categories.
        paddle.disable_static()
        x = paddle.rand([1000])
        out = paddle.multinomial(x, num_samples=100, replacement=False)
        x_numpy = x.numpy()

        unique_out = np.unique(out.numpy())
        self.assertEqual(
            len(unique_out), 100,
            "replacement is False. categories can't be sampled repeatedly")
        paddle.enable_static()

    def test_static(self):
        paddle.enable_static()
        startup_program = fluid.Program()
        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            x = fluid.data('x', shape=[4], dtype='float32')
            out = paddle.multinomial(x, num_samples=100000, replacement=True)

        place = fluid.CPUPlace()
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)

        exe.run(startup_program)
        x_np = np.random.rand(4).astype('float32')
        out = exe.run(train_program, feed={'x': x_np}, fetch_list=[out])

        sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
        sample_prob /= sample_prob.sum()

        prob = x_np / x_np.sum(axis=-1, keepdims=True)
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))


class TestMultinomialAlias(unittest.TestCase):
    def test_alias(self):
        paddle.disable_static()
        x = paddle.rand([4])
        paddle.multinomial(x, num_samples=10, replacement=True)
        paddle.tensor.multinomial(x, num_samples=10, replacement=True)
        paddle.tensor.random.multinomial(x, num_samples=10, replacement=True)


if __name__ == "__main__":
    unittest.main()