Generating random numbers with given batch size (#8337)
* Generating random numbers with given batch size uniform_random_batch_size_like_op gaussian_random_batch_size_like_op * More comments about random seed. * Move test_*_random_batch_size_like_op to unittestsemailweixu-patch-1
parent
118d950e74
commit
6752b06f8c
@ -0,0 +1,64 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/batch_size_like.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
void BatchSizeLikeOp::InferShape(framework::InferShapeContext *ctx) const {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
||||
"Input(Input) of %s should not be null.", Type());
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of %s should not be null.",
|
||||
Type());
|
||||
|
||||
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
||||
PADDLE_ENFORCE_GT(shape.size(), 0);
|
||||
std::vector<int64_t> shape_int64(shape.size(), 0);
|
||||
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
||||
[](int a) { return static_cast<int64_t>(a); });
|
||||
auto output_dim = framework::make_ddim(shape_int64);
|
||||
|
||||
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
|
||||
PADDLE_ENFORCE_GE(input_dim_idx, 0);
|
||||
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
|
||||
|
||||
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
|
||||
PADDLE_ENFORCE_GE(output_dim_idx, 0);
|
||||
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
|
||||
|
||||
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
|
||||
ctx->SetOutputDim("Out", output_dim);
|
||||
}
|
||||
|
||||
BatchSizeLikeOpMaker::BatchSizeLikeOpMaker(OpProto *proto,
|
||||
OpAttrChecker *op_checker)
|
||||
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("Input",
|
||||
"(Tensor) Tensor "
|
||||
"whose input_dim_idx'th dimension specifies the batch_size");
|
||||
AddOutput("Out",
|
||||
"(Tensor) Tensor of specified shape will be filled "
|
||||
"with the specified value");
|
||||
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
|
||||
AddAttr<int>("input_dim_idx",
|
||||
"(int, default 0) The index of input's batch size dimension")
|
||||
.SetDefault(0);
|
||||
AddAttr<int>("output_dim_idx",
|
||||
"(int, default 0) The index of output's batch size dimension")
|
||||
.SetDefault(0);
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,36 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class BatchSizeLikeOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override;
|
||||
};
|
||||
|
||||
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker);
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,73 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/operators/batch_size_like.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
|
||||
protected:
|
||||
using BatchSizeLikeOp::BatchSizeLikeOp;
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
|
||||
public:
|
||||
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: BatchSizeLikeOpMaker(proto, op_checker) {
|
||||
AddAttr<float>("mean",
|
||||
"(float, default 0.0) "
|
||||
"mean of random tensor.")
|
||||
.SetDefault(.0f);
|
||||
AddAttr<float>("std",
|
||||
"(float, default 1.0) "
|
||||
"std of random tensor.")
|
||||
.SetDefault(1.0f);
|
||||
AddAttr<int>("seed",
|
||||
"(int, default 0) "
|
||||
"Random seed of generator."
|
||||
"0 means use system wide seed."
|
||||
"Note that if seed is not 0, this operator will always "
|
||||
"generate the same random numbers every time.")
|
||||
.SetDefault(0);
|
||||
AddAttr<int>("dtype",
|
||||
"(int, default 5(FP32)) "
|
||||
"Output data type.")
|
||||
.SetDefault(framework::proto::DataType::FP32);
|
||||
|
||||
AddComment(R"DOC(
|
||||
GaussianRandom Operator.
|
||||
|
||||
Used to initialize tensors with gaussian random generator.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP_WITHOUT_GRADIENT(
|
||||
gaussian_random_batch_size_like,
|
||||
paddle::operators::GaussianRandomBatchSizeLikeOp,
|
||||
paddle::operators::GaussianRandomBatchSizeLikeOpMaker);
|
||||
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
|
@ -0,0 +1,72 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/operators/batch_size_like.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
|
||||
protected:
|
||||
using BatchSizeLikeOp::BatchSizeLikeOp;
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
|
||||
public:
|
||||
UniformRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: BatchSizeLikeOpMaker(proto, op_checker) {
|
||||
AddComment(R"DOC(
|
||||
Uniform random operator
|
||||
|
||||
This operator initializes a tensor with the same batch_size as the Input tensor
|
||||
with random values sampled from a uniform distribution.
|
||||
|
||||
)DOC");
|
||||
AddAttr<float>("min",
|
||||
"(float, default -1.0) "
|
||||
"Minimum value of uniform random")
|
||||
.SetDefault(-1.0f);
|
||||
AddAttr<float>("max",
|
||||
"(float, default 1.0) "
|
||||
"Maximun value of uniform random")
|
||||
.SetDefault(1.0f);
|
||||
AddAttr<int>("seed",
|
||||
"(int, default 0) "
|
||||
"Random seed used for generating samples. "
|
||||
"0 means use a seed generated by the system."
|
||||
"Note that if seed is not 0, this operator will always "
|
||||
"generate the same random numbers every time.")
|
||||
.SetDefault(0);
|
||||
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
|
||||
.SetDefault(framework::proto::DataType::FP32);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP_WITHOUT_GRADIENT(
|
||||
uniform_random_batch_size_like,
|
||||
paddle::operators::UniformRandomBatchSizeLikeOp,
|
||||
paddle::operators::UniformRandomBatchSizeLikeOpMaker);
|
||||
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
|
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestGaussianRandomBatchSizeLike(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "gaussian_random_batch_size_like"
|
||||
self.inputs = {'Input': np.zeros((500, 2000), dtype="float32")}
|
||||
self.attrs = {'mean': 1., 'std': 2., 'shape': [-1, 2000]}
|
||||
self.outputs = {'Out': np.zeros((500, 2000), dtype='float32')}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output_customized(self.verify_output)
|
||||
|
||||
def verify_output(self, outs):
|
||||
self.assertEqual(outs[0].shape, (500, 2000))
|
||||
hist, _ = np.histogram(outs[0], range=(-3, 5))
|
||||
hist = hist.astype("float32")
|
||||
hist /= float(outs[0].size)
|
||||
data = np.random.normal(size=(500, 2000), loc=1, scale=2)
|
||||
hist2, _ = np.histogram(data, range=(-3, 5))
|
||||
hist2 = hist2.astype("float32")
|
||||
hist2 /= float(outs[0].size)
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
hist, hist2, rtol=0, atol=0.01),
|
||||
"hist: " + str(hist) + " hist2: " + str(hist2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestUniformRandomBatchSizeLike(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "uniform_random_batch_size_like"
|
||||
self.inputs = {'Input': np.zeros((500, 2000), dtype="float32")}
|
||||
self.attrs = {'min': 1., 'max': 2., 'shape': [-1, 2000]}
|
||||
self.outputs = {'Out': np.zeros((500, 2000), dtype='float32')}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output_customized(self.verify_output)
|
||||
|
||||
def verify_output(self, outs):
|
||||
self.assertEqual(outs[0].shape, (500, 2000))
|
||||
hist, _ = np.histogram(outs[0], range=(1, 2))
|
||||
hist = hist.astype("float32")
|
||||
hist /= float(outs[0].size)
|
||||
prob = 0.1 * np.ones((10))
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue