parent
							
								
									ea6a251c0b
								
							
						
					
					
						commit
						009c049e82
					
				| @ -0,0 +1,170 @@ | ||||
| // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include "paddle/fluid/framework/op_registry.h" | ||||
| #include "paddle/fluid/framework/operator.h" | ||||
| #include "paddle/fluid/operators/uniform_random_op.h" | ||||
| #include "paddle/fluid/platform/enforce.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| template <typename T> | ||||
| class CPURandintKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& ctx) const override { | ||||
|     std::vector<int64_t> new_shape; | ||||
|     auto list_new_shape_tensor = | ||||
|         ctx.MultiInput<framework::Tensor>("ShapeTensorList"); | ||||
|     if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) { | ||||
|       if (ctx.HasInput("ShapeTensor")) { | ||||
|         auto* shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor"); | ||||
|         new_shape = GetNewDataFromShapeTensor(shape_tensor); | ||||
|       } else if (list_new_shape_tensor.size() > 0) { | ||||
|         new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     auto* out = ctx.Output<framework::LoDTensor>("Out"); | ||||
|     if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape)); | ||||
|     T* data = out->mutable_data<T>(ctx.GetPlace()); | ||||
|     int64_t size = out->numel(); | ||||
|     std::random_device rd; | ||||
|     std::mt19937 gen(rd()); | ||||
|     std::uniform_int_distribution<> dist(ctx.Attr<int>("low"), | ||||
|                                          ctx.Attr<int>("high") - 1); | ||||
|     for (int64_t i = 0; i < size; ++i) data[i] = dist(gen); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class RandintOp : public framework::OperatorWithKernel { | ||||
|  public: | ||||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||||
| 
 | ||||
|   void InferShape(framework::InferShapeContext* ctx) const override { | ||||
|     PADDLE_ENFORCE_EQ( | ||||
|         ctx->HasOutput("Out"), true, | ||||
|         platform::errors::InvalidArgument("Output(Out) of RandintOp is null.")); | ||||
|     PADDLE_ENFORCE_LT( | ||||
|         ctx->Attrs().Get<int>("low"), ctx->Attrs().Get<int>("high"), | ||||
|         platform::errors::InvalidArgument("randint's low must less then high, " | ||||
|                                           "but received: low = %d, high = %d.", | ||||
|                                           ctx->Attrs().Get<int>("low"), | ||||
|                                           ctx->Attrs().Get<int>("high"))); | ||||
| 
 | ||||
|     if (ctx->HasInputs("ShapeTensorList")) { | ||||
|       // top prority shape
 | ||||
|       auto inputs_name = ctx->Inputs("ShapeTensorList"); | ||||
|       PADDLE_ENFORCE_GT( | ||||
|           inputs_name.size(), 0, | ||||
|           platform::errors::InvalidArgument( | ||||
|               "Input(ShapeTensorList)'size of Op(randint) can't be zero." | ||||
|               "Please check the Attr(shape)'s size of" | ||||
|               "Op(fluid.layers.randint).)")); | ||||
|       auto out_dims = std::vector<int>(inputs_name.size(), -1); | ||||
|       ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); | ||||
| 
 | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape"); | ||||
|     if (ctx->HasInput("ShapeTensor") && shape.empty()) { | ||||
|       auto shape_dims = ctx->GetInputDim("ShapeTensor"); | ||||
|       PADDLE_ENFORCE_EQ(shape_dims.size(), 1, | ||||
|                         platform::errors::InvalidArgument( | ||||
|                             "ShapeError: Input(ShapeTensor)' dimension size of " | ||||
|                             "Op(randint) must be 1." | ||||
|                             "But received ShapeTensor's dimensions = %d.", | ||||
|                             shape_dims.size())); | ||||
|       int num_ele = 1; | ||||
|       for (int i = 0; i < shape_dims.size(); ++i) { | ||||
|         num_ele *= shape_dims[i]; | ||||
|       } | ||||
|       auto vec_dims = std::vector<int64_t>(num_ele, -1); | ||||
|       auto out_dims = framework::make_ddim(vec_dims); | ||||
|       ctx->SetOutputDim("Out", out_dims); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     PADDLE_ENFORCE_EQ(shape.empty(), false, | ||||
|                       platform::errors::InvalidArgument( | ||||
|                           "if there is no Input(ShapeTensorList) and no " | ||||
|                           "Input(ShapeTensor),the " | ||||
|                           "attr(shape) information must " | ||||
|                           "be set by Attr(shape).")); | ||||
| 
 | ||||
|     std::vector<int64_t> tensor_shape; | ||||
|     tensor_shape.reserve(shape.size()); | ||||
|     for (auto dim : shape) { | ||||
|       tensor_shape.push_back(static_cast<int64_t>(dim)); | ||||
|     } | ||||
|     ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape)); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   framework::OpKernelType GetExpectedKernelType( | ||||
|       const framework::ExecutionContext& ctx) const override { | ||||
|     return framework::OpKernelType( | ||||
|         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")), | ||||
|         ctx.GetPlace()); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class RandintOpMaker : public framework::OpProtoAndCheckerMaker { | ||||
|  public: | ||||
|   void Make() override { | ||||
|     AddInput("ShapeTensor", | ||||
|              "(Tensor<int64_t> or Tensor<int32_t>, optional) . If provided, " | ||||
|              "randint" | ||||
|              "according to " | ||||
|              "this given shape. It means that it has a higher priority than " | ||||
|              "Attr(shape) but a lower priority than Input(ShapeTensor).") | ||||
|         .AsDispensable(); | ||||
|     AddInput("ShapeTensorList", | ||||
|              "(vector<Tensor<int64_t>> or vector<Tensor<int32_t>>, optional). " | ||||
|              "If provided, randint use this. The shape of the tensor " | ||||
|              "must be [1], it has the highest priority comparing with " | ||||
|              "Input(ShapeTensor) and attr(shape).") | ||||
|         .AsDuplicable() | ||||
|         .AsDispensable(); | ||||
|     AddOutput("Out", "The output tensor of randint op"); | ||||
|     AddComment(R"DOC( | ||||
| This operator initializes a tensor with random integers sampled from a | ||||
| uniform distribution. The random result is in set [low, high). | ||||
| )DOC"); | ||||
|     AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor.") | ||||
|         .SetDefault({}); | ||||
|     AddAttr<int>("low", | ||||
|                  "The lower bound on the range of random values to generate."); | ||||
|     AddAttr<int>("high", | ||||
|                  "The upper bound on the range of random values to generate."); | ||||
|     AddAttr<int>("dtype", "Output tensor data type. [Default INT64].") | ||||
|         .SetDefault(framework::proto::VarType::INT64); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| 
 | ||||
| REGISTER_OPERATOR( | ||||
|     randint, ops::RandintOp, ops::RandintOpMaker, | ||||
|     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||||
|     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>) | ||||
| 
 | ||||
| REGISTER_OP_CPU_KERNEL(randint, ops::CPURandintKernel<int>, | ||||
|                        ops::CPURandintKernel<int64_t>) | ||||
| @ -0,0 +1,76 @@ | ||||
| // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| #include <thrust/random.h> | ||||
| #include <thrust/transform.h> | ||||
| #include "paddle/fluid/framework/op_registry.h" | ||||
| #include "paddle/fluid/operators/uniform_random_op.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| template <typename T> | ||||
| struct UniformIntGenerator { | ||||
|   T low_, high_; | ||||
|   __host__ __device__ UniformIntGenerator(T low, T high) | ||||
|       : low_(low), high_(high) {} | ||||
| 
 | ||||
|   __host__ __device__ T operator()(const unsigned int n) const { | ||||
|     thrust::minstd_rand rng; | ||||
|     rng.seed(0); | ||||
|     thrust::uniform_int_distribution<T> dist(low_, high_); | ||||
|     rng.discard(n); | ||||
|     T out = dist(rng); | ||||
|     return out; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Use std::uniform_int_distribution and thrust::uniform_int_distribution(thrust | ||||
| // is a std library in CUDA) to | ||||
| // implement randint. | ||||
| template <typename T> | ||||
| class GPURandintKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& context) const override { | ||||
|     std::vector<int64_t> new_shape; | ||||
|     auto list_new_shape_tensor = | ||||
|         context.MultiInput<framework::Tensor>("ShapeTensorList"); | ||||
|     if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) { | ||||
|       if (context.HasInput("ShapeTensor")) { | ||||
|         auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor"); | ||||
|         new_shape = GetNewDataFromShapeTensor(shape_tensor); | ||||
|       } else if (list_new_shape_tensor.size() > 0) { | ||||
|         new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     auto* out = context.Output<framework::LoDTensor>("Out"); | ||||
|     if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape)); | ||||
|     T* data = out->mutable_data<T>(context.GetPlace()); | ||||
|     T low = static_cast<T>(context.Attr<int>("low")); | ||||
|     T high = static_cast<T>(context.Attr<int>("high")) - 1; | ||||
| 
 | ||||
|     thrust::counting_iterator<unsigned int> index_sequence_begin(0); | ||||
|     int64_t size = out->numel(); | ||||
|     thrust::transform(index_sequence_begin, index_sequence_begin + size, | ||||
|                       thrust::device_ptr<T>(data), | ||||
|                       UniformIntGenerator<T>(low, high)); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators | ||||
| }  // namespace paddle | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| REGISTER_OP_CUDA_KERNEL(randint, ops::GPURandintKernel<int>, | ||||
|                         ops::GPURandintKernel<int64_t>) | ||||
| @ -0,0 +1,173 @@ | ||||
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import unittest | ||||
| import numpy as np | ||||
| from op_test import OpTest | ||||
| 
 | ||||
| import paddle.fluid.core as core | ||||
| from paddle.fluid.op import Operator | ||||
| import paddle.fluid as fluid | ||||
| from paddle.fluid import Program, program_guard | ||||
| import paddle | ||||
| 
 | ||||
| 
 | ||||
| def output_hist(out): | ||||
|     hist, _ = np.histogram(out, range=(-5, 10)) | ||||
|     hist = hist.astype("float32") | ||||
|     hist /= float(out.size) | ||||
|     prob = 0.1 * np.ones((10)) | ||||
|     return hist, prob | ||||
| 
 | ||||
| 
 | ||||
| class TestRandintOp(OpTest): | ||||
|     def setUp(self): | ||||
|         self.op_type = "randint" | ||||
|         self.inputs = {} | ||||
|         self.init_attrs() | ||||
|         self.outputs = {"Out": np.zeros((10000, 784)).astype("float32")} | ||||
| 
 | ||||
|     def init_attrs(self): | ||||
|         self.attrs = {"shape": [10000, 784], "low": -5, "high": 10} | ||||
|         self.output_hist = output_hist | ||||
| 
 | ||||
|     def test_check_output(self): | ||||
|         self.check_output_customized(self.verify_output) | ||||
| 
 | ||||
|     def verify_output(self, outs): | ||||
|         hist, prob = self.output_hist(np.array(outs[0])) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 hist, prob, rtol=0, atol=0.1), "hist: " + str(hist)) | ||||
| 
 | ||||
| 
 | ||||
| class TestRandintOpError(unittest.TestCase): | ||||
|     def test_errors(self): | ||||
|         main_prog = Program() | ||||
|         start_prog = Program() | ||||
|         with program_guard(main_prog, start_prog): | ||||
| 
 | ||||
|             def test_shape(): | ||||
|                 shape = np.array([2, 3]) | ||||
|                 paddle.randint(5, shape=shape, dtype='int32') | ||||
| 
 | ||||
|             self.assertRaises(TypeError, test_shape) | ||||
| 
 | ||||
|             def test_dtype(): | ||||
|                 paddle.randint(5, shape=[32, 32], dtype='float32') | ||||
| 
 | ||||
|             self.assertRaises(TypeError, test_dtype) | ||||
| 
 | ||||
|             def test_low_high(): | ||||
|                 paddle.randint(low=5, high=5, shape=[32, 32], dtype='int32') | ||||
| 
 | ||||
|             self.assertRaises(ValueError, test_low_high) | ||||
| 
 | ||||
| 
 | ||||
| class TestRandintOp_attr_tensorlist(OpTest): | ||||
|     def setUp(self): | ||||
|         self.op_type = "randint" | ||||
|         self.new_shape = (10000, 784) | ||||
|         shape_tensor = [] | ||||
|         for index, ele in enumerate(self.new_shape): | ||||
|             shape_tensor.append(("x" + str(index), np.ones( | ||||
|                 (1)).astype("int64") * ele)) | ||||
|         self.inputs = {'ShapeTensorList': shape_tensor} | ||||
|         self.init_attrs() | ||||
|         self.outputs = {"Out": np.zeros((10000, 784)).astype("int32")} | ||||
| 
 | ||||
|     def init_attrs(self): | ||||
|         self.attrs = {"low": -5, "high": 10} | ||||
|         self.output_hist = output_hist | ||||
| 
 | ||||
|     def test_check_output(self): | ||||
|         self.check_output_customized(self.verify_output) | ||||
| 
 | ||||
|     def verify_output(self, outs): | ||||
|         hist, prob = self.output_hist(np.array(outs[0])) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 hist, prob, rtol=0, atol=0.1), "hist: " + str(hist)) | ||||
| 
 | ||||
