cherry-pick from feature/anakin-engine: Anakin support facebox #16111
parent
a32d420043
commit
a1d200a5de
@ -0,0 +1,233 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
|
||||
#include "paddle/fluid/framework/ir/node.h"
|
||||
#include "paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
template <int times>
|
||||
std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
const std::string pattern_name =
|
||||
"simplify_anakin_detection_pattern_pass" + std::to_string(times);
|
||||
FusePassBase::Init(pattern_name, graph.get());
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
std::vector<PDNode *> input_nodes;
|
||||
for (int i = 0; i < times; i++) {
|
||||
input_nodes.push_back(gpd.mutable_pattern()
|
||||
->NewNode("x" + std::to_string(i))
|
||||
->assert_is_op_input("density_prior_box", "Input")
|
||||
->AsInput());
|
||||
}
|
||||
input_nodes.push_back(gpd.mutable_pattern()
|
||||
->NewNode("x" + std::to_string(times))
|
||||
->assert_is_op_input("box_coder", "TargetBox")
|
||||
->AsInput());
|
||||
|
||||
input_nodes.push_back(gpd.mutable_pattern()
|
||||
->NewNode("x" + std::to_string(times + 1))
|
||||
->assert_is_op_input("multiclass_nms", "Scores")
|
||||
->AsInput());
|
||||
|
||||
patterns::AnakinDetectionPattern pattern(gpd.mutable_pattern(), pattern_name);
|
||||
pattern(input_nodes, times);
|
||||
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
||||
Graph *g) {
|
||||
const int kNumFields = 7;
|
||||
const int kPriorBoxLocOffset = 1;
|
||||
const int kReshape1Offset = 2;
|
||||
const int kReshape1OutOffset = 3;
|
||||
const int kPriorBoxVarOffset = 4;
|
||||
const int kReshape2Offset = 5;
|
||||
const int kReshape2OutOffset = 6;
|
||||
std::vector<Node *> nodes;
|
||||
|
||||
for (int i = 0; i < times; i++) {
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("prior_box" + std::to_string(i))));
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("box_out" + std::to_string(i))));
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("reshape1" + std::to_string(i))));
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("reshape1_out" + std::to_string(i))));
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("reshape2" + std::to_string(i))));
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("reshape2_out" + std::to_string(i))));
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
subgraph.at(pattern.GetPDNode("box_var_out" + std::to_string(i))));
|
||||
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("prior_box" + std::to_string(i))));
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("box_out" + std::to_string(i))));
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("reshape1" + std::to_string(i))));
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("reshape1_out" + std::to_string(i))));
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("box_var_out" + std::to_string(i))));
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("reshape2" + std::to_string(i))));
|
||||
nodes.push_back(
|
||||
subgraph.at(pattern.GetPDNode("reshape2_out" + std::to_string(i))));
|
||||
}
|
||||
|
||||
Node *concat_op1 = subgraph.at(pattern.GetPDNode("concat1"));
|
||||
Node *concat_out1 = subgraph.at(pattern.GetPDNode("concat1_out"));
|
||||
|
||||
Node *concat_op2 = subgraph.at(pattern.GetPDNode("concat2"));
|
||||
Node *concat_out2 = subgraph.at(pattern.GetPDNode("concat2_out"));
|
||||
|
||||
Node *box_coder_third_input = subgraph.at(input_nodes[times]);
|
||||
Node *box_coder_op = subgraph.at(pattern.GetPDNode("box_coder"));
|
||||
Node *box_coder_out = subgraph.at(pattern.GetPDNode("box_coder_out"));
|
||||
|
||||
Node *multiclass_nms_second_input = subgraph.at(input_nodes[times + 1]);
|
||||
Node *multiclass_nms = subgraph.at(pattern.GetPDNode("multiclass_nms"));
|
||||
Node *multiclass_nms_out =
|
||||
subgraph.at(pattern.GetPDNode("multiclass_nms_out"));
|
||||
|
||||
std::string code_type =
|
||||
boost::get<std::string>(box_coder_op->Op()->GetAttr("code_type"));
|
||||
bool box_normalized =
|
||||
boost::get<bool>(box_coder_op->Op()->GetAttr("box_normalized"));
|
||||
// auto variance =
|
||||
// boost::get<std::vector<float>>(box_coder_op->Op()->GetAttr("variance"));
|
||||
int background_label =
|
||||
boost::get<int>(multiclass_nms->Op()->GetAttr("background_label"));
|
||||
float score_threshold =
|
||||
boost::get<float>(multiclass_nms->Op()->GetAttr("score_threshold"));
|
||||
int nms_top_k = boost::get<int>(multiclass_nms->Op()->GetAttr("nms_top_k"));
|
||||
float nms_threshold =
|
||||
boost::get<float>(multiclass_nms->Op()->GetAttr("nms_threshold"));
|
||||
float nms_eta = boost::get<float>(multiclass_nms->Op()->GetAttr("nms_eta"));
|
||||
int keep_top_k =
|
||||
boost::get<int>(multiclass_nms->Op()->GetAttr("keep_top_k"));
|
||||
|
||||
std::vector<std::string> concat1_input_names;
|
||||
for (int i = 0; i < times; i++) {
|
||||
concat1_input_names.push_back(
|
||||
nodes[i * kNumFields + kPriorBoxLocOffset]->Name());
|
||||
}
|
||||
|
||||
int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
|
||||
framework::OpDesc concat1_desc;
|
||||
concat1_desc.SetType("concat");
|
||||
concat1_desc.SetInput("X", concat1_input_names);
|
||||
concat1_desc.SetAttr("axis", axis);
|
||||
concat1_desc.SetOutput("Out", {concat_out1->Name()});
|
||||
|
||||
auto *new_add_concat_op = graph->CreateOpNode(&concat1_desc);
|
||||
|
||||
for (int i = 0; i < times; i++) {
|
||||
nodes[i * kNumFields + kPriorBoxLocOffset]->outputs.push_back(
|
||||
new_add_concat_op);
|
||||
new_add_concat_op->inputs.push_back(
|
||||
nodes[i * kNumFields + kPriorBoxLocOffset]);
|
||||
}
|
||||
|
||||
framework::OpDesc new_op_desc;
|
||||
new_op_desc.SetType("detection_out");
|
||||
new_op_desc.SetInput("PriorBox", {concat_out1->Name()});
|
||||
new_op_desc.SetInput("TargetBox", {box_coder_third_input->Name()});
|
||||
new_op_desc.SetInput("Scores", {multiclass_nms_second_input->Name()});
|
||||
new_op_desc.SetAttr("code_type", code_type);
|
||||
new_op_desc.SetAttr("box_normalized", box_normalized);
|
||||
new_op_desc.SetAttr("background_label", background_label);
|
||||
new_op_desc.SetAttr("score_threshold", score_threshold);
|
||||
new_op_desc.SetAttr("nms_top_k", nms_top_k);
|
||||
new_op_desc.SetAttr("nms_threshold", nms_threshold);
|
||||
new_op_desc.SetAttr("nms_eta", nms_eta);
|
||||
new_op_desc.SetAttr("keep_top_k", keep_top_k);
|
||||
new_op_desc.SetOutput("Out", {multiclass_nms_out->Name()});
|
||||
new_op_desc.Flush();
|
||||
|
||||
// Create a new node for the fused op.
|
||||
auto *detection_out_op = graph->CreateOpNode(&new_op_desc);
|
||||
|
||||
std::unordered_set<const Node *> delete_nodes;
|
||||
|
||||
for (int i = 0; i < times; i++) {
|
||||
nodes[i * kNumFields + kPriorBoxLocOffset]->outputs.push_back(concat_op1);
|
||||
delete_nodes.insert(nodes[i * kNumFields + kReshape1Offset]);
|
||||
delete_nodes.insert(nodes[i * kNumFields + kReshape1OutOffset]);
|
||||
delete_nodes.insert(nodes[i * kNumFields + kPriorBoxVarOffset]);
|
||||
delete_nodes.insert(nodes[i * kNumFields + kReshape2Offset]);
|
||||
delete_nodes.insert(nodes[i * kNumFields + kReshape2OutOffset]);
|
||||
}
|
||||
|
||||
delete_nodes.insert(concat_op1);
|
||||
delete_nodes.insert(concat_op2);
|
||||
delete_nodes.insert(concat_out2);
|
||||
delete_nodes.insert(box_coder_op);
|
||||
delete_nodes.insert(box_coder_out);
|
||||
delete_nodes.insert(multiclass_nms);
|
||||
|
||||
new_add_concat_op->outputs.push_back(concat_out1);
|
||||
concat_out1->inputs.push_back(new_add_concat_op);
|
||||
|
||||
detection_out_op->inputs.push_back(concat_out1);
|
||||
detection_out_op->inputs.push_back(box_coder_third_input);
|
||||
detection_out_op->inputs.push_back(multiclass_nms_second_input);
|
||||
detection_out_op->outputs.push_back(multiclass_nms_out);
|
||||
|
||||
concat_out1->outputs.push_back(detection_out_op);
|
||||
box_coder_third_input->outputs.push_back(detection_out_op);
|
||||
multiclass_nms_second_input->outputs.push_back(detection_out_op);
|
||||
multiclass_nms_out->inputs.push_back(detection_out_op);
|
||||
|
||||
// Delete the unneeded nodes.
|
||||
GraphSafeRemoveNodes(graph.get(), delete_nodes);
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
return graph;
|
||||
}
|
||||
|
||||
template class SimplifyAnakinDetectionPatternPass<1>;
|
||||
template class SimplifyAnakinDetectionPatternPass<3>;
|
||||
template class SimplifyAnakinDetectionPatternPass<4>;
|
||||
template class SimplifyAnakinDetectionPatternPass<5>;
|
||||
template class SimplifyAnakinDetectionPatternPass<6>;
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(simplify_anakin_detection_pattern_pass,
|
||||
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<1>);
|
||||
|
||||
REGISTER_PASS(simplify_anakin_detection_pattern_pass3,
|
||||
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<3>);
|
||||
|
||||
REGISTER_PASS(simplify_anakin_detection_pattern_pass4,
|
||||
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<4>);
|
||||
|
||||
REGISTER_PASS(simplify_anakin_detection_pattern_pass5,
|
||||
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<5>);
|
||||
|
||||
REGISTER_PASS(simplify_anakin_detection_pattern_pass6,
|
||||
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<6>);
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// There may be many transpose-flatten structures in a model, and the output of
|
||||
// these structures will be used as inputs to the concat Op. This pattern will
|
||||
// be detected by our pass. The times here represents the repeat times of this
|
||||
// structure.
|
||||
template <int times>
|
||||
class SimplifyAnakinDetectionPatternPass : public FusePassBase {
|
||||
public:
|
||||
virtual ~SimplifyAnakinDetectionPatternPass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,79 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/anakin/convert/density_prior_box.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
using anakin::graph::GraphGlobalMem;
|
||||
using anakin::AK_FLOAT;
|
||||
using anakin::saber::NV;
|
||||
using anakin::saber::Shape;
|
||||
using anakin::PTuple;
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) {
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
auto input_name = op_desc.Input("Input").front();
|
||||
auto image_name = op_desc.Input("Image").front();
|
||||
auto output_name = op_desc.Output("Boxes").front();
|
||||
|
||||
auto op_name = op_desc.Type() + ":" + op_desc.Output("Boxes").front();
|
||||
|
||||
auto fixed_sizes =
|
||||
boost::get<std::vector<float>>(op_desc.GetAttr("fixed_sizes"));
|
||||
auto fixed_ratios =
|
||||
boost::get<std::vector<float>>(op_desc.GetAttr("fixed_ratios"));
|
||||
auto densities = boost::get<std::vector<int>>(op_desc.GetAttr("densities"));
|
||||
|
||||
// lack flip
|
||||
auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
|
||||
auto variances = boost::get<std::vector<float>>(op_desc.GetAttr("variances"));
|
||||
|
||||
// lack img_h, img_w
|
||||
auto step_h = boost::get<float>(op_desc.GetAttr("step_h"));
|
||||
auto step_w = boost::get<float>(op_desc.GetAttr("step_w"));
|
||||
auto offset = boost::get<float>(op_desc.GetAttr("offset"));
|
||||
std::vector<std::string> order = {"MIN", "COM", "MAX"};
|
||||
std::vector<float> temp_v = {};
|
||||
|
||||
engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name});
|
||||
engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", temp_v);
|
||||
engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", temp_v);
|
||||
engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", temp_v);
|
||||
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_sizes", fixed_sizes);
|
||||
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratios", fixed_ratios);
|
||||
engine_->AddOpAttr<PTuple<int>>(op_name, "density", densities);
|
||||
engine_->AddOpAttr(op_name, "is_flip", false);
|
||||
engine_->AddOpAttr(op_name, "is_clip", clip);
|
||||
engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances);
|
||||
engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
|
||||
engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
|
||||
engine_->AddOpAttr(op_name, "step_h", step_h);
|
||||
engine_->AddOpAttr(op_name, "step_w", step_w);
|
||||
engine_->AddOpAttr(op_name, "offset", offset);
|
||||
engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", order);
|
||||
}
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_ANAKIN_OP_CONVERTER(density_prior_box, DensityPriorBoxOpConverter);
|
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
class DensityPriorBoxOpConverter : public AnakinOpConverter {
|
||||
public:
|
||||
DensityPriorBoxOpConverter() = default;
|
||||
|
||||
virtual void operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) override;
|
||||
virtual ~DensityPriorBoxOpConverter() {}
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,72 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/anakin/convert/detection_out.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
|
||||
using anakin::graph::GraphGlobalMem;
|
||||
using anakin::AK_FLOAT;
|
||||
using anakin::saber::NV;
|
||||
using anakin::saber::Shape;
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) {
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
auto target_name = op_desc.Input("TargetBox").front();
|
||||
auto prior_box_name = op_desc.Input("PriorBox").front();
|
||||
auto scores_name = op_desc.Input("Scores").front();
|
||||
auto output_name = op_desc.Output("Out").front();
|
||||
|
||||
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
|
||||
|
||||
auto code_type = boost::get<std::string>(op_desc.GetAttr("code_type"));
|
||||
auto background_label = boost::get<int>(op_desc.GetAttr("background_label"));
|
||||
auto score_threshold = boost::get<float>(op_desc.GetAttr("score_threshold"));
|
||||
auto nms_top_k = boost::get<int>(op_desc.GetAttr("nms_top_k"));
|
||||
auto nms_threshold = boost::get<float>(op_desc.GetAttr("nms_threshold"));
|
||||
auto nms_eta = boost::get<float>(op_desc.GetAttr("nms_eta"));
|
||||
auto keep_top_k = boost::get<int>(op_desc.GetAttr("keep_top_k"));
|
||||
std::string anakin_code_type;
|
||||
if (code_type == "decode_center_size") {
|
||||
anakin_code_type = "CENTER_SIZE";
|
||||
} else if (code_type == "encode_center_size") {
|
||||
PADDLE_THROW(
|
||||
"Not support encode_center_size code_type in DetectionOut of anakin");
|
||||
}
|
||||
|
||||
engine_->AddOp(op_name, "DetectionOutput",
|
||||
{target_name, scores_name, prior_box_name}, {output_name});
|
||||
engine_->AddOpAttr(op_name, "share_location", true);
|
||||
engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
|
||||
engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
|
||||
engine_->AddOpAttr(op_name, "background_id", background_label);
|
||||
engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
|
||||
engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
|
||||
engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
|
||||
engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
|
||||
engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
|
||||
engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
|
||||
}
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_ANAKIN_OP_CONVERTER(detection_out, DetectionOutOpConverter);
|
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
class DetectionOutOpConverter : public AnakinOpConverter {
|
||||
public:
|
||||
DetectionOutOpConverter() = default;
|
||||
|
||||
virtual void operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) override;
|
||||
virtual ~DetectionOutOpConverter() {}
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue