|
|
|
@ -34,15 +34,22 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto labels_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
|
|
|
|
|
"Input(Label)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
|
|
|
|
|
"The 2nd dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
|
|
|
|
|
int rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same rank.");
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
|
|
|
|
|
framework::product(labels_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
|
|
|
|
|
framework::slice_ddim(labels_dims, 0, rank),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
@ -65,23 +72,24 @@ class SigmoidCrossEntropyWithLogitsGradOp
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto labels_dims = ctx->GetInputDim("Label");
|
|
|
|
|
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims.size(), 2,
|
|
|
|
|
"Input(Label)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dout_dims.size(), 2,
|
|
|
|
|
"Input(Out@Grad)'s rank should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1],
|
|
|
|
|
"The 2nd dimension of Input(X) and Input(Label) should "
|
|
|
|
|
"be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0],
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Out@Grad) "
|
|
|
|
|
"should be equal.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], dout_dims[1],
|
|
|
|
|
"The 2nd dimension of Input(X) and Input(Out@Grad) "
|
|
|
|
|
"should be equal.");
|
|
|
|
|
|
|
|
|
|
int rank = x_dims.size();
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
|
|
|
|
|
framework::product(labels_dims) <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
|
|
|
|
|
framework::slice_ddim(labels_dims, 0, rank),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::slice_ddim(x_dims, 0, rank),
|
|
|
|
|
framework::slice_ddim(dout_dims, 0, rank),
|
|
|
|
|
"Input(X) and Input(Out@Grad) shall have the same shape.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
}
|
|
|
|
|