|
|
|
@ -27,8 +27,6 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
const std::string kFeedOpType = "feed";
|
|
|
|
|
const std::string kFetchOpType = "fetch";
|
|
|
|
|
const std::string kDropOutOpType = "dropout";
|
|
|
|
|
const std::string kBatchNormOpType = "batch_norm";
|
|
|
|
|
|
|
|
|
|
bool HasDependentVar(const proto::OpDesc& op_desc,
|
|
|
|
|
const std::set<std::string>& dependent_vars) {
|
|
|
|
@ -186,18 +184,13 @@ void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
|
|
|
|
|
prune_impl(input, output, 0, -1, dependent_vars);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void inference_optimize_impl(const proto::ProgramDesc& input,
|
|
|
|
|
proto::ProgramDesc* output, int block_id) {
|
|
|
|
|
*output = input;
|
|
|
|
|
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
|
|
|
|
|
void inference_optimize_impl(proto::ProgramDesc* input, int block_id) {
|
|
|
|
|
auto* op_field = input->mutable_blocks(block_id)->mutable_ops();
|
|
|
|
|
for (auto& op_desc : *op_field) {
|
|
|
|
|
if (op_desc.type() == kDropOutOpType ||
|
|
|
|
|
op_desc.type() == kBatchNormOpType) {
|
|
|
|
|
for (auto& attr : *op_desc.mutable_attrs()) {
|
|
|
|
|
if (attr.name() == "is_test") {
|
|
|
|
|
attr.set_b(true);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
for (auto& attr : *op_desc.mutable_attrs()) {
|
|
|
|
|
if (attr.name() == "is_test") {
|
|
|
|
|
attr.set_b(true);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -205,7 +198,12 @@ void inference_optimize_impl(const proto::ProgramDesc& input,
|
|
|
|
|
|
|
|
|
|
void InferenceOptimize(const proto::ProgramDesc& input,
|
|
|
|
|
proto::ProgramDesc* output) {
|
|
|
|
|
inference_optimize_impl(input, output, 0);
|
|
|
|
|
*output = input;
|
|
|
|
|
int num_blocks = output->blocks_size();
|
|
|
|
|
PADDLE_ENFORCE_GT(num_blocks, 0, "ProgramDesc must have at least one block");
|
|
|
|
|
for (int i = 0; i < num_blocks; ++i) {
|
|
|
|
|
inference_optimize_impl(output, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|