|
|
|
@ -30,8 +30,10 @@ class CVMOp : public framework::OperatorWithKernel {
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank should be 2."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), 2UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank should be 2, but got %d", x_dims.size()));
|
|
|
|
|
|
|
|
|
|
if (ctx->Attrs().Get<bool>("use_cvm")) {
|
|
|
|
|
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
|
|
|
|
@ -68,26 +70,31 @@ class CVMGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto cvm_dims = ctx->GetInputDim("CVM");
|
|
|
|
|
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank should be 2."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expect Input(X)'s rank == 2, but got %d", x_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dy_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument("Input(Y@Grad)'s rank should be 2."));
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expect Input(X)'s rank == 2, but got %d", dy_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cvm_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expect Input(X)'s rank == 2, but got %d", cvm_dims.size()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[0], dy_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The 1st dimension of Input(X) and Input(Y@Grad) should "
|
|
|
|
|
"be equal."));
|
|
|
|
|
"be equal, X is %d, Y@Grad is %d",
|
|
|
|
|
x_dims[0], dy_dims[0]));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cvm_dims[1], 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"When Attr(soft_label) == false, the 2nd dimension of "
|
|
|
|
|
"Input(CVM) should be 2."));
|
|
|
|
|
"Input(CVM) should be 2, but got %d cvm_dims[1]"));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
|