|
|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/prune.h"
|
|
|
|
|
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/attribute.h"
|
|
|
|
@ -58,12 +59,13 @@ TEST(Prune, one_operator) {
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc *pdesc = program.Proto();
|
|
|
|
|
f::proto::ProgramDesc pruned;
|
|
|
|
|
|
|
|
|
|
f::Prune(*pdesc, &pruned);
|
|
|
|
|
std::set<std::string> feed_var_names = {};
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
|
|
|
|
|
|
|
|
|
|
feed_var_names.insert("a");
|
|
|
|
|
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
|
|
|
|
|
f::Prune(*pdesc, &pruned);
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -81,11 +83,11 @@ TEST(Prune, forward) {
|
|
|
|
|
block);
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc *pdesc = program.Proto();
|
|
|
|
|
|
|
|
|
|
std::set<std::string> feed_var_names = {"a"};
|
|
|
|
|
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
|
|
|
|
|
f::proto::ProgramDesc pruned;
|
|
|
|
|
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
|
|
|
|
|
f::Prune(*pdesc, &pruned);
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -107,7 +109,8 @@ TEST(Prune, multi_input_op) {
|
|
|
|
|
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc pruned;
|
|
|
|
|
f::Prune(*pdesc, &pruned);
|
|
|
|
|
std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -126,7 +129,8 @@ TEST(Prune, multi_output_op) {
|
|
|
|
|
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc pruned;
|
|
|
|
|
f::Prune(*pdesc, &pruned);
|
|
|
|
|
std::set<std::string> feed_var_names = {"a"};
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -146,6 +150,7 @@ TEST(Prune, multi_target) {
|
|
|
|
|
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
|
|
|
|
|
|
|
|
|
|
f::proto::ProgramDesc pruned;
|
|
|
|
|
f::Prune(*pdesc, &pruned);
|
|
|
|
|
std::set<std::string> feed_var_names = {"a"};
|
|
|
|
|
f::Prune(*pdesc, feed_var_names, &pruned);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
|
|
|
|
|
}
|
|
|
|
|