parent
fcd6f64b98
commit
e2c08d286f
@ -1,30 +1,39 @@
|
|||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include "paddle/platform/dynload/curand.h"
|
||||||
|
#include "paddle/platform/gpu_info.h"
|
||||||
|
|
||||||
#include "paddle/framework/op_registry.h"
|
#include "paddle/framework/op_registry.h"
|
||||||
#include "paddle/operators/guassian_random_op.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
namespace paddle {
|
||||||
namespace operators {
|
namespace operators {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class GaussianRandomOpKernel<platform::GPUPlace, T>
|
class GaussianRandomKernel : public framework::OpKernel {
|
||||||
: public framework::OpKernel {
|
|
||||||
public:
|
public:
|
||||||
void Compute(const framework::KernelContext& context) const override {
|
void Compute(const framework::ExecutionContext& context) const override {
|
||||||
auto mean = context.op_.GetAttr<T>("mean");
|
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
|
||||||
auto std = context.op_.GetAttr<T>("std");
|
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
|
||||||
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
|
auto* tensor = context.Output<framework::Tensor>(0);
|
||||||
T* r = output->mutable_data<T>(context.GetPlace());
|
T* data = tensor->mutable_data<T>(context.GetPlace());
|
||||||
auto ctx =
|
|
||||||
static_cast<const platform::GPUDeviceContext*>(context.device_context_);
|
int seed = context.op_.GetAttr<int>("seed");
|
||||||
// generator need to modify context
|
if (seed == 0) {
|
||||||
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
|
seed = std::random_device()();
|
||||||
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std);
|
}
|
||||||
|
curandGenerator_t g;
|
||||||
|
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
|
||||||
|
&g, CURAND_RNG_PSEUDO_DEFAULT));
|
||||||
|
PADDLE_ENFORCE(
|
||||||
|
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
|
||||||
|
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
|
||||||
|
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
|
||||||
|
std);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace operators
|
} // namespace operators
|
||||||
} // namespace paddle
|
} // namespace paddle
|
||||||
|
|
||||||
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
|
namespace ops = paddle::operators;
|
||||||
float>
|
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
|
||||||
RandomOpKernel_GPU_float;
|
|
||||||
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include <random>
|
|
||||||
#include "glog/logging.h"
|
|
||||||
#include "paddle/framework/eigen.h"
|
|
||||||
#include "paddle/framework/operator.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
|
|
||||||
template <typename Place, typename T>
|
|
||||||
class GaussianRandomOpKernel : public framework::OpKernel {
|
|
||||||
public:
|
|
||||||
void Compute(const framework::KernelContext& context) const override {}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
||||||
Loading…
Reference in new issue