Cherry-pick benchmark related changes from release/1.4 (#17156)

* cherry-pick commit from 8877054

* cherry-pick commit from 3f0b97d

* cherry-pick from 16691:Anakin subgraph support yolo_v3 and faster-rcnn

(cherry picked from commit 8643dbc233)

* Cherry-Pick from 16662 : Anakin subgraph cpu support

(cherry picked from commit 7ad182e16c)

* Cherry-pick from 1662, 16797.. : add anakin int8 support

(cherry picked from commit e14ab180fe)

* Cherry-pick from 16813 : change singleton to graph RegistBlock
test=release/1.4

(cherry picked from commit 4b9fa42307)

* Cherry Pick : 16837 Support ShuffleNet and MobileNet-v2

Support ShuffleNet and MobileNet-v2, test=release/1.4

(cherry picked from commit a6fb066f90)

* Cherry-pick : anakin subgraph add opt config layout argument #16846
test=release/1.4

(cherry picked from commit 8121b3eccb)

* 1. add shuffle_channel_detect

(cherry picked from commit 6efdea8997)

* update shuffle_channel op convert, test=release/1.4

(cherry picked from commit e4726a066f)

* Modify symbol export rules

test=develop
revert-17304-fix_default_paddle_version
石晓伟 6 years ago committed by GitHub
parent 16922e0093
commit a72dbe9abf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,8 +25,9 @@ endif()
if(ANAKIN_FOUND)
message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
include_directories(${ANAKIN_ROOT})
include_directories(${ANAKIN_ROOT}/include)
include_directories(${ANAKIN_ROOT}/include/saber)
include_directories(${ANAKIN_ROOT}/saber)
link_directories(${ANAKIN_ROOT})
add_definitions(-DPADDLE_WITH_ANAKIN)
endif()

@ -71,6 +71,7 @@ pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)

@ -48,17 +48,37 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
auto base_op_desc = mul->Op();
// Create an FC Node.
// OpDesc desc(base_op_desc, nullptr);
OpDesc desc;
std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
// For anakin subgraph int8
// When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
// fake_dequant"
// can be detected by the quant_dequant_fuse_pass. This pass will add
// "input_scale",
// "weight_scale" which are extracted from fake_quant op and fake_dequant op
// to mul op,
// and then delete the fake_quant op and fake_dequant op in the graph. If
// the mul op
// has the scale info, we should add those to the fused fc.
if (base_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});

@ -1640,7 +1640,8 @@ PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times) {
int times,
const std::string &quant_type) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
@ -1648,24 +1649,22 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
// the quant op always be one.
auto quant_op_in_scale =
pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
->AsInput();
auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
->assert_is_op("fake_quantize_range_abs_max");
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto quant_op_out =
pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output("fake_quantize_range_abs_max", "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
// there are 'times' quantized and dequant op
std::vector<PDNode *> nodes;
@ -1707,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
}
}
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2")
->AsIntermediate();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
reshape1_op->LinksFrom({reshape1_in});
reshape1_out->LinksFrom({reshape1_op});
transpose_op->LinksFrom({reshape1_out});
transpose_out->LinksFrom({transpose_op});
reshape2_op->LinksFrom({transpose_out});
reshape2_out->LinksFrom({reshape2_op});
}
} // namespace ir
} // namespace framework
} // namespace paddle

