|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/sample_logits_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "paddle/fluid/operators/math/sample_prob.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -60,6 +61,10 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
|
|
|
|
|
"The probabilites of sampled positive and negtive labels.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("LogitsDim", "Store dim information of Logits for gradient op")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("LabelsDim", "Store dim information of Logits for gradient op")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("SampledLogits",
|
|
|
|
|
"(Tensor, default: Tensor<float>), A 2-D tensor with shape"
|
|
|
|
|
"[N, NT + S]. The outputs value of sampled logits, which will be"
|
|
|
|
@ -121,6 +126,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Output(SampledLogits) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"),
|
|
|
|
|
"Output(SampledLabels) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("LogitsDim"),
|
|
|
|
|
"Output(LogitsDim) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("LabelsDim"),
|
|
|
|
|
"Output(LabelsDim) should be not null.");
|
|
|
|
|
|
|
|
|
|
auto logits_dims = ctx->GetInputDim("Logits");
|
|
|
|
|
auto labels_dims = ctx->GetInputDim("Labels");
|
|
|
|
@ -137,6 +146,15 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
|
|
|
|
|
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
|
|
|
|
|
ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]});
|
|
|
|
|
|
|
|
|
|
// append 0 to shape variable to avoid optimized by memory optimize pass
|
|
|
|
|
auto logits_dim_vec = framework::vectorize(logits_dims);
|
|
|
|
|
logits_dim_vec.push_back(0);
|
|
|
|
|
ctx->SetOutputDim("LogitsDim", framework::make_ddim(logits_dim_vec));
|
|
|
|
|
|
|
|
|
|
auto labels_dim_vec = framework::vectorize(labels_dims);
|
|
|
|
|
labels_dim_vec.push_back(0);
|
|
|
|
|
ctx->SetOutputDim("LabelsDim", framework::make_ddim(labels_dim_vec));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -155,28 +173,27 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Logits"),
|
|
|
|
|
"Input(Logits) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
|
|
|
|
"Input(Labels) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("LogitsDim"),
|
|
|
|
|
"Input(LogitsDim) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("LabelsDim"),
|
|
|
|
|
"Input(LabelsDim) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Samples"),
|
|
|
|
|
"Input(Samples) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
|
|
|
|
|
"Input(SampledLogits) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")),
|
|
|
|
|
"Input(SampledLogits@Grad) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
|
|
|
|
|
"Output(Logits@Grad) should be not null.");
|
|
|
|
|
|
|
|
|
|
auto logit_dims = ctx->GetInputDim("Logits");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Labels");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
|
|
|
|
|
auto logits_dims = ctx->GetInputDim("LogitsDim");
|
|
|
|
|
logits_dims = framework::DDim(logits_dims.Get(), logits_dims.size() - 1);
|
|
|
|
|
auto labels_dims = ctx->GetInputDim("LabelsDim");
|
|
|
|
|
labels_dims = framework::DDim(labels_dims.Get(), labels_dims.size() - 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
|
|
|
|
|
"The label should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL,
|
|
|
|
|
PADDLE_ENFORCE_EQ(logits_dims.size(), 2UL,
|
|
|
|
|
"The logits should be a 2-D tensor.");
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Logits"),
|
|
|
|
|
ctx->GetInputDim("Logits"));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Logits"), logits_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -199,10 +216,9 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto* grad_op = new framework::OpDesc();
|
|
|
|
|
grad_op->SetType("sample_logits_grad");
|
|
|
|
|
grad_op->SetInput("Logits", Input("Logits"));
|
|
|
|
|
grad_op->SetInput("Labels", Input("Labels"));
|
|
|
|
|
grad_op->SetInput("LogitsDim", Output("LogitsDim"));
|
|
|
|
|
grad_op->SetInput("LabelsDim", Output("LabelsDim"));
|
|
|
|
|
grad_op->SetInput("Samples", Output("Samples"));
|
|
|
|
|
grad_op->SetInput("SampledLogits", Output("SampledLogits"));
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("SampledLogits"),
|
|
|
|
|
OutputGrad("SampledLogits"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
|
|
|
|
|