Fix beam_search InferShape (#25169)

* fix beam_search infershape, test=develop

* fix beam search op unittest, test=develop
fix_copy_if_different
liu zhengxi 5 years ago committed by GitHub
parent cd4d9122e7
commit 68e93d8a17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -95,6 +95,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
std::vector<std::string>({"selected_ids", "selected_scores"})) {
OP_INOUT_CHECK(ctx->HasOutput(arg), "Output", arg, "BeamSeach");
}
auto id_dims = ctx->GetInputDim("pre_ids");
ctx->SetOutputDim("selected_scores", ctx->GetInputDim("pre_scores"));
ctx->SetOutputDim("selected_ids", id_dims);
ctx->SetOutputDim("parent_idx", {id_dims[0]});
}
protected:

@ -38,9 +38,9 @@ class BeamSearchOpTester(unittest.TestCase):
self._create_pre_scores()
self._create_scores()
self._create_pre_ids()
self.scope.var('selected_ids')
self.scope.var('selected_scores')
self.scope.var('parent_idx')
self.scope.var('selected_ids').get_tensor()
self.scope.var('selected_scores').get_tensor()
self.scope.var('parent_idx').get_tensor()
def test_run(self):
op = Operator(

Loading…
Cancel
Save