Merge pull request #14720 from sneaxiy/fix_seq_mask_op_infershape

Fix sequence_mask_op InferShape
revert-14398-imperative
Zeng Jinle 6 years ago committed by GitHub
commit ff4237309a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,13 +36,11 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
auto maxlen = ctx->Attrs().Get<int>("maxlen");
if (maxlen > 0) { // We can only infershape when maxlen > 0
int maxlen = ctx->Attrs().Get<int>("maxlen");
auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
dim.push_back(maxlen);
dim.push_back(maxlen > 0 ? maxlen : -1);
ctx->SetOutputDim("Y", framework::make_ddim(dim));
}
}
};
class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {

Loading…
Cancel
Save