add tensor support for gaussian_random_op test=develop (#24389)

release/2.0-alpha
wangchaochaohu 5 years ago committed by GitHub
parent da4a1db7bb
commit 53bdee64e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,48 +20,26 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline framework::DDim GetShape(const framework::ExecutionContext &ctx) {
inline framework::DDim GetShape(const framework::ExecutionContext &ctx,
std::string op_type) {
// 1. shape is a Tensor
if (ctx.HasInput("ShapeTensor")) {
auto *shape_tensor = ctx.Input<framework::LoDTensor>("ShapeTensor");
auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto vec_shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
auto vec_shape = GetDataFromTensor<int>(shape_tensor);
return framework::make_ddim(vec_shape);
}
// 2. shape is a list/tuple containing Tensor
auto shape_tensor_list = ctx.MultiInput<framework::Tensor>("ShapeTensorList");
if (shape_tensor_list.size() > 0) {
std::vector<int> vec_shape;
for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
auto tensor = shape_tensor_list[i];
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument(
"If the element type of 'shape'(tensor_list type) in "
"FillConstantOp is Tensor, the shape of this Tensor element must "
"be [1]. But received the Tensor element's shape is [%s]",
tensor->dims()));
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_shape.push_back(*temp.data<int>());
} else {
vec_shape.push_back(*tensor->data<int>());
}
}
auto vec_shape = GetDataFromTensorList(shape_tensor_list);
return framework::make_ddim(vec_shape);
}
@ -115,7 +93,8 @@ class FillConstantKernel : public framework::OpKernel<T> {
}
value = tensor_data[0];
}
auto shape = GetShape(ctx);
const std::string op_type = "fill_constant";
auto shape = GetShape(ctx, op_type);
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();

@ -14,7 +14,7 @@ limitations under the License. */
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
@ -22,8 +22,37 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// CPU kernel for the gaussian_random op.
//
// Fills Output("Out") with samples drawn from a normal distribution whose
// parameters come from the "mean" and "std" attributes. The output shape is
// obtained at run time through GetShape(), and the tensor is resized to it
// before its buffer is requested.
template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const float mu = context.Attr<float>("mean");
    const float sigma = context.Attr<float>("std");
    auto* out = context.Output<framework::Tensor>("Out");

    // When the "seed" attribute is 0, take a seed from std::random_device
    // instead of using the attribute value.
    unsigned int seed_value =
        static_cast<unsigned int>(context.Attr<int>("seed"));
    std::minstd_rand rng;
    if (seed_value == 0) {
      seed_value = std::random_device()();
    }
    rng.seed(seed_value);
    std::normal_distribution<T> gaussian(mu, sigma);

    // Size the output first; numel() and the buffer below depend on it.
    const std::string op_type = "gaussian_random";
    out->Resize(GetShape(context, op_type));
    const int64_t total = out->numel();
    T* cursor = out->mutable_data<T>(context.GetPlace());

    // One draw per element, written in order.
    int64_t remaining = total;
    while (remaining > 0) {
      *cursor = gaussian(rng);
      ++cursor;
      --remaining;
    }
  }
};
template <typename T>
class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean");
@ -58,12 +87,26 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
for (auto dim : shape) {
temp.push_back(static_cast<int64_t>(dim));
}
PADDLE_ENFORCE_GT(
shape.size(), 0UL,
platform::errors::InvalidArgument(
"Attribute(shape) of GaussianRandomOp must be set "
"and shape.size() > 0, but reveived shape.size() is %d",
shape.size()));
if (shape.empty() && ctx->HasInput("ShapeTensor")) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");
int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
auto vec_dims = std::vector<int>(num_ele, -1);
ctx->SetOutputDim("Out", framework::make_ddim(vec_dims));
return;
}
// Attr(shape) is only mandatory when neither ShapeTensor nor ShapeTensorList
// supplies the shape. The condition used to be written as
// !(HasInput("ShapeTensor") && !HasInputs("ShapeTensorList")); the misplaced
// parenthesis made the check fire whenever ShapeTensorList was present, even
// though that input is documented as taking priority over Attr(shape).
if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
  PADDLE_ENFORCE_GT(
      shape.size(), 0UL,
      platform::errors::InvalidArgument(
          "Attribute(shape) of GaussianRandomOp must be set "
          "and shape.size() > 0, but reveived shape.size() is %d",
          shape.size()));
}
ctx->SetOutputDim("Out", framework::make_ddim(temp));
}
@ -85,6 +128,16 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context(), layout, library);
}
// Selects the kernel type used for the input variable `var_name`.
//
// The shape-carrying inputs ("ShapeTensor" and "ShapeTensorList") are given
// the expected kernel type unchanged. Every other variable keeps the expected
// data type but takes its place and layout from the tensor itself.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  const bool is_shape_input =
      var_name == "ShapeTensor" || var_name == "ShapeTensorList";
  if (!is_shape_input) {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
  return expected_kernel_type;
}
};
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
@ -94,7 +147,18 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) "
"The dimension of random tensor.");
"The dimension of random tensor.")
.SetDefault({});
AddInput("ShapeTensor",
"(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape).")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int>>, optional). The shape of the output. "
"It has a higher priority than Attr(shape)."
"The shape of the element in vector must be [1].")
.AsDuplicable()
.AsDispensable();
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
@ -135,5 +199,5 @@ REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
ops::CPUGaussianRandomBatchSizeLikeKernel<double>);

