|
|
|
|
@ -33,9 +33,9 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
|
|
|
|
|
if (ctx.HasInput("ShapeTensor")) {
|
|
|
|
|
auto *shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
|
|
|
|
|
new_shape = get_new_data_from_shape_tensor(shape_tensor);
|
|
|
|
|
new_shape = GetNewDataFromShapeTensor(shape_tensor);
|
|
|
|
|
} else if (list_new_shape_tensor.size() > 0) {
|
|
|
|
|
new_shape = get_new_shape_from_shape_tensorlist(list_new_shape_tensor);
|
|
|
|
|
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -169,14 +169,14 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("ShapeTensor",
|
|
|
|
|
"(Tensor<int64_t>, optional). If provided, uniform_ranodom "
|
|
|
|
|
"according to "
|
|
|
|
|
"this given shape. That is to say it has a higher priority than "
|
|
|
|
|
"this given shape. It means that it has a higher priority than "
|
|
|
|
|
"the shape attribute, while the shape attribute still should be "
|
|
|
|
|
"set correctly to gurantee shape inference in compile time.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddInput("ShapeTensorList",
|
|
|
|
|
"(vector<Tensor<int64_t>>, optional). If provided, uniform_random "
|
|
|
|
|
"will use this"
|
|
|
|
|
"The shape of the tensor in vector MUST BE [1]"
|
|
|
|
|
"use this."
|
|
|
|
|
"The shape of the tensor in vector MUST BE [1],"
|
|
|
|
|
"it has the highest priority compare with Input(Shape) and "
|
|
|
|
|
"attr(shape).")
|
|
|
|
|
.AsDuplicable()
|
|
|
|
|
|