|
|
|
@ -135,7 +135,12 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
alias_data, alias_probs_data, seed);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
default: { PADDLE_THROW("Unsupported SamplerType."); }
|
|
|
|
|
default: {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Unsupported SamplerType. SamplerType should be 0: Uniform, "
|
|
|
|
|
"1: LogUniform or 2: CostumDist. Received SamplerType: %d",
|
|
|
|
|
sampler_type));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PrepareSamples<DeviceContext, T>(context, sampler);
|
|
|
|
@ -225,9 +230,9 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
weight, false, table_names, epmap,
|
|
|
|
|
context, local_scope);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
PADDLE_THROW(platform::errors::PreconditionNotMet(
|
|
|
|
|
"paddle is not compiled with distribute support, can not do "
|
|
|
|
|
"parameter prefetch!");
|
|
|
|
|
"parameter prefetch!"));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
auto weight_mat = EigenMatrix<T>::From(
|
|
|
|
@ -347,7 +352,12 @@ class NCEGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
alias_data, alias_probs_data, seed);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
default: { PADDLE_THROW("Unsupported SamplerType."); }
|
|
|
|
|
default: {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Unsupported SamplerType. SamplerType should be 0: Uniform, "
|
|
|
|
|
"1: LogUniform or 2: CostumDist. Received SamplerType: %d",
|
|
|
|
|
sampler_type));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// T b = 1. / num_total_classes * num_neg_samples;
|
|
|
|
@ -409,9 +419,9 @@ class NCEGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *table_t = context.Input<SelectedRows>("Weight");
|
|
|
|
|
table_dim = table_t->value().dims();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The parameter Weight of a NCE_OP "
|
|
|
|
|
"must be either LoDTensor or SelectedRows");
|
|
|
|
|
"must be either LoDTensor or SelectedRows"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));
|
|
|
|
|