|
|
|
@ -44,11 +44,11 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
framework::Tensor* tensor(nullptr);
|
|
|
|
|
auto out_var = ctx.OutputVar("Out");
|
|
|
|
|
auto out_var = context.OutputVar("Out");
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
} else if (out_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto shape = ctx.Attr<std::vector<int>>("shape");
|
|
|
|
|
auto shape = context.Attr<std::vector<int>>("shape");
|
|
|
|
|
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(shape));
|
|
|
|
|
} else {
|
|
|
|
|