@ -880,7 +880,8 @@ struct QuantDequantOpFuse : public PatternBase {
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times = 1);
const std::string& weight_name, int times,
const std::string& quant_type);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
@ -891,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
}
};
struct ShuffleChannelPattern : public PatternBase {
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
void operator()(PDNode* reshape1_in);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.

@ -25,7 +25,8 @@ namespace framework {
namespace ir {
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::string op_type) {
const std::string& op_type,
const std::string& quant_type) {
const std::string pattern_name = "quant_dequant_fuse";
// FusePassBase::Init(pattern_name, graph);
const int kNumFields = 5;
@ -38,7 +39,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("fake_quantize_range_abs_max", "X")
->assert_is_op_input(quant_type, "X")
->AsInput();
std::string quantized_op_type = "";
@ -46,6 +47,9 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
if (op_type == "conv2d") {
quantized_op_type = "conv2d";
weight_name = "Filter";
} else if (op_type == "depthwise_conv2d") {
quantized_op_type = "depthwise_conv2d";
weight_name = "Filter";
} else if (op_type == "conv2d_fusion") {
quantized_op_type = "conv2d_fusion";
weight_name = "Filter";
@ -62,7 +66,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
}
patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, quantized_op_type, weight_name, times);
pattern(x, quantized_op_type, weight_name, times, quant_type);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
@ -103,7 +107,6 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < times; i++) {
// max_range = (range * range) / weight_scale
float max_range = boost::get<float>(
nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range"));
float weight_scale = (range * range) / max_range;
@ -118,7 +121,8 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
new_op_desc.SetType(quantized_op_type);
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion") {
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
@ -156,11 +160,17 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_fuse";
FusePassBase::Init(pattern_name, graph);
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul"};
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
"depthwise_conv2d"};
auto* scope = param_scope();
for (auto& op_type : quantized_op_types) {
for (int i = 1; i <= 6; i++) {
RunQuantDequant(graph, scope, i, op_type);
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
for (int i = 6; i >= 1; i--) {
RunQuantDequant(graph, scope, i, op_type, quant_type);
}
}
}
}

@ -0,0 +1,93 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(reshape1_op); \
GET_IR_NODE(reshape1_out); \
GET_IR_NODE(transpose_op); \
GET_IR_NODE(transpose_out); \
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "shufflechannel_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("reshape2", "X")
->AsInput();
patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
PADDLE_ENFORCE(subgraph.count(x));
auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op();
std::string input_name = input_node->Name();
std::string output_name = reshape2_out->Name();
auto reshape1_shape =
boost::get<std::vector<int>>(reshape1_desc->GetAttr("shape"));
auto reshape2_shape =
boost::get<std::vector<int>>(reshape2_desc->GetAttr("shape"));
int i_c = reshape1_shape[2];
int o_c = reshape2_shape[1];
int group = o_c / i_c;
framework::OpDesc new_op_desc;
new_op_desc.SetType("shuffle_channel");
new_op_desc.SetInput("X", {input_name});
new_op_desc.SetOutput("Out", {output_name});
new_op_desc.SetAttr("group", group);
new_op_desc.Flush();
// Create a new node for the fused op.
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(input_node, new_op);
IR_NODE_LINK_TO(new_op, reshape2_out);
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
transpose_out, reshape2_op});
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass);

@ -0,0 +1,34 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ShuffleChannelDetectPass : public FusePassBase {
public:
virtual ~ShuffleChannelDetectPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -1,4 +1,9 @@
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
scope op_registry gtest)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
@ -14,5 +19,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL)
cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL)
cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL)
#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL)
cc_test(test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL)

