|
|
|
@ -50,7 +50,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
|
|
|
|
|
void InferShape(const framework::InferShapeContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
|
|
|
|
|
"uniform_random's min must less then max");
|
|
|
|
|
auto tensor = ctx.Output<framework::Tensor>(0);
|
|
|
|
|
auto* tensor = ctx.Output<framework::Tensor>(0);
|
|
|
|
|
auto dims = GetAttr<std::vector<int>>("dims");
|
|
|
|
|
tensor->Resize(framework::make_ddim(dims));
|
|
|
|
|
}
|
|
|
|
|