|
|
|
@ -43,7 +43,7 @@ template <typename T>
|
|
|
|
|
class GPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
framework::Tensor* tensor(nullptr);
|
|
|
|
|
framework::Tensor* tensor = nullptr;
|
|
|
|
|
auto out_var = context.OutputVar("Out");
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
@ -52,7 +52,9 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(shape));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Only support SelectedRows and Tensor");
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"uniform_random_op's output only"
|
|
|
|
|
"supports SelectedRows and Tensor");
|
|
|
|
|
}
|
|
|
|
|
T* data = tensor->mutable_data<T>(context.GetPlace());
|
|
|
|
|
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
|
|
|
|
|