|
|
|
@ -26,7 +26,6 @@ public:
|
|
|
|
|
void Compute(const framework::KernelContext& context) const override {
|
|
|
|
|
auto mean = context.op_.GetAttr<T>("mean");
|
|
|
|
|
auto std = context.op_.GetAttr<T>("std");
|
|
|
|
|
// auto seed = context.op_.GetAttr<T>("seed");
|
|
|
|
|
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
|
|
|
|
|
T* r = output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto ctx =
|
|
|
|
@ -60,7 +59,6 @@ public:
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddAttr<std::vector<int>>("shape", "The shape of matrix to be randomized");
|
|
|
|
|
// AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
|
|
|
|
|
AddAttr<float>("mean", "mean value of random.").SetDefault(.0);
|
|
|
|
|
AddAttr<float>("std", "minimum value of random value")
|
|
|
|
|
.SetDefault(1.0)
|
|
|
|
|