faster rcnn input is presistable. (fix it in paddle-trt)

test=develop
revert-15207-remove_op_handle_lock_and_fix_var
nhzlx 6 years ago
parent 73b47df1f4
commit a6aa8ea771

@ -1101,12 +1101,6 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var; return out_var;
} }
// only support "identity" and "relu" now.
/*
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
*/
std::unordered_set<std::string> conv_act_set({"identity", "relu"}); std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {

@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
std::vector<std::string> ExtractParameters( std::vector<std::string> ExtractParameters(
const std::unordered_set<Node *> &nodes) { const std::unordered_set<Node *> &nodes) {
// We can judge whether a variable is a parameter by
// its presistable property, but sometimes the presistable
// of the feed op output is true, so we have to identify it.
std::vector<std::string> feed_outputs;
for (const auto &node : nodes) {
if (!node->IsOp()) continue;
std::string op_type = node->Op()->Type();
if (op_type == "feed") {
std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
std::copy(output_names.begin(), output_names.end(),
std::back_inserter(feed_outputs));
}
}
std::vector<std::string> parameters; std::vector<std::string> parameters;
for (const auto &node : nodes) { for (const auto &node : nodes) {
if (!node->IsVar()) continue; if (!node->IsVar()) continue;
if (node->Var()->Persistable()) { if (node->Var()->Persistable() &&
std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) ==
feed_outputs.end()) {
parameters.push_back(node->Name()); parameters.push_back(node->Name());
} }
} }

Loading…
Cancel
Save