change api based on design doc

revert-4814-Add_sequence_project_op
Yang Yang 7 years ago
parent e0cee58c84
commit bdca4b37c4

@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
return false;
}
void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) {
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
@ -99,8 +99,10 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) {
*op_field->Add() = input.blocks(block_id).ops(i);
}
}
}
// return should_run;
void Prune(const ProgramDesc& input, ProgramDesc& output) {
prune_impl(input, output, 0);
}
} // namespace framework

@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id);
void Prune(const ProgramDesc& input, ProgramDesc& output);
} // namespace framework
} // namespace paddle

@ -68,11 +68,11 @@ TEST(Prune, one_operator) {
f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
@ -91,7 +91,7 @@ TEST(Prune, forward) {
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
@ -111,7 +111,7 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
@ -128,7 +128,7 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
@ -146,6 +146,6 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned, 0);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}

Loading…
Cancel
Save