"remove unused test net modified"

revert-3824-remove_grad_op_type
dongzhihong 8 years ago
parent df4fe671fe
commit 6bac3e17b5

@ -22,8 +22,8 @@ template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
@ -35,7 +35,8 @@ class GaussianRandomKernel : public framework::OpKernel {
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < framework::product(tensor->dims()); ++i) {
ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g);
}
}

@ -26,8 +26,8 @@ template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
@ -40,7 +40,6 @@ class GaussianRandomKernel : public framework::OpKernel {
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
std);
}

@ -14,13 +14,15 @@ class GaussianRandomTest(unittest.TestCase):
def test_gaussian_random(self, place):
scope = core.Scope()
scope.new_var("Out").get_tensor()
op = Operator(
"gaussian_random",
Out="Out",
dims=[1000, 784],
mean=.0,
std=1.,
seed=0)
seed=10)
op.infer_shape(scope)
context = core.DeviceContext.create(place)
op.run(scope, context)

Loading…
Cancel
Save