@ -16,16 +16,13 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
template <typename TargetT, ::anakin::Precision PrecisionT>
ActivationOpConverter<TargetT, PrecisionT>::ActivationOpConverter(
const std::string &op_type)
: op_type_(op_type) {
auto it = anakin_op_types_.find(op_type_);
PADDLE_ENFORCE(it != anakin_op_types_.end(),
@ -33,10 +30,10 @@ ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
anakin_op_type_ = it->second;
}
void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ActivationOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
@ -44,8 +41,17 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
engine_->AddOp(op_name, "Activation", {input_name}, {output_name});
engine_->AddOpAttr(op_name, "type", anakin_op_type_);
this->engine_->AddOp(op_name, "Activation", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "type", anakin_op_type_);
if (op_type_ == "swish") {
float beta = boost::get<float>(op_desc.GetAttr("beta"));
this->engine_->AddOpAttr(op_name, "clip_relu_num", beta);
}
if (op_type_ == "relu6") {
float threshold = boost::get<float>(op_desc.GetAttr("threshold"));
this->engine_->AddOpAttr(op_name, "clip_relu_num", threshold);
}
}
} // namespace anakin
@ -54,3 +60,5 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
REGISTER_ANAKIN_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(tanh, TanhOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(swish, SwishOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(relu6, Relu6OpConverter);

@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ActivationOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ActivationOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
explicit ActivationOpConverter(const std::string &op_type);
@ -36,18 +37,36 @@ class ActivationOpConverter : public AnakinOpConverter {
std::string op_type_;
std::string anakin_op_type_;
std::map<std::string, std::string> anakin_op_types_{{"tanh", "TanH"},
{"sigmoid", "Sigmoid"}};
{"sigmoid", "Sigmoid"},
{"relu6", "ClippedRelu"},
{"swish", "Swish"}};
};
class TanhOpConverter : public ActivationOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class TanhOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
TanhOpConverter() : ActivationOpConverter("tanh") {}
TanhOpConverter() : ActivationOpConverter<TargetT, PrecisionT>("tanh") {}
};
class SigmoidOpConverter : public ActivationOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class SigmoidOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
SigmoidOpConverter() : ActivationOpConverter("sigmoid") {}
SigmoidOpConverter()
: ActivationOpConverter<TargetT, PrecisionT>("sigmoid") {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class Relu6OpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
Relu6OpConverter() : ActivationOpConverter<TargetT, PrecisionT>("relu6") {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class SwishOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
SwishOpConverter() : ActivationOpConverter<TargetT, PrecisionT>("swish") {}
};
} // namespace anakin
} // namespace inference
} // namespace paddle

@ -0,0 +1,55 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
void AffineChannelOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
this->engine_->AddOp(op_name, "AffineChannel", {input_name}, {output_name});
// Copy the Scale to CPUPlace and get the pointer.
auto *scale_v = scope.FindVar(op_desc.Input("Scale").front());
PADDLE_ENFORCE_NOT_NULL(scale_v);
auto weight1 = pblock_from_var<TargetT, PrecisionT>(*scale_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
// Copy the Bias to CPUPlace and get the pointer.
auto *bias_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(bias_v);
auto weight2 = pblock_from_var<TargetT, PrecisionT>(*bias_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(affine_channel, AffineChannelOpConverter);

@ -0,0 +1,40 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
class AffineChannelOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
AffineChannelOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) override;
virtual ~AffineChannelOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle

@ -18,107 +18,64 @@
#include <map>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void BatchNormOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
std::map<std::string, std::string> inputs;
for (auto k : {"X", "Scale", "Bias", "Mean", "Variance"}) {
PADDLE_ENFORCE_EQ(op_desc.Input(k).size(), 1UL);
auto v = op_desc.Input(k).front();
inputs.insert({k, v});
}
auto input = op_desc.Input("X").front();
auto output = op_desc.Output("Y").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Y").front();
auto epsilon = boost::get<float>(op_desc.GetAttr("epsilon"));
// auto momentum = boost::get<float>(op_desc.GetAttr("momentum"));
auto bn_op_name = op_name + ":bn";
auto bn_output = bn_op_name + "_output";
engine_->AddOp(bn_op_name, "BatchNorm", {inputs["X"]}, {bn_output});
engine_->AddOpAttr(bn_op_name, "epsilon", epsilon);
engine_->AddOpAttr(bn_op_name, "momentum", static_cast<float>(1.0));
this->engine_->AddOp(bn_op_name, "BatchNorm", {input}, {bn_output});
this->engine_->AddOpAttr(bn_op_name, "epsilon", epsilon);
this->engine_->AddOpAttr(bn_op_name, "momentum", static_cast<float>(1.0));
auto scale_op_name = op_name + ":scale";
auto get_lod_tensor = [this, &scope, &op_name](const std::string &var_name,
framework::LoDTensor *tensor) {
auto *v = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(v);
auto *t = v->GetMutable<framework::LoDTensor>();
tensor->Resize(t->dims());
TensorCopySync(*t, platform::CPUPlace(), tensor);
};
framework::LoDTensor bias_t;
framework::LoDTensor mean_t;
framework::LoDTensor scale_t;
framework::LoDTensor variance_t;
get_lod_tensor(inputs["Bias"], &bias_t);
get_lod_tensor(inputs["Mean"], &mean_t);
get_lod_tensor(inputs["Scale"], &scale_t);
get_lod_tensor(inputs["Variance"], &variance_t);
auto fill_shape = [](size_t n, std::vector<int> shape) {
shape.insert(shape.begin(), 1);
if (shape.size() < n) {
shape.insert(shape.end(), n - shape.size(), 1);
}
return shape;
};
Shape shape1(fill_shape(4, framework::vectorize2int(mean_t.dims())));
Shape shape2(fill_shape(4, framework::vectorize2int(variance_t.dims())));
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
auto *mean_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(mean_t.data<float>(), mean_t.numel(), mean_data);
engine_->AddOpAttr(bn_op_name, "weight_1", *weight1);
auto *weight2 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape2);
auto *variance_data =
static_cast<float *>(weight2->h_tensor().mutable_data());
std::copy_n(variance_t.data<float>(), variance_t.numel(), variance_data);
engine_->AddOpAttr(bn_op_name, "weight_2", *weight2);
Shape shape3(std::vector<int>({1, 1, 1, 1}));
auto *weight3 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape3);
auto *alpha_data = static_cast<float *>(weight3->h_tensor().mutable_data());
float weight3_data[] = {1};
std::copy(std::begin(weight3_data), std::end(weight3_data), alpha_data);
engine_->AddOpAttr(bn_op_name, "weight_3", *weight3);
Shape scale_shape(fill_shape(4, framework::vectorize2int(scale_t.dims())));
auto *scale =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(scale_shape);
auto *scale_data = static_cast<float *>(scale->h_tensor().mutable_data());
std::copy_n(scale_t.data<float>(), scale_t.numel(), scale_data);
Shape bias_shape(fill_shape(4, framework::vectorize2int(bias_t.dims())));
auto *bias =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(bias_shape);
auto *bias_data = static_cast<float *>(bias->h_tensor().mutable_data());
std::copy_n(bias_t.data<float>(), bias_t.numel(), bias_data);
engine_->AddOp(scale_op_name, "Scale", {bn_output}, {output});
engine_->AddOpAttr(scale_op_name, "axis", 1);
engine_->AddOpAttr(scale_op_name, "num_axes", 1);
engine_->AddOpAttr(scale_op_name, "bias_term", true);
engine_->AddOpAttr(scale_op_name, "weight_1", *scale);
engine_->AddOpAttr(scale_op_name, "weight_2", *bias);
this->engine_->AddOp(scale_op_name, "Scale", {bn_output}, {output});
this->engine_->AddOpAttr(scale_op_name, "axis", 1);
this->engine_->AddOpAttr(scale_op_name, "num_axes", 1);
this->engine_->AddOpAttr(scale_op_name, "bias_term", true);
auto *mean_v = scope.FindVar(op_desc.Input("Mean").front());
PADDLE_ENFORCE_NOT_NULL(mean_v);
auto weight1 = pblock_from_var<TargetT, PrecisionT>(*mean_v, this->engine_);
this->engine_->AddOpAttr(bn_op_name, "weight_1", *weight1);
auto *variance_v = scope.FindVar(op_desc.Input("Variance").front());
PADDLE_ENFORCE_NOT_NULL(variance_v);
auto weight2 =
pblock_from_var<TargetT, PrecisionT>(*variance_v, this->engine_);
this->engine_->AddOpAttr(bn_op_name, "weight_2", *weight2);
auto *weight3 = pblock_from_vector<TargetT, PrecisionT>(
std::vector<float>({1}), this->engine_);
this->engine_->AddOpAttr(bn_op_name, "weight_3", *weight3);
auto *scale_v = scope.FindVar(op_desc.Input("Scale").front());
PADDLE_ENFORCE_NOT_NULL(scale_v);
auto scale = pblock_from_var<TargetT, PrecisionT>(*scale_v, this->engine_);
this->engine_->AddOpAttr(scale_op_name, "weight_1", *scale);
auto *bias_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(bias_v);
auto bias = pblock_from_var<TargetT, PrecisionT>(*bias_v, this->engine_);
this->engine_->AddOpAttr(scale_op_name, "weight_2", *bias);
}
} // namespace anakin

