|
|
|
@ -27,14 +27,23 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInputs("X"),
|
|
|
|
|
"Input(X) of RefByTrainerIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("TrainerId"),
|
|
|
|
|
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of RefByTrainerIdOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1,
|
|
|
|
|
"TrainerId should be a scalar.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) of RefByTrainerIdOp should not be null."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("TrainerId"), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(TrainerId) of RefByTrainerIdOp should not be null."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(Out) of RefByTrainerIdOp should not be null."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("TrainerId").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument("TrainerId should be a scalar."));
|
|
|
|
|
// Out's shape is determined at runtime.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|