|
|
|
@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_probs->numel(), num_total_classes);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_probs->numel(), num_total_classes));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAlias) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_alias->numel(), num_total_classes);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAlias) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_alias->numel(), num_total_classes));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAliasProbs).numel() = %d, "
|
|
|
|
|
"Attr(num_total_classes) = %d.",
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: The number of elements in "
|
|
|
|
|
"Input(CustomDistAliasProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAliasProbs).numel() = %d, "
|
|
|
|
|
"Attr(num_total_classes) = %d.",
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes));
|
|
|
|
|
|
|
|
|
|
const float *probs_data = dist_probs->data<float>();
|
|
|
|
|
const int *alias_data = dist_alias->data<int>();
|
|
|
|
@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
for (int x = 0; x < sample_labels->numel(); x++) {
|
|
|
|
|
PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
|
|
|
|
|
"ValueError: Every sample label should be "
|
|
|
|
|
"non-negative. But received: "
|
|
|
|
|
"Input(SampleLabels)[%d] = %d",
|
|
|
|
|
x, sample_labels_data[x]);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ValueError: Every sample label should be "
|
|
|
|
|
"non-negative. But received: "
|
|
|
|
|
"Input(SampleLabels)[%d] = %d",
|
|
|
|
|
x, sample_labels_data[x]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sample_out = context.Output<Tensor>("SampleLogits");
|
|
|
|
@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_probs->numel(), num_total_classes);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_probs->numel(), num_total_classes));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAlias) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_alias->numel(), num_total_classes);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAlias) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
|
|
|
|
|
"= %d.",
|
|
|
|
|
dist_alias->numel(), num_total_classes));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes,
|
|
|
|
|
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAliasProbs).numel() = %d, "
|
|
|
|
|
"Attr(num_total_classes) = %d.",
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: The number of elements in "
|
|
|
|
|
"Input(CustomDistAliasProbs) "
|
|
|
|
|
"should be equal to the number of total classes. But Received: "
|
|
|
|
|
"Input(CustomDistAliasProbs).numel() = %d, "
|
|
|
|
|
"Attr(num_total_classes) = %d.",
|
|
|
|
|
dist_alias_probs->numel(), num_total_classes));
|
|
|
|
|
|
|
|
|
|
const float *probs_data = dist_probs->data<float>();
|
|
|
|
|
const int *alias_data = dist_alias->data<int>();
|
|
|
|
|