@@ -25,63 +25,64 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor, default: Tensor<float>), The unscaled log probabilities "
              "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
              "and K is the class number.");
-    AddInput("Label",
-             "(Tensor) The ground truth which is a 2-D tensor. Label is a "
+    AddInput("Labels",
+             "(Tensor) The ground truth which is a 2-D tensor. Labels is a "
              "Tensor<int64> with shape [N x NT], where NT is the number of"
              "true labels for each example.");
-    AddInput(
-        "CustomSamples",
-        "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shaoe [N x "
-        "S+NT]."
-        "The customized sample labels with true labels at first. This tensor"
-        "is only use_custom_samples is true.")
+    AddInput("CustomizedSamples",
+             "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
+             "NT + S],"
+             " where N is the batch size, NT is the number of true labels "
+             "and S is the number of negtive sample for each example."
+             "The first NT elements of each row should be the same with true "
+             "labels, "
+             "followed by S custom negtive samples. This tensor"
+             "is only used when use_customized_samples is true.")
         .AsDispensable();
     AddInput(
-        "CustomProbabilities",
-        "(Tensor, default: Tensor<float>), A 2-D tensor with shaoe [N x S+NT]."
-        "The customized sample probabilities with true labels at first. This "
-        "tensor is only use_custom_samples is true.")
+        "CustomizedProbabilities",
+        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
+        "The tensor has the same shape with CustomSamples,"
+        "and each element represents probability of element in CustomSamples. "
+        "This "
+        "tensor is only used when use_customized_samples is true.")
         .AsDispensable();
-    AddOutput(
-        "Samples",
-        "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N x "
-        "S+NT]."
-        "The outputs value of sampler by given the true label, where S is the "
-        "number of negative sample for each example. So Samples includes NT "
-        "true"
-        "labels and S negative labels for each example. This will be used in"
-        "backward calculation.")
+    AddOutput("Samples",
+              "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
+              "NT + S]."
+              "The outputs value of sampler, including NT true lables and S "
+              "negetive samples "
+              "for each example. This will be used in"
+              "backward calculation.")
         .AsIntermediate();
     AddOutput(
         "Probabilities",
-        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x "
-        "S+NT]."
-        "The outputs value of progabilites of samples by given the true label, "
-        "where S is the "
-        "number of negative sample for each example. So Samples includes NT "
-        "true"
-        "labels and S negative labels for each example.")
+        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
+        "The probabilites of sampled positive and negtive labels.")
         .AsIntermediate();
     AddOutput("SampledLogits",
               "(Tensor, default: Tensor<float>), A 2-D tensor with shape"
-              "[N x S+NT]. The outputs value of sample logits, which will be"
-              "used in backward calculation.")
+              "[N, NT + S]. The outputs value of sampled logits, which will be"
+              "used in backward propagation.")
         .AsIntermediate();
     AddOutput(
-        "SampledLabel",
-        "(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled label"
-        "with shape [N x S + NT].");
+        "SampledLabels",
+        "(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled labels"
+        "with shape [N, NT]. The tonsor contains hard labels as input to "
+        " softmax op, that is 0, 1, …, NT-1 because of the first NT elements"
+        " of Sampels are positive lables.");
     AddAttr<bool>(
-        "use_custom_samples",
-        "An indicator whether to use custom samples with probabilities, if True"
-        "the operator will use custom samples and custom probabilities"
+        "use_customized_samples",
+        "An indicator whether to use customized samples with probabilities, if "
+        "True"
+        "the operator will use customized samples and customized probabilities"
        "otherwise, the operator will generate them by itself.")
        .SetDefault(false);
    AddAttr<bool>(
        "uniq",
        "An indicator whether to sample non-repetitive negtive labels, if True"
        "the operator will sample negtive labels without replacement."
-        "otherwise, the operator will sample negtive labels with replacement.")
+        "Otherwise, the operator will sample negtive labels with replacement.")
        .SetDefault(true);
    AddAttr<bool>(
        "remove_accidental_hits",
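
Reviewer note (not part of the patch): a minimal, self-contained sketch of the tensor layout the new CustomizedSamples/SampledLabels doc strings describe, assuming NT true labels followed by S negatives per row; all values below are made up for illustration.

```cpp
// Illustrative sketch only (not Paddle code): one row of CustomizedSamples as
// described by the doc strings above, and the hard labels SampledLabels would
// carry for that row.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t NT = 2;  // number of true labels per example
  const int64_t S = 3;   // number of negative samples per example
  std::vector<int64_t> true_labels = {7, 42};    // made-up ground-truth classes
  std::vector<int64_t> negatives = {3, 19, 55};  // made-up sampled negatives

  // CustomizedSamples row: the NT true labels come first, then the S
  // negatives, giving shape [NT + S] per example.
  std::vector<int64_t> samples_row;
  samples_row.reserve(NT + S);
  samples_row.insert(samples_row.end(), true_labels.begin(), true_labels.end());
  samples_row.insert(samples_row.end(), negatives.begin(), negatives.end());

  // SampledLabels row: because the true labels occupy the first NT columns of
  // the sampled row, the hard labels fed to softmax are simply 0..NT-1.
  std::vector<int64_t> sampled_labels_row;
  for (int64_t i = 0; i < NT; ++i) sampled_labels_row.push_back(i);

  for (int64_t v : samples_row) std::cout << v << ' ';         // 7 42 3 19 55
  std::cout << '\n';
  for (int64_t v : sampled_labels_row) std::cout << v << ' ';  // 0 1
  std::cout << '\n';
  return 0;
}
```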
@@ -95,8 +96,7 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
   """
   Computes sampled output training logits and labels suitable for implementing
-  sampled softmax.       
-
+  sampled softmax.
   """
 
 )DOC");
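
Reviewer note (not part of the patch): the DOC comment above refers to sampled softmax; the sketch below is an illustrative, Paddle-independent version of the loss that would be computed downstream from one row of SampledLogits, where the softmax runs over only the NT + S sampled columns and the hard label comes from SampledLabels (so it lies in [0, NT)).

```cpp
// Illustrative sketch only (not Paddle code): softmax cross-entropy over one
// row of sampled logits, i.e. the "sampled softmax" the DOC comment names.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Cross-entropy of one sampled-logits row [NT + S] against one hard label,
// which indexes into the sampled columns and therefore lies in [0, NT).
double SampledSoftmaxLoss(const std::vector<double>& sampled_logits_row,
                          int64_t sampled_label) {
  double max_logit = *std::max_element(sampled_logits_row.begin(),
                                       sampled_logits_row.end());
  double sum_exp = 0.0;
  for (double v : sampled_logits_row) sum_exp += std::exp(v - max_logit);
  // -log softmax(row)[label], computed in a numerically stable way.
  return -(sampled_logits_row[sampled_label] - max_logit - std::log(sum_exp));
}

int main() {
  // NT + S = 2 + 3 = 5 sampled columns for a single example.
  std::vector<double> sampled_logits_row = {2.0, 0.5, -1.0, 0.1, 0.3};
  std::cout << SampledSoftmaxLoss(sampled_logits_row, 0) << "\n";
  return 0;
}
```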
@@ -110,7 +110,8 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Logits"),
                    "Input(Logits) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Labels"),
+                   "Input(Labels) should be not null.");
 
     PADDLE_ENFORCE(ctx->HasOutput("Samples"),
                    "Output(Samples) should be not null.");
@@ -118,11 +119,11 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
                    "Output(Probabilities) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"),
                    "Output(SampledLogits) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput("SampledLabel"),
-                   "Output(SampledLabel) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"),
+                   "Output(SampledLabels) should be not null.");
 
     auto logits_dims = ctx->GetInputDim("Logits");
-    auto labels_dims = ctx->GetInputDim("Label");
+    auto labels_dims = ctx->GetInputDim("Labels");
 
     PADDLE_ENFORCE_EQ(
         logits_dims.size(), 2UL,
@@ -135,7 +136,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
     ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
     ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
-    ctx->SetOutputDim("SampledLabel", {logits_dims[0], labels_dims[1]});
+    ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]});
   }
 
  protected:
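
Reviewer note (not part of the patch): the shape bookkeeping in the InferShape hunk above, spelled out with plain integers; treating num_sampled_classes as NT + S is an assumption consistent with the doc strings, since its computation falls outside this hunk.

```cpp
// Illustrative sketch only (not Paddle code): the output shapes set by the
// InferShape hunk above, with num_sampled_classes assumed to equal NT + S
// (consistent with the doc strings; its computation is outside this hunk).
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t N = 32, K = 10000;  // Logits is [N, K]
  const int64_t NT = 1, S = 25;     // Labels is [N, NT]; S negatives per row
  const int64_t num_sampled_classes = NT + S;

  std::array<int64_t, 2> samples_dims = {N, num_sampled_classes};
  std::array<int64_t, 2> probabilities_dims = {N, num_sampled_classes};
  std::array<int64_t, 2> sampled_logits_dims = {N, num_sampled_classes};
  std::array<int64_t, 2> sampled_labels_dims = {N, NT};  // labels_dims[1]

  assert(samples_dims == probabilities_dims);
  assert(sampled_logits_dims[1] == NT + S);
  assert(sampled_labels_dims[1] == NT);
  (void)K;
  return 0;
}
```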
@@ -144,7 +145,6 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
     auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
     framework::OpKernelType kt =
         framework::OpKernelType(data_type, ctx.device_context());
-    // kt.place_ = platform::CPUPlace();
     return kt;
   }
 };
@@ -157,7 +157,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Logits"),
                    "Input(Logits) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Labels"),
+                   "Input(Labels) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput("Samples"),
                    "Input(Samples) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
@@ -168,7 +169,7 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
                    "Output(Logits@Grad) should be not null.");
 
     auto logit_dims = ctx->GetInputDim("Logits");
-    auto label_dims = ctx->GetInputDim("Label");
+    auto label_dims = ctx->GetInputDim("Labels");
     PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
                       "The label should be a 2-D tensor.");
     PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL,
@@ -185,7 +186,6 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
         ctx.InputVar(framework::GradVarName("SampledLogits")));
     framework::OpKernelType kt =
         framework::OpKernelType(data_type, ctx.device_context());
-    // kt.place_ = platform::CPUPlace();
     return kt;
   }
 };
@@ -200,7 +200,7 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
     auto* grad_op = new framework::OpDesc();
     grad_op->SetType("sample_logits_grad");
     grad_op->SetInput("Logits", Input("Logits"));
-    grad_op->SetInput("Label", Input("Label"));
+    grad_op->SetInput("Labels", Input("Labels"));
     grad_op->SetInput("Samples", Output("Samples"));
     grad_op->SetInput("SampledLogits", Output("SampledLogits"));
     grad_op->SetInput(framework::GradVarName("SampledLogits"),