@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/transform.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
namespace paddle {
namespace operators {
@ -41,7 +42,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
if (seed == 0) {
std::random_device rd;
@ -50,6 +50,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
const std::string op_type = "gaussian_random";
auto shape = GetShape(context, op_type);
tensor->Resize(shape);
T* data = tensor->mutable_data<T>(context.GetPlace());
int64_t size = tensor->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
@ -57,12 +62,33 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
}
};
// CUDA kernel class registered for gaussian_random_batch_size_like.
//
// Writes Gaussian samples into Output("Out") using the shape the tensor
// already has; no Resize is performed here. Element i is produced by applying
// GaussianGenerator<T>(mean, std, seed) to index i via thrust::transform.
template <typename T>
class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out = context.Output<framework::Tensor>("Out");
    T* out_data = out->mutable_data<T>(context.GetPlace());

    // When the "seed" attribute is 0, take a seed from std::random_device
    // instead of using the attribute value.
    unsigned int seed_value =
        static_cast<unsigned int>(context.Attr<int>("seed"));
    if (seed_value == 0) {
      std::random_device device;
      seed_value = device();
    }

    const T mu = static_cast<T>(context.Attr<float>("mean"));
    const T sigma = static_cast<T>(context.Attr<float>("std"));

    // Map the index range [0, numel) onto the output buffer.
    thrust::counting_iterator<unsigned int> first_index(0);
    const int64_t total = out->numel();
    thrust::transform(first_index, first_index + total,
                      thrust::device_ptr<T>(out_data),
                      GaussianGenerator<T>(mu, sigma, seed_value));
  }
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(gaussian_random,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(
gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<float>,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<double>);

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/mean_op.h"
namespace paddle {
@ -26,7 +27,6 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
@ -35,6 +35,11 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
engine.seed(seed);
std::normal_distribution<T> dist(mean, std);
const std::string op_type = "gaussian_random";
auto shape = GetShape(context, op_type);
tensor->Resize(shape);
T* data = tensor->mutable_data<T>(context.GetPlace());
int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);

@ -357,8 +357,9 @@ class Normal(Distribution):
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
normal_random_tmp = nn.gaussian_random_batch_size_like(
zero_tmp, zero_tmp.shape, mean=0., std=1., seed=seed)
zero_tmp_shape = nn.shape(zero_tmp)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, mean=0., std=1., seed=seed)
output = normal_random_tmp * (zero_tmp + self.scale) + self.loc
return nn.reshape(output, output_shape)
else:

@ -10169,33 +10169,55 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
Generate a random tensor whose data is drawn from a Gaussian distribution.
Args:
shape (Tuple[int] | List[int]): Shape of the generated random tensor.
shape (tuple[int] | list[int] | Variable | list[Variable]): Shape of the generated random tensor.
mean (float): Mean of the random tensor, defaults to 0.0.
std (float): Standard deviation of the random tensor, defaults to 1.0.
seed (int): ${seed_comment}
dtype(np.dtype | core.VarDesc.VarType | str): Output data type, float32 or float64.
Returns:
Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: float32 or float64 as specified.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = fluid.layers.gaussian_random(shape=[3, 4])
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = fluid.layers.gaussian_random(shape=[dim_1, dim_2])
# example 3:
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = fluid.layers.gaussian_random(var_shape)
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = fluid.layers.gaussian_random(var_shape_int32)
.. code-block:: python
# declarative mode
import numpy as np
from paddle import fluid
x = fluid.layers.gaussian_random((2, 3), std=2., seed=10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
start = fluid.default_startup_program()
main = fluid.default_main_program()
exe.run(start)
x_np, = exe.run(main, feed={}, fetch_list=[x])
@ -10209,33 +10231,44 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
place = fluid.CPUPlace()
with dg.guard(place) as g:
x = fluid.layers.gaussian_random((2, 4), mean=2., dtype="float32", seed=10)
x_np = x.numpy()
x_np = x.numpy()
x_np
# array([[2.3060477 , 2.676496 , 3.9911983 , 0.9990833 ],
# [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
"""
helper = LayerHelper('gaussian_random', **locals())
check_type(shape, 'shape', (list, tuple), 'fluid.layers.gaussian_random')
check_dtype(dtype, 'dtype', ['float32', 'float64'],
'fluid.layers.gaussian_random')
out = helper.create_variable_for_type_inference(dtype)
# Reject unsupported `shape` containers before the op is built. The message
# names gaussian_random so the error points at the API actually being called.
if not isinstance(shape, (list, tuple, Variable)):
    raise TypeError(
        "The type of 'shape' in gaussian_random must be Variable, list or tuple, but "
        "received %s." % (type(shape)))
c_dtype = convert_np_dtype_to_dtype_(dtype)
attrs = {
'mean': mean,
'std': std,
'seed': seed,
'dtype': c_dtype,
'use_mkldnn': False
}
inputs = {}
utils._get_shape_tensor_inputs(
inputs=inputs,
helper=helper,
attrs=attrs,
shape=shape,
op_type='gaussian_random')
helper.append_op(
type='gaussian_random',
inputs=inputs,
outputs={'Out': out},
attrs={
'shape': shape,
'mean': mean,
'std': std,
'seed': seed,
'dtype': c_dtype,
'use_mkldnn': False
})
attrs=attrs)
return out

@ -27,17 +27,15 @@ class TestMKLDNNGaussianRandomOpSeed10(TestGaussianRandomOp):
class TestMKLDNNGaussianRandomOpSeed0(TestGaussianRandomOp):
def setUp(self):
TestGaussianRandomOp.setUp(self)
self.use_mkldnn = True
self.attrs = {
"shape": [1000, 784],
"mean": .0,
"std": 1.,
"seed": 0,
"shape": [123, 92],
"mean": 1.0,
"std": 2.0,
"seed": 10,
"use_mkldnn": self.use_mkldnn
}
def init_kernel_type(self):
self.use_mkldnn = True
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save