parent
ea6a251c0b
commit
009c049e82
@@ -0,0 +1,170 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <random>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

template <typename T>
class CPURandintKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    std::vector<int64_t> new_shape;
    auto list_new_shape_tensor =
        ctx.MultiInput<framework::Tensor>("ShapeTensorList");
    if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
      if (ctx.HasInput("ShapeTensor")) {
        auto* shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
        new_shape = GetNewDataFromShapeTensor(shape_tensor);
      } else if (list_new_shape_tensor.size() > 0) {
        new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
      }
    }

    auto* out = ctx.Output<framework::LoDTensor>("Out");
    if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
    T* data = out->mutable_data<T>(ctx.GetPlace());
    int64_t size = out->numel();
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> dist(ctx.Attr<int>("low"),
                                         ctx.Attr<int>("high") - 1);
    for (int64_t i = 0; i < size; ++i) data[i] = dist(gen);
  }
};
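
A note on the sampling above: std::uniform_int_distribution draws from a closed range, which is why the CPU kernel passes high - 1 to realize the half-open interval [low, high). A minimal Python sketch of the same convention gap (random.randint is inclusive on both ends like the C++ distribution; np.random.randint is half-open like this op):

import random

import numpy as np

low, high = -5, 10
# random.randint samples the closed range [low, high], so we pass
# high - 1 to match randint's half-open [low, high) convention.
inclusive_draw = random.randint(low, high - 1)
# np.random.randint already excludes the upper bound.
half_open_draw = np.random.randint(low, high)
assert low <= inclusive_draw < high
assert low <= half_open_draw < high
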

class RandintOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::InvalidArgument("Output(Out) of RandintOp is null."));
    PADDLE_ENFORCE_LT(
        ctx->Attrs().Get<int>("low"), ctx->Attrs().Get<int>("high"),
        platform::errors::InvalidArgument("randint's low must be less than "
                                          "high, but received: low = %d, "
                                          "high = %d.",
                                          ctx->Attrs().Get<int>("low"),
                                          ctx->Attrs().Get<int>("high")));

    if (ctx->HasInputs("ShapeTensorList")) {
      // Input(ShapeTensorList) has the highest priority for the output shape.
      auto inputs_name = ctx->Inputs("ShapeTensorList");
      PADDLE_ENFORCE_GT(
          inputs_name.size(), 0,
          platform::errors::InvalidArgument(
              "Input(ShapeTensorList)'s size of Op(randint) can't be zero. "
              "Please check the Attr(shape)'s size of "
              "Op(fluid.layers.randint)."));
      auto out_dims = std::vector<int>(inputs_name.size(), -1);
      ctx->SetOutputDim("Out", framework::make_ddim(out_dims));

      return;
    }

    auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
    if (ctx->HasInput("ShapeTensor") && shape.empty()) {
      auto shape_dims = ctx->GetInputDim("ShapeTensor");
      PADDLE_ENFORCE_EQ(shape_dims.size(), 1,
                        platform::errors::InvalidArgument(
                            "ShapeError: Input(ShapeTensor)'s dimension size "
                            "of Op(randint) must be 1. But received "
                            "ShapeTensor's dimensions = %d.",
                            shape_dims.size()));
      int num_ele = 1;
      for (int i = 0; i < shape_dims.size(); ++i) {
        num_ele *= shape_dims[i];
      }
      auto vec_dims = std::vector<int64_t>(num_ele, -1);
      auto out_dims = framework::make_ddim(vec_dims);
      ctx->SetOutputDim("Out", out_dims);
      return;
    }

    PADDLE_ENFORCE_EQ(shape.empty(), false,
                      platform::errors::InvalidArgument(
                          "If there is no Input(ShapeTensorList) and no "
                          "Input(ShapeTensor), the shape information must "
                          "be set by Attr(shape)."));

    std::vector<int64_t> tensor_shape;
    tensor_shape.reserve(shape.size());
    for (auto dim : shape) {
      tensor_shape.push_back(static_cast<int64_t>(dim));
    }
    ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
        ctx.GetPlace());
  }
};
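
When the shape arrives via ShapeTensor or ShapeTensorList, InferShape above can only pin down the rank, so it fills every static dimension with -1 and leaves the real resize to the kernel at run time. A hedged Python sketch of how that looks from the API side (variable names are mine; it mirrors the test program further down):

import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    # The values inside var_shape are only known at run time, so the
    # static shape of `out` should be two unknown dims, e.g. [-1, -1].
    var_shape = fluid.data(name='var_shape', shape=[2], dtype='int64')
    out = paddle.randint(low=0, high=10, shape=var_shape, dtype='int64')
    print(out.shape)
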

class RandintOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("ShapeTensor",
             "(Tensor<int64_t> or Tensor<int32_t>, optional). If provided, "
             "randint will use this given shape to generate the output. It "
             "has a higher priority than Attr(shape) but a lower priority "
             "than Input(ShapeTensorList).")
        .AsDispensable();
    AddInput("ShapeTensorList",
             "(vector<Tensor<int64_t>> or vector<Tensor<int32_t>>, optional). "
             "If provided, randint will use this to set the output shape. "
             "The shape of each tensor must be [1], and this input has the "
             "highest priority compared with Input(ShapeTensor) and "
             "Attr(shape).")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out", "The output tensor of randint op.");
    AddComment(R"DOC(
This operator initializes a tensor with random integers sampled uniformly
from the half-open interval [low, high).
)DOC");
    AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor.")
        .SetDefault({});
    AddAttr<int>("low",
                 "The lower bound on the range of random values to generate.");
    AddAttr<int>("high",
                 "The upper bound (exclusive) on the range of random values "
                 "to generate.");
    AddAttr<int>("dtype", "Output tensor data type. [Default INT64].")
        .SetDefault(framework::proto::VarType::INT64);
  }
};
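
To make the precedence described above concrete, here is a hedged Python sketch of the three ways a caller can supply the shape (the concrete sizes are arbitrary):

import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    # 1. Attr(shape): a plain Python list, lowest priority.
    a = paddle.randint(low=0, high=10, shape=[3, 4], dtype='int32')
    # 2. Input(ShapeTensor): a 1-D integer tensor holding the full shape.
    shape_tensor = fluid.layers.fill_constant([2], "int64", 5)
    b = paddle.randint(low=0, high=10, shape=shape_tensor, dtype='int32')
    # 3. Input(ShapeTensorList): a list mixing Python ints with
    #    [1]-shaped tensors, highest priority.
    dim = fluid.layers.fill_constant([1], "int64", 6)
    c = paddle.randint(low=0, high=10, shape=[dim, 4], dtype='int32')
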

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
    randint, ops::RandintOp, ops::RandintOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

REGISTER_OP_CPU_KERNEL(randint, ops::CPURandintKernel<int>,
                       ops::CPURandintKernel<int64_t>)
@@ -0,0 +1,76 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"

namespace paddle {
namespace operators {

template <typename T>
struct UniformIntGenerator {
  T low_, high_;
  __host__ __device__ UniformIntGenerator(T low, T high)
      : low_(low), high_(high) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(0);
    thrust::uniform_int_distribution<T> dist(low_, high_);
    rng.discard(n);
    T out = dist(rng);
    return out;
  }
};
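
UniformIntGenerator is a counter-based generator: for output index n it re-seeds the engine with the constant 0 and discards n states, so every element gets an independent, reproducible draw with no RNG state shared between GPU threads. The fixed seed also means repeated runs of this kernel produce identical output. A rough Python analogue of the idea (a linear skip stands in for discard; purely illustrative, names are mine):

import random

def draw_at_index(n, low, high, seed=0):
    # Fixed seed plus a skip of n draws emulates
    # rng.seed(0); rng.discard(n) in the functor above.
    rng = random.Random(seed)
    for _ in range(n):
        rng.random()
    # random.randint is inclusive on both ends, matching
    # thrust::uniform_int_distribution(low, high).
    return rng.randint(low, high)

# The same index always yields the same value.
assert draw_at_index(7, 0, 9) == draw_at_index(7, 0, 9)
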

// Implement randint with thrust::uniform_int_distribution (Thrust is an
// STL-like template library bundled with CUDA). Note that the distribution
// is inclusive on both ends, so high - 1 below yields [low, high).
template <typename T>
class GPURandintKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    std::vector<int64_t> new_shape;
    auto list_new_shape_tensor =
        context.MultiInput<framework::Tensor>("ShapeTensorList");
    if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) {
      if (context.HasInput("ShapeTensor")) {
        auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor");
        new_shape = GetNewDataFromShapeTensor(shape_tensor);
      } else if (list_new_shape_tensor.size() > 0) {
        new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
      }
    }

    auto* out = context.Output<framework::LoDTensor>("Out");
    if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
    T* data = out->mutable_data<T>(context.GetPlace());
    T low = static_cast<T>(context.Attr<int>("low"));
    T high = static_cast<T>(context.Attr<int>("high")) - 1;

    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
    int64_t size = out->numel();
    thrust::transform(index_sequence_begin, index_sequence_begin + size,
                      thrust::device_ptr<T>(data),
                      UniformIntGenerator<T>(low, high));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(randint, ops::GPURandintKernel<int>,
                        ops::GPURandintKernel<int64_t>)
@@ -0,0 +1,173 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


def output_hist(out):
    hist, _ = np.histogram(out, range=(-5, 10))
    hist = hist.astype("float32")
    hist /= float(out.size)
    prob = 0.1 * np.ones(10)
    return hist, prob
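
One subtlety in this helper: [-5, 10) contains 15 integers, but np.histogram splits the range into 10 bins of width 1.5, so the ideal per-bin mass actually alternates between 2/15 (~0.133) and 1/15 (~0.067) rather than being exactly 0.1; the atol=0.1 used by the checks below absorbs that gap. A quick illustrative check with numpy's own generator:

import numpy as np

samples = np.random.randint(-5, 10, size=100000)
hist, _ = np.histogram(samples, range=(-5, 10))
# Entries alternate around 0.133 and 0.067, all within 0.1 of the
# flat 0.1 reference returned by output_hist.
print(hist / samples.size)
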


class TestRandintOp(OpTest):
    def setUp(self):
        self.op_type = "randint"
        self.inputs = {}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((10000, 784)).astype("float32")}

    def init_attrs(self):
        self.attrs = {"shape": [10000, 784], "low": -5, "high": 10}
        self.output_hist = output_hist

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        hist, prob = self.output_hist(np.array(outs[0]))
        self.assertTrue(
            np.allclose(
                hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))


class TestRandintOpError(unittest.TestCase):
    def test_errors(self):
        main_prog = Program()
        start_prog = Program()
        with program_guard(main_prog, start_prog):

            def test_shape():
                shape = np.array([2, 3])
                paddle.randint(5, shape=shape, dtype='int32')

            self.assertRaises(TypeError, test_shape)

            def test_dtype():
                paddle.randint(5, shape=[32, 32], dtype='float32')

            self.assertRaises(TypeError, test_dtype)

            def test_low_high():
                paddle.randint(low=5, high=5, shape=[32, 32], dtype='int32')

            self.assertRaises(ValueError, test_low_high)


class TestRandintOp_attr_tensorlist(OpTest):
    def setUp(self):
        self.op_type = "randint"
        self.new_shape = (10000, 784)
        shape_tensor = []
        for index, ele in enumerate(self.new_shape):
            shape_tensor.append(("x" + str(index), np.ones(
                1).astype("int64") * ele))
        self.inputs = {'ShapeTensorList': shape_tensor}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((10000, 784)).astype("int32")}

    def init_attrs(self):
        self.attrs = {"low": -5, "high": 10}
        self.output_hist = output_hist

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        hist, prob = self.output_hist(np.array(outs[0]))
        self.assertTrue(
            np.allclose(
                hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))


class TestRandint_attr_tensor(OpTest):
    def setUp(self):
        self.op_type = "randint"
        self.inputs = {"ShapeTensor": np.array([10000, 784]).astype("int64")}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((10000, 784)).astype("int64")}

    def init_attrs(self):
        self.attrs = {"low": -5, "high": 10}
        self.output_hist = output_hist

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        hist, prob = self.output_hist(np.array(outs[0]))
        self.assertTrue(
            np.allclose(
                hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))


# Test the Python API.
class TestRandintAPI(unittest.TestCase):
    def test_api(self):
        startup_program = fluid.Program()
        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            # with only a high argument, results are from [0, 5)
            output1 = paddle.randint(5)
            # shape is a list and dtype is 'int32'
            output2 = paddle.randint(
                low=-100, high=100, shape=[64, 64], dtype='int32')
            # shape is a tuple and dtype is 'int64'
            output3 = paddle.randint(
                low=-100, high=100, shape=(32, 32, 3), dtype='int64')
            # shape is a list containing a tensor and dtype is 'int32'
            dim_1 = fluid.layers.fill_constant([1], "int64", 32)
            output4 = paddle.randint(
                low=-100, high=100, shape=[dim_1, 5], dtype='int32')
            # shape is a tensor and dtype is 'int64'
            var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
            output5 = paddle.randint(
                low=1, high=1000, shape=var_shape, dtype='int64')

        place = fluid.CPUPlace()
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)

        exe.run(startup_program)
        outs = exe.run(
            train_program,
            feed={'var_shape': np.array([100, 100]).astype('int64')},
            fetch_list=[output1, output2, output3, output4, output5])


class TestRandintDygraphMode(unittest.TestCase):
    def test_check_output(self):
        with fluid.dygraph.guard():
            x = paddle.randint(10, shape=[10], dtype="int32")
            x_np = x.numpy()
            for i in range(10):
                self.assertTrue(x_np[i] >= 0 and x_np[i] < 10)


if __name__ == "__main__":
    unittest.main()