|
|
|
@ -166,11 +166,12 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::set<T> st(labels.begin(), labels.end());
|
|
|
|
|
labels.assign(st.begin(), st.end());
|
|
|
|
|
|
|
|
|
|
auto &local_scope = context.scope().NewScope();
|
|
|
|
|
framework::Scope &local_scope = context.scope().NewScope();
|
|
|
|
|
|
|
|
|
|
auto height_sections = context.Attr<std::vector<int>>("height_sections");
|
|
|
|
|
auto table_names = context.Attr<std::vector<std::string>>("table_names");
|
|
|
|
|
|
|
|
|
|
auto *ids = local_scope.Var("Ids@Local");
|
|
|
|
|
auto *ids = local_scope.Var("Ids@Prefetch");
|
|
|
|
|
auto *x_tensor = ids->GetMutable<framework::LoDTensor>();
|
|
|
|
|
x_tensor->mutable_data<int64_t>(
|
|
|
|
|
framework::make_ddim({static_cast<int64_t>(labels.size()), 1}),
|
|
|
|
@ -179,12 +180,18 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::memcpy(x_tensor->data<int64_t>(), labels.data(),
|
|
|
|
|
labels.size() * sizeof(int64_t));
|
|
|
|
|
|
|
|
|
|
local_scope.Var("Weight@Local");
|
|
|
|
|
std::vector<int> w_dims = paddle::framework::vectorize2int(
|
|
|
|
|
context.Input<Tensor>("Weight")->dims());
|
|
|
|
|
w_dims[0] = static_cast<int>(labels.size());
|
|
|
|
|
|
|
|
|
|
auto *w_tensor = local_scope.Var("Weight@Prefetch")
|
|
|
|
|
->GetMutable<framework::LoDTensor>();
|
|
|
|
|
w_tensor->Resize(framework::make_ddim(w_dims));
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_DISTRIBUTE
|
|
|
|
|
operators::distributed::prefetch("Ids@Local", "Weight@Local", table_names,
|
|
|
|
|
epmap, height_sections, context,
|
|
|
|
|
&local_scope);
|
|
|
|
|
operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch",
|
|
|
|
|
table_names, epmap, height_sections,
|
|
|
|
|
context, local_scope);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"paddle is not compiled with distribute support, can not do "
|
|
|
|
@ -192,7 +199,7 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
auto weight_mat = EigenMatrix<T>::From(
|
|
|
|
|
(local_scope.Var("Weight@Local")->Get<framework::LoDTensor>()));
|
|
|
|
|
(local_scope.Var("Weight@Prefetch")->Get<framework::LoDTensor>()));
|
|
|
|
|
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
|
|
|
|
|
std::vector<int64_t>::iterator it =
|
|
|
|
|
std::find(labels.begin(), labels.end(), sample_labels_data[i]);
|
|
|
|
|