| 
 | ||||
| class TestRandint_attr_tensor(OpTest): | ||||
|     def setUp(self): | ||||
|         self.op_type = "randint" | ||||
|         self.inputs = {"ShapeTensor": np.array([10000, 784]).astype("int64")} | ||||
|         self.init_attrs() | ||||
|         self.outputs = {"Out": np.zeros((10000, 784)).astype("int64")} | ||||
| 
 | ||||
|     def init_attrs(self): | ||||
|         self.attrs = {"low": -5, "high": 10} | ||||
|         self.output_hist = output_hist | ||||
| 
 | ||||
|     def test_check_output(self): | ||||
|         self.check_output_customized(self.verify_output) | ||||
| 
 | ||||
|     def verify_output(self, outs): | ||||
|         hist, prob = self.output_hist(np.array(outs[0])) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 hist, prob, rtol=0, atol=0.1), "hist: " + str(hist)) | ||||
| 
 | ||||
| 
 | ||||
| # Test python API | ||||
| class TestRandintAPI(unittest.TestCase): | ||||
|     def test_api(self): | ||||
|         startup_program = fluid.Program() | ||||
|         train_program = fluid.Program() | ||||
|         with fluid.program_guard(train_program, startup_program): | ||||
|             # results are from [0, 5). | ||||
|             output1 = paddle.randint(5) | ||||
|             # shape is a list and dtype is 'int32' | ||||
|             output2 = paddle.randint( | ||||
|                 low=-100, high=100, shape=[64, 64], dtype='int32') | ||||
|             # shape is a tuple and dtype is 'int64' | ||||
|             output3 = paddle.randint( | ||||
|                 low=-100, high=100, shape=(32, 32, 3), dtype='int64') | ||||
|             # shape is a tensorlist and dtype is 'float32' | ||||
|             dim_1 = fluid.layers.fill_constant([1], "int64", 32) | ||||
|             dim_2 = fluid.layers.fill_constant([1], "int32", 50) | ||||
|             output4 = paddle.randint( | ||||
|                 low=-100, high=100, shape=[dim_1, 5], dtype='int32') | ||||
|             # shape is a tensor and dtype is 'float64' | ||||
|             var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") | ||||
|             output5 = paddle.randint( | ||||
|                 low=1, high=1000, shape=var_shape, dtype='int64') | ||||
| 
 | ||||
|             place = fluid.CPUPlace() | ||||
|             if fluid.core.is_compiled_with_cuda(): | ||||
|                 place = fluid.CUDAPlace(0) | ||||
|             exe = fluid.Executor(place) | ||||
| 
 | ||||
|             exe.run(startup_program) | ||||
|             outs = exe.run( | ||||
|                 train_program, | ||||
|                 feed={'var_shape': np.array([100, 100]).astype('int64')}, | ||||
|                 fetch_list=[output1, output2, output3, output4, output5]) | ||||
| 
 | ||||
| 
 | ||||
| class TestRandintDygraphMode(unittest.TestCase): | ||||
|     def test_check_output(self): | ||||
|         with fluid.dygraph.guard(): | ||||
|             x = paddle.randint(10, shape=[10], dtype="int32") | ||||
|             x_np = x.numpy() | ||||
|             for i in range(10): | ||||
|                 self.assertTrue((x_np[i] >= 0 and x_np[i] < 10)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
					Loading…
					
					
				
		Reference in new issue