|
|
|
@ -26,8 +26,8 @@ template <typename T>
|
|
|
|
|
class GaussianRandomKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
|
|
|
|
|
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
|
|
|
|
|
float mean = context.op_.GetAttr<float>("mean");
|
|
|
|
|
float std = context.op_.GetAttr<float>("std");
|
|
|
|
|
auto* tensor = context.Output<framework::Tensor>(0);
|
|
|
|
|
T* data = tensor->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
@ -40,7 +40,6 @@ class GaussianRandomKernel : public framework::OpKernel {
|
|
|
|
|
&g, CURAND_RNG_PSEUDO_DEFAULT));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
|
|
|
|
|
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
|
|
|
|
|
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
|
|
|
|
|
std);
|
|
|
|
|
}
|
|
|
|
|