@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class BatchNormOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class BatchNormOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
BatchNormOpConverter() = default;

@ -15,34 +15,23 @@
#include "paddle/fluid/inference/anakin/convert/concat.h"
#include <algorithm>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void ConcatOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ConcatOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
int axis = boost::get<int>(op_desc.GetAttr("axis"));
auto input_names = op_desc.Input("X");
// PADDLE_ENFORCE(axis > 0,
// "The axis attr of Concat op should be large than 0 for trt");
auto y_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Concat", input_names, {y_name});
engine_->AddOpAttr(op_name, "axis", axis);
this->engine_->AddOp(op_name, "Concat", input_names, {y_name});
this->engine_->AddOpAttr(op_name, "axis", axis);
}
} // namespace anakin

@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ConcatOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ConcatOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
ConcatOpConverter() = default;

@ -16,21 +16,18 @@
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/inference/anakin/convert/helper.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void Conv2dOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1UL);
@ -39,46 +36,69 @@ void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op,
auto input_name = op_desc.Input("Input").front();
auto output_name = op_desc.Output("Output").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Output").front();
engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
this->engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
auto *filter_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(filter_v);
auto *filter_t = filter_v->GetMutable<framework::LoDTensor>();
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(filter_t->dims());
TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get());
auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
// const int n_output = weight_tensor->dims()[0];
// const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
// auto filter_num = n_input * filter_h * filter_w ;
auto filter_num = weight_tensor->dims()[0];
engine_->AddOpAttr<int>(op_name, "filter_num", filter_num);
engine_->AddOpAttr<PTuple<int>>(op_name, "kernel_size", {filter_h, filter_w});
this->engine_->template AddOpAttr<int>(op_name, "filter_num", filter_num);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "kernel_size",
{filter_h, filter_w});
auto strides = boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "strides", strides);
auto paddings = boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
auto dilations = boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
engine_->AddOpAttr<PTuple<int>>(op_name, "dilation_rate", dilations);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dilation_rate",
dilations);
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
engine_->AddOpAttr(op_name, "group", groups);
engine_->AddOpAttr(op_name, "axis", 1);
engine_->AddOpAttr(op_name, "bias_term", false);
this->engine_->AddOpAttr(op_name, "group", groups);
this->engine_->AddOpAttr(op_name, "axis", 1);
this->engine_->AddOpAttr(op_name, "bias_term", false);
::anakin::saber::Shape anakin_shape(weight_shape);
bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
auto weight_shape = framework::vectorize2int(filter_t->dims());
Shape anakin_shape(weight_shape);
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(anakin_shape);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(weight_tensor->data<float>(), weight_tensor->numel(), cpu_data);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
engine_->AddOpAttr(op_name, "weight_1", *weight1);
if (enable_int8) {
const float int8_range = 127.;
float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
float weight_scale = boost::get<float>(op_desc.GetAttr("weight_scale"));
PBlock<TargetT> *weight1 =
new PBlock<TargetT>(anakin_shape, ::anakin::AK_INT8);
this->engine_->RegistBlock(weight1);
float *weight_data = weight_tensor->data<float>();
std::vector<char> weight_int8;
int weight_num = weight_tensor->numel();
for (int i = 0; i < weight_tensor->numel(); i++) {
bool is_valid_int8 =
((weight_data[i] >= -128) && (weight_data[i] <= 127));
PADDLE_ENFORCE(is_valid_int8,
"We are in anakin subgraph int8 mode, the weight of conv "
"should be in range [-128, 127]");
weight_int8.push_back(static_cast<char>(weight_data[i]));
}
memcpy(static_cast<void *>(weight1->h_tensor().mutable_data()),
static_cast<void *>(weight_int8.data()), sizeof(char) * weight_num);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8);
this->engine_->Graph()->SetWeightsScale(op_name,
{weight_scale / int8_range}, false);
this->engine_->AddTensorScale(input_name, in_scale / int8_range);
} else {
auto *weight1 = pblock_from_tensor<TargetT, PrecisionT>(
*weight_tensor, weight_shape, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
}
}
} // namespace anakin

