@@ -78,27 +78,47 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     auto *output_t = context.Output<LoDTensor>("Out");
     int64_t padding_idx = context.Attr<int64_t>("padding_idx");
 
-    size_t N = table_t->dims()[0];
-    size_t D = table_t->dims()[1];
-    size_t K = ids_t->numel();
-    auto *ids = ids_t->data<int64_t>();
-    auto *table = table_t->data<T>();
-    auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-    dim3 threads(128, 8);
-    dim3 grids(8, 1);
-
-    if (padding_idx == -1)
-      LookupTable<
-          T, 128, 8, 8,
-          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          output, table, ids, N, K, D, padding_idx);
-    else
-      LookupTable<
-          T, 128, 8, 8,
-          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          output, table, ids, N, K, D, padding_idx);
+    auto id_name = context.Inputs("Ids").front();
+    auto out_name = context.Outputs("Out").front();
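+    // the variable names are kept so the prefetch path below can look the
+    // Ids/Out variables up in the scope by name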
+
+    // for remote prefetch
+    auto epmap = context.Attr<std::vector<std::string>>("epmap");
+    auto height_sections = context.Attr<std::vector<int>>("height_sections");
+    auto table_names = context.Attr<std::vector<std::string>>("table_names");
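+    // epmap lists the parameter-server endpoints; height_sections holds the
+    // row count of each remote table shard; table_names holds the names of
+    // the shard variables on those servers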
+
+    if (!epmap.empty()) {
+      // if epmap is not empty, then the parameter will be fetched from remote
+      // parameter server
+#ifdef PADDLE_WITH_DISTRIBUTE
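+      // prefetch splits the ids among the remote shards, fetches the matching
+      // embedding rows over RPC, and assembles them into Out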
+      operators::distributed::prefetch(id_name, out_name, table_names, epmap,
+                                       height_sections, context);
+#else
+      PADDLE_THROW(
+          "Paddle is not compiled with distributed support; cannot do "
+          "parameter prefetch!");
+#endif
+    } else {
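+      // no endpoints configured: do the embedding lookup locally on the GPU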
+      size_t N = table_t->dims()[0];
+      size_t D = table_t->dims()[1];
+      size_t K = ids_t->numel();
+      auto *ids = ids_t->data<int64_t>();
+      auto *table = table_t->data<T>();
+      auto *output = output_t->mutable_data<T>(context.GetPlace());
+
+      dim3 threads(128, 8);
+      dim3 grids(8, 1);
+
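+      // the <T, 128, 8, 8, PaddingFlag> template arguments must match the
+      // 128x8 block and 8-block grid above; PaddingFlag=true selects the
+      // variant that zeroes output rows whose id equals padding_idx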
+      if (padding_idx == -1)
+        LookupTable<T, 128, 8, 8, false><<<
+            grids, threads, 0, context.cuda_device_context().stream()>>>(
+            output, table, ids, N, K, D, padding_idx);
+      else
+        LookupTable<T, 128, 8, 8, true><<<
+            grids, threads, 0, context.cuda_device_context().stream()>>>(
+            output, table, ids, N, K, D, padding_idx);
+    }
   }
 };
 
@@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     auto &dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     bool is_sparse = context.Attr<bool>("is_sparse");
+
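+    // with is_sparse the gradient of W is emitted as SelectedRows holding
+    // only the rows that were actually looked up, instead of a dense tensor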
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {