|
|
@ -30,8 +30,13 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
auto x_numel = product(x_dims);
|
|
|
|
auto x_numel = product(x_dims);
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
|
|
|
|
int new_dim = ctx->Attrs().Get<int>("new_dim");
|
|
|
|
int new_dim = ctx->Attrs().Get<int>("new_dim");
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
ctx->SetOutputDim("Out",
|
|
|
|
ctx->SetOutputDim("Out",
|
|
|
|
{x_numel / new_dim, static_cast<int64_t>(new_dim)});
|
|
|
|
{x_numel / new_dim, static_cast<int64_t>(new_dim)});
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// when compiling, the batch size is undetermined, just set to -1
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", {-1, static_cast<int64_t>(new_dim)});
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|