|
|
|
@ -26,25 +26,47 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(Input) of %s should not be null.", Type());
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of %s should not be null.", Type());
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type());
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type());
|
|
|
|
|
|
|
|
|
|
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
PADDLE_ENFORCE_GT(shape.size(), 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(shape.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape size must be larger than 0, but received: %s.",
|
|
|
|
|
shape.size()));
|
|
|
|
|
std::vector<int64_t> shape_int64(shape.size(), 0);
|
|
|
|
|
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
|
|
|
|
[](int a) { return static_cast<int64_t>(a); });
|
|
|
|
|
auto output_dim = framework::make_ddim(shape_int64);
|
|
|
|
|
|
|
|
|
|
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
|
|
|
|
|
PADDLE_ENFORCE_GE(input_dim_idx, 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
|
|
|
|
|
int input_dim_size = static_cast<int>(ctx->GetInputDim("Input").size());
|
|
|
|
|
PADDLE_ENFORCE_GE(input_dim_idx, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input dimension index must be larger "
|
|
|
|
|
"equal than 0, but received: %s.",
|
|
|
|
|
input_dim_idx));
|
|
|
|
|
PADDLE_ENFORCE_GT(input_dim_size, input_dim_idx,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input dimension size must be larger than "
|
|
|
|
|
"input dimension index, but received input "
|
|
|
|
|
"dimension size: %s, input dimension index: %s.",
|
|
|
|
|
input_dim_size, input_dim_idx));
|
|
|
|
|
|
|
|
|
|
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
|
|
|
|
|
PADDLE_ENFORCE_GE(output_dim_idx, 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
|
|
|
|
|
int output_dim_size = static_cast<int>(shape.size());
|
|
|
|
|
PADDLE_ENFORCE_GE(output_dim_idx, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output dimension index must be larger "
|
|
|
|
|
"equal than 0, but received: %s.",
|
|
|
|
|
output_dim_idx));
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
output_dim_size, output_dim_idx,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output dimension size must be larger than output dimension index, "
|
|
|
|
|
"but received output dimension size: %s, output dimension index: "
|
|
|
|
|
"%s.",
|
|
|
|
|
output_dim_size, output_dim_idx));
|
|
|
|
|
|
|
|
|
|
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
|
|
|
|
|
ctx->SetOutputDim("Out", output_dim);
|
|
|
|
|