|
|
@ -23,7 +23,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(const framework::InferShapeContext& context) const override {
|
|
|
|
void InferShape(const framework::InferShapeContext& context) const override {
|
|
|
|
auto* tensor = context.Output<framework::Tensor>(0);
|
|
|
|
auto* tensor = context.Output<framework::Tensor>("Out");
|
|
|
|
auto dims = GetAttr<std::vector<int>>("dims");
|
|
|
|
auto dims = GetAttr<std::vector<int>>("dims");
|
|
|
|
PADDLE_ENFORCE(dims.size() > 0UL,
|
|
|
|
PADDLE_ENFORCE(dims.size() > 0UL,
|
|
|
|
"dims can be one int or array. dims must be set.");
|
|
|
|
"dims can be one int or array. dims must be set.");
|
|
|
|