|
|
@ -22,18 +22,16 @@ class GatherTreeOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"),
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree");
|
|
|
|
"Input(Ids) of GatherTreeOp should not be null.");
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Parents"),
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree");
|
|
|
|
"Input(Parents) of GatherTreeOp should not be null.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
|
|
|
"Output(Out) of GatherTreeOp should not be null.");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto ids_dims = ctx->GetInputDim("Ids");
|
|
|
|
auto ids_dims = ctx->GetInputDim("Ids");
|
|
|
|
auto parents_dims = ctx->GetInputDim("Parents");
|
|
|
|
auto parents_dims = ctx->GetInputDim("Parents");
|
|
|
|
PADDLE_ENFORCE(ids_dims == parents_dims,
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true,
|
|
|
|
"The shape of Input(Parents) must be same with the shape of "
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"Input(Ids).");
|
|
|
|
"The shape of Input(Parents) must be same with the "
|
|
|
|
|
|
|
|
"shape of Input(Ids)."));
|
|
|
|
ctx->SetOutputDim("Out", ids_dims);
|
|
|
|
ctx->SetOutputDim("Out", ids_dims);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|