|
|
@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) {
|
|
|
|
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
|
|
|
|
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
|
|
|
|
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
|
|
|
|
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If the output of an op modifies feed vars, the op should not clip.
|
|
|
|
|
|
|
|
TEST(Prune, recurrrent_op_2) {
|
|
|
|
|
|
|
|
f::ProgramDesc program;
|
|
|
|
|
|
|
|
f::BlockDesc *block = program.MutableBlock(0);
|
|
|
|
|
|
|
|
f::BlockDesc *sub_block = program.AppendBlock(*block);
|
|
|
|
|
|
|
|
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
|
|
|
|
|
|
|
|
f::AttributeMap{}, block);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> state_var_name(1, "y");
|
|
|
|
|
|
|
|
AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1, c1"}}},
|
|
|
|
|
|
|
|
{{"ex_states", state_var_name},
|
|
|
|
|
|
|
|
{"states", state_var_name},
|
|
|
|
|
|
|
|
{"sub_block", sub_block}},
|
|
|
|
|
|
|
|
block);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(sub_block != nullptr);
|
|
|
|
|
|
|
|
AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}},
|
|
|
|
|
|
|
|
f::AttributeMap{}, sub_block);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc *pdesc = program.Proto();
|
|
|
|
|
|
|
|
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc pruned;
|
|
|
|
|
|
|
|
std::set<std::string> feed_var_names = {"x", "a"};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
|
|
|
EXPECT_EQ(pruned.blocks_size(), 2);
|
|
|
|
|
|
|
|
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
|
|
|
|
|
|
|
|
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
|
|
|
|
|
|
|
|
}
|
|
|
|