|
|
|
@ -102,9 +102,27 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto dist_alias = context.Input<Tensor>("CustomDistAlias");
|
|
|
|
|
auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_probs->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAlias) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_alias->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAliasProbs).numel() = %d, "
|
|
|
|
|
"Attr(num_total_classes) = %d.",
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes);
|
|
|
|
|
|
|
|
|
|
const float *probs_data = dist_probs->data<float>();
|
|
|
|
|
const int *alias_data = dist_alias->data<int>();
|
|
|
|
@ -121,7 +139,11 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int64_t *sample_labels_data = sample_labels->data<int64_t>();
|
|
|
|
|
|
|
|
|
|
for (int x = 0; x < sample_labels->numel(); x++) {
|
|
|
|
|
PADDLE_ENFORCE_GE(sample_labels_data[x], 0, "nce sample label %d", x);
|
|
|
|
|
PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
|
|
|
|
|
"ValueError: Every sample label should be "
|
|
|
|
|
"non-negative. But received: "
|
|
|
|
|
"Input(SampleLabels)[%d] = %d",
|
|
|
|
|
x, sample_labels_data[x]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sample_out = context.Output<Tensor>("SampleLogits");
|
|
|
|
@ -289,9 +311,27 @@ class NCEGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto dist_alias = context.Input<Tensor>("CustomDistAlias");
|
|
|
|
|
auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_probs->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAlias) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_alias->numel(), num_total_classes);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAliasProbs).numel() = %d, "
|
|
|
|
|
"Attr(num_total_classes) = %d.",
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes);
|
|
|
|
|
|
|
|
|
|
const float *probs_data = dist_probs->data<float>();
|
|
|
|
|
const int *alias_data = dist_alias->data<int>();
|
|
|
|
|