|
|
|
@ -48,7 +48,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
} else if (out_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto shape = context.Attr<std::vector<int>>("shape");
|
|
|
|
|
auto shape = context.Attr<std::vector<int64_t>>("shape");
|
|
|
|
|
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(shape));
|
|
|
|
|
} else {
|
|
|
|
|