@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class Conv2dOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class Conv2dOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
Conv2dOpConverter() = default;

@ -16,21 +16,18 @@
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/inference/anakin/convert/helper.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void Conv2dFusionOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1UL);
@ -40,71 +37,74 @@ void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op,
auto input_name = op_desc.Input("Input").front();
auto output_name = op_desc.Output("Output").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Output").front();
engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
this->engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
auto *filter_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(filter_v);
auto *filter_t = filter_v->GetMutable<framework::LoDTensor>();
auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
auto *b_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(b_v);
auto *b_t = b_v->GetMutable<framework::LoDTensor>();
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(filter_t->dims());
TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
// const int n_output = weight_tensor->dims()[0];
// const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
// auto filter_num = n_input * filter_h * filter_w ;
auto filter_num = weight_tensor->dims()[0];
engine_->AddOpAttr<int>(op_name, "filter_num", filter_num);
engine_->AddOpAttr<PTuple<int>>(op_name, "kernel_size", {filter_h, filter_w});
this->engine_->template AddOpAttr<int>(op_name, "filter_num", filter_num);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "kernel_size",
{filter_h, filter_w});
auto strides = boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "strides", strides);
auto paddings = boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
auto dilations = boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
engine_->AddOpAttr<PTuple<int>>(op_name, "dilation_rate", dilations);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dilation_rate",
dilations);
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
engine_->AddOpAttr(op_name, "group", groups);
engine_->AddOpAttr(op_name, "axis", 1);
engine_->AddOpAttr(op_name, "bias_term", true);
auto weight_shape = framework::vectorize2int(filter_t->dims());
Shape anakin_shape(weight_shape);
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(anakin_shape);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(weight_tensor->data<float>(), weight_tensor->numel(), cpu_data);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
engine_->AddOpAttr(op_name, "weight_1", *weight1);
auto bias_shape = framework::vectorize2int(b_t->dims());
framework::LoDTensor bias_tensor;
bias_tensor.Resize(b_t->dims());
TensorCopySync((*b_t), platform::CPUPlace(), &bias_tensor);
auto *bias_data = bias_tensor.data<float>();
bias_shape.insert(bias_shape.begin(), 1);
bias_shape.insert(bias_shape.begin(), 1);
bias_shape.insert(bias_shape.begin(), 1);
// bias_shape.push_back(1);
// bias_shape.push_back(1);
Shape anakin_bias_shape(bias_shape);
this->engine_->AddOpAttr(op_name, "group", groups);
this->engine_->AddOpAttr(op_name, "axis", 1);
this->engine_->AddOpAttr(op_name, "bias_term", true);
auto *weight2 = GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(
anakin_bias_shape);
float *cpu_data2 = static_cast<float *>(weight2->h_tensor().mutable_data());
std::copy_n(bias_data, bias_tensor.numel(), cpu_data2);
weight2->d_tensor().set_shape(anakin_bias_shape);
weight2->d_tensor().copy_from(weight2->h_tensor());
engine_->AddOpAttr(op_name, "weight_2", *weight2);
::anakin::saber::Shape anakin_shape(weight_shape);
bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
if (enable_int8) {
const float int8_range = 127.;
float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
float weight_scale = boost::get<float>(op_desc.GetAttr("weight_scale"));
PBlock<TargetT> *weight1 =
new PBlock<TargetT>(anakin_shape, ::anakin::AK_INT8);
this->engine_->RegistBlock(weight1);
float *weight_data = weight_tensor->data<float>();
std::vector<char> weight_int8;
int weight_num = weight_tensor->numel();
for (int i = 0; i < weight_tensor->numel(); i++) {
bool is_valid_int8 =
((weight_data[i] >= -128) && (weight_data[i] <= 127));
PADDLE_ENFORCE(is_valid_int8,
"We are in anakin subgraph int8 mode, the weight of conv "
"should be in range [-128, 127]");
weight_int8.push_back(static_cast<char>(weight_data[i]));
}
memcpy(static_cast<void *>(weight1->h_tensor().mutable_data()),
static_cast<void *>(weight_int8.data()), sizeof(char) * weight_num);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8);
this->engine_->Graph()->SetWeightsScale(op_name,
{weight_scale / int8_range}, false);
this->engine_->AddTensorScale(input_name, in_scale / int8_range);
} else {
auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
auto *weight1 = pblock_from_tensor<TargetT, PrecisionT>(
*weight_tensor, weight_shape, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
auto weight2 = pblock_from_var<TargetT, PrecisionT>(*b_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
}
}
} // namespace anakin

@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class Conv2dFusionOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class Conv2dFusionOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
Conv2dFusionOpConverter() = default;

@ -17,17 +17,14 @@
#include <map>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void DensityPriorBoxOpConverter::operator()(
template <typename TargetT, ::anakin::Precision PrecisionT>
void DensityPriorBoxOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc& op, const framework::BlockDesc& block_desc,
const framework::Scope& scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
@ -81,22 +78,30 @@ void DensityPriorBoxOpConverter::operator()(
std::vector<float> temp_v = {};
engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name});
engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", min_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", max_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", aspect_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_size", fixed_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratio", fixed_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "density", dens);
engine_->AddOpAttr(op_name, "is_flip", is_flip);
engine_->AddOpAttr(op_name, "is_clip", is_clip);
engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances);
engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
engine_->AddOpAttr(op_name, "step_h", step_h);
engine_->AddOpAttr(op_name, "step_w", step_w);
engine_->AddOpAttr(op_name, "offset", offset);
engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", t_order);
this->engine_->AddOp(op_name, "PriorBox", {input_name, image_name},
{output_name});
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "min_size",
min_sizes);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "max_size",
max_sizes);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "aspect_ratio",
aspect_ratios);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "fixed_size",
fixed_sizes);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "fixed_ratio",
fixed_ratios);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "density", dens);
this->engine_->AddOpAttr(op_name, "is_flip", is_flip);
this->engine_->AddOpAttr(op_name, "is_clip", is_clip);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "variance",
variances);
this->engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
this->engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
this->engine_->AddOpAttr(op_name, "step_h", step_h);
this->engine_->AddOpAttr(op_name, "step_w", step_w);
this->engine_->AddOpAttr(op_name, "offset", offset);
this->engine_->template AddOpAttr<PTuple<std::string>>(op_name, "order",
t_order);
}
} // namespace anakin

@ -22,7 +22,9 @@ namespace paddle {
namespace inference {
namespace anakin {
class DensityPriorBoxOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class DensityPriorBoxOpConverter
: public AnakinOpConverter<TargetT, PrecisionT> {
public:
DensityPriorBoxOpConverter() = default;

@ -16,19 +16,14 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void DetectionOutOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto target_name = op_desc.Input("TargetBox").front();
auto prior_box_name = op_desc.Input("PriorBox").front();
@ -52,18 +47,19 @@ void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
"Not support encode_center_size code_type in DetectionOut of anakin");
}
engine_->AddOp(op_name, "DetectionOutput",
{target_name, scores_name, prior_box_name}, {output_name});
engine_->AddOpAttr(op_name, "share_location", true);
engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
engine_->AddOpAttr(op_name, "background_id", background_label);
engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
this->engine_->AddOp(op_name, "DetectionOutput",
{target_name, scores_name, prior_box_name},
{output_name});
this->engine_->AddOpAttr(op_name, "share_location", true);
this->engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
this->engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
this->engine_->AddOpAttr(op_name, "background_id", background_label);
this->engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
this->engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
this->engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
this->engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
this->engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
this->engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
}
} // namespace anakin

@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class DetectionOutOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class DetectionOutOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
DetectionOutOpConverter() = default;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save