|
|
|
@ -23,7 +23,7 @@ template <typename T>
|
|
|
|
|
class GaussianRandomOpKernel<platform::CPUPlace, T>
|
|
|
|
|
: public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::KernelContext& context) const override {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto mean = context.op_.GetAttr<T>("mean");
|
|
|
|
|
auto std = context.op_.GetAttr<T>("std");
|
|
|
|
|
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
|
|
|
|
@ -41,15 +41,14 @@ class GaussianRandomOpKernel<platform::CPUPlace, T>
|
|
|
|
|
|
|
|
|
|
class GaussianRandomOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(
|
|
|
|
|
const std::vector<const framework::Tensor*>& inputs,
|
|
|
|
|
const std::vector<framework::Tensor*>& outputs) const override {
|
|
|
|
|
void InferShape(const framework::InferShapeContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero.");
|
|
|
|
|
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one.");
|
|
|
|
|
PADDLE_ENFORCE(outputs[0] != nullptr,
|
|
|
|
|
"Outputs of RandomOp must all be set.");
|
|
|
|
|
outputs[0]->Resize(
|
|
|
|
|
framework::make_ddim(this->GetAttr<std::vector<int>>("shape")));
|
|
|
|
|
auto* tensor = ctx.Output<Tensor>(0);
|
|
|
|
|
auto dims = GetAttr(std::vector<int>("shape"));
|
|
|
|
|
tensor->Resize(framework::make_ddim(dims));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|