|
|
|
@ -36,13 +36,11 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
|
|
|
|
|
|
|
|
|
|
auto maxlen = ctx->Attrs().Get<int>("maxlen");
|
|
|
|
|
if (maxlen > 0) { // We can only infershape when maxlen > 0
|
|
|
|
|
int maxlen = ctx->Attrs().Get<int>("maxlen");
|
|
|
|
|
auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
|
|
|
|
|
dim.push_back(maxlen);
|
|
|
|
|
dim.push_back(maxlen > 0 ? maxlen : -1);
|
|
|
|
|
ctx->SetOutputDim("Y", framework::make_ddim(dim));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|