|
|
@ -116,7 +116,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
auto* d_output_data = d_output->data<T>();
|
|
|
|
auto* d_output_data = d_output->data<T>();
|
|
|
|
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
|
|
|
|
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
|
|
|
|
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
|
|
|
|
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
|
|
|
|
d_output->numel(), stream);
|
|
|
|
d_output->numel() * sizeof(T), stream);
|
|
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
auto ids_t = context.Input<LoDTensor>("Ids");
|
|
|
|
auto ids_t = context.Input<LoDTensor>("Ids");
|
|
|
|