From a72dbe9abf1936d99c0afc370308519c3889b4e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com>
Date: Tue, 7 May 2019 10:55:19 +0800
Subject: [PATCH] Cherry-pick benchmark related changes from release/1.4
 (#17156)

* cherry-pick commit from 8877054

* cherry-pick commit from 3f0b97d

* Cherry-pick from 16691: Anakin subgraph supports yolo_v3 and faster-rcnn

(cherry picked from commit 8643dbc233f12f829b64cc0ee6926e41fb891ddf)

* Cherry-pick from 16662: Anakin subgraph cpu support

(cherry picked from commit 7ad182e16cbd099523dd274d3b4051b3734c9adf)

* Cherry-pick from 1662, 16797..: add anakin int8 support

(cherry picked from commit e14ab180fe76b97aa33c0089f98d1cfa771905e9)

* Cherry-pick from 16813: change singleton to graph RegistBlock, test=release/1.4

(cherry picked from commit 4b9fa42307aeb90c8e2710ad07d02e286a4620aa)

* Cherry-pick 16837: support ShuffleNet and MobileNet-v2, test=release/1.4

(cherry picked from commit a6fb066f90d1009ab32e981cf8e7d47d55bbc9e6)

* Cherry-pick: Anakin subgraph adds an opt config layout argument (#16846), test=release/1.4

(cherry picked from commit 8121b3eccbd57a41e448ae0e1e716d634abd338d)

* add shuffle_channel_detect

(cherry picked from commit 6efdea8997cda4d737777f9be89cc9991120df64)

* update shuffle_channel op convert, test=release/1.4

(cherry picked from commit e4726a066fd6b7790745cb02a7601531c556793b)

* Modify symbol export rules, test=develop
---
 cmake/anakin_subgraph.cmake | 3 +-
 paddle/fluid/framework/ir/CMakeLists.txt | 1 +
 paddle/fluid/framework/ir/fc_fuse_pass.cc | 20 +++
 .../framework/ir/graph_pattern_detector.cc | 56 ++++++--
 .../framework/ir/graph_pattern_detector.h | 18 ++-
 .../ir/quant_conv2d_dequant_fuse_pass.cc | 28 ++--
 .../ir/shuffle_channel_detect_pass.cc | 93 ++++++++++
 .../ir/shuffle_channel_detect_pass.h | 34 +++++
 .../inference/anakin/convert/CMakeLists.txt | 9 +-
 .../inference/anakin/convert/activation.cc | 32 +++--
 .../inference/anakin/convert/activation.h | 31 +++-
 .../anakin/convert/affine_channel.cc | 55 ++++++++
 .../inference/anakin/convert/affine_channel.h | 40 ++++++
 .../inference/anakin/convert/batch_norm.cc | 119 +++++-----------
 .../inference/anakin/convert/batch_norm.h | 3 +-
 .../fluid/inference/anakin/convert/concat.cc | 23 +--
 .../fluid/inference/anakin/convert/concat.h | 3 +-
 .../fluid/inference/anakin/convert/conv2d.cc | 88 +++++++-----
 .../fluid/inference/anakin/convert/conv2d.h | 3 +-
 .../inference/anakin/convert/conv2d_fusion.cc | 114 +++++++--------
 .../inference/anakin/convert/conv2d_fusion.h | 3 +-
 .../anakin/convert/density_prior_box.cc | 47 ++++---
 .../anakin/convert/density_prior_box.h | 4 +-
 .../inference/anakin/convert/detection_out.cc | 38 +++--
 .../inference/anakin/convert/detection_out.h | 3 +-
 .../fluid/inference/anakin/convert/dropout.cc | 36 ++---
 .../fluid/inference/anakin/convert/dropout.h | 3 +-
 .../inference/anakin/convert/elementwise.cc | 41 ++----
 .../inference/anakin/convert/elementwise.h | 8 +-
 paddle/fluid/inference/anakin/convert/fc.cc | 115 ++++++++-------
 paddle/fluid/inference/anakin/convert/fc.h | 9 +-
 .../fluid/inference/anakin/convert/flatten.cc | 16 +--
 .../fluid/inference/anakin/convert/flatten.h | 3 +-
 .../fluid/inference/anakin/convert/helper.cc | 32 +++++
 .../fluid/inference/anakin/convert/helper.h | 95 +++++++++++++
 .../inference/anakin/convert/im2sequence.cc | 27 ++--
 .../inference/anakin/convert/im2sequence.h | 3 +-
 .../inference/anakin/convert/op_converter.h | 133 +++++++++++++-----
 .../fluid/inference/anakin/convert/pool2d.cc | 29 ++--
 .../fluid/inference/anakin/convert/pool2d.h | 3 +-
 paddle/fluid/inference/anakin/convert/relu.cc | 35 +++--
 paddle/fluid/inference/anakin/convert/relu.h | 15 +-
 .../fluid/inference/anakin/convert/reshape.cc | 16 +--
 .../fluid/inference/anakin/convert/reshape.h | 3 +-
 .../inference/anakin/convert/roi_align.cc | 54 +++++++
 .../inference/anakin/convert/roi_align.h | 39 +++++
 .../fluid/inference/anakin/convert/scale.cc | 21 ++-
 paddle/fluid/inference/anakin/convert/scale.h | 3 +-
 .../anakin/convert/shuffle_channel.cc | 47 +++++++
 .../anakin/convert/shuffle_channel.h | 38 +++++
 .../fluid/inference/anakin/convert/softmax.cc | 17 +--
 .../fluid/inference/anakin/convert/softmax.h | 3 +-
 .../fluid/inference/anakin/convert/split.cc | 25 ++--
 paddle/fluid/inference/anakin/convert/split.h | 3 +-
 paddle/fluid/inference/anakin/convert/sum.cc | 21 ++-
 paddle/fluid/inference/anakin/convert/sum.h | 3 +-
 .../anakin/convert/test_activation_op.cc | 83 ++++++++++-
 .../anakin/convert/test_affine_channel_op.cc | 75 ++++++++++
 .../anakin/convert/test_batch_norm_op.cc | 24 +++-
 .../anakin/convert/test_concat_op.cc | 41 +++---
 .../anakin/convert/test_conv2d_op.cc | 27 +++-
 .../anakin/convert/test_dropout_op.cc | 23 ++-
 .../anakin/convert/test_elementwise_op.cc | 41 +++++-
 .../inference/anakin/convert/test_fc_op.cc | 27 +++-
 .../anakin/convert/test_flatten_op.cc | 26 +++-
 .../anakin/convert/test_pool2d_op.cc | 96 +++++++------
 .../inference/anakin/convert/test_relu_op.cc | 35 ++++-
 .../anakin/convert/test_reshape_op.cc | 44 +++++-
 .../anakin/convert/test_softmax_op.cc | 26 +++-
 .../inference/anakin/convert/test_split_op.cc | 74 ++++++----
 .../inference/anakin/convert/test_sum_op.cc | 23 ++-
 .../anakin/convert/test_transpose_op.cc | 44 +++++-
 .../inference/anakin/convert/transpose.cc | 16 +--
 .../inference/anakin/convert/transpose.h | 3 +-
 .../inference/anakin/convert/ut_helper.h | 60 ++++----
 paddle/fluid/inference/anakin/engine.cc | 90 ++++++++++--
 paddle/fluid/inference/anakin/engine.h | 56 +++++---
 paddle/fluid/inference/anakin/op_teller.cc | 5 +
 .../inference/anakin/test_anakin_engine.cc | 9 +-
 paddle/fluid/inference/analysis/argument.h | 35 +++--
 .../inference/analysis/ir_pass_manager.cc | 8 ++
 .../ir_passes/anakin_subgraph_pass.cc | 90 ++++++++++--
 .../analysis/ir_passes/anakin_subgraph_pass.h | 8 ++
 .../analysis/ir_passes/subgraph_util.cc | 2 -
 paddle/fluid/inference/api/analysis_config.cc | 25 +++-
 .../fluid/inference/api/analysis_predictor.cc | 11 +-
 .../inference/api/paddle_analysis_config.h | 9 +-
 .../inference/api/paddle_pass_builder.cc | 8 +-
 paddle/fluid/inference/check_symbol.sh | 2 +-
 .../fluid/operators/anakin/anakin_engine_op.h | 60 ++++----
 paddle/fluid/pybind/inference_api.cc | 10 ++
 91 files changed, 2128 insertions(+), 852 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
 create mode 100644 paddle/fluid/framework/ir/shuffle_channel_detect_pass.h
 create mode 100644 paddle/fluid/inference/anakin/convert/affine_channel.cc
 create mode 100644 paddle/fluid/inference/anakin/convert/affine_channel.h
 create mode 100644 paddle/fluid/inference/anakin/convert/helper.cc
 create mode 100644 paddle/fluid/inference/anakin/convert/helper.h
 create mode 100644 paddle/fluid/inference/anakin/convert/roi_align.cc
 create mode 100644 paddle/fluid/inference/anakin/convert/roi_align.h
 create mode 100644 paddle/fluid/inference/anakin/convert/shuffle_channel.cc
 create mode 100644 paddle/fluid/inference/anakin/convert/shuffle_channel.h
 create mode 100644 paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc

diff --git a/cmake/anakin_subgraph.cmake b/cmake/anakin_subgraph.cmake
index 4a7d32a635..b5437e776d 100644
--- a/cmake/anakin_subgraph.cmake
+++ b/cmake/anakin_subgraph.cmake
@@ -25,8 +25,9 @@ endif()
 
 if(ANAKIN_FOUND)
   message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
+  include_directories(${ANAKIN_ROOT})
   include_directories(${ANAKIN_ROOT}/include)
-  include_directories(${ANAKIN_ROOT}/include/saber)
+  include_directories(${ANAKIN_ROOT}/saber)
   link_directories(${ANAKIN_ROOT})
   add_definitions(-DPADDLE_WITH_ANAKIN)
 endif()
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 16fc1721eb..943b76e376 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -71,6 +71,7 @@ pass_library(runtime_context_cache_pass base)
 pass_library(expected_kernel_cache_pass base)
 pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(fillconstant_elementwisemul_fuse inference)
+pass_library(shuffle_channel_detect_pass inference)
 
 if(ANAKIN_FOUND)
   pass_library(simplify_anakin_priorbox_detection_out_pass inference)
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index ca008763bf..cd8030519c 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -48,17 +48,37 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
   GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
 
+  auto base_op_desc = mul->Op();
   // Create an FC Node.
+  // OpDesc desc(base_op_desc, nullptr);
   OpDesc desc;
   std::string fc_x_in = subgraph.at(x)->Name();
   std::string fc_Y_in = w->Name();
   std::string fc_bias_in = fc_bias->Name();
   std::string fc_out_out = fc_out->Name();
+
   desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
   desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
   desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
   desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
   desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
+
+  // For the Anakin subgraph int8 mode: a pattern like
+  // "fake_quant + mul + fake_dequant" is detected by the
+  // quant_dequant_fuse_pass, which adds the "input_scale" and "weight_scale"
+  // attributes (extracted from the fake_quant and fake_dequant ops) to the
+  // mul op and then deletes both ops from the graph. If the mul op carries
+  // this scale info, it must be copied to the fused fc op as well.
+  if (base_op_desc->HasAttr("enable_int8")) {
+    desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
+    desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
+    desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
+  }
+
   desc.SetType("fc");
   auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
   GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
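The scale attributes propagated above follow the convention used throughout this patch: scales are expressed against the int8 range 127, so the int8 branches of the conv2d/conv2d_fusion converters below hand Anakin weight_scale / 127 and in_scale / 127, and a stored int8 value q with scale s represents the real value q * s / 127. A minimal self-contained sketch of that convention; the helper names are illustrative and not part of the patch:

#include <cmath>
#include <cstdint>

// Hypothetical helpers mirroring the convention registered via
// SetWeightsScale / AddTensorScale in the converters below.
constexpr float kInt8Range = 127.0f;

// real -> int8 under a per-tensor scale: q = round(real * 127 / scale).
inline int8_t QuantizeToInt8(float real, float scale) {
  return static_cast<int8_t>(std::round(real * kInt8Range / scale));
}

// int8 -> real: real = q * scale / 127 (inverse of the above).
inline float DequantizeFromInt8(int8_t q, float scale) {
  return static_cast<float>(q) * scale / kInt8Range;
}

The weight_scale itself is recovered by the quant_dequant fuse pass below as (range * range) / max_range from the fake_dequant op's attributes.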
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 8468f9ccc1..0dcf064902 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1640,7 +1640,8 @@ PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
 void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
                                               const std::string &op_type,
                                               const std::string &weight_name,
-                                              int times) {
+                                              int times,
+                                              const std::string &quant_type) {
   const int kNumFields = 5;
   const int kQuantizedWeightOffset = 0;
   const int kQuantizedOpOffset = 1;
@@ -1648,24 +1649,22 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
   const int kDequantOpOffset = 3;
   const int kDequantOpOutOffset = 4;
   // there is always exactly one quant op.
-  auto quant_op_in_scale =
-      pattern->NewNode(GetNodeName("quant_op_in_scale"))
-          ->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
-          ->AsInput();
-  auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
-                      ->assert_is_op("fake_quantize_range_abs_max");
+  auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
+                               ->assert_is_op_input(quant_type, "InScale")
+                               ->AsInput();
+  auto quant_op =
+      pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
 
   auto quant_op_out_scale =
       pattern->NewNode(GetNodeName("quant_op_out_scale"))
-          ->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
+          ->assert_is_op_output(quant_type, "OutScale")
           ->assert_is_op_input("fake_dequantize_max_abs", "Scale")
           ->AsIntermediate();
 
-  auto quant_op_out =
-      pattern->NewNode(GetNodeName("quant_op_out"))
-          ->assert_is_op_output("fake_quantize_range_abs_max", "Out")
-          ->assert_is_op_input(op_type)
-          ->AsIntermediate();
+  auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
+                          ->assert_is_op_output(quant_type, "Out")
+                          ->assert_is_op_input(op_type)
+                          ->AsIntermediate();
 
   // there are 'times' pairs of quantized and dequant ops
   std::vector<PDNode *> nodes;
@@ -1707,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
   }
 }
 
+void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
+  auto reshape1_op =
+      pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
+
+  auto reshape1_out = pattern->NewNode(reshape1_out_repr())
+                          ->assert_is_op_output("reshape2", "Out")
+                          ->assert_is_op_input("transpose2")
+                          ->AsIntermediate();
+
+  auto transpose_op =
+      pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
+
+  auto transpose_out = pattern->NewNode(transpose_out_repr())
+                           ->assert_is_op_output("transpose2", "Out")
+                           ->assert_is_op_input("reshape2")
+                           ->AsIntermediate();
+
+  auto reshape2_op =
+      pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
+  auto reshape2_out = pattern->NewNode(reshape2_out_repr())
+                          ->assert_is_op_output("reshape2", "Out")
+                          ->AsOutput();
+
+  reshape1_op->LinksFrom({reshape1_in});
+  reshape1_out->LinksFrom({reshape1_op});
+  transpose_op->LinksFrom({reshape1_out});
+  transpose_out->LinksFrom({transpose_op});
+  reshape2_op->LinksFrom({transpose_out});
+  reshape2_out->LinksFrom({reshape2_op});
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index a5ac3a0c37..907371b56b 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -880,7 +880,8 @@ struct QuantDequantOpFuse : public PatternBase {
       : PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
 
   void operator()(PDNode* quant_op_input, const std::string& op_name,
-                  const std::string& weight_name, int times = 1);
+                  const std::string& weight_name, int times,
+                  const std::string& quant_type);
 
   std::string GetNodeName(const std::string& op_type) {
     return PDNodeName(name_scope_, repr_, id_, op_type);
@@ -891,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
   }
 };
 
+struct ShuffleChannelPattern : public PatternBase {
+  ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
+
+  void operator()(PDNode* reshape1_in);
+
+  PATTERN_DECL_NODE(reshape1_op);
+  PATTERN_DECL_NODE(reshape1_out);
+
+  PATTERN_DECL_NODE(transpose_op);
+  PATTERN_DECL_NODE(transpose_out);
+  PATTERN_DECL_NODE(reshape2_op);
+  PATTERN_DECL_NODE(reshape2_out);
+};
+
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 7cab9c353d..017e3ef234 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -25,7 +25,8 @@ namespace framework {
 namespace ir {
 
 void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
-                     std::string op_type) {
+                     const std::string& op_type,
+                     const std::string& quant_type) {
   const std::string pattern_name = "quant_dequant_fuse";
   // FusePassBase::Init(pattern_name, graph);
   const int kNumFields = 5;
@@ -38,7 +39,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode("x")
-                ->assert_is_op_input("fake_quantize_range_abs_max", "X")
+                ->assert_is_op_input(quant_type, "X")
                 ->AsInput();
 
   std::string quantized_op_type = "";
@@ -46,6 +47,9 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
   if (op_type == "conv2d") {
     quantized_op_type = "conv2d";
     weight_name = "Filter";
+  } else if (op_type == "depthwise_conv2d") {
+    quantized_op_type = "depthwise_conv2d";
+    weight_name = "Filter";
   } else if (op_type == "conv2d_fusion") {
     quantized_op_type = "conv2d_fusion";
     weight_name = "Filter";
@@ -62,7 +66,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
   }
 
   patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
-  pattern(x, quantized_op_type, weight_name, times);
+  pattern(x, quantized_op_type, weight_name, times, quant_type);
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -103,7 +107,6 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
     std::unordered_set<const Node*> delete_nodes;
 
     for (int i = 0; i < times; i++) {
-      // max_range = (range * range) / weight_scale
       float max_range = boost::get<float>(
           nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range"));
       float weight_scale = (range * range) / max_range;
@@ -118,7 +121,8 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
       new_op_desc.SetType(quantized_op_type);
 
       if (quantized_op_type == "conv2d" ||
-          quantized_op_type == "conv2d_fusion") {
+          quantized_op_type == "conv2d_fusion" ||
+          quantized_op_type == "depthwise_conv2d") {
         new_op_desc.SetInput("Input", {new_input});
         new_op_desc.SetOutput("Output", {new_output});
       } else if (quantized_op_type == "fc") {
@@ -156,11 +160,17 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "quant_dequant_fuse";
   FusePassBase::Init(pattern_name, graph);
 
-  std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul"};
+  std::unordered_set<std::string> quant_types = {
+      "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
+
+  std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
+                                                        "depthwise_conv2d"};
   auto* scope = param_scope();
-  for (auto& op_type : quantized_op_types) {
-    for (int i = 1; i <= 6; i++) {
-      RunQuantDequant(graph, scope, i, op_type);
+  for (auto& quant_type : quant_types) {
+    for (auto& op_type : quantized_op_types) {
+      for (int i = 6; i >= 1; i--) {
+        RunQuantDequant(graph, scope, i, op_type, quant_type);
+      }
     }
   }
 }
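The shuffle_channel_detect_pass introduced next collapses the reshape2 -> transpose2 -> reshape2 stack that ShuffleNet uses into a single shuffle_channel op, recovering the group count as o_c / i_c from the two reshape shapes. For reference, a minimal sketch of the computation being fused (plain C++, NCHW layout assumed, group dividing the channel count; not part of the patch):

#include <algorithm>
#include <vector>

// Channel shuffle on an NCHW buffer: view C as (group, C/group),
// transpose to (C/group, group), flatten back. This is exactly what the
// matched reshape2 -> transpose2 -> reshape2 stack computes.
std::vector<float> ShuffleChannel(const std::vector<float>& in, int n, int c,
                                  int h, int w, int group) {
  std::vector<float> out(in.size());
  const int step = c / group;  // channels per group
  const int spatial = h * w;
  for (int b = 0; b < n; ++b) {
    for (int i = 0; i < c; ++i) {
      // Channel i = (g, s) in (group, step) moves to (s, g) in (step, group).
      const int j = (i % step) * group + i / step;
      std::copy_n(in.begin() + (b * c + i) * spatial, spatial,
                  out.begin() + (b * c + j) * spatial);
    }
  }
  return out;
}

Replacing the three ops with one kernel avoids materializing the two intermediate tensors, which is why this pattern matters for ShuffleNet-style models on the Anakin backend.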
diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
new file mode 100644
index 0000000000..e55783637a
--- /dev/null
+++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
+#define GET_NODES             \
+  GET_IR_NODE(reshape1_op);   \
+  GET_IR_NODE(reshape1_out);  \
+  GET_IR_NODE(transpose_op);  \
+  GET_IR_NODE(transpose_out); \
+  GET_IR_NODE(reshape2_op);   \
+  GET_IR_NODE(reshape2_out);
+
+void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
+  const std::string pattern_name = "shufflechannel_pattern";
+  FusePassBase::Init(pattern_name, graph);
+
+  GraphPatternDetector gpd;
+  auto* x = gpd.mutable_pattern()
+                ->NewNode("x")
+                ->assert_is_op_input("reshape2", "X")
+                ->AsInput();
+
+  patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
+  pattern(x);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_NODES;
+
+    PADDLE_ENFORCE(subgraph.count(x));
+    auto* input_node = subgraph.at(x);
+    auto reshape1_desc = reshape1_op->Op();
+    auto reshape2_desc = reshape2_op->Op();
+    std::string input_name = input_node->Name();
+    std::string output_name = reshape2_out->Name();
+
+    auto reshape1_shape =
+        boost::get<std::vector<int>>(reshape1_desc->GetAttr("shape"));
+    auto reshape2_shape =
+        boost::get<std::vector<int>>(reshape2_desc->GetAttr("shape"));
+
+    int i_c = reshape1_shape[2];
+    int o_c = reshape2_shape[1];
+    int group = o_c / i_c;
+
+    framework::OpDesc new_op_desc;
+    new_op_desc.SetType("shuffle_channel");
+    new_op_desc.SetInput("X", {input_name});
+    new_op_desc.SetOutput("Out", {output_name});
+
+    new_op_desc.SetAttr("group", group);
+    new_op_desc.Flush();
+
+    // Create a new node for the fused op.
+    auto* new_op = graph->CreateOpNode(&new_op_desc);
+
+    IR_NODE_LINK_TO(input_node, new_op);
+    IR_NODE_LINK_TO(new_op, reshape2_out);
+
+    // Delete the unneeded nodes.
+    GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
+                                 transpose_out, reshape2_op});
+  };
+
+  gpd(graph, handler);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(shuffle_channel_detect_pass,
+              paddle::framework::ir::ShuffleChannelDetectPass);
diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h
new file mode 100644
index 0000000000..008f8013ef
--- /dev/null
+++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class ShuffleChannelDetectPass : public FusePassBase {
+ public:
+  virtual ~ShuffleChannelDetectPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/anakin/convert/CMakeLists.txt b/paddle/fluid/inference/anakin/convert/CMakeLists.txt
index d3d1522dcc..5d85525a65 100644
--- a/paddle/fluid/inference/anakin/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/anakin/convert/CMakeLists.txt
@@ -1,4 +1,9 @@
-cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
+cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
+elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
+batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
+detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
+roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
+scope op_registry gtest)
 
 cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
 cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
@@ -14,5 +19,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
 cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL)
 cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL)
 cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL)
-#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter
im2sequence_op im2col) cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL) +cc_test(test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL) diff --git a/paddle/fluid/inference/anakin/convert/activation.cc b/paddle/fluid/inference/anakin/convert/activation.cc index a9aeb19ffd..523571f1aa 100644 --- a/paddle/fluid/inference/anakin/convert/activation.cc +++ b/paddle/fluid/inference/anakin/convert/activation.cc @@ -16,16 +16,13 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; - namespace paddle { namespace inference { namespace anakin { -ActivationOpConverter::ActivationOpConverter(const std::string &op_type) +template +ActivationOpConverter::ActivationOpConverter( + const std::string &op_type) : op_type_(op_type) { auto it = anakin_op_types_.find(op_type_); PADDLE_ENFORCE(it != anakin_op_types_.end(), @@ -33,10 +30,10 @@ ActivationOpConverter::ActivationOpConverter(const std::string &op_type) anakin_op_type_ = it->second; } -void ActivationOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void ActivationOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); @@ -44,8 +41,17 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op, auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); auto input_name = op_desc.Input("X").front(); auto output_name = op_desc.Output("Out").front(); - engine_->AddOp(op_name, "Activation", {input_name}, {output_name}); - engine_->AddOpAttr(op_name, "type", anakin_op_type_); + this->engine_->AddOp(op_name, "Activation", {input_name}, {output_name}); + this->engine_->AddOpAttr(op_name, "type", anakin_op_type_); + + if (op_type_ == "swish") { + float beta = boost::get(op_desc.GetAttr("beta")); + this->engine_->AddOpAttr(op_name, "clip_relu_num", beta); + } + if (op_type_ == "relu6") { + float threshold = boost::get(op_desc.GetAttr("threshold")); + this->engine_->AddOpAttr(op_name, "clip_relu_num", threshold); + } } } // namespace anakin @@ -54,3 +60,5 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op, REGISTER_ANAKIN_OP_CONVERTER(sigmoid, SigmoidOpConverter); REGISTER_ANAKIN_OP_CONVERTER(tanh, TanhOpConverter); +REGISTER_ANAKIN_OP_CONVERTER(swish, SwishOpConverter); +REGISTER_ANAKIN_OP_CONVERTER(relu6, Relu6OpConverter); diff --git a/paddle/fluid/inference/anakin/convert/activation.h b/paddle/fluid/inference/anakin/convert/activation.h index 592a3d5bd9..a2475e492c 100644 --- a/paddle/fluid/inference/anakin/convert/activation.h +++ b/paddle/fluid/inference/anakin/convert/activation.h @@ -22,7 +22,8 @@ namespace paddle { namespace inference { namespace anakin { -class ActivationOpConverter : public AnakinOpConverter { +template +class ActivationOpConverter : public AnakinOpConverter { public: explicit ActivationOpConverter(const std::string &op_type); @@ -36,18 +37,36 @@ class ActivationOpConverter : public AnakinOpConverter { std::string op_type_; std::string anakin_op_type_; std::map anakin_op_types_{{"tanh", "TanH"}, - {"sigmoid", "Sigmoid"}}; + {"sigmoid", 
"Sigmoid"}, + {"relu6", "ClippedRelu"}, + {"swish", "Swish"}}; }; -class TanhOpConverter : public ActivationOpConverter { +template +class TanhOpConverter : public ActivationOpConverter { public: - TanhOpConverter() : ActivationOpConverter("tanh") {} + TanhOpConverter() : ActivationOpConverter("tanh") {} }; -class SigmoidOpConverter : public ActivationOpConverter { +template +class SigmoidOpConverter : public ActivationOpConverter { public: - SigmoidOpConverter() : ActivationOpConverter("sigmoid") {} + SigmoidOpConverter() + : ActivationOpConverter("sigmoid") {} }; + +template +class Relu6OpConverter : public ActivationOpConverter { + public: + Relu6OpConverter() : ActivationOpConverter("relu6") {} +}; + +template +class SwishOpConverter : public ActivationOpConverter { + public: + SwishOpConverter() : ActivationOpConverter("swish") {} +}; + } // namespace anakin } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/affine_channel.cc b/paddle/fluid/inference/anakin/convert/affine_channel.cc new file mode 100644 index 0000000000..534e7dca81 --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/affine_channel.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/anakin/convert/affine_channel.h" +#include +#include +#include +#include "paddle/fluid/inference/anakin/convert/helper.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +void AffineChannelOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); + + auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); + auto input_name = op_desc.Input("X").front(); + auto output_name = op_desc.Output("Out").front(); + this->engine_->AddOp(op_name, "AffineChannel", {input_name}, {output_name}); + + // Copy the Scale to CPUPlace and get the pointer. + auto *scale_v = scope.FindVar(op_desc.Input("Scale").front()); + PADDLE_ENFORCE_NOT_NULL(scale_v); + auto weight1 = pblock_from_var(*scale_v, this->engine_); + this->engine_->AddOpAttr(op_name, "weight_1", *weight1); + + // Copy the Bias to CPUPlace and get the pointer. 
+ auto *bias_v = scope.FindVar(op_desc.Input("Bias").front()); + PADDLE_ENFORCE_NOT_NULL(bias_v); + auto weight2 = pblock_from_var(*bias_v, this->engine_); + this->engine_->AddOpAttr(op_name, "weight_2", *weight2); +} + +} // namespace anakin +} // namespace inference +} // namespace paddle + +REGISTER_ANAKIN_OP_CONVERTER(affine_channel, AffineChannelOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/affine_channel.h b/paddle/fluid/inference/anakin/convert/affine_channel.h new file mode 100644 index 0000000000..443f610128 --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/affine_channel.h @@ -0,0 +1,40 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/inference/anakin/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +class AffineChannelOpConverter : public AnakinOpConverter { + public: + AffineChannelOpConverter() = default; + + virtual void operator()(const framework::proto::OpDesc &op, + const framework::BlockDesc &block_desc, + const framework::Scope &scope, + bool test_mode) override; + virtual ~AffineChannelOpConverter() {} + + private: +}; + +} // namespace anakin +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/batch_norm.cc b/paddle/fluid/inference/anakin/convert/batch_norm.cc index 38cf617202..b41f5dc925 100644 --- a/paddle/fluid/inference/anakin/convert/batch_norm.cc +++ b/paddle/fluid/inference/anakin/convert/batch_norm.cc @@ -18,107 +18,64 @@ #include #include #include - -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; +#include "paddle/fluid/inference/anakin/convert/helper.h" namespace paddle { namespace inference { namespace anakin { -void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void BatchNormOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1); std::map inputs; for (auto k : {"X", "Scale", "Bias", "Mean", "Variance"}) { PADDLE_ENFORCE_EQ(op_desc.Input(k).size(), 1UL); - auto v = op_desc.Input(k).front(); - inputs.insert({k, v}); } + auto input = op_desc.Input("X").front(); auto output = op_desc.Output("Y").front(); auto op_name = op_desc.Type() + ":" + op_desc.Output("Y").front(); auto epsilon = boost::get(op_desc.GetAttr("epsilon")); - // auto momentum = boost::get(op_desc.GetAttr("momentum")); auto bn_op_name = op_name + ":bn"; auto bn_output = bn_op_name + "_output"; - engine_->AddOp(bn_op_name, "BatchNorm", {inputs["X"]}, {bn_output}); - engine_->AddOpAttr(bn_op_name, "epsilon", epsilon); - engine_->AddOpAttr(bn_op_name, "momentum", 
static_cast(1.0)); + this->engine_->AddOp(bn_op_name, "BatchNorm", {input}, {bn_output}); + this->engine_->AddOpAttr(bn_op_name, "epsilon", epsilon); + this->engine_->AddOpAttr(bn_op_name, "momentum", static_cast(1.0)); auto scale_op_name = op_name + ":scale"; - auto get_lod_tensor = [this, &scope, &op_name](const std::string &var_name, - framework::LoDTensor *tensor) { - auto *v = scope.FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL(v); - auto *t = v->GetMutable(); - tensor->Resize(t->dims()); - TensorCopySync(*t, platform::CPUPlace(), tensor); - }; - - framework::LoDTensor bias_t; - framework::LoDTensor mean_t; - framework::LoDTensor scale_t; - framework::LoDTensor variance_t; - get_lod_tensor(inputs["Bias"], &bias_t); - get_lod_tensor(inputs["Mean"], &mean_t); - get_lod_tensor(inputs["Scale"], &scale_t); - get_lod_tensor(inputs["Variance"], &variance_t); - - auto fill_shape = [](size_t n, std::vector shape) { - shape.insert(shape.begin(), 1); - if (shape.size() < n) { - shape.insert(shape.end(), n - shape.size(), 1); - } - return shape; - }; - Shape shape1(fill_shape(4, framework::vectorize2int(mean_t.dims()))); - Shape shape2(fill_shape(4, framework::vectorize2int(variance_t.dims()))); - auto *weight1 = - GraphGlobalMem::Global().template new_block(shape1); - auto *mean_data = static_cast(weight1->h_tensor().mutable_data()); - std::copy_n(mean_t.data(), mean_t.numel(), mean_data); - engine_->AddOpAttr(bn_op_name, "weight_1", *weight1); - - auto *weight2 = - GraphGlobalMem::Global().template new_block(shape2); - auto *variance_data = - static_cast(weight2->h_tensor().mutable_data()); - std::copy_n(variance_t.data(), variance_t.numel(), variance_data); - engine_->AddOpAttr(bn_op_name, "weight_2", *weight2); - - Shape shape3(std::vector({1, 1, 1, 1})); - auto *weight3 = - GraphGlobalMem::Global().template new_block(shape3); - auto *alpha_data = static_cast(weight3->h_tensor().mutable_data()); - float weight3_data[] = {1}; - std::copy(std::begin(weight3_data), std::end(weight3_data), alpha_data); - engine_->AddOpAttr(bn_op_name, "weight_3", *weight3); - - Shape scale_shape(fill_shape(4, framework::vectorize2int(scale_t.dims()))); - auto *scale = - GraphGlobalMem::Global().template new_block(scale_shape); - auto *scale_data = static_cast(scale->h_tensor().mutable_data()); - std::copy_n(scale_t.data(), scale_t.numel(), scale_data); - - Shape bias_shape(fill_shape(4, framework::vectorize2int(bias_t.dims()))); - auto *bias = - GraphGlobalMem::Global().template new_block(bias_shape); - auto *bias_data = static_cast(bias->h_tensor().mutable_data()); - std::copy_n(bias_t.data(), bias_t.numel(), bias_data); - - engine_->AddOp(scale_op_name, "Scale", {bn_output}, {output}); - engine_->AddOpAttr(scale_op_name, "axis", 1); - engine_->AddOpAttr(scale_op_name, "num_axes", 1); - engine_->AddOpAttr(scale_op_name, "bias_term", true); - engine_->AddOpAttr(scale_op_name, "weight_1", *scale); - engine_->AddOpAttr(scale_op_name, "weight_2", *bias); + this->engine_->AddOp(scale_op_name, "Scale", {bn_output}, {output}); + this->engine_->AddOpAttr(scale_op_name, "axis", 1); + this->engine_->AddOpAttr(scale_op_name, "num_axes", 1); + this->engine_->AddOpAttr(scale_op_name, "bias_term", true); + + auto *mean_v = scope.FindVar(op_desc.Input("Mean").front()); + PADDLE_ENFORCE_NOT_NULL(mean_v); + auto weight1 = pblock_from_var(*mean_v, this->engine_); + this->engine_->AddOpAttr(bn_op_name, "weight_1", *weight1); + + auto *variance_v = scope.FindVar(op_desc.Input("Variance").front()); + 
PADDLE_ENFORCE_NOT_NULL(variance_v); + auto weight2 = + pblock_from_var(*variance_v, this->engine_); + this->engine_->AddOpAttr(bn_op_name, "weight_2", *weight2); + + auto *weight3 = pblock_from_vector( + std::vector({1}), this->engine_); + this->engine_->AddOpAttr(bn_op_name, "weight_3", *weight3); + + auto *scale_v = scope.FindVar(op_desc.Input("Scale").front()); + PADDLE_ENFORCE_NOT_NULL(scale_v); + auto scale = pblock_from_var(*scale_v, this->engine_); + this->engine_->AddOpAttr(scale_op_name, "weight_1", *scale); + + auto *bias_v = scope.FindVar(op_desc.Input("Bias").front()); + PADDLE_ENFORCE_NOT_NULL(bias_v); + auto bias = pblock_from_var(*bias_v, this->engine_); + this->engine_->AddOpAttr(scale_op_name, "weight_2", *bias); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/batch_norm.h b/paddle/fluid/inference/anakin/convert/batch_norm.h index c56735f15b..52156aeb02 100644 --- a/paddle/fluid/inference/anakin/convert/batch_norm.h +++ b/paddle/fluid/inference/anakin/convert/batch_norm.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class BatchNormOpConverter : public AnakinOpConverter { +template +class BatchNormOpConverter : public AnakinOpConverter { public: BatchNormOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/concat.cc b/paddle/fluid/inference/anakin/convert/concat.cc index ae90c08369..584a82ead4 100644 --- a/paddle/fluid/inference/anakin/convert/concat.cc +++ b/paddle/fluid/inference/anakin/convert/concat.cc @@ -15,34 +15,23 @@ #include "paddle/fluid/inference/anakin/convert/concat.h" #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::Precision; -using anakin::saber::NV; -using anakin::saber::X86; -using anakin::saber::Shape; -using anakin::PBlock; -using anakin::PTuple; - namespace paddle { namespace inference { namespace anakin { -void ConcatOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void ConcatOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); int axis = boost::get(op_desc.GetAttr("axis")); auto input_names = op_desc.Input("X"); - // PADDLE_ENFORCE(axis > 0, - // "The axis attr of Concat op should be large than 0 for trt"); auto y_name = op_desc.Output("Out").front(); auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); - engine_->AddOp(op_name, "Concat", input_names, {y_name}); - engine_->AddOpAttr(op_name, "axis", axis); + this->engine_->AddOp(op_name, "Concat", input_names, {y_name}); + this->engine_->AddOpAttr(op_name, "axis", axis); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/concat.h b/paddle/fluid/inference/anakin/convert/concat.h index 974ff689bf..fb5514affa 100644 --- a/paddle/fluid/inference/anakin/convert/concat.h +++ b/paddle/fluid/inference/anakin/convert/concat.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class ConcatOpConverter : public AnakinOpConverter { +template +class ConcatOpConverter : public AnakinOpConverter { public: ConcatOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/conv2d.cc b/paddle/fluid/inference/anakin/convert/conv2d.cc index 308f14604b..70e0adf5ea 100644 --- a/paddle/fluid/inference/anakin/convert/conv2d.cc +++ 
b/paddle/fluid/inference/anakin/convert/conv2d.cc @@ -16,21 +16,18 @@ #include #include #include +#include "paddle/fluid/inference/anakin/convert/helper.h" -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; using anakin::PTuple; namespace paddle { namespace inference { namespace anakin { -void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void Conv2dOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1UL); PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1UL); @@ -39,46 +36,69 @@ void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op, auto input_name = op_desc.Input("Input").front(); auto output_name = op_desc.Output("Output").front(); auto op_name = op_desc.Type() + ":" + op_desc.Output("Output").front(); - engine_->AddOp(op_name, "Convolution", {input_name}, {output_name}); + this->engine_->AddOp(op_name, "Convolution", {input_name}, {output_name}); auto *filter_v = scope.FindVar(op_desc.Input("Filter").front()); PADDLE_ENFORCE_NOT_NULL(filter_v); - auto *filter_t = filter_v->GetMutable(); - std::unique_ptr weight_tensor( - new framework::LoDTensor()); - weight_tensor->Resize(filter_t->dims()); - TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get()); + auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace()); + auto weight_shape = framework::vectorize2int(weight_tensor->dims()); PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); - // const int n_output = weight_tensor->dims()[0]; - // const int n_input = weight_tensor->dims()[1]; const int filter_h = weight_tensor->dims()[2]; const int filter_w = weight_tensor->dims()[3]; - // auto filter_num = n_input * filter_h * filter_w ; + auto filter_num = weight_tensor->dims()[0]; - engine_->AddOpAttr(op_name, "filter_num", filter_num); - engine_->AddOpAttr>(op_name, "kernel_size", {filter_h, filter_w}); + this->engine_->template AddOpAttr(op_name, "filter_num", filter_num); + this->engine_->template AddOpAttr>(op_name, "kernel_size", + {filter_h, filter_w}); auto strides = boost::get>(op_desc.GetAttr("strides")); - engine_->AddOpAttr>(op_name, "strides", strides); + this->engine_->template AddOpAttr>(op_name, "strides", strides); auto paddings = boost::get>(op_desc.GetAttr("paddings")); - engine_->AddOpAttr>(op_name, "padding", paddings); + this->engine_->template AddOpAttr>(op_name, "padding", paddings); auto dilations = boost::get>(op_desc.GetAttr("dilations")); - engine_->AddOpAttr>(op_name, "dilation_rate", dilations); + this->engine_->template AddOpAttr>(op_name, "dilation_rate", + dilations); const int groups = boost::get(op_desc.GetAttr("groups")); - engine_->AddOpAttr(op_name, "group", groups); - engine_->AddOpAttr(op_name, "axis", 1); - engine_->AddOpAttr(op_name, "bias_term", false); + this->engine_->AddOpAttr(op_name, "group", groups); + this->engine_->AddOpAttr(op_name, "axis", 1); + this->engine_->AddOpAttr(op_name, "bias_term", false); + + ::anakin::saber::Shape anakin_shape(weight_shape); + bool enable_int8 = boost::get(op_desc.HasAttr("enable_int8")); - auto weight_shape = framework::vectorize2int(filter_t->dims()); - Shape anakin_shape(weight_shape); - auto *weight1 = - GraphGlobalMem::Global().template 
new_block(anakin_shape); - float *cpu_data = static_cast(weight1->h_tensor().mutable_data()); - std::copy_n(weight_tensor->data(), weight_tensor->numel(), cpu_data); - weight1->d_tensor().set_shape(anakin_shape); - weight1->d_tensor().copy_from(weight1->h_tensor()); - engine_->AddOpAttr(op_name, "weight_1", *weight1); + if (enable_int8) { + const float int8_range = 127.; + float in_scale = boost::get(op_desc.GetAttr("input_scale")); + float weight_scale = boost::get(op_desc.GetAttr("weight_scale")); + PBlock *weight1 = + new PBlock(anakin_shape, ::anakin::AK_INT8); + this->engine_->RegistBlock(weight1); + float *weight_data = weight_tensor->data(); + std::vector weight_int8; + int weight_num = weight_tensor->numel(); + for (int i = 0; i < weight_tensor->numel(); i++) { + bool is_valid_int8 = + ((weight_data[i] >= -128) && (weight_data[i] <= 127)); + PADDLE_ENFORCE(is_valid_int8, + "We are in anakin subgraph int8 mode, the weight of conv " + "should be in range [-128, 127]"); + weight_int8.push_back(static_cast(weight_data[i])); + } + memcpy(static_cast(weight1->h_tensor().mutable_data()), + static_cast(weight_int8.data()), sizeof(char) * weight_num); + weight1->d_tensor().set_shape(anakin_shape); + weight1->d_tensor().copy_from(weight1->h_tensor()); + this->engine_->AddOpAttr(op_name, "weight_1", *weight1); + this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8); + this->engine_->Graph()->SetWeightsScale(op_name, + {weight_scale / int8_range}, false); + this->engine_->AddTensorScale(input_name, in_scale / int8_range); + } else { + auto *weight1 = pblock_from_tensor( + *weight_tensor, weight_shape, this->engine_); + this->engine_->AddOpAttr(op_name, "weight_1", *weight1); + } } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/conv2d.h b/paddle/fluid/inference/anakin/convert/conv2d.h index dca5d19f46..b22cb8ea93 100644 --- a/paddle/fluid/inference/anakin/convert/conv2d.h +++ b/paddle/fluid/inference/anakin/convert/conv2d.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class Conv2dOpConverter : public AnakinOpConverter { +template +class Conv2dOpConverter : public AnakinOpConverter { public: Conv2dOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/conv2d_fusion.cc b/paddle/fluid/inference/anakin/convert/conv2d_fusion.cc index fa1ab0efee..a1568b8bde 100644 --- a/paddle/fluid/inference/anakin/convert/conv2d_fusion.cc +++ b/paddle/fluid/inference/anakin/convert/conv2d_fusion.cc @@ -16,21 +16,18 @@ #include #include #include +#include "paddle/fluid/inference/anakin/convert/helper.h" -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; using anakin::PTuple; namespace paddle { namespace inference { namespace anakin { -void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void Conv2dFusionOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1UL); PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1UL); @@ -40,71 +37,74 @@ void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op, auto input_name = op_desc.Input("Input").front(); auto output_name = op_desc.Output("Output").front(); auto op_name = op_desc.Type() + ":" + 
op_desc.Output("Output").front(); - engine_->AddOp(op_name, "Convolution", {input_name}, {output_name}); + this->engine_->AddOp(op_name, "Convolution", {input_name}, {output_name}); auto *filter_v = scope.FindVar(op_desc.Input("Filter").front()); PADDLE_ENFORCE_NOT_NULL(filter_v); - auto *filter_t = filter_v->GetMutable(); + + auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace()); + auto weight_shape = framework::vectorize2int(weight_tensor->dims()); auto *b_v = scope.FindVar(op_desc.Input("Bias").front()); PADDLE_ENFORCE_NOT_NULL(b_v); - auto *b_t = b_v->GetMutable(); - - std::unique_ptr weight_tensor( - new framework::LoDTensor()); - weight_tensor->Resize(filter_t->dims()); - TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get()); PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); - - // const int n_output = weight_tensor->dims()[0]; - // const int n_input = weight_tensor->dims()[1]; const int filter_h = weight_tensor->dims()[2]; const int filter_w = weight_tensor->dims()[3]; - // auto filter_num = n_input * filter_h * filter_w ; auto filter_num = weight_tensor->dims()[0]; - engine_->AddOpAttr(op_name, "filter_num", filter_num); - engine_->AddOpAttr>(op_name, "kernel_size", {filter_h, filter_w}); + this->engine_->template AddOpAttr(op_name, "filter_num", filter_num); + this->engine_->template AddOpAttr>(op_name, "kernel_size", + {filter_h, filter_w}); auto strides = boost::get>(op_desc.GetAttr("strides")); - engine_->AddOpAttr>(op_name, "strides", strides); + this->engine_->template AddOpAttr>(op_name, "strides", strides); auto paddings = boost::get>(op_desc.GetAttr("paddings")); - engine_->AddOpAttr>(op_name, "padding", paddings); + this->engine_->template AddOpAttr>(op_name, "padding", paddings); auto dilations = boost::get>(op_desc.GetAttr("dilations")); - engine_->AddOpAttr>(op_name, "dilation_rate", dilations); + this->engine_->template AddOpAttr>(op_name, "dilation_rate", + dilations); const int groups = boost::get(op_desc.GetAttr("groups")); - engine_->AddOpAttr(op_name, "group", groups); - engine_->AddOpAttr(op_name, "axis", 1); - engine_->AddOpAttr(op_name, "bias_term", true); - - auto weight_shape = framework::vectorize2int(filter_t->dims()); - Shape anakin_shape(weight_shape); - auto *weight1 = - GraphGlobalMem::Global().template new_block(anakin_shape); - float *cpu_data = static_cast(weight1->h_tensor().mutable_data()); - std::copy_n(weight_tensor->data(), weight_tensor->numel(), cpu_data); - weight1->d_tensor().set_shape(anakin_shape); - weight1->d_tensor().copy_from(weight1->h_tensor()); - engine_->AddOpAttr(op_name, "weight_1", *weight1); - - auto bias_shape = framework::vectorize2int(b_t->dims()); - framework::LoDTensor bias_tensor; - bias_tensor.Resize(b_t->dims()); - TensorCopySync((*b_t), platform::CPUPlace(), &bias_tensor); - auto *bias_data = bias_tensor.data(); - bias_shape.insert(bias_shape.begin(), 1); - bias_shape.insert(bias_shape.begin(), 1); - bias_shape.insert(bias_shape.begin(), 1); - // bias_shape.push_back(1); - // bias_shape.push_back(1); - Shape anakin_bias_shape(bias_shape); + this->engine_->AddOpAttr(op_name, "group", groups); + this->engine_->AddOpAttr(op_name, "axis", 1); + this->engine_->AddOpAttr(op_name, "bias_term", true); - auto *weight2 = GraphGlobalMem::Global().template new_block( - anakin_bias_shape); - float *cpu_data2 = static_cast(weight2->h_tensor().mutable_data()); - std::copy_n(bias_data, bias_tensor.numel(), cpu_data2); - weight2->d_tensor().set_shape(anakin_bias_shape); - 
weight2->d_tensor().copy_from(weight2->h_tensor()); - engine_->AddOpAttr(op_name, "weight_2", *weight2); + ::anakin::saber::Shape anakin_shape(weight_shape); + bool enable_int8 = boost::get(op_desc.HasAttr("enable_int8")); + if (enable_int8) { + const float int8_range = 127.; + float in_scale = boost::get(op_desc.GetAttr("input_scale")); + float weight_scale = boost::get(op_desc.GetAttr("weight_scale")); + PBlock *weight1 = + new PBlock(anakin_shape, ::anakin::AK_INT8); + this->engine_->RegistBlock(weight1); + float *weight_data = weight_tensor->data(); + std::vector weight_int8; + int weight_num = weight_tensor->numel(); + for (int i = 0; i < weight_tensor->numel(); i++) { + bool is_valid_int8 = + ((weight_data[i] >= -128) && (weight_data[i] <= 127)); + PADDLE_ENFORCE(is_valid_int8, + "We are in anakin subgraph int8 mode, the weight of conv " + "should be in range [-128, 127]"); + weight_int8.push_back(static_cast(weight_data[i])); + } + memcpy(static_cast(weight1->h_tensor().mutable_data()), + static_cast(weight_int8.data()), sizeof(char) * weight_num); + weight1->d_tensor().set_shape(anakin_shape); + weight1->d_tensor().copy_from(weight1->h_tensor()); + this->engine_->AddOpAttr(op_name, "weight_1", *weight1); + this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8); + this->engine_->Graph()->SetWeightsScale(op_name, + {weight_scale / int8_range}, false); + this->engine_->AddTensorScale(input_name, in_scale / int8_range); + } else { + auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace()); + auto weight_shape = framework::vectorize2int(weight_tensor->dims()); + auto *weight1 = pblock_from_tensor( + *weight_tensor, weight_shape, this->engine_); + this->engine_->AddOpAttr(op_name, "weight_1", *weight1); + auto weight2 = pblock_from_var(*b_v, this->engine_); + this->engine_->AddOpAttr(op_name, "weight_2", *weight2); + } } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/conv2d_fusion.h b/paddle/fluid/inference/anakin/convert/conv2d_fusion.h index 0d9ef28183..768814d3f9 100644 --- a/paddle/fluid/inference/anakin/convert/conv2d_fusion.h +++ b/paddle/fluid/inference/anakin/convert/conv2d_fusion.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class Conv2dFusionOpConverter : public AnakinOpConverter { +template +class Conv2dFusionOpConverter : public AnakinOpConverter { public: Conv2dFusionOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/density_prior_box.cc b/paddle/fluid/inference/anakin/convert/density_prior_box.cc index 30796f7592..5bbaeb57a7 100644 --- a/paddle/fluid/inference/anakin/convert/density_prior_box.cc +++ b/paddle/fluid/inference/anakin/convert/density_prior_box.cc @@ -17,17 +17,14 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; using anakin::PTuple; namespace paddle { namespace inference { namespace anakin { -void DensityPriorBoxOpConverter::operator()( +template +void DensityPriorBoxOpConverter::operator()( const framework::proto::OpDesc& op, const framework::BlockDesc& block_desc, const framework::Scope& scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); @@ -81,22 +78,30 @@ void DensityPriorBoxOpConverter::operator()( std::vector temp_v = {}; - engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name}); - engine_->AddOpAttr>(op_name, "min_size", min_sizes); - engine_->AddOpAttr>(op_name, "max_size", max_sizes); - engine_->AddOpAttr>(op_name, 
"aspect_ratio", aspect_ratios); - engine_->AddOpAttr>(op_name, "fixed_size", fixed_sizes); - engine_->AddOpAttr>(op_name, "fixed_ratio", fixed_ratios); - engine_->AddOpAttr>(op_name, "density", dens); - engine_->AddOpAttr(op_name, "is_flip", is_flip); - engine_->AddOpAttr(op_name, "is_clip", is_clip); - engine_->AddOpAttr>(op_name, "variance", variances); - engine_->AddOpAttr(op_name, "img_h", static_cast(0)); - engine_->AddOpAttr(op_name, "img_w", static_cast(0)); - engine_->AddOpAttr(op_name, "step_h", step_h); - engine_->AddOpAttr(op_name, "step_w", step_w); - engine_->AddOpAttr(op_name, "offset", offset); - engine_->AddOpAttr>(op_name, "order", t_order); + this->engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, + {output_name}); + this->engine_->template AddOpAttr>(op_name, "min_size", + min_sizes); + this->engine_->template AddOpAttr>(op_name, "max_size", + max_sizes); + this->engine_->template AddOpAttr>(op_name, "aspect_ratio", + aspect_ratios); + this->engine_->template AddOpAttr>(op_name, "fixed_size", + fixed_sizes); + this->engine_->template AddOpAttr>(op_name, "fixed_ratio", + fixed_ratios); + this->engine_->template AddOpAttr>(op_name, "density", dens); + this->engine_->AddOpAttr(op_name, "is_flip", is_flip); + this->engine_->AddOpAttr(op_name, "is_clip", is_clip); + this->engine_->template AddOpAttr>(op_name, "variance", + variances); + this->engine_->AddOpAttr(op_name, "img_h", static_cast(0)); + this->engine_->AddOpAttr(op_name, "img_w", static_cast(0)); + this->engine_->AddOpAttr(op_name, "step_h", step_h); + this->engine_->AddOpAttr(op_name, "step_w", step_w); + this->engine_->AddOpAttr(op_name, "offset", offset); + this->engine_->template AddOpAttr>(op_name, "order", + t_order); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/density_prior_box.h b/paddle/fluid/inference/anakin/convert/density_prior_box.h index bf9210711a..5714f57a04 100644 --- a/paddle/fluid/inference/anakin/convert/density_prior_box.h +++ b/paddle/fluid/inference/anakin/convert/density_prior_box.h @@ -22,7 +22,9 @@ namespace paddle { namespace inference { namespace anakin { -class DensityPriorBoxOpConverter : public AnakinOpConverter { +template +class DensityPriorBoxOpConverter + : public AnakinOpConverter { public: DensityPriorBoxOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/detection_out.cc b/paddle/fluid/inference/anakin/convert/detection_out.cc index 262ad28a65..73dd6f2832 100644 --- a/paddle/fluid/inference/anakin/convert/detection_out.cc +++ b/paddle/fluid/inference/anakin/convert/detection_out.cc @@ -16,19 +16,14 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; - namespace paddle { namespace inference { namespace anakin { -void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void DetectionOutOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); auto target_name = op_desc.Input("TargetBox").front(); auto prior_box_name = op_desc.Input("PriorBox").front(); @@ -52,18 +47,19 @@ void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op, "Not support encode_center_size code_type in DetectionOut of anakin"); } - engine_->AddOp(op_name, "DetectionOutput", - 
diff --git a/paddle/fluid/inference/anakin/convert/detection_out.cc b/paddle/fluid/inference/anakin/convert/detection_out.cc
index 262ad28a65..73dd6f2832 100644
--- a/paddle/fluid/inference/anakin/convert/detection_out.cc
+++ b/paddle/fluid/inference/anakin/convert/detection_out.cc
@@ -16,19 +16,14 @@
 #include <map>
 #include <vector>
 
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::saber::NV;
-using anakin::saber::Shape;
-
 namespace paddle {
 namespace inference {
 namespace anakin {
 
-void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
-                                         const framework::BlockDesc &block_desc,
-                                         const framework::Scope &scope,
-                                         bool test_mode) {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void DetectionOutOpConverter<TargetT, PrecisionType>::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
   auto target_name = op_desc.Input("TargetBox").front();
   auto prior_box_name = op_desc.Input("PriorBox").front();
@@ -52,18 +47,19 @@ void DetectionOutOpConverter::operator()(
         "Not support encode_center_size code_type in DetectionOut of anakin");
   }
 
-  engine_->AddOp(op_name, "DetectionOutput",
-                 {target_name, scores_name, prior_box_name}, {output_name});
-  engine_->AddOpAttr(op_name, "share_location", true);
-  engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
-  engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
-  engine_->AddOpAttr(op_name, "background_id", background_label);
-  engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
-  engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
-  engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
-  engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
-  engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
-  engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
+  this->engine_->AddOp(op_name, "DetectionOutput",
+                       {target_name, scores_name, prior_box_name},
+                       {output_name});
+  this->engine_->AddOpAttr(op_name, "share_location", true);
+  this->engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
+  this->engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
+  this->engine_->AddOpAttr(op_name, "background_id", background_label);
+  this->engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
+  this->engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
+  this->engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
+  this->engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
+  this->engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
+  this->engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
 }
 
 } // namespace anakin
diff --git a/paddle/fluid/inference/anakin/convert/detection_out.h b/paddle/fluid/inference/anakin/convert/detection_out.h
index ca78f10fdc..c34342a66c 100644
--- a/paddle/fluid/inference/anakin/convert/detection_out.h
+++ b/paddle/fluid/inference/anakin/convert/detection_out.h
@@ -22,7 +22,8 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-class DetectionOutOpConverter : public AnakinOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class DetectionOutOpConverter : public AnakinOpConverter<TargetT, PrecisionType> {
  public:
   DetectionOutOpConverter() = default;
diff --git a/paddle/fluid/inference/anakin/convert/dropout.cc b/paddle/fluid/inference/anakin/convert/dropout.cc
index bc9b26dcf2..6c5f80b5f8 100644
--- a/paddle/fluid/inference/anakin/convert/dropout.cc
+++ b/paddle/fluid/inference/anakin/convert/dropout.cc
@@ -16,24 +16,16 @@
 #include <algorithm>
 #include <string>
 #include <vector>
-
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::Precision;
-using anakin::saber::NV;
-using anakin::saber::X86;
-using anakin::saber::Shape;
-using anakin::PBlock;
-using anakin::PTuple;
+#include "paddle/fluid/inference/anakin/convert/helper.h"
 
 namespace paddle {
 namespace inference {
 namespace anakin {
 
-void DropoutOpConverter::operator()(const framework::proto::OpDesc &op,
-                                    const framework::BlockDesc &block_desc,
-                                    const framework::Scope &scope,
-                                    bool test_mode) {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void DropoutOpConverter<TargetT, PrecisionType>::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
   PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
   PADDLE_ENFORCE_EQ(op_desc.Output("Mask").size(), 1);
@@ -43,21 +35,17 @@ void DropoutOpConverter::operator()(
   auto out_name = op_desc.Output("Out").front();
   auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
 
-  engine_->AddOp(op_name, "Scale", {x_name}, {out_name});
+  this->engine_->AddOp(op_name, "Scale", {x_name}, {out_name});
 
   auto dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
   auto factor = 1 - dropout_prob;
-  Shape shape1(std::vector<int>({1, 1, 1, 1}));
-  auto *weight1 =
-      GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
-  auto *factor_data = static_cast<float *>(weight1->h_tensor().mutable_data());
-  float weight1_data[] = {factor};
-  std::copy(std::begin(weight1_data), std::end(weight1_data), factor_data);
+  auto *weight1 = pblock_from_vector<TargetT, PrecisionType>(
+      std::vector<float>({factor}), this->engine_);
 
-  engine_->AddOpAttr(op_name, "weight_1", *weight1);
-  engine_->AddOpAttr(op_name, "axis", 0);
-  engine_->AddOpAttr(op_name, "num_axes", 0);
-  engine_->AddOpAttr(op_name, "bias_term", false);
+  this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
+  this->engine_->AddOpAttr(op_name, "axis", 0);
+  this->engine_->AddOpAttr(op_name, "num_axes", 0);
+  this->engine_->AddOpAttr(op_name, "bias_term", false);
 }
 
 } // namespace anakin
diff --git a/paddle/fluid/inference/anakin/convert/dropout.h b/paddle/fluid/inference/anakin/convert/dropout.h
index 11412e217e..801aa3dd16 100644
--- a/paddle/fluid/inference/anakin/convert/dropout.h
+++ b/paddle/fluid/inference/anakin/convert/dropout.h
@@ -20,7 +20,8 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-class DropoutOpConverter : public AnakinOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class DropoutOpConverter : public AnakinOpConverter<TargetT, PrecisionType> {
  public:
   DropoutOpConverter() = default;
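At inference time dropout only rescales its input, so the converter lowers it to Anakin's Scale layer whose single weight is `1 - dropout_prob`; the new `pblock_from_vector` call replaces the hand-rolled `GraphGlobalMem` block without changing that arithmetic. A self-contained sanity check of the scaling convention (the probability 0.3 is an arbitrary example, not a value from the patch):

    #include <cassert>
    #include <cmath>

    int main() {
      float dropout_prob = 0.3f;           // attribute read from the OpDesc
      float factor = 1.0f - dropout_prob;  // the single weight_1 element
      float x = 2.0f;                      // some activation value
      // The emitted Scale op computes factor * x, i.e. 0.7 * 2.0 = 1.4.
      assert(std::fabs(factor * x - 1.4f) < 1e-6f);
      return 0;
    }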
diff --git a/paddle/fluid/inference/anakin/convert/elementwise.cc b/paddle/fluid/inference/anakin/convert/elementwise.cc
index fe9a896d82..dd32baa0b9 100644
--- a/paddle/fluid/inference/anakin/convert/elementwise.cc
+++ b/paddle/fluid/inference/anakin/convert/elementwise.cc
@@ -17,20 +17,14 @@
 #include <string>
 #include <vector>
 
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::Precision;
-using anakin::saber::NV;
-using anakin::saber::X86;
-using anakin::saber::Shape;
-using anakin::PBlock;
 using anakin::PTuple;
 
 namespace paddle {
 namespace inference {
 namespace anakin {
 
-void ElementwiseAddOpConverter::operator()(
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void ElementwiseAddOpConverter<TargetT, PrecisionType>::operator()(
     const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
     const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
@@ -43,14 +37,16 @@ void ElementwiseAddOpConverter::operator()(
   auto out_name = op_desc.Output("Out").front();
   auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
 
-  engine_->AddOp(op_name, "Eltwise", {x_name, y_name}, {out_name});
+  this->engine_->AddOp(op_name, "Eltwise", {x_name, y_name}, {out_name});
   std::string elementwise_type = "Add";
-  engine_->AddOpAttr(op_name, "type", elementwise_type);
+  this->engine_->template AddOpAttr<std::string>(op_name, "type",
+                                                 elementwise_type);
   std::vector<float> coeff = {1.0, 1.0};
-  engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
+  this->engine_->template AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
 }
 
-void ElementwiseMulOpConverter::operator()(
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void ElementwiseMulOpConverter<TargetT, PrecisionType>::operator()(
     const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
     const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
@@ -63,21 +59,12 @@ void ElementwiseMulOpConverter::operator()(
   auto out_name = op_desc.Output("Out").front();
   auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
 
-  engine_->AddOp(op_name, "Scale", {x_name, y_name}, {out_name});
-  // Fill a number to weight_1 as a placeholder.
-  Shape shape1(std::vector<int>({1, 1, 1, 1}));
-  auto *weight1 =
-      GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
-  auto *placeholder_data =
-      static_cast<float *>(weight1->h_tensor().mutable_data());
-  float weight1_data[] = {1};
-  std::copy(std::begin(weight1_data), std::end(weight1_data), placeholder_data);
-  engine_->AddOpAttr(op_name, "weight_1", *weight1);
-
-  auto axis = boost::get<int>(op_desc.GetAttr("axis"));
-  engine_->AddOpAttr(op_name, "axis", axis);
-  engine_->AddOpAttr(op_name, "num_axes", 1);
-  engine_->AddOpAttr(op_name, "bias_term", false);
+  this->engine_->AddOp(op_name, "Eltwise", {x_name, y_name}, {out_name});
+  std::string elementwise_type = "Prod";
+  this->engine_->template AddOpAttr<std::string>(op_name, "type",
+                                                 elementwise_type);
+  std::vector<float> coeff = {1.0, 1.0};
+  this->engine_->template AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
 }
 
 } // namespace anakin
diff --git a/paddle/fluid/inference/anakin/convert/elementwise.h b/paddle/fluid/inference/anakin/convert/elementwise.h
index e4664493a9..190a8b55f0 100644
--- a/paddle/fluid/inference/anakin/convert/elementwise.h
+++ b/paddle/fluid/inference/anakin/convert/elementwise.h
@@ -20,7 +20,9 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-class ElementwiseAddOpConverter : public AnakinOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class ElementwiseAddOpConverter
+    : public AnakinOpConverter<TargetT, PrecisionType> {
  public:
   ElementwiseAddOpConverter() = default;
 
@@ -33,7 +35,9 @@
  private:
 };
 
-class ElementwiseMulOpConverter : public AnakinOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class ElementwiseMulOpConverter
+    : public AnakinOpConverter<TargetT, PrecisionType> {
  public:
   ElementwiseMulOpConverter() = default;
diff --git a/paddle/fluid/inference/anakin/convert/fc.cc b/paddle/fluid/inference/anakin/convert/fc.cc
index a80a1a47e9..0621e3377b 100644
--- a/paddle/fluid/inference/anakin/convert/fc.cc
+++ b/paddle/fluid/inference/anakin/convert/fc.cc
@@ -16,23 +16,19 @@
 #include <algorithm>
 #include <string>
 #include <vector>
-
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::saber::NV;
-using anakin::saber::Shape;
+#include "paddle/fluid/inference/anakin/convert/helper.h"
 
 namespace paddle {
 namespace inference {
 namespace anakin {
 
-void FcBaseOpConverter::operator()(const framework::proto::OpDesc &op,
-                                   const framework::BlockDesc &block_desc,
-                                   const framework::Scope &scope,
-                                   bool test_mode) {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void FcBaseOpConverter<TargetT, PrecisionType>::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
   auto input_names = op_desc.InputNames();
-  bool with_bias = input_names.size() == 3;
+  bool with_bias = input_names.size() >= 3;
 
   std::string w_name = "Y";
   std::string i_name = "X";
@@ -46,71 +42,74 @@ void FcBaseOpConverter::operator()(
   // get weights
   auto *y_v = scope.FindVar(op_desc.Input(w_name).front());
   PADDLE_ENFORCE_NOT_NULL(y_v);
-  auto *y_t = y_v->GetMutable<framework::LoDTensor>();
-
-  auto input_name = op_desc.Input(i_name).front();
-  auto output_name = op_desc.Output("Out").front();
+  auto weight_tensor = tensor_from_var(*y_v, platform::CPUPlace());
+  auto weight_shape = framework::vectorize2int(weight_tensor->dims());
 
-  engine_->AddOp(op_name, "Dense", {input_name}, {output_name});
-  engine_->AddOpAttr(op_name, "bias_term", with_bias);
-  engine_->AddOpAttr(op_name, "axis", 1);
-
-  auto weight_shape = framework::vectorize2int(y_t->dims());
   int out_dim = weight_shape[1];
-  engine_->AddOpAttr(op_name, "out_dim", out_dim);
   const int w_m = weight_shape[0];
   const int w_k = weight_shape[1];
 
-  if (weight_shape.size() < 4UL) {
-    weight_shape.insert(weight_shape.begin(), 4UL - weight_shape.size(), 1);
-  }
-  Shape anakin_shape(weight_shape);
+  auto input_name = op_desc.Input(i_name).front();
+  auto output_name = op_desc.Output("Out").front();
 
-  framework::LoDTensor weight_tensor;
-  weight_tensor.Resize(y_t->dims());
-  TensorCopySync((*y_t), platform::CPUPlace(), &weight_tensor);
-  auto *weight_data = weight_tensor.data<float>();
-  PADDLE_ENFORCE(w_m * w_k == weight_tensor.numel());
+  this->engine_->AddOp(op_name, "Dense", {input_name}, {output_name});
+  this->engine_->AddOpAttr(op_name, "bias_term", with_bias);
+  this->engine_->AddOpAttr(op_name, "axis", 1);
+  this->engine_->AddOpAttr(op_name, "out_dim", out_dim);
 
-  std::vector<float> trans_weight_data(weight_tensor.numel());
+  auto *weight_data = weight_tensor->data<float>();
+  PADDLE_ENFORCE(w_m * w_k == weight_tensor->numel());
+
+  std::vector<float> trans_weight_data(weight_tensor->numel());
   for (int i = 0; i < w_m; i++) {
     for (int j = 0; j < w_k; j++) {
       trans_weight_data[i + j * w_m] = weight_data[i * w_k + j];
     }
   }
-  auto *weight1 =
-      GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(anakin_shape);
-  float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
-  std::copy_n(trans_weight_data.data(), weight_tensor.numel(), cpu_data);
-  weight1->d_tensor().set_shape(anakin_shape);
-  weight1->d_tensor().copy_from(weight1->h_tensor());
-  engine_->AddOpAttr(op_name, "weight_1", *weight1);
+
+  int weight_num = weight_tensor->numel();
+  bool enable_int8 = op_desc.HasAttr("enable_int8");
+  if (enable_int8) {
+    if (weight_shape.size() < 4UL) {
+      weight_shape.insert(weight_shape.begin(), 4UL - weight_shape.size(), 1);
+    }
+    ::anakin::saber::Shape anakin_shape(weight_shape);
+    const float int8_range = 127.;
+    float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
+    float weight_scale = boost::get<float>(op_desc.GetAttr("weight_scale"));
+    PBlock<TargetT> *weight1 =
+        new PBlock<TargetT>(anakin_shape, ::anakin::AK_INT8);
+    this->engine_->RegistBlock(weight1);
+    std::vector<char> weight_int8;
+    for (int i = 0; i < weight_num; i++) {
+      bool is_valid_int8 =
+          ((trans_weight_data[i] >= -128) && (trans_weight_data[i] <= 127));
+      PADDLE_ENFORCE(is_valid_int8,
+                     "We are in anakin subgraph int8 mode, the weight of fc "
+                     "should be in range [-128, 127]");
+      weight_int8.push_back(static_cast<char>(trans_weight_data[i]));
+    }
+    memcpy(static_cast<void *>(weight1->h_tensor().mutable_data()),
+           static_cast<void *>(weight_int8.data()), sizeof(char) * weight_num);
+    weight1->d_tensor().set_shape(anakin_shape);
+    weight1->d_tensor().copy_from(weight1->h_tensor());
+    this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
+    this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8);
+    this->engine_->Graph()->SetWeightsScale(
+        op_name, {weight_scale / int8_range}, false);
+    this->engine_->AddTensorScale(input_name, in_scale / int8_range);
+  } else {
+    auto *weight1 = pblock_from_vector<TargetT, PrecisionType>(
+        trans_weight_data, this->engine_);
+    this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
+  }
 
   // get bias
   if (with_bias) {
     auto *b_v = scope.FindVar(op_desc.Input("Bias").front());
     PADDLE_ENFORCE_NOT_NULL(b_v);
-    auto *b_t = b_v->GetMutable<framework::LoDTensor>();
-
-    auto bias_shape = framework::vectorize2int(b_t->dims());
-    framework::LoDTensor bias_tensor;
-    bias_tensor.Resize(b_t->dims());
-    TensorCopySync((*b_t), platform::CPUPlace(), &bias_tensor);
-    auto *bias_data = bias_tensor.data<float>();
-    bias_shape.insert(bias_shape.begin(), 1);
-    bias_shape.insert(bias_shape.begin(), 1);
-    bias_shape.insert(bias_shape.begin(), 1);
-    // bias_shape.push_back(1);
-    // bias_shape.push_back(1);
-    Shape anakin_bias_shape(bias_shape);
-
-    auto *weight2 = GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(
-        anakin_bias_shape);
-    float *cpu_data2 = static_cast<float *>(weight2->h_tensor().mutable_data());
-    std::copy_n(bias_data, bias_tensor.numel(), cpu_data2);
-    weight2->d_tensor().set_shape(anakin_bias_shape);
-    weight2->d_tensor().copy_from(weight2->h_tensor());
-    engine_->AddOpAttr(op_name, "weight_2", *weight2);
+    auto weight2 = pblock_from_var<TargetT, PrecisionType>(*b_v, this->engine_);
+    this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
   }
 }
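The int8 branch above stores the already-quantized weights as raw int8 codes and registers per-tensor scales divided by `int8_range = 127`. A short sketch of that bookkeeping with made-up calibration numbers (`in_scale` and `weight_scale` are hypothetical here; in the converter they come from the `input_scale` and `weight_scale` attributes):

    #include <cassert>
    #include <cmath>

    int main() {
      const float int8_range = 127.f;
      float in_scale = 2.54f;      // hypothetical calibration result
      float weight_scale = 1.27f;  // hypothetical calibration result
      // What the converter forwards to Anakin:
      float tensor_scale = in_scale / int8_range;  // 0.02
      float w_scale = weight_scale / int8_range;   // 0.01
      // A weight stored as the integer-valued float 100.0f is narrowed to
      // the int8 code 100; its effective real value is code * w_scale.
      char code = static_cast<char>(100.0f);
      assert(std::fabs(code * w_scale - 1.0f) < 1e-4f);
      assert(std::fabs(tensor_scale - 0.02f) < 1e-6f);
      return 0;
    }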
 
 } // namespace anakin
diff --git a/paddle/fluid/inference/anakin/convert/fc.h b/paddle/fluid/inference/anakin/convert/fc.h
index fb461908b3..6fe65e3ecd 100644
--- a/paddle/fluid/inference/anakin/convert/fc.h
+++ b/paddle/fluid/inference/anakin/convert/fc.h
@@ -20,7 +20,8 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-class FcBaseOpConverter : public AnakinOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class FcBaseOpConverter : public AnakinOpConverter<TargetT, PrecisionType> {
  public:
   FcBaseOpConverter() = default;
 
@@ -32,13 +33,15 @@
 };
 
 // with bias
-class FcOpConverter : public FcBaseOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class FcOpConverter : public FcBaseOpConverter<TargetT, PrecisionType> {
  public:
   FcOpConverter() = default;
 };
 
 // without bias
-class MulOpConverter : public FcBaseOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class MulOpConverter : public FcBaseOpConverter<TargetT, PrecisionType> {
  public:
   MulOpConverter() = default;
 };
diff --git a/paddle/fluid/inference/anakin/convert/flatten.cc b/paddle/fluid/inference/anakin/convert/flatten.cc
index 7f5c151096..7ce519a4de 100644
--- a/paddle/fluid/inference/anakin/convert/flatten.cc
+++ b/paddle/fluid/inference/anakin/convert/flatten.cc
@@ -15,20 +15,16 @@
 #include "paddle/fluid/inference/anakin/convert/flatten.h"
 #include <vector>
 
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::saber::NV;
-using anakin::saber::Shape;
 using anakin::PTuple;
 
 namespace paddle {
 namespace inference {
 namespace anakin {
 
-void FlattenOpConverter::operator()(const framework::proto::OpDesc &op,
-                                    const framework::BlockDesc &block_desc,
-                                    const framework::Scope &scope,
-                                    bool test_mode) {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void FlattenOpConverter<TargetT, PrecisionType>::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
   PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1UL);
   PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1UL);
@@ -41,8 +37,8 @@ void FlattenOpConverter::operator()(
   std::vector<int> out_dims = {0, -1, 1, 1};
   auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
 
-  engine_->AddOp(op_name, "Reshape", {input}, {output});
-  engine_->AddOpAttr<PTuple<int>>(op_name, "dims", out_dims);
+  this->engine_->AddOp(op_name, "Reshape", {input}, {output});
+  this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dims", out_dims);
 }
 
 } // namespace anakin
diff --git a/paddle/fluid/inference/anakin/convert/flatten.h b/paddle/fluid/inference/anakin/convert/flatten.h
index c9cc0006eb..6e5e059927 100644
--- a/paddle/fluid/inference/anakin/convert/flatten.h
+++ b/paddle/fluid/inference/anakin/convert/flatten.h
@@ -20,7 +20,8 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-class FlattenOpConverter : public AnakinOpConverter {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+class FlattenOpConverter : public AnakinOpConverter<TargetT, PrecisionType> {
  public:
   FlattenOpConverter() = default;
diff --git a/paddle/fluid/inference/anakin/convert/helper.cc b/paddle/fluid/inference/anakin/convert/helper.cc
new file mode 100644
index 0000000000..7804619bf8
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/helper.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/anakin/convert/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+std::unique_ptr<framework::LoDTensor> tensor_from_var(
+    const framework::Variable& var, const platform::Place& place) {
+  auto& src = var.Get<framework::LoDTensor>();
+  std::unique_ptr<framework::LoDTensor> dst(new framework::LoDTensor());
+  dst->Resize(src.dims());
+  TensorCopySync((src), place, dst.get());
+  return dst;
+}
+
+} // namespace anakin
+} // namespace inference
+} // namespace paddle
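`tensor_from_var` centralizes the copy-out-of-scope boilerplate (Resize plus TensorCopySync) that each converter previously open-coded. A minimal usage sketch, mirroring how the conv2d and fc converters in this patch consume it (`scope` and `op_desc` are the converter's operator() arguments; "Filter" is the conv weight input):

    // Pull a weight variable out of the scope and land it in host memory.
    auto *filter_v = scope.FindVar(op_desc.Input("Filter").front());
    PADDLE_ENFORCE_NOT_NULL(filter_v);
    // The returned unique_ptr owns the CPU copy of the LoDTensor.
    auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
    auto weight_shape = framework::vectorize2int(weight_tensor->dims());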
diff --git a/paddle/fluid/inference/anakin/convert/helper.h b/paddle/fluid/inference/anakin/convert/helper.h
new file mode 100644
index 0000000000..7b0fb211dc
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/helper.h
@@ -0,0 +1,95 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/inference/anakin/engine.h"
+
+#include "framework/core/net/net.h"
+#include "framework/core/types.h"
+#include "framework/graph/graph.h"
+#include "framework/graph/graph_global_mem.h"
+#include "saber/saber_types.h"
+
+using anakin::saber::Shape;
+using anakin::AK_FLOAT;
+using anakin::AK_INT8;
+using anakin::PBlock;
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+std::unique_ptr<framework::LoDTensor> tensor_from_var(
+    const framework::Variable& var, const platform::Place& place);
+
+template <typename TargetT, ::anakin::Precision PrecisionType>
+PBlock<TargetT>* pblock_from_tensor(
+    const framework::LoDTensor& tensor, std::vector<int> shape_vec,
+    AnakinEngine<TargetT, PrecisionType>* engine) {
+  while (shape_vec.size() < 4) {
+    shape_vec.insert(shape_vec.begin(), 1);
+  }
+  Shape shape(shape_vec);
+  PBlock<TargetT>* weight = new PBlock<TargetT>(shape, AK_FLOAT);
+  engine->RegistBlock(weight);
+  float* cpu_data = static_cast<float*>(weight->h_tensor().mutable_data());
+  std::copy_n(tensor.data<float>(), tensor.numel(), cpu_data);
+  weight->d_tensor().set_shape(shape);
+  weight->d_tensor().copy_from(weight->h_tensor());
+  return weight;
+}
+
+template <typename TargetT, ::anakin::Precision PrecisionType>
+PBlock<TargetT>* pblock_from_vector(
+    const std::vector<float>& vec, std::vector<int> shape_vec,
+    AnakinEngine<TargetT, PrecisionType>* engine) {
+  while (shape_vec.size() < 4) {
+    shape_vec.insert(shape_vec.begin(), 1);
+  }
+  Shape shape(shape_vec);
+  PBlock<TargetT>* weight = new PBlock<TargetT>(shape, AK_FLOAT);
+  engine->RegistBlock(weight);
+  auto* weight_data = static_cast<float*>(weight->h_tensor().mutable_data());
+  std::copy(std::begin(vec), std::end(vec), weight_data);
+  weight->d_tensor().set_shape(shape);
+  weight->d_tensor().copy_from(weight->h_tensor());
+  return weight;
+}
+
+template <typename TargetT, ::anakin::Precision PrecisionType>
+PBlock<TargetT>* pblock_from_vector(
+    const std::vector<float>& vec,
+    AnakinEngine<TargetT, PrecisionType>* engine) {
+  int size = vec.size();
+  return pblock_from_vector<TargetT, PrecisionType>(
+      vec, std::vector<int>({1, 1, 1, size}), engine);
+}
+
+template <typename TargetT, ::anakin::Precision PrecisionType>
+PBlock<TargetT>* pblock_from_var(
+    const framework::Variable& var,
+    AnakinEngine<TargetT, PrecisionType>* engine) {
+  auto tensor = tensor_from_var(var, platform::CPUPlace());
+  auto shape = framework::vectorize2int(tensor->dims());
+  return pblock_from_tensor<TargetT, PrecisionType>(*tensor, shape, engine);
+}
+
+} // namespace anakin
+} // namespace inference
+} // namespace paddle
op_desc.Output("Out").front(); - engine_->AddOp(op_name, "Im2Sequence", {x_name}, {out_name}); + this->engine_->AddOp(op_name, "Im2Sequence", {x_name}, {out_name}); std::vector dilations = {1, 1}; auto paddings = boost::get>(op_desc.GetAttr("paddings")); auto strides = boost::get>(op_desc.GetAttr("strides")); auto kernels = boost::get>(op_desc.GetAttr("kernels")); - engine_->AddOpAttr>(op_name, "paddings", paddings); - engine_->AddOpAttr>(op_name, "strides", strides); - engine_->AddOpAttr>(op_name, "window_size", kernels); - engine_->AddOpAttr>(op_name, "dilations", dilations); + this->engine_->template AddOpAttr>(op_name, "paddings", paddings); + this->engine_->template AddOpAttr>(op_name, "strides", strides); + this->engine_->template AddOpAttr>(op_name, "window_size", + kernels); + this->engine_->template AddOpAttr>(op_name, "dilations", + dilations); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/im2sequence.h b/paddle/fluid/inference/anakin/convert/im2sequence.h index 714679c1d9..8241d4d6f9 100644 --- a/paddle/fluid/inference/anakin/convert/im2sequence.h +++ b/paddle/fluid/inference/anakin/convert/im2sequence.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class Im2SequenceConverter : public AnakinOpConverter { +template +class Im2SequenceConverter : public AnakinOpConverter { public: Im2SequenceConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/op_converter.h b/paddle/fluid/inference/anakin/convert/op_converter.h index 1ca62658ef..a6ae51bd4b 100644 --- a/paddle/fluid/inference/anakin/convert/op_converter.h +++ b/paddle/fluid/inference/anakin/convert/op_converter.h @@ -32,10 +32,10 @@ namespace paddle { namespace inference { namespace anakin { -using AnakinNvEngine = - AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>; - +template class AnakinOpConverter { + using AnakinEngineT = AnakinEngine; + public: AnakinOpConverter() = default; @@ -45,7 +45,7 @@ class AnakinOpConverter { void ConvertOp(const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, const std::unordered_set ¶meters, - const framework::Scope &scope, AnakinNvEngine *engine, + const framework::Scope &scope, AnakinEngineT *engine, bool test_mode = false) { framework::OpDesc op_desc(op, nullptr); std::string op_type = op_desc.Type(); @@ -65,7 +65,7 @@ class AnakinOpConverter { void ConvertBlock(framework::BlockDesc *block_desc, const std::unordered_set ¶meters, - const framework::Scope &scope, AnakinNvEngine *engine) { + const framework::Scope &scope, AnakinEngineT *engine) { std::unique_lock lock(mutex_); framework::proto::BlockDesc *block = block_desc->Proto(); for (auto i = 0; i < block->ops_size(); i++) { @@ -79,9 +79,8 @@ class AnakinOpConverter { framework::BlockDesc *block_desc, framework::Scope *scope, const std::vector &inputs, const std::unordered_set ¶meters, - const std::vector &outputs, AnakinNvEngine *engine) { + const std::vector &outputs, AnakinEngineT *engine) { ConvertBlock(block_desc, parameters, *scope, engine); - engine->Freeze(); // if the max_batch size int max_batch_size = engine->GetMaxBatchSize(); PADDLE_ENFORCE(max_batch_size > 0, @@ -91,6 +90,18 @@ class AnakinOpConverter { // the block_desc. auto max_input_shape = engine->GetMaxInputShape(); std::map> temp_max_input_shape; + // Register outputs with anakin using the RegistVar interface before Freeze. + // Note that RegistVar's parameters can only be outputs, not inputs. 
diff --git a/paddle/fluid/inference/anakin/convert/op_converter.h b/paddle/fluid/inference/anakin/convert/op_converter.h
index 1ca62658ef..a6ae51bd4b 100644
--- a/paddle/fluid/inference/anakin/convert/op_converter.h
+++ b/paddle/fluid/inference/anakin/convert/op_converter.h
@@ -32,10 +32,10 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-using AnakinNvEngine =
-    AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
-
+template <typename TargetT, ::anakin::Precision PrecisionType>
 class AnakinOpConverter {
+  using AnakinEngineT = AnakinEngine<TargetT, PrecisionType>;
+
  public:
   AnakinOpConverter() = default;
 
@@ -45,7 +45,7 @@
   void ConvertOp(const framework::proto::OpDesc &op,
                  const framework::BlockDesc &block_desc,
                  const std::unordered_set<std::string> &parameters,
-                 const framework::Scope &scope, AnakinNvEngine *engine,
+                 const framework::Scope &scope, AnakinEngineT *engine,
                  bool test_mode = false) {
     framework::OpDesc op_desc(op, nullptr);
     std::string op_type = op_desc.Type();
@@ -65,7 +65,7 @@
   void ConvertBlock(framework::BlockDesc *block_desc,
                     const std::unordered_set<std::string> &parameters,
-                    const framework::Scope &scope, AnakinNvEngine *engine) {
+                    const framework::Scope &scope, AnakinEngineT *engine) {
     std::unique_lock<std::mutex> lock(mutex_);
     framework::proto::BlockDesc *block = block_desc->Proto();
     for (auto i = 0; i < block->ops_size(); i++) {
@@ -79,9 +79,8 @@
       framework::BlockDesc *block_desc, framework::Scope *scope,
       const std::vector<std::string> &inputs,
      const std::unordered_set<std::string> &parameters,
-      const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
+      const std::vector<std::string> &outputs, AnakinEngineT *engine) {
     ConvertBlock(block_desc, parameters, *scope, engine);
-    engine->Freeze();
     // if the max_batch size
     int max_batch_size = engine->GetMaxBatchSize();
     PADDLE_ENFORCE(max_batch_size > 0,
@@ -91,6 +90,18 @@
     // the block_desc.
     auto max_input_shape = engine->GetMaxInputShape();
     std::map<std::string, std::vector<int>> temp_max_input_shape;
+    // Register outputs with anakin using the RegistVar interface before
+    // Freeze. Note that RegistVar's parameters can only be outputs, not inputs.
+    for (auto &output : outputs) {
+      engine->Graph()->RegistVar(output);
+    }
+    engine->Freeze();
+    // Add scale for tensor in int8 mode.
+    auto tensor_scales = engine->GetTensorScales();
+
+    for (auto &item : tensor_scales) {
+      engine->Graph()->SetVarScale(item.first, item.second);
+    }
 
     for (auto &input : inputs) {
       if (parameters.count(input)) continue;
@@ -99,7 +110,7 @@
       input_shape[0] = max_batch_size;
       if (max_input_shape.count(input)) {
         PADDLE_ENFORCE(max_input_shape[input].size() == 4,
-                       "the dimensions of max_input_shape setted from "
+                       "the dimensions of max_input_shape set from "
                        "config->EnableAnakinEngine must be 4");
         for (int i = 1; i < 4; i++) {
           input_shape[i] = max_input_shape[input][i];
         }
@@ -118,50 +129,104 @@
       }
       temp_max_input_shape[input] = input_shape;
       engine->SetInputShape(input, input_shape);
-      engine->Graph()->RegistVar(input);  // For share from data.
     }
     engine->SetMaxInputShape(temp_max_input_shape);
     engine->Optimize();
-
-    // For anakin share with fluid tensor.
-    engine->AllocTmpMem();
-    engine->InitGraph();
+    engine->InitNet();
   }
 
-  void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
+  void SetEngine(AnakinEngineT *engine) { engine_ = engine; }
   virtual ~AnakinOpConverter() {}
 
  protected:
   bool test_mode_;
-  AnakinNvEngine *engine_{nullptr};
+  AnakinEngineT *engine_{nullptr};
 
  private:
-  std::unordered_map<std::string, AnakinOpConverter *> converters_;
+  std::unordered_map<std::string, AnakinOpConverter<TargetT, PrecisionType> *>
+      converters_;
   framework::Scope *scope_{nullptr};
   std::mutex mutex_;
 };
 
+template class AnakinOpConverter<::anakin::saber::NV,
+                                 ::anakin::Precision::FP32>;
+template class AnakinOpConverter<::anakin::saber::NV,
+                                 ::anakin::Precision::INT8>;
+
+template class AnakinOpConverter<::anakin::saber::X86,
+                                 ::anakin::Precision::FP32>;
+template class AnakinOpConverter<::anakin::saber::X86,
+                                 ::anakin::Precision::INT8>;
 } // namespace anakin
 } // namespace inference
 } // namespace paddle
 
-#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__)                 \
-  struct anakin_##op_type__##_converter                                      \
-      : public ::paddle::framework::Registrar {                              \
-    anakin_##op_type__##_converter() {                                       \
-      LOG(INFO) << "register convert " << #op_type__;                        \
-      ::paddle::inference::Registry<                                         \
-          ::paddle::inference::anakin::AnakinOpConverter>::Global()          \
-          .Register<::paddle::inference::anakin::Converter__>(#op_type__);   \
-    }                                                                        \
-  };                                                                         \
-  anakin_##op_type__##_converter anakin_##op_type__##_converter__;           \
-  int TouchConverterRegister_anakin_##op_type__() {                          \
-    anakin_##op_type__##_converter__.Touch();                                \
-    return 0;                                                                \
+#define REGISTER_ANAKIN_OP_CONVERTER_BASE(op_type__, Converter__,             \
+                                          place_type__, place_class__,        \
+                                          precision_type__, precision_class__) \
+  struct anakin_##op_type__##_##place_type__##_##precision_type__##_converter \
+      : public ::paddle::framework::Registrar {                               \
+    anakin_##op_type__##_##place_type__##_##precision_type__##_converter() {  \
+      LOG(INFO) << "register convert " << #op_type__ << " ";                  \
+      ::paddle::inference::Registry<                                          \
+          ::paddle::inference::anakin::AnakinOpConverter<                     \
+              place_class__, precision_class__>>::Global()                    \
+          .Register<Converter__>(#op_type__);                                 \
+    }                                                                         \
+  };                                                                          \
+  anakin_##op_type__##_##place_type__##_##precision_type__##_converter        \
+      anakin_##op_type__##_##place_type__##_##precision_type__##_converter__; \
+  int Touch_anakin_##op_type__##_##place_type__##_##precision_type__() {      \
+    anakin_##op_type__##_##place_type__##_##precision_type__##_converter__    \
+        .Touch();                                                             \
+    return 0;                                                                 \
   }
 
-#define USE_ANAKIN_CONVERTER(op_type__)                             \
-  extern int TouchConverterRegister_anakin_##op_type__();           \
-  int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
-      TouchConverterRegister_anakin_##op_type__();
+#define WRAP(...) __VA_ARGS__
+
+#define REGISTER_CUDA_ANAKIN_OP_CONVERTER(op_type__, Converter__,             \
+                                          precision_type__)                   \
+  REGISTER_ANAKIN_OP_CONVERTER_BASE(                                          \
+      op_type__,                                                              \
+      ::paddle::inference::anakin::Converter__<WRAP(                          \
+          ::anakin::saber::NV, ::anakin::Precision::precision_type__)>,       \
+      CUDA, ::anakin::saber::NV, precision_type__,                            \
+      ::anakin::Precision::precision_type__)
+
+#define REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__,              \
+                                         precision_type__)                    \
+  REGISTER_ANAKIN_OP_CONVERTER_BASE(                                          \
+      op_type__,                                                              \
+      ::paddle::inference::anakin::Converter__<WRAP(                          \
+          ::anakin::saber::X86, ::anakin::Precision::precision_type__)>,      \
+      CPU, ::anakin::saber::X86, precision_type__,                            \
+      ::anakin::Precision::precision_type__)
+
+#ifdef PADDLE_WITH_CUDA
+#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__)                  \
+  REGISTER_CUDA_ANAKIN_OP_CONVERTER(op_type__, Converter__, FP32);            \
+  REGISTER_CUDA_ANAKIN_OP_CONVERTER(op_type__, Converter__, INT8);            \
+  REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, FP32);             \
+  REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, INT8)
+#else
+#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__)                  \
+  REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, FP32);             \
+  REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, INT8)
+#endif
+
+#define USE_ANAKIN_CONVERTER_BASE(op_type__, place_type__, precision_type__)  \
+  extern int Touch_anakin_##op_type__##_##place_type__##_##precision_type__(); \
+  int use_converter_anakin_##op_type__##_##place_type__##_##precision_type__  \
+      __attribute__((unused)) =                                               \
+          Touch_anakin_##op_type__##_##place_type__##_##precision_type__();
+
+#define USE_ANAKIN_CONVERTER(op_type__) \
+  USE_ANAKIN_CONVERTER_BASE(op_type__, CUDA, FP32)
+#define USE_INT8_ANAKIN_CONVERTER(op_type__) \
+  USE_ANAKIN_CONVERTER_BASE(op_type__, CUDA, INT8)
+
+#define USE_CPU_ANAKIN_CONVERTER(op_type__) \
+  USE_ANAKIN_CONVERTER_BASE(op_type__, CPU, FP32)
+#define USE_CPU_INT8_ANAKIN_CONVERTER(op_type__) \
+  USE_ANAKIN_CONVERTER_BASE(op_type__, CPU, INT8)
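With the converters templated on place and precision, each REGISTER_ANAKIN_OP_CONVERTER now stamps out one registrar per (place, precision) pair; the WRAP helper exists so a template-id containing a comma can travel through the macro as a single argument. A hand-expanded illustration (the expansion is paraphrased, not literal preprocessor output):

    // In the converter's .cc file, e.g. relu.cc later in this patch:
    //   REGISTER_ANAKIN_OP_CONVERTER(relu, ReluOpConverter);
    // under PADDLE_WITH_CUDA is shorthand for roughly:
    //   REGISTER_CUDA_ANAKIN_OP_CONVERTER(relu, ReluOpConverter, FP32);
    //   REGISTER_CUDA_ANAKIN_OP_CONVERTER(relu, ReluOpConverter, INT8);
    //   REGISTER_CPU_ANAKIN_OP_CONVERTER(relu, ReluOpConverter, FP32);
    //   REGISTER_CPU_ANAKIN_OP_CONVERTER(relu, ReluOpConverter, INT8);
    // each defining a Touch_anakin_relu_<PLACE>_<PRECISION>() hook.

    // A translation unit that must keep a converter linked in then writes:
    USE_ANAKIN_CONVERTER(relu);           // CUDA, FP32
    USE_CPU_INT8_ANAKIN_CONVERTER(relu);  // X86, INT8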
diff --git a/paddle/fluid/inference/anakin/convert/pool2d.cc b/paddle/fluid/inference/anakin/convert/pool2d.cc
index 87eefe712a..11e7c717fd 100644
--- a/paddle/fluid/inference/anakin/convert/pool2d.cc
+++ b/paddle/fluid/inference/anakin/convert/pool2d.cc
@@ -17,23 +17,16 @@
 #include <string>
 #include <vector>
 
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::Precision;
-using anakin::saber::NV;
-using anakin::saber::X86;
-using anakin::saber::Shape;
-using anakin::PBlock;
 using anakin::PTuple;
 
 namespace paddle {
 namespace inference {
 namespace anakin {
 
-void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
-                                   const framework::BlockDesc &block_desc,
-                                   const framework::Scope &scope,
-                                   bool test_mode) {
+template <typename TargetT, ::anakin::Precision PrecisionType>
+void Pool2dOpConverter<TargetT, PrecisionType>::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
   PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
   PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
@@ -65,13 +58,13 @@ void Pool2dOpConverter::operator()(
     PADDLE_THROW("TensorRT unsupported pooling type!");
   }
 
-  engine_->AddOp(op_name, "Pooling", {x_name}, {y_name});
-  engine_->AddOpAttr<PTuple<int>>(op_name, "pool_size", ksize);
-  engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
-  engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
-  engine_->AddOpAttr(op_name, "method", anakin_pool_type);
-  engine_->AddOpAttr(op_name, "global_pooling", global_pooling);
- engine_->AddOpAttr(op_name, "cmp_out_shape_floor_as_conv", !ceil_mode); + this->engine_->AddOp(op_name, "Pooling", {x_name}, {y_name}); + this->engine_->template AddOpAttr>(op_name, "pool_size", ksize); + this->engine_->template AddOpAttr>(op_name, "strides", strides); + this->engine_->template AddOpAttr>(op_name, "padding", paddings); + this->engine_->AddOpAttr(op_name, "method", anakin_pool_type); + this->engine_->AddOpAttr(op_name, "global_pooling", global_pooling); + this->engine_->AddOpAttr(op_name, "cmp_out_shape_floor_as_conv", !ceil_mode); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/pool2d.h b/paddle/fluid/inference/anakin/convert/pool2d.h index ec28e48ac8..7a06ff1b66 100644 --- a/paddle/fluid/inference/anakin/convert/pool2d.h +++ b/paddle/fluid/inference/anakin/convert/pool2d.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class Pool2dOpConverter : public AnakinOpConverter { +template +class Pool2dOpConverter : public AnakinOpConverter { public: Pool2dOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/relu.cc b/paddle/fluid/inference/anakin/convert/relu.cc index 993437d014..0085340663 100644 --- a/paddle/fluid/inference/anakin/convert/relu.cc +++ b/paddle/fluid/inference/anakin/convert/relu.cc @@ -16,19 +16,30 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; - namespace paddle { namespace inference { namespace anakin { -void ReluOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void ReluOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); + + auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); + auto input_name = op_desc.Input("X").front(); + auto output_name = op_desc.Output("Out").front(); + + this->engine_->AddOp(op_name, "ReLU", {input_name}, {output_name}); + this->engine_->AddOpAttr(op_name, "alpha", 0); +} + +template +void LeakyReluOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); @@ -37,8 +48,9 @@ void ReluOpConverter::operator()(const framework::proto::OpDesc &op, auto input_name = op_desc.Input("X").front(); auto output_name = op_desc.Output("Out").front(); - engine_->AddOp(op_name, "ReLU", {input_name}, {output_name}); - engine_->AddOpAttr(op_name, "alpha", 0); + float alpha = boost::get(op_desc.GetAttr("alpha")); + this->engine_->AddOp(op_name, "ReLU", {input_name}, {output_name}); + this->engine_->AddOpAttr(op_name, "alpha", alpha); } } // namespace anakin @@ -46,3 +58,4 @@ void ReluOpConverter::operator()(const framework::proto::OpDesc &op, } // namespace paddle REGISTER_ANAKIN_OP_CONVERTER(relu, ReluOpConverter); +REGISTER_ANAKIN_OP_CONVERTER(leaky_relu, LeakyReluOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/relu.h b/paddle/fluid/inference/anakin/convert/relu.h index 6ede506511..f366f05a94 100644 --- a/paddle/fluid/inference/anakin/convert/relu.h +++ 
b/paddle/fluid/inference/anakin/convert/relu.h @@ -22,7 +22,8 @@ namespace paddle { namespace inference { namespace anakin { -class ReluOpConverter : public AnakinOpConverter { +template +class ReluOpConverter : public AnakinOpConverter { public: ReluOpConverter() = default; @@ -33,6 +34,18 @@ class ReluOpConverter : public AnakinOpConverter { virtual ~ReluOpConverter() {} }; +template +class LeakyReluOpConverter : public AnakinOpConverter { + public: + LeakyReluOpConverter() = default; + + virtual void operator()(const framework::proto::OpDesc &op, + const framework::BlockDesc &block_desc, + const framework::Scope &scope, + bool test_mode) override; + virtual ~LeakyReluOpConverter() {} +}; + } // namespace anakin } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/reshape.cc b/paddle/fluid/inference/anakin/convert/reshape.cc index 17e0a1acb5..d73736b7fe 100644 --- a/paddle/fluid/inference/anakin/convert/reshape.cc +++ b/paddle/fluid/inference/anakin/convert/reshape.cc @@ -15,20 +15,16 @@ #include "paddle/fluid/inference/anakin/convert/reshape.h" #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; using anakin::PTuple; namespace paddle { namespace inference { namespace anakin { -void ReshapeOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void ReshapeOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1UL); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1UL); @@ -37,13 +33,13 @@ void ReshapeOpConverter::operator()(const framework::proto::OpDesc &op, auto output = op_desc.Output("Out").front(); auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); - engine_->AddOp(op_name, "Reshape", {input}, {output}); + this->engine_->AddOp(op_name, "Reshape", {input}, {output}); auto shape = boost::get>(op_desc.GetAttr("shape")); if (shape.size() < 4) { shape.insert(shape.end(), 4 - shape.size(), 1); } - engine_->AddOpAttr>(op_name, "dims", shape); + this->engine_->template AddOpAttr>(op_name, "dims", shape); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/reshape.h b/paddle/fluid/inference/anakin/convert/reshape.h index 9ce2ea2a4f..88de2641e6 100644 --- a/paddle/fluid/inference/anakin/convert/reshape.h +++ b/paddle/fluid/inference/anakin/convert/reshape.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class ReshapeOpConverter : public AnakinOpConverter { +template +class ReshapeOpConverter : public AnakinOpConverter { public: ReshapeOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/roi_align.cc b/paddle/fluid/inference/anakin/convert/roi_align.cc new file mode 100644 index 0000000000..8702f638e1 --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/roi_align.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/anakin/convert/roi_align.h" +#include +#include + +namespace paddle { +namespace inference { +namespace anakin { + +template +void RoiAlignOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Input("ROIs").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); + + auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); + auto input_x_name = op_desc.Input("X").front(); + auto input_rois_name = op_desc.Input("ROIs").front(); + auto output_name = op_desc.Output("Out").front(); + + auto spatial_scale = boost::get(op_desc.GetAttr("spatial_scale")); + auto pooled_height = boost::get(op_desc.GetAttr("pooled_height")); + auto pooled_width = boost::get(op_desc.GetAttr("pooled_width")); + auto sampling_ratio = boost::get(op_desc.GetAttr("sampling_ratio")); + + this->engine_->AddOp(op_name, "RoiAlign", {input_x_name, input_rois_name}, + {output_name}); + this->engine_->AddOpAttr(op_name, "spatial_scale", spatial_scale); + this->engine_->AddOpAttr(op_name, "pooled_height", pooled_height); + this->engine_->AddOpAttr(op_name, "pooled_width", pooled_width); + this->engine_->AddOpAttr(op_name, "sampling_ratio", sampling_ratio); +} + +} // namespace anakin +} // namespace inference +} // namespace paddle + +REGISTER_ANAKIN_OP_CONVERTER(roi_align, RoiAlignOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/roi_align.h b/paddle/fluid/inference/anakin/convert/roi_align.h new file mode 100644 index 0000000000..8b5d23a016 --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/roi_align.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/fluid/inference/anakin/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +class RoiAlignOpConverter : public AnakinOpConverter { + public: + RoiAlignOpConverter() = default; + + virtual void operator()(const framework::proto::OpDesc &op, + const framework::BlockDesc &block_desc, + const framework::Scope &scope, + bool test_mode) override; + virtual ~RoiAlignOpConverter() {} +}; + +} // namespace anakin +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/scale.cc b/paddle/fluid/inference/anakin/convert/scale.cc index dd68af4f79..2559ec498c 100644 --- a/paddle/fluid/inference/anakin/convert/scale.cc +++ b/paddle/fluid/inference/anakin/convert/scale.cc @@ -16,19 +16,14 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; - namespace paddle { namespace inference { namespace anakin { -void ScaleOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void ScaleOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); @@ -44,10 +39,10 @@ void ScaleOpConverter::operator()(const framework::proto::OpDesc &op, PADDLE_ENFORCE(bias_after_scale, "The anakin scale layer only support bias after scale now."); - engine_->AddOp(op_name, "Power", {input_name}, {output_name}); - engine_->AddOpAttr(op_name, "shift", bias); - engine_->AddOpAttr(op_name, "scale", scale); - engine_->AddOpAttr(op_name, "power", static_cast(1.0)); + this->engine_->AddOp(op_name, "Power", {input_name}, {output_name}); + this->engine_->AddOpAttr(op_name, "shift", bias); + this->engine_->AddOpAttr(op_name, "scale", scale); + this->engine_->AddOpAttr(op_name, "power", static_cast(1.0)); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/scale.h b/paddle/fluid/inference/anakin/convert/scale.h index ba3bcdd214..f19a920193 100644 --- a/paddle/fluid/inference/anakin/convert/scale.h +++ b/paddle/fluid/inference/anakin/convert/scale.h @@ -22,7 +22,8 @@ namespace paddle { namespace inference { namespace anakin { -class ScaleOpConverter : public AnakinOpConverter { +template +class ScaleOpConverter : public AnakinOpConverter { public: ScaleOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/shuffle_channel.cc b/paddle/fluid/inference/anakin/convert/shuffle_channel.cc new file mode 100644 index 0000000000..fdd2e3182e --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/shuffle_channel.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/anakin/convert/shuffle_channel.h" +#include +#include +#include + +using anakin::PTuple; + +namespace paddle { +namespace inference { +namespace anakin { + +template +void ShuffleChannelOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); + + auto input = op_desc.Input("X").front(); + auto output = op_desc.Output("Out").front(); + auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); + this->engine_->AddOp(op_name, "ShuffleChannel", {input}, {output}); + + auto group = boost::get(op_desc.GetAttr("group")); + this->engine_->AddOpAttr(op_name, "group", group); +} + +} // namespace anakin +} // namespace inference +} // namespace paddle + +REGISTER_ANAKIN_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/shuffle_channel.h b/paddle/fluid/inference/anakin/convert/shuffle_channel.h new file mode 100644 index 0000000000..457a14865a --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/shuffle_channel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/inference/anakin/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +class ShuffleChannelOpConverter + : public AnakinOpConverter { + public: + ShuffleChannelOpConverter() = default; + + virtual void operator()(const framework::proto::OpDesc &op, + const framework::BlockDesc &block_desc, + const framework::Scope &scope, + bool test_mode) override; + virtual ~ShuffleChannelOpConverter() {} +}; + +} // namespace anakin +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/anakin/convert/softmax.cc b/paddle/fluid/inference/anakin/convert/softmax.cc index a6c1e971b1..a4dc5a9156 100644 --- a/paddle/fluid/inference/anakin/convert/softmax.cc +++ b/paddle/fluid/inference/anakin/convert/softmax.cc @@ -14,19 +14,14 @@ #include "paddle/fluid/inference/anakin/convert/softmax.h" -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::saber::NV; -using anakin::saber::Shape; - namespace paddle { namespace inference { namespace anakin { -void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void SoftMaxOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1UL); @@ -41,8 +36,8 @@ void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op, auto input_shape_in_fluid = input_var_desc->GetShape(); size_t input_dims = input_shape_in_fluid.size(); - engine_->AddOp(op_name, "Softmax", {input}, {output}); - engine_->AddOpAttr(op_name, "axis", static_cast(input_dims - 1)); + this->engine_->AddOp(op_name, "Softmax", {input}, {output}); + this->engine_->AddOpAttr(op_name, "axis", static_cast(input_dims - 1)); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/softmax.h b/paddle/fluid/inference/anakin/convert/softmax.h index a16356d5bb..dc431b5b86 100644 --- a/paddle/fluid/inference/anakin/convert/softmax.h +++ b/paddle/fluid/inference/anakin/convert/softmax.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class SoftMaxOpConverter : public AnakinOpConverter { +template +class SoftMaxOpConverter : public AnakinOpConverter { public: SoftMaxOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/split.cc b/paddle/fluid/inference/anakin/convert/split.cc index ec582c1812..e63edea94a 100644 --- a/paddle/fluid/inference/anakin/convert/split.cc +++ b/paddle/fluid/inference/anakin/convert/split.cc @@ -16,23 +16,16 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::Precision; -using anakin::saber::NV; -using anakin::saber::X86; -using anakin::saber::Shape; -using anakin::PBlock; using anakin::PTuple; namespace paddle { namespace inference { namespace anakin { -void SplitOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, - bool test_mode) { +template +void SplitOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); auto input_name = op_desc.Input("X").front(); auto y_names = op_desc.Output("Out"); @@ -51,14 +44,16 @@ void 
SplitOpConverter::operator()(const framework::proto::OpDesc &op, num_sum += output_lengths[i]; slice_point.push_back(num_sum); } - engine_->AddOp(op_name, "Slice", {input_name}, y_names); - engine_->AddOpAttr(op_name, "axis", axis); - engine_->AddOpAttr>(op_name, "slice_point", slice_point); + this->engine_->AddOp(op_name, "Slice", {input_name}, y_names); + this->engine_->AddOpAttr(op_name, "axis", axis); + this->engine_->template AddOpAttr>(op_name, "slice_point", + slice_point); // slice_dim is useless in anakin - engine_->AddOpAttr(op_name, "slice_dim", 4); + this->engine_->AddOpAttr(op_name, "slice_dim", 4); } } // namespace anakin } // namespace inference } // namespace paddle + REGISTER_ANAKIN_OP_CONVERTER(split, SplitOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/split.h b/paddle/fluid/inference/anakin/convert/split.h index 184112e589..819915315d 100644 --- a/paddle/fluid/inference/anakin/convert/split.h +++ b/paddle/fluid/inference/anakin/convert/split.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class SplitOpConverter : public AnakinOpConverter { +template +class SplitOpConverter : public AnakinOpConverter { public: SplitOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/sum.cc b/paddle/fluid/inference/anakin/convert/sum.cc index 2a4178e237..870c079340 100644 --- a/paddle/fluid/inference/anakin/convert/sum.cc +++ b/paddle/fluid/inference/anakin/convert/sum.cc @@ -17,22 +17,16 @@ #include #include -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; -using anakin::Precision; -using anakin::saber::NV; -using anakin::saber::X86; -using anakin::saber::Shape; -using anakin::PBlock; using anakin::PTuple; namespace paddle { namespace inference { namespace anakin { -void SumOpConverter::operator()(const framework::proto::OpDesc &op, - const framework::BlockDesc &block_desc, - const framework::Scope &scope, bool test_mode) { +template +void SumOpConverter::operator()( + const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc, + const framework::Scope &scope, bool test_mode) { framework::OpDesc op_desc(op, nullptr); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 2); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); @@ -43,9 +37,10 @@ void SumOpConverter::operator()(const framework::proto::OpDesc &op, std::vector coeff = {1, 1}; std::string elementwise_type = "Add"; - engine_->AddOp(op_name, "Eltwise", input_names, {out_name}); - engine_->AddOpAttr>(op_name, "coeff", coeff); - engine_->AddOpAttr(op_name, "type", elementwise_type); + this->engine_->AddOp(op_name, "Eltwise", input_names, {out_name}); + this->engine_->template AddOpAttr>(op_name, "coeff", coeff); + this->engine_->template AddOpAttr(op_name, "type", + elementwise_type); } } // namespace anakin diff --git a/paddle/fluid/inference/anakin/convert/sum.h b/paddle/fluid/inference/anakin/convert/sum.h index b5d402b77f..aefc64c623 100644 --- a/paddle/fluid/inference/anakin/convert/sum.h +++ b/paddle/fluid/inference/anakin/convert/sum.h @@ -20,7 +20,8 @@ namespace paddle { namespace inference { namespace anakin { -class SumOpConverter : public AnakinOpConverter { +template +class SumOpConverter : public AnakinOpConverter { public: SumOpConverter() = default; diff --git a/paddle/fluid/inference/anakin/convert/test_activation_op.cc b/paddle/fluid/inference/anakin/convert/test_activation_op.cc index 8bedd4a749..4f898252d2 100644 --- a/paddle/fluid/inference/anakin/convert/test_activation_op.cc +++ 
b/paddle/fluid/inference/anakin/convert/test_activation_op.cc @@ -21,12 +21,14 @@ namespace paddle { namespace inference { namespace anakin { -static void test_activation_op(const std::string &op_type) { - auto *converter = Registry::Global().Lookup(op_type); - PADDLE_ENFORCE(converter != nullptr); +template +static void test_activation_op(const std::string& op_type, + const platform::DeviceContext& context, + bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("act-X", {10, 6, 1, 1}); validator.DeclOutputVar("act-Out", {10, 6, 1, 1}); framework::OpDesc desc; @@ -34,6 +36,14 @@ static void test_activation_op(const std::string &op_type) { desc.SetInput("X", {"act-X"}); desc.SetOutput("Out", {"act-Out"}); + if (op_type == "swish") { + desc.SetAttr("beta", 1.0f); + } + + if (op_type == "relu6") { + desc.SetAttr("threshold", 6.0f); + } + LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); LOG(INFO) << "execute"; @@ -41,13 +51,74 @@ static void test_activation_op(const std::string &op_type) { validator.Execute(5); } -TEST(sigm_op, test) { test_activation_op("sigmoid"); } -TEST(tanh_op, test) { test_activation_op("tanh"); } +#ifdef PADDLE_WITH_CUDA +TEST(sigm_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_activation_op<::anakin::saber::NV>("sigmoid", ctx, true); +} + +TEST(tanh_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_activation_op<::anakin::saber::NV>("tanh", ctx, true); +} + +TEST(relu6_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_activation_op<::anakin::saber::NV>("relu6", ctx, true); +} + +TEST(swish_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_activation_op<::anakin::saber::NV>("swish", ctx, true); +} +#endif + +/* +TEST(sigm_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_activation_op<::anakin::saber::X86>("sigmoid", ctx, false); +} + +TEST(tanh_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_activation_op<::anakin::saber::X86>("tanh", ctx, false); +} + +TEST(relu6_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_activation_op<::anakin::saber::X86>("relu6", ctx, false); +} + +TEST(swish_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_activation_op<::anakin::saber::X86>("swish", ctx, false); +} +*/ + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(sigmoid); USE_OP(tanh); +USE_OP(relu6); +USE_OP(swish); + +USE_CPU_ANAKIN_CONVERTER(sigmoid); +USE_CPU_ANAKIN_CONVERTER(tanh); +USE_CPU_ANAKIN_CONVERTER(relu6); +USE_CPU_ANAKIN_CONVERTER(swish); +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(sigmoid); USE_ANAKIN_CONVERTER(tanh); +USE_ANAKIN_CONVERTER(relu6); +USE_ANAKIN_CONVERTER(swish); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc b/paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc new file mode 100644 index 0000000000..f6399387aa --- /dev/null +++ b/paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc @@ -0,0 +1,75 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/anakin/convert/affine_channel.h" +#include "paddle/fluid/inference/anakin/convert/op_converter.h" +#include "paddle/fluid/inference/anakin/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace anakin { + +template +void test_affine_channel_op(const platform::DeviceContext& context, + bool use_gpu) { + // Declare the difference between the inputs. + std::unordered_set parameters({"scale", "bias"}); + + framework::Scope scope; + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); + validator.DeclInputVar("x", {1, 3, 5, 2}); + validator.DeclOutputVar("out", {1, 3, 5, 2}); + validator.DeclParamVar("scale", {3}); + validator.DeclParamVar("bias", {3}); + + // Prepare Op descriptions. + framework::OpDesc desc; + desc.SetType("affine_channel"); + desc.SetInput("X", {"x"}); + desc.SetInput("Bias", {"bias"}); + desc.SetInput("Scale", {"scale"}); + desc.SetOutput("Out", {"out"}); + + // Layout must be explicitly specified here as NCHW. + desc.SetAttr("data_layout", std::string("NCHW")); + + validator.SetOp(*desc.Proto()); + validator.Execute(1); +} + +#ifdef PADDLE_WITH_CUDA +TEST(affine_channel_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_affine_channel_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(affine_channel_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_affine_channel_op<::anakin::saber::X86>(ctx, false); +} + +} // namespace anakin +} // namespace inference +} // namespace paddle + +USE_OP(affine_channel); +USE_CPU_ANAKIN_CONVERTER(affine_channel); +#ifdef PADDLE_WITH_CUDA +USE_ANAKIN_CONVERTER(affine_channel); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_batch_norm_op.cc b/paddle/fluid/inference/anakin/convert/test_batch_norm_op.cc index 2832e1c8d1..c008ef1bd5 100644 --- a/paddle/fluid/inference/anakin/convert/test_batch_norm_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_batch_norm_op.cc @@ -19,12 +19,14 @@ namespace paddle { namespace inference { namespace anakin { -TEST(batch_norm_op, test) { +template +void test_batchnorm_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters( {"batch_norm_scale", "batch_norm_bias", "batch_norm_mean", "batch_norm_variance"}); framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); std::vector param_shape{2}; validator.DeclInputVar("batch_norm_X", {1, 2, 5, 5}); @@ -64,8 +66,26 @@ TEST(batch_norm_op, test) { validator.Execute(1, neglected_output); } +#ifdef PADDLE_WITH_CUDA +TEST(batch_norm_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_batchnorm_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(batch_norm_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + 
test_batchnorm_op<::anakin::saber::X86>(ctx, false); +} + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(batch_norm); +USE_CPU_ANAKIN_CONVERTER(batch_norm); + +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(batch_norm); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_concat_op.cc b/paddle/fluid/inference/anakin/convert/test_concat_op.cc index ecf44def5a..42dfbeb5cd 100644 --- a/paddle/fluid/inference/anakin/convert/test_concat_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_concat_op.cc @@ -21,10 +21,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(concat_op, test) { +template +void test_concat_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters({""}); framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("concat_x1", {1, 2, 1, 1}); validator.DeclInputVar("concat_x2", {1, 3, 1, 1}); validator.DeclInputVar("concat_x3", {1, 1, 1, 1}); @@ -44,31 +46,26 @@ TEST(concat_op, test) { validator.Execute(1); } -TEST(concat_op, test2) { - std::unordered_set parameters({""}); - framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); - validator.DeclInputVar("concat_x1", {1, 4}); - validator.DeclInputVar("concat_x2", {3, 4}); - validator.DeclInputVar("concat_x3", {2, 4}); - validator.DeclOutputVar("concat_out", {6, 4}); - - // Prepare Op description - framework::OpDesc desc; - desc.SetType("concat"); - desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"}); - desc.SetOutput("Out", {"concat_out"}); - - int axis = 0; - desc.SetAttr("axis", axis); - - validator.SetOp(*desc.Proto()); +#ifdef PADDLE_WITH_CUDA +TEST(concat_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_concat_op<::anakin::saber::NV>(ctx, true); +} +#endif - validator.Execute(1); +TEST(concat_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_concat_op<::anakin::saber::X86>(ctx, false); } } // namespace anakin } // namespace inference } // namespace paddle USE_OP(concat); +USE_CPU_ANAKIN_CONVERTER(concat); + +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(concat); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_conv2d_op.cc b/paddle/fluid/inference/anakin/convert/test_conv2d_op.cc index 6d93e50bc9..e95e11c4f9 100644 --- a/paddle/fluid/inference/anakin/convert/test_conv2d_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_conv2d_op.cc @@ -21,13 +21,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(conv2d_op, test) { - auto* conv2d_converter = - Registry::Global().Lookup("conv2d"); - ASSERT_TRUE(conv2d_converter != nullptr); +template +void test_conv2d_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters({"conv2d-Y"}); framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("conv2d-X", {1, 3, 3, 3}); validator.DeclParamVar("conv2d-Y", {4, 3, 1, 1}); validator.DeclOutputVar("conv2d-Out", {1, 4, 3, 3}); @@ -54,9 +53,27 @@ TEST(conv2d_op, test) { validator.Execute(3); } +#ifdef PADDLE_WITH_CUDA +TEST(conv2d_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_conv2d_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(conv2d_op, cpu) { + 
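// Note: the X86 path below only works because the converter is also
// registered through USE_CPU_ANAKIN_CONVERTER(conv2d) at the bottom of this
// file; the NV path keeps using USE_ANAKIN_CONVERTER under PADDLE_WITH_CUDA.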
platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_conv2d_op<::anakin::saber::X86>(ctx, false); +} + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(conv2d); +USE_CPU_ANAKIN_CONVERTER(conv2d); + +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(conv2d); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_dropout_op.cc b/paddle/fluid/inference/anakin/convert/test_dropout_op.cc index b2de5ae0a6..ae27e27ded 100644 --- a/paddle/fluid/inference/anakin/convert/test_dropout_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_dropout_op.cc @@ -21,10 +21,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(dropout_op, native) { +template +void test_dropout_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("x", {1, 1, 2, 2}); validator.DeclOutputVar("out", {1, 1, 2, 2}); validator.DeclOutputVar("mask", {1, 1, 2, 2}); @@ -45,9 +47,26 @@ TEST(dropout_op, native) { validator.Execute(1, neglected_output); } +#ifdef PADDLE_WITH_CUDA +TEST(dropout_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_dropout_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(dropout_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_dropout_op<::anakin::saber::X86>(ctx, false); +} + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(dropout); +USE_CPU_ANAKIN_CONVERTER(dropout); +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(dropout); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_elementwise_op.cc b/paddle/fluid/inference/anakin/convert/test_elementwise_op.cc index 3a437f5fdb..bff7529490 100644 --- a/paddle/fluid/inference/anakin/convert/test_elementwise_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_elementwise_op.cc @@ -21,10 +21,14 @@ namespace paddle { namespace inference { namespace anakin { -static void test_elementwise_op(const std::string &op_type) { +template +static void test_elementwise_op(const std::string& op_type, + const platform::DeviceContext& context, + bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("x", {1, 1, 2, 2}); validator.DeclInputVar("y", {1, 1, 2, 2}); validator.DeclOutputVar("out", {1, 1, 2, 2}); @@ -43,14 +47,41 @@ static void test_elementwise_op(const std::string &op_type) { validator.Execute(1); } -TEST(elementwise_op, native_add) { test_elementwise_op("elementwise_add"); } -TEST(elementwise_op, native_mul) { test_elementwise_op("elementwise_mul"); } +#ifdef PADDLE_WITH_CUDA +TEST(elementwise_op, native_add_gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_elementwise_op<::anakin::saber::NV>("elementwise_add", ctx, true); +} +TEST(elementwise_op, native_mul_gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_elementwise_op<::anakin::saber::NV>("elementwise_mul", ctx, true); +} +#endif + +TEST(elementwise_op, native_add_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_elementwise_op<::anakin::saber::X86>("elementwise_add", ctx, false); +} 
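+
+// Each converter test is now a function template: the TargetT argument
+// (::anakin::saber::NV or ::anakin::saber::X86) selects the Anakin backend,
+// while use_gpu picks the device context that runs the reference Fluid op.
+// The CPU variants below mirror the CUDA-guarded GPU tests above.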
+ +TEST(elementwise_op, native_mul_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_elementwise_op<::anakin::saber::X86>("elementwise_mul", ctx, false); +} } // namespace anakin } // namespace inference } // namespace paddle USE_OP(elementwise_add); -USE_ANAKIN_CONVERTER(elementwise_add); USE_OP(elementwise_mul); +#ifdef PADDLE_WITH_CUDA +USE_ANAKIN_CONVERTER(elementwise_add); USE_ANAKIN_CONVERTER(elementwise_mul); +#endif + +USE_CPU_ANAKIN_CONVERTER(elementwise_add); +USE_CPU_ANAKIN_CONVERTER(elementwise_mul); diff --git a/paddle/fluid/inference/anakin/convert/test_fc_op.cc b/paddle/fluid/inference/anakin/convert/test_fc_op.cc index ee6d1dc291..a24c809c02 100644 --- a/paddle/fluid/inference/anakin/convert/test_fc_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_fc_op.cc @@ -20,13 +20,13 @@ namespace paddle { namespace inference { namespace anakin { -TEST(fc_op, test) { - auto* fc_converter = Registry::Global().Lookup("fc"); - ASSERT_TRUE(fc_converter); - +template +void test_mul_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters({"mul_y"}); framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("mul_x", {1, 1, 2, 2}); validator.DeclParamVar("mul_y", {4, 2}); validator.DeclOutputVar("mul_out", {1, 2}); @@ -42,9 +42,26 @@ TEST(fc_op, test) { validator.Execute(10); } +#ifdef PADDLE_WITH_CUDA +TEST(mul_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_mul_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(mul_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_mul_op<::anakin::saber::X86>(ctx, false); +} + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(mul); +USE_CPU_ANAKIN_CONVERTER(fc); +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(fc); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_flatten_op.cc b/paddle/fluid/inference/anakin/convert/test_flatten_op.cc index d13281f11f..5765f5ebd1 100644 --- a/paddle/fluid/inference/anakin/convert/test_flatten_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_flatten_op.cc @@ -20,13 +20,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(flatten_op, test) { - auto *converter = Registry::Global().Lookup("flatten"); - ASSERT_TRUE(converter); - +template +void test_flatten_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("flatten-X", {3, 10, 10, 4}); validator.DeclOutputVar("flatten-Out", {3, 400, 1, 1}); framework::OpDesc desc; @@ -42,10 +41,27 @@ TEST(flatten_op, test) { validator.Execute(5); } +#ifdef PADDLE_WITH_CUDA +TEST(flatten_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_flatten_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(flatten_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_flatten_op<::anakin::saber::X86>(ctx, false); +} + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(reshape); USE_OP_ITSELF(flatten); +USE_CPU_ANAKIN_CONVERTER(flatten); +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(flatten); +#endif diff --git 
a/paddle/fluid/inference/anakin/convert/test_pool2d_op.cc b/paddle/fluid/inference/anakin/convert/test_pool2d_op.cc index 1ac0194677..90503b1fbb 100644 --- a/paddle/fluid/inference/anakin/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_pool2d_op.cc @@ -19,15 +19,14 @@ namespace paddle { namespace inference { namespace anakin { -void test_pool2d(bool global_pooling, bool ceil_mode, +template +void test_pool2d(const platform::DeviceContext& context, bool use_gpu, + bool global_pooling, bool ceil_mode, std::string pool_type = "max") { - auto* pool2d_converter = - Registry::Global().Lookup("pool2d"); - ASSERT_TRUE(pool2d_converter); - framework::Scope scope; std::unordered_set parameters; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); // The ITensor's Dims should not contain the batch size. // So, the ITensor's Dims of input and output should be C * H * W. @@ -64,56 +63,61 @@ void test_pool2d(bool global_pooling, bool ceil_mode, validator.Execute(1); } -void test_pool2d2(bool global_pooling, bool ceil_mode, - std::string pool_type = "max") { - auto* pool2d_converter = - Registry::Global().Lookup("pool2d"); - ASSERT_TRUE(pool2d_converter); - - framework::Scope scope; - std::unordered_set parameters; - AnakinConvertValidation validator(parameters, &scope); - - // The ITensor's Dims should not contain the batch size. - // So, the ITensor's Dims of input and output should be C * H * W. - validator.DeclInputVar("pool2d_x", {1, 1, 17, 17}); - validator.DeclOutputVar("pool2d_out", {1, 1, 17, 17}); - - // Prepare Op description - framework::OpDesc desc; - desc.SetType("pool2d"); - desc.SetInput("X", {"pool2d_x"}); - desc.SetOutput("Out", {"pool2d_out"}); - - std::vector ksize({3, 3}); - std::vector strides({1, 1}); - std::vector paddings({1, 1}); - std::string pooling_t = pool_type; +#ifdef PADDLE_WITH_CUDA +TEST(Pool2dOpConverter, normal) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_pool2d<::anakin::saber::NV>(ctx, true, false, false); +} +TEST(Pool2dOpConverter, test_global_pooling) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_pool2d<::anakin::saber::NV>(ctx, true, true, false); +} - desc.SetAttr("pooling_type", pooling_t); - desc.SetAttr("ksize", ksize); - desc.SetAttr("strides", strides); - desc.SetAttr("paddings", paddings); - desc.SetAttr("global_pooling", global_pooling); - desc.SetAttr("ceil_mode", true); +TEST(Pool2dOpConverter, max_ceil_test) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_pool2d<::anakin::saber::NV>(ctx, true, false, true); +} - LOG(INFO) << "set OP"; - validator.SetOp(*desc.Proto()); - LOG(INFO) << "execute"; +TEST(Pool2dOpConverter, avg_ceil_test) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_pool2d<::anakin::saber::NV>(ctx, true, false, true, "avg"); +} +#endif - validator.Execute(1); +TEST(Pool2dOpConverter, normal_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_pool2d<::anakin::saber::X86>(ctx, false, false, false); +} +TEST(Pool2dOpConverter, test_global_pooling_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_pool2d<::anakin::saber::X86>(ctx, false, true, false); } -TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); } -TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, 
false); } +TEST(Pool2dOpConverter, max_ceil_test_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_pool2d<::anakin::saber::X86>(ctx, false, false, true); +} -TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); } -TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); } -TEST(Pool2dOpConverter, avg_ceil_test2) { test_pool2d2(false, true, "avg"); } +TEST(Pool2dOpConverter, avg_ceil_test_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_pool2d<::anakin::saber::X86>(ctx, false, false, true, "avg"); +} } // namespace anakin } // namespace inference } // namespace paddle USE_OP(pool2d); +USE_CPU_ANAKIN_CONVERTER(pool2d); + +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(pool2d); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_relu_op.cc b/paddle/fluid/inference/anakin/convert/test_relu_op.cc index 04e624518a..3f22479651 100644 --- a/paddle/fluid/inference/anakin/convert/test_relu_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_relu_op.cc @@ -21,18 +21,23 @@ namespace paddle { namespace inference { namespace anakin { -static void test_activation_op(const std::string &op_type) { - auto *converter = Registry::Global().Lookup(op_type); - PADDLE_ENFORCE(converter != nullptr); +template +static void test_activation_op(const std::string& op_type, + const platform::DeviceContext& context, + bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("act-X", {10, 6, 1, 1}); validator.DeclOutputVar("act-Out", {10, 6, 1, 1}); framework::OpDesc desc; desc.SetType(op_type); desc.SetInput("X", {"act-X"}); desc.SetOutput("Out", {"act-Out"}); + if (op_type == "leaky_relu") { + desc.SetAttr("alpha", 0.1f); + } LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); @@ -41,10 +46,30 @@ static void test_activation_op(const std::string &op_type) { validator.Execute(5); } -TEST(sigm_op, test) { test_activation_op("relu"); } +#ifdef PADDLE_WITH_CUDA +TEST(relu_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_activation_op<::anakin::saber::NV>("relu", ctx, true); +} + +TEST(leaky_relu_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_activation_op<::anakin::saber::NV>("leaky_relu", ctx, true); +} +#endif + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(relu); +USE_OP(leaky_relu); +USE_CPU_ANAKIN_CONVERTER(relu); +USE_CPU_ANAKIN_CONVERTER(leaky_relu); + +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(relu); +USE_ANAKIN_CONVERTER(leaky_relu); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_reshape_op.cc b/paddle/fluid/inference/anakin/convert/test_reshape_op.cc index 306ebf510f..e102bd3ac3 100644 --- a/paddle/fluid/inference/anakin/convert/test_reshape_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_reshape_op.cc @@ -20,12 +20,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(reshape, test) { - auto* converter = Registry::Global().Lookup("reshape"); - ASSERT_TRUE(converter); +template +void test_reshape1_op(const platform::DeviceContext& context, bool use_gpu) { framework::Scope scope; std::unordered_set parameters; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); 
  // validator.DeclInputVar("reshape-X", {2, 3, 3, 1});
  // validator.DeclOutputVar("reshape-Out", {3, 2, 1, 3});
@@ -45,10 +45,12 @@ TEST(reshape, test) {
   validator.Execute(1);
 }

-TEST(reshape, test2) {
+template <typename TargetT>
+void test_reshape2_op(const platform::DeviceContext& context, bool use_gpu) {
   framework::Scope scope;
   std::unordered_set<std::string> parameters;
-  AnakinConvertValidation validator(parameters, &scope);
+  AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
+      parameters, &scope, context, use_gpu);

   validator.DeclInputVar("reshape-X", {1, 2, 4});
   validator.DeclOutputVar("reshape-Out", {1, 4, 2});
@@ -66,9 +68,39 @@ TEST(reshape, test2) {
   validator.Execute(1);
 }

+#ifdef PADDLE_WITH_CUDA
+TEST(reshape1_op, gpu) {
+  platform::CUDAPlace gpu_place(0);
+  platform::CUDADeviceContext ctx(gpu_place);
+  test_reshape1_op<::anakin::saber::NV>(ctx, true);
+}
+
+TEST(reshape2_op, gpu) {
+  platform::CUDAPlace gpu_place(0);
+  platform::CUDADeviceContext ctx(gpu_place);
+  test_reshape2_op<::anakin::saber::NV>(ctx, true);
+}
+#endif
+
+TEST(reshape1_op, cpu) {
+  platform::CPUPlace cpu_place;
+  platform::CPUDeviceContext ctx(cpu_place);
+  test_reshape1_op<::anakin::saber::X86>(ctx, false);
+}
+
+TEST(reshape2_op, cpu) {
+  platform::CPUPlace cpu_place;
+  platform::CPUDeviceContext ctx(cpu_place);
+  test_reshape2_op<::anakin::saber::X86>(ctx, false);
+}
+
 }  // namespace anakin
 }  // namespace inference
 }  // namespace paddle

 USE_OP(reshape);
+USE_CPU_ANAKIN_CONVERTER(reshape);
+
+#ifdef PADDLE_WITH_CUDA
 USE_ANAKIN_CONVERTER(reshape);
+#endif
diff --git a/paddle/fluid/inference/anakin/convert/test_softmax_op.cc b/paddle/fluid/inference/anakin/convert/test_softmax_op.cc
index 8c14fae0a6..de0b18fdbf 100644
--- a/paddle/fluid/inference/anakin/convert/test_softmax_op.cc
+++ b/paddle/fluid/inference/anakin/convert/test_softmax_op.cc
@@ -20,12 +20,12 @@ namespace paddle {
 namespace inference {
 namespace anakin {

-TEST(softmax, test) {
-  auto* converter = Registry<AnakinOpConverter>::Global().Lookup("softmax");
-  ASSERT_TRUE(converter);
+template <typename TargetT>
+void test_softmax_op(const platform::DeviceContext& context, bool use_gpu) {
   framework::Scope scope;
   std::unordered_set<std::string> parameters;
-  AnakinConvertValidation validator(parameters, &scope);
+  AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
+      parameters, &scope, context, use_gpu);

   validator.DeclInputVar("softmax-X", {1, 10, 2});
   validator.DeclOutputVar("softmax-Out", {1, 10, 2});
@@ -41,9 +41,27 @@ TEST(softmax, test) {
   validator.Execute(1);
 }

+#ifdef PADDLE_WITH_CUDA
+TEST(softmax_op, gpu) {
+  platform::CUDAPlace gpu_place(0);
+  platform::CUDADeviceContext ctx(gpu_place);
+  test_softmax_op<::anakin::saber::NV>(ctx, true);
+}
+#endif
+
+TEST(softmax_op, cpu) {
+  platform::CPUPlace cpu_place;
+  platform::CPUDeviceContext ctx(cpu_place);
+  test_softmax_op<::anakin::saber::X86>(ctx, false);
+}
+
 }  // namespace anakin
 }  // namespace inference
 }  // namespace paddle

 USE_OP(softmax);
+USE_CPU_ANAKIN_CONVERTER(softmax);
+
+#ifdef PADDLE_WITH_CUDA
 USE_ANAKIN_CONVERTER(softmax);
+#endif
diff --git a/paddle/fluid/inference/anakin/convert/test_split_op.cc b/paddle/fluid/inference/anakin/convert/test_split_op.cc
index aa61c01a51..9a42ffd853 100644
--- a/paddle/fluid/inference/anakin/convert/test_split_op.cc
+++ b/paddle/fluid/inference/anakin/convert/test_split_op.cc
@@ -21,12 +21,14 @@ namespace paddle {
 namespace inference {
 namespace anakin {

-template <int Axis>
-void AnakinSliceTest(const std::vector<int> &in_shape,
+template <typename TargetT, int Axis>
+void AnakinSliceTest(const platform::DeviceContext &context, bool use_gpu,
+                     const std::vector<int> &in_shape,
                      const std::vector<int> &sections) {
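  // AnakinSliceTest builds a single `split` op over `in_shape`, splits it
  // along the compile-time Axis into the given `sections`, and checks the
  // Anakin engine's output against the Fluid reference on the chosen TargetT.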
std::unordered_set parameters({""}); framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("split_input", in_shape); std::vector output_vars; @@ -55,51 +57,58 @@ void AnakinSliceTest(const std::vector &in_shape, // batch = 0, axis = 1, same shape TEST(split_op, test_same_shape_axis1_batch1) { - AnakinSliceTest<1>({1, 4, 2, 2}, {2, 2}); + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + AnakinSliceTest<::anakin::saber::NV, 1>(ctx, true, {1, 4, 2, 2}, {2, 2}); } // batch = 0, axis = 1, different shape TEST(split_op, test_different_shape_axis1_batch1) { - AnakinSliceTest<1>({1, 3, 2, 2}, {2, 1}); -} -// batch = 10, axis = 1, same shape -TEST(split_op, test_same_shape_axis1_batch10) { - AnakinSliceTest<1>({1, 4, 2, 2}, {2, 2}); -} -// batch = 10, axis = 1, different shape -TEST(split_op, test_different_shape_axis1_batch10) { - AnakinSliceTest<1>({1, 3, 2, 2}, {2, 1}); + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + AnakinSliceTest<::anakin::saber::NV, 1>(ctx, true, {1, 3, 2, 2}, {2, 1}); } // batch = 0, axis = 2, same shape TEST(split_op, test_same_shape_axis2_batch1) { - AnakinSliceTest<2>({1, 3, 4, 2}, {2, 2}); + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + AnakinSliceTest<::anakin::saber::NV, 2>(ctx, true, {1, 3, 4, 2}, {2, 2}); } // batch = 0, axis = 2, different shape TEST(split_op, test_different_shape_axis2_batch1) { - AnakinSliceTest<2>({1, 3, 3, 2}, {2, 1}); -} -// batch = 10, axis = 2, same shape -TEST(split_op, test_same_shape_axis2_batch10) { - AnakinSliceTest<2>({1, 3, 4, 2}, {2, 2}); -} -// batch = 10, axis = 2, different shape -TEST(split_op, test_different_shape_axis2_batch10) { - AnakinSliceTest<2>({1, 3, 3, 2}, {2, 1}); + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + AnakinSliceTest<::anakin::saber::NV, 2>(ctx, true, {1, 3, 3, 2}, {2, 1}); } + // batch = 0, axis = 3, same shape TEST(split_op, test_same_shape_axis3_batch1) { - AnakinSliceTest<3>({1, 3, 2, 4}, {2, 2}); + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + AnakinSliceTest<::anakin::saber::NV, 3>(ctx, true, {1, 3, 2, 4}, {2, 2}); } // batch = 0, axis = 3, different shape TEST(split_op, test_different_shape_axis3_batch1) { - AnakinSliceTest<3>({1, 3, 2, 3}, {2, 1}); + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + AnakinSliceTest<::anakin::saber::NV, 3>(ctx, true, {1, 3, 2, 3}, {2, 1}); } -// batch = 10, axis = 3, same shape -TEST(split_op, test_same_shape_axis3_batch10) { - AnakinSliceTest<3>({1, 3, 2, 4}, {2, 2}); + +TEST(split_op, test_different_shape_axis1_batch1_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + AnakinSliceTest<::anakin::saber::X86, 1>(ctx, false, {1, 3, 2, 3}, {2, 1}); +} + +TEST(split_op, test_different_shape_axis2_batch1_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + AnakinSliceTest<::anakin::saber::X86, 2>(ctx, false, {1, 3, 4, 2}, {2, 2}); } -// batch = 10, axis = 3, different shape -TEST(split_op, test_different_shape_axis3_batch10) { - AnakinSliceTest<3>({1, 3, 2, 3}, {2, 1}); + +TEST(split_op, test_different_shape_axis3_batch1_cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + AnakinSliceTest<::anakin::saber::X86, 3>(ctx, false, {1, 3, 2, 4}, {2, 
2}); } } // namespace anakin @@ -107,4 +116,7 @@ TEST(split_op, test_different_shape_axis3_batch10) { } // namespace paddle USE_OP(split); +USE_CPU_ANAKIN_CONVERTER(split); +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(split); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_sum_op.cc b/paddle/fluid/inference/anakin/convert/test_sum_op.cc index d6a59a0166..65f67ebd12 100644 --- a/paddle/fluid/inference/anakin/convert/test_sum_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_sum_op.cc @@ -22,10 +22,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(sum, native) { +template +static void test_sum_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("sum_x1", {1, 2, 1, 2}); validator.DeclInputVar("sum_x2", {1, 2, 1, 2}); validator.DeclOutputVar("sum_out", {1, 2, 1, 2}); @@ -40,9 +42,26 @@ TEST(sum, native) { validator.Execute(1); } +#ifdef PADDLE_WITH_CUDA +TEST(sum_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_sum_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(sum_op, cpu) { + platform::CPUPlace cpu_place; + platform::CPUDeviceContext ctx(cpu_place); + test_sum_op<::anakin::saber::X86>(ctx, false); +} + } // namespace anakin } // namespace inference } // namespace paddle USE_OP(sum); +USE_CPU_ANAKIN_CONVERTER(sum); +#ifdef PADDLE_WITH_CUDA USE_ANAKIN_CONVERTER(sum); +#endif diff --git a/paddle/fluid/inference/anakin/convert/test_transpose_op.cc b/paddle/fluid/inference/anakin/convert/test_transpose_op.cc index 016ed26f02..51b69dfbb0 100644 --- a/paddle/fluid/inference/anakin/convert/test_transpose_op.cc +++ b/paddle/fluid/inference/anakin/convert/test_transpose_op.cc @@ -20,12 +20,12 @@ namespace paddle { namespace inference { namespace anakin { -TEST(transpose_op, test) { - auto* converter = Registry::Global().Lookup("transpose"); - ASSERT_TRUE(converter != nullptr); +template +void test_transpose1_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("transpose-X", {2, 3, 4, 5}); validator.DeclOutputVar("transpose-Out", {4, 2, 5, 3}); @@ -43,11 +43,12 @@ TEST(transpose_op, test) { validator.Execute(3); } -// test input shape's dims < 4 -TEST(transpose_op, test2) { +template +void test_transpose2_op(const platform::DeviceContext& context, bool use_gpu) { std::unordered_set parameters; framework::Scope scope; - AnakinConvertValidation validator(parameters, &scope); + AnakinConvertValidation validator( + parameters, &scope, context, use_gpu); validator.DeclInputVar("transpose-X", {3, 4, 5}); validator.DeclOutputVar("transpose-Out", {3, 5, 4}); @@ -65,9 +66,38 @@ TEST(transpose_op, test2) { validator.Execute(1); } +#ifdef PADDLE_WITH_CUDA +TEST(transpose1_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_transpose1_op<::anakin::saber::NV>(ctx, true); +} + +TEST(transpose2_op, gpu) { + platform::CUDAPlace gpu_place(0); + platform::CUDADeviceContext ctx(gpu_place); + test_transpose2_op<::anakin::saber::NV>(ctx, true); +} +#endif + +TEST(transpose1_op, cpu) { + platform::CPUPlace cpu_place; + 
platform::CPUDeviceContext ctx(cpu_place);
+  test_transpose1_op<::anakin::saber::X86>(ctx, false);
+}
+
+TEST(transpose2_op, cpu) {
+  platform::CPUPlace cpu_place;
+  platform::CPUDeviceContext ctx(cpu_place);
+  test_transpose2_op<::anakin::saber::X86>(ctx, false);
+}
+
 }  // namespace anakin
 }  // namespace inference
 }  // namespace paddle

 USE_OP(transpose);
+USE_CPU_ANAKIN_CONVERTER(transpose);
+#ifdef PADDLE_WITH_CUDA
 USE_ANAKIN_CONVERTER(transpose);
+#endif
diff --git a/paddle/fluid/inference/anakin/convert/transpose.cc b/paddle/fluid/inference/anakin/convert/transpose.cc
index f35372fe5c..28071ca844 100644
--- a/paddle/fluid/inference/anakin/convert/transpose.cc
+++ b/paddle/fluid/inference/anakin/convert/transpose.cc
@@ -17,20 +17,16 @@
 #include
 #include

-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::saber::NV;
-using anakin::saber::Shape;
 using anakin::PTuple;

 namespace paddle {
 namespace inference {
 namespace anakin {

-void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
-                                      const framework::BlockDesc &block_desc,
-                                      const framework::Scope &scope,
-                                      bool test_mode) {
+template <typename TargetT>
+void TransposeOpConverter<TargetT>::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
   framework::OpDesc op_desc(op, nullptr);
   PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
   PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
@@ -38,7 +34,7 @@ void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
   auto input = op_desc.Input("X").front();
   auto output = op_desc.Output("Out").front();
   auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
-  engine_->AddOp(op_name, "Permute", {input}, {output});
+  this->engine_->AddOp(op_name, "Permute", {input}, {output});

   auto axis = boost::get<std::vector<int>>(op_desc.GetAttr("axis"));
   size_t axis_size = axis.size();
@@ -46,7 +42,7 @@ void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
     axis.push_back(axis_size);
     axis_size += 1;
   }
-  engine_->AddOpAttr<PTuple<int>>(op_name, "dims", axis);
+  this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dims", axis);
 }

 }  // namespace anakin
diff --git a/paddle/fluid/inference/anakin/convert/transpose.h b/paddle/fluid/inference/anakin/convert/transpose.h
index bacbf152bc..b7b0a0f209 100644
--- a/paddle/fluid/inference/anakin/convert/transpose.h
+++ b/paddle/fluid/inference/anakin/convert/transpose.h
@@ -20,7 +20,8 @@ namespace paddle {
 namespace inference {
 namespace anakin {

-class TransposeOpConverter : public AnakinOpConverter {
+template <typename TargetT>
+class TransposeOpConverter : public AnakinOpConverter<TargetT> {
  public:
  TransposeOpConverter() = default;
diff --git a/paddle/fluid/inference/anakin/convert/ut_helper.h b/paddle/fluid/inference/anakin/convert/ut_helper.h
index 029aff6704..2f8f953892 100644
--- a/paddle/fluid/inference/anakin/convert/ut_helper.h
+++ b/paddle/fluid/inference/anakin/convert/ut_helper.h
@@ -32,14 +32,8 @@ limitations under the License.
*/ #include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/platform/enforce.h" -using anakin::graph::GraphGlobalMem; -using anakin::AK_FLOAT; using anakin::Precision; -using anakin::saber::NV; using anakin::saber::X86; -using anakin::saber::Shape; -using anakin::PBlock; -using anakin::PTuple; namespace paddle { namespace inference { @@ -55,8 +49,8 @@ float random(float low, float high) { return dist(mt); } -void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place, - const platform::DeviceContext& ctx) { +void RandomizeTensor(framework::LoDTensor* tensor, + const platform::Place& place) { auto dims = tensor->dims(); size_t num_elements = analysis::AccuDims(dims, dims.size()); PADDLE_ENFORCE_GT(num_elements, 0); @@ -78,17 +72,19 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place, * anakin * layer. */ +template class AnakinConvertValidation { - using AnakinNvEngineT = AnakinEngine; + using AnakinNvEngineT = AnakinEngine; public: AnakinConvertValidation() = delete; AnakinConvertValidation(const std::unordered_set& parameters, - framework::Scope* scope) - : parameters_(parameters), scope_(scope), place_(0) { - PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0); - engine_.reset(new AnakinEngine(true)); + framework::Scope* scope, + const platform::DeviceContext& ctx, + bool use_gpu = true) + : parameters_(parameters), scope_(scope), ctx_(ctx), use_gpu_(use_gpu) { + engine_.reset(new AnakinEngine(true)); } // Declare a Variable as input with random initialization. @@ -108,11 +104,10 @@ class AnakinConvertValidation { } void DeclVar(const std::string& name, const std::vector dim_vec) { - platform::CUDADeviceContext ctx(place_); auto* x = scope_->Var(name); auto* x_tensor = x->GetMutable(); x_tensor->Resize(framework::make_ddim(dim_vec)); - RandomizeTensor(x_tensor, place_, ctx); + RandomizeTensor(x_tensor, ctx_.GetPlace()); std::vector dim_vec_int64; for (auto& ele : dim_vec) { @@ -132,7 +127,7 @@ class AnakinConvertValidation { // should init anakin engine here. 
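    // The calls below mirror what the subgraph pass does when it builds a
    // real engine: convert the op into the Anakin graph, freeze the graph,
    // declare the actual and maximum input shapes, optimize, and finally
    // construct the executable net with InitNet().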
auto& block_desc = program_desc_.Block(framework::kRootBlockIndex); - Singleton::Global().ConvertOp( + Singleton>::Global().ConvertOp( desc, block_desc, parameters_, *scope_, engine_.get(), true /*test_mode*/); engine_->Freeze(); @@ -151,7 +146,7 @@ class AnakinConvertValidation { } engine_->SetMaxInputShape(temp_max_input_shape); engine_->Optimize(); - engine_->InitGraph(); + engine_->InitNet(); } // We use the set 'neglected_output' here, because some Ops like batch norm, @@ -160,11 +155,8 @@ class AnakinConvertValidation { void Execute(int batch_size, std::unordered_set neglected_output = {}) { // Execute Fluid Op - platform::CUDADeviceContext ctx(place_); - op_->Run(*scope_, place_); + op_->Run(*scope_, ctx_.GetPlace()); - // std::vector input_vector; - // std::vector output_vector; std::map inputs; for (const auto& input : op_desc_->InputArgumentNames()) { if (parameters_.count(input)) continue; @@ -180,20 +172,27 @@ class AnakinConvertValidation { std::vector fluid_out; auto* var = scope_->FindVar(output); auto tensor = var->GetMutable(); - framework::TensorToVector(*tensor, ctx, &fluid_out); + framework::TensorToVector(*tensor, ctx_, &fluid_out); fluid_outputs.push_back(fluid_out); outputs.insert({output, tensor}); } - engine_->Execute(inputs, outputs, stream_); + if (!use_gpu_) { + engine_->Execute(inputs, outputs); + } else { + cudaStream_t stream; + PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream), 0); + engine_->Execute(inputs, outputs, stream); + } + int i_output = 0; for (const auto& output : op_desc_->OutputArgumentNames()) { if (neglected_output.count(output)) continue; std::vector anakin_out; auto* var = scope_->FindVar(output); auto tensor = var->GetMutable(); - framework::TensorToVector(*tensor, ctx, &anakin_out); + framework::TensorToVector(*tensor, ctx_, &anakin_out); size_t anakin_out_size = anakin_out.size(); auto fluid_out = fluid_outputs[i_output++]; @@ -205,15 +204,24 @@ class AnakinConvertValidation { private: std::unique_ptr engine_{nullptr}; - cudaStream_t stream_; std::unique_ptr op_; std::unique_ptr op_desc_; framework::ProgramDesc program_desc_; const std::unordered_set& parameters_; framework::Scope* scope_; - platform::CUDAPlace place_; + const platform::DeviceContext& ctx_; + bool use_gpu_{true}; }; +template class AnakinConvertValidation<::anakin::saber::NV, + ::anakin::Precision::FP32>; +template class AnakinConvertValidation<::anakin::saber::X86, + ::anakin::Precision::FP32>; + +template class AnakinConvertValidation<::anakin::saber::NV, + ::anakin::Precision::INT8>; +template class AnakinConvertValidation<::anakin::saber::X86, + ::anakin::Precision::INT8>; } // namespace anakin } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/anakin/engine.cc b/paddle/fluid/inference/anakin/engine.cc index ba044c9401..529a859458 100644 --- a/paddle/fluid/inference/anakin/engine.cc +++ b/paddle/fluid/inference/anakin/engine.cc @@ -35,12 +35,15 @@ namespace anakin { template AnakinEngine::AnakinEngine( bool need_summary, int device, int max_batch_size, - std::map> max_input_shape) + std::map> max_input_shape, + std::vector program_inputs, bool auto_config_layout) : graph_(new AnakinGraphT()), net_(new AnakinNetT(need_summary)) { device_ = device; max_batch_size_ = max_batch_size; max_input_shape_ = max_input_shape; + program_inputs_ = program_inputs; + auto_config_layout_ = auto_config_layout; } template @@ -54,8 +57,8 @@ void AnakinEngine::SetInputShape( } template -void AnakinEngine::InitGraph() { - net_->init(*graph_); +void 
AnakinEngine::InitNet() { + net_->init(*graph_, auto_config_layout_); } template @@ -67,11 +70,11 @@ void AnakinEngine::AddOp( } template -void AnakinEngine::Execute( - const std::map &inputs, - const std::map &outputs, - cudaStream_t stream) { +void AnakinEngine::BindInput( + const std::map &inputs) { +#ifdef PADDLE_WITH_CUDA cudaDeviceSynchronize(); +#endif for (const auto &input : inputs) { auto *tensor = input.second; auto *data = tensor->data(); @@ -85,16 +88,53 @@ void AnakinEngine::Execute( int max_shape_sum = std::accumulate(max_input_shape.begin(), max_input_shape.end(), 1, std::multiplies()); - - PADDLE_ENFORCE(max_shape_sum >= tensor->numel(), - "The anakin input max shape should be greater than" - " or equal to the real input shape, Please set the max " - "input shape using EnableAnakinEngine"); + if (tensor->numel() > max_shape_sum) { + PADDLE_ENFORCE(std::find(program_inputs_.begin(), program_inputs_.end(), + input.first) == program_inputs_.end(), + "The anakin input max shape should be greater than" + " or equal to the real input shape, Please set the max " + "input shape using EnableAnakinEngine"); + VLOG(3) << "Anakin Net will be reset because of the inputs out of range: " + << input.first; + graph_->Reshape(input.first, fluid_input_shape); + net_.reset(new AnakinNetT(true)); + net_->init(*graph_); + anakin_input = net_->get_in(input.first); + } anakin_input->reshape(fluid_input_shape); ::anakin::saber::Tensor tmp_anakin_tensor(data, TargetT(), 0, fluid_input_shape); anakin_input->copy_from(tmp_anakin_tensor); } +} + +template +void AnakinEngine::Execute( + const std::map &inputs, + const std::map &outputs) { + BindInput(inputs); + net_->prediction(); + for (const auto &output : outputs) { + platform::CPUPlace cpu_place; + auto *tensor = output.second; + auto *anakin_output = net_->get_out(output.first); + auto *anakin_data = anakin_output->data(); + auto anakin_output_shape = anakin_output->valid_shape(); + tensor->Resize(framework::make_ddim(anakin_output_shape)); + auto *fluid_data = tensor->mutable_data(cpu_place); + memory::Copy(cpu_place, static_cast(fluid_data), cpu_place, + static_cast(anakin_data), + tensor->numel() * sizeof(float)); + } +} + +#ifdef PADDLE_WITH_CUDA +template +void AnakinEngine::Execute( + const std::map &inputs, + const std::map &outputs, + cudaStream_t stream) { + BindInput(inputs); net_->prediction(); cudaDeviceSynchronize(); for (const auto &output : outputs) { @@ -111,10 +151,11 @@ void AnakinEngine::Execute( } cudaDeviceSynchronize(); } +#endif template void AnakinEngine::Freeze() { - PADDLE_ENFORCE(graph_->Freeze_v3(), "Freeze anakin subgraph."); + PADDLE_ENFORCE(graph_->Freeze(), "Freeze anakin subgraph."); } template @@ -122,6 +163,12 @@ void AnakinEngine::Optimize() { PADDLE_ENFORCE(graph_->Optimize(), "Graph optimization."); } +template +void AnakinEngine::RegistBlock( + ::anakin::PBlock *block_p) { + PADDLE_ENFORCE(graph_->RegistBlock(block_p), "Block register."); +} + template std::unique_ptr> AnakinEngine::Clone() { @@ -130,7 +177,24 @@ AnakinEngine::Clone() { return std::unique_ptr(engine); } +#ifdef PADDLE_WITH_CUDA template class AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>; +template class AnakinEngineManager<::anakin::saber::NV, + ::anakin::Precision::FP32>; + +template class AnakinEngine<::anakin::saber::NV, ::anakin::Precision::INT8>; +template class AnakinEngineManager<::anakin::saber::NV, + ::anakin::Precision::INT8>; +#endif + +template class AnakinEngine<::anakin::saber::X86, ::anakin::Precision::FP32>; 
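+// AnakinEngine and AnakinEngineManager are defined in this .cc file, so each
+// (TargetT, PrecisionT) combination reachable from the converter registration
+// macros has to be instantiated explicitly; the NV instantiations stay behind
+// PADDLE_WITH_CUDA so that CPU-only builds still link.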
+template class AnakinEngineManager<::anakin::saber::X86, + ::anakin::Precision::FP32>; +template class AnakinEngine<::anakin::saber::X86, ::anakin::Precision::INT8>; +template class AnakinEngineManager<::anakin::saber::X86, + ::anakin::Precision::INT8>; + +// template class AnakinEngine<::anakin::saber::X86, ::anakin::Precision::FP32>; } // namespace anakin } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/anakin/engine.h b/paddle/fluid/inference/anakin/engine.h index 4845ffdf5b..fb40f56511 100644 --- a/paddle/fluid/inference/anakin/engine.h +++ b/paddle/fluid/inference/anakin/engine.h @@ -32,7 +32,6 @@ #include "saber/saber_types.h" using anakin::Precision; -using anakin::saber::NV; namespace anakin { @@ -58,9 +57,11 @@ class AnakinEngine { public: explicit AnakinEngine( bool need_summary = false, int device = 0, int max_batch_size = 1, - std::map> max_input_shape = {}); + std::map> max_input_shape = {}, + std::vector program_inputs = {}, + bool auto_config_layout = false); ~AnakinEngine(); - void InitGraph(); + void InitNet(); void SetInputShape(const std::string &name, std::vector shape); void AddOp(const std::string &name, const std::string &type, const std::vector &inputs, @@ -81,20 +82,35 @@ class AnakinEngine { void SetMaxInputShape(std::map> shape) { max_input_shape_ = shape; } + const std::vector &GetScalableInputs() { + return program_inputs_; + } + void SetScalableInputs(std::vector program_inputs) { + program_inputs_ = program_inputs; + } int GetMaxBatchSize() { return max_batch_size_; } void Freeze(); void Optimize(); - void AllocTmpMem() { - PADDLE_ENFORCE(net_->alloc_memory_first(*graph_), - "anakin alloc temp memory first failed"); - } + void RegistBlock(::anakin::PBlock *block_p); void Save(std::string path) { graph_->save(path); } - bool IsInit() { return initialized_; } int GetDevice() { return device_; } + void AddTensorScale(const std::string &tensor_name, float scale) { + tensor_scales_[tensor_name] = scale; + } + std::unordered_map GetTensorScales() { + return tensor_scales_; + } + void Execute(const std::map &inputs, + const std::map &outputs); +#ifdef PADDLE_WITH_CUDA void Execute(const std::map &inputs, const std::map &outputs, cudaStream_t stream); +#endif + + private: + void BindInput(const std::map &inputs); private: bool initialized_{false}; @@ -103,27 +119,33 @@ class AnakinEngine { int device_; std::unique_ptr graph_; std::unique_ptr net_; + std::vector program_inputs_; + std::unordered_map tensor_scales_; + // Always be false in gpu mode but true in most cpu cases. 
+ bool auto_config_layout_; }; +template class AnakinEngineManager { - using AnakinNvEngineT = AnakinEngine; + using AnakinEngineT = AnakinEngine; public: bool HasEngine(const std::string &name) const { if (engines_.count(name) == 0) return false; return engines_.at(name).get() != nullptr; } - AnakinNvEngineT *Get(const std::string &name) const { + AnakinEngineT *Get(const std::string &name) const { return engines_.at(name).get(); } - AnakinNvEngineT *Create( - bool need_summary, int device, int max_batch_size, - std::map> max_input_shape, - std::string engine_name) { + AnakinEngineT *Create(bool need_summary, int device, int max_batch_size, + std::map> max_input_shape, + std::vector program_inputs, + bool auto_config_layout, std::string engine_name) { std::unique_lock lk(mut_); - auto *p = new AnakinEngine( - need_summary, device, max_batch_size, max_input_shape); + auto *p = new AnakinEngine( + need_summary, device, max_batch_size, max_input_shape, program_inputs, + auto_config_layout); engines_[engine_name].reset(p); return p; } @@ -135,7 +157,7 @@ class AnakinEngineManager { } private: - std::unordered_map> engines_; + std::unordered_map> engines_; std::mutex mut_; }; } // namespace anakin diff --git a/paddle/fluid/inference/anakin/op_teller.cc b/paddle/fluid/inference/anakin/op_teller.cc index 2042fb18ea..67b771226c 100644 --- a/paddle/fluid/inference/anakin/op_teller.cc +++ b/paddle/fluid/inference/anakin/op_teller.cc @@ -44,6 +44,11 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("sum"); teller_set.insert("depthwise_conv2d"); teller_set.insert("prior_box"); + teller_set.insert("leaky_relu"); + teller_set.insert("affine_channel"); + teller_set.insert("relu6"); + teller_set.insert("swish"); + teller_set.insert("shuffle_channel"); } bool operator()(const std::string& op_type, diff --git a/paddle/fluid/inference/anakin/test_anakin_engine.cc b/paddle/fluid/inference/anakin/test_anakin_engine.cc index 8fd6b8bec9..422f415a5d 100644 --- a/paddle/fluid/inference/anakin/test_anakin_engine.cc +++ b/paddle/fluid/inference/anakin/test_anakin_engine.cc @@ -19,7 +19,6 @@ limitations under the License. 
*/ #include "paddle/fluid/inference/anakin/engine.h" -using anakin::graph::GraphGlobalMem; using anakin::AK_FLOAT; using anakin::Precision; using anakin::saber::NV; @@ -52,11 +51,9 @@ TEST_F(TestAnakinEngine, Execute) { engine_->AddOpAttr("op1", "axis", 1); std::vector shape = {1, 1, 1, 2}; Shape tmp_shape(shape); - // PBlock weight1(tmp_shape); - auto *weight1 = - GraphGlobalMem::Global().template new_block(tmp_shape); - // auto *weight1 = new PBlock(tmp_shape, AK_FLOAT); + PBlock *weight1 = new PBlock(tmp_shape, AK_FLOAT); + engine_->RegistBlock(weight1); float *cpu_data = static_cast(weight1->h_tensor().mutable_data()); cpu_data[0] = 2.; weight1->d_tensor().set_shape(tmp_shape); @@ -68,7 +65,7 @@ TEST_F(TestAnakinEngine, Execute) { // engine_->AddOpAttr("x", "input_shape", input_shape); engine_->SetInputShape("x", {1, 1, 1, 1}); engine_->Optimize(); - engine_->InitGraph(); + engine_->InitNet(); framework::LoDTensor x; framework::LoDTensor y; x.Resize({1, 1, 1, 1}); diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index a736ca393c..66e8d8b528 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -64,20 +64,20 @@ struct Argument { bool Has(const std::string& key) const { return valid_fields_.count(key); } -#define DECL_ARGUMENT_FIELD(field__, Field, type__) \ - public: \ - type__& field__() { \ - PADDLE_ENFORCE(Has(#field__)); \ - return field__##_; \ - } \ - void Set##Field(const type__& x) { \ - field__##_ = x; \ - valid_fields_.insert(#field__); \ - } \ - DECL_ARGUMENT_FIELD_VALID(field__); \ - type__* field__##_ptr() { return &field__##_; } \ - \ - private: \ +#define DECL_ARGUMENT_FIELD(field__, Field, type__) \ + public: \ + type__& field__() { \ + PADDLE_ENFORCE(Has(#field__), "There is no such field"); \ + return field__##_; \ + } \ + void Set##Field(const type__& x) { \ + field__##_ = x; \ + valid_fields_.insert(#field__); \ + } \ + DECL_ARGUMENT_FIELD_VALID(field__); \ + type__* field__##_ptr() { return &field__##_; } \ + \ + private: \ type__ field__##_; #define DECL_ARGUMENT_FIELD_VALID(field__) \ @@ -169,7 +169,14 @@ struct Argument { anakin_max_shape_t); DECL_ARGUMENT_FIELD(anakin_max_batch_size, AnakinMaxBatchSize, int); DECL_ARGUMENT_FIELD(anakin_min_subgraph_size, AnakinMinSubgraphSize, int); + DECL_ARGUMENT_FIELD(anakin_precision_mode, AnakinPrecisionMode, + AnalysisConfig::Precision); + DECL_ARGUMENT_FIELD(anakin_auto_config_layout, AnakinAutoConfigLayout, bool); DECL_ARGUMENT_FIELD(use_anakin, UseAnakin, bool); + DECL_ARGUMENT_FIELD(anakin_passes_filter, AnakinPassesFilter, + std::vector); + DECL_ARGUMENT_FIELD(anakin_ops_filter, AnakinOpsFilter, + std::vector); // Memory optimized related. 
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 78e502c670..4714c30507 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -114,6 +114,7 @@ void IRPassManager::CreatePasses(Argument *argument, if (pass_name == "anakin_subgraph_pass") { pass->Set("program", new framework::ProgramDesc *(&argument->main_program())); + pass->Set("use_gpu", new bool(argument->use_gpu())); pass->Set("gpu_device_id", new int(argument->gpu_device_id())); pass->Set("model_from_memory", new bool(argument->model_from_memory())); pass->Set("engine_opt_info", new std::map( @@ -122,6 +123,13 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("max_input_shape", new std::map>( argument->anakin_max_input_shape())); pass->Set("max_batch_size", new int(argument->anakin_max_batch_size())); + bool enable_int8 = + argument->anakin_precision_mode() == AnalysisConfig::Precision::kInt8; + pass->Set("enable_int8", new bool(enable_int8)); + pass->Set("anakin_ops_filter", + new std::vector(argument->anakin_ops_filter())); + pass->Set("auto_config_layout", + new bool(argument->anakin_auto_config_layout())); } pre_pass = pass_name; diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc index b8d8b6fed8..9586ce3e6b 100644 --- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc @@ -39,8 +39,14 @@ void analysis::AnakinSubgraphPass::ApplyImpl( framework::ir::Graph *graph) const { framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph); - auto teller = [](const framework::ir::Node *node) { - if (!node->IsOp() || !node->Op()) return false; + auto &anakin_ops_filter = Get>("anakin_ops_filter"); + + auto teller = [&anakin_ops_filter](const framework::ir::Node *node) { + if (!node->IsOp() || !node->Op()) + return false; + else if (std::find(anakin_ops_filter.begin(), anakin_ops_filter.end(), + node->Op()->Type()) != anakin_ops_filter.end()) + return false; return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); }; @@ -191,22 +197,78 @@ void AnakinSubgraphPass::CreateAnakinOp( SetAttr(op_desc->Proto(), "engine_key", engine_key); auto max_input_shape = Get>>("max_input_shape"); - auto max_batch_size = Get("max_batch_size"); + auto program_inputs = program_desc->GetFeedTargetNames(); - auto *anakin_engine = - inference::Singleton::Global().Create( - true, Get("gpu_device_id"), max_batch_size, max_input_shape, - engine_key); + bool use_gpu = Get("use_gpu"); + SetAttr(op_desc->Proto(), "use_gpu", use_gpu); + bool enable_int8 = Get("enable_int8"); + SetAttr(op_desc->Proto(), "enable_int8", enable_int8); + if (enable_int8) { + CreateAnakinEngine<::anakin::Precision::INT8>(&block_desc, params, + input_names, output_mapping, + program_inputs, engine_key); + } else { + CreateAnakinEngine<::anakin::Precision::FP32>(&block_desc, params, + input_names, output_mapping, + program_inputs, engine_key); + } +} + +template <::anakin::Precision PrecisionT> +void AnakinSubgraphPass::CreateAnakinEngine( + framework::BlockDesc *block_desc, const std::vector ¶ms, + const std::set &input_names, + const std::vector &output_mapping, + const std::vector &program_inputs, + const std::string &engine_key) const { + framework::BlockDesc 
block_desc_temp(nullptr, block_desc->Proto()); + bool use_gpu = Get("use_gpu"); + auto max_batch_size = Get("max_batch_size"); + auto max_input_shape = + Get>>("max_input_shape"); + bool auto_config_layout = Get("auto_config_layout"); + if (use_gpu) { +#ifdef PADDLE_WITH_CUDA + inference::Singleton< + anakin::AnakinEngineManager<::anakin::saber::NV, PrecisionT>>::Global() + .Create(true, Get("gpu_device_id"), max_batch_size, + max_input_shape, program_inputs, false, engine_key); +#endif + } else { + inference::Singleton< + anakin::AnakinEngineManager<::anakin::saber::X86, PrecisionT>>::Global() + .Create(true, Get("gpu_device_id"), max_batch_size, + max_input_shape, program_inputs, auto_config_layout, + engine_key); + } auto *scope = param_scope(); std::unordered_set param_set(params.begin(), params.end()); - framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto()); - - inference::Singleton::Global() - .ConvertBlockToAnakinEngine( - &block_desc_temp, scope, - std::vector(input_names.begin(), input_names.end()), - param_set, output_mapping, anakin_engine); + if (use_gpu) { +#ifdef PADDLE_WITH_CUDA + auto *anakin_engine = + inference::Singleton>::Global() + .Get(engine_key); + inference::Singleton>::Global() + .ConvertBlockToAnakinEngine( + &block_desc_temp, scope, + std::vector(input_names.begin(), input_names.end()), + param_set, output_mapping, anakin_engine); +#endif + } else { + auto *anakin_engine = + inference::Singleton>::Global() + .Get(engine_key); + inference::Singleton>::Global() + .ConvertBlockToAnakinEngine( + &block_desc_temp, scope, + std::vector(input_names.begin(), input_names.end()), + param_set, output_mapping, anakin_engine); + } } } // namespace analysis diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h index e80b8bb612..4ab2297b2d 100644 --- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h +++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include #include #include #include "paddle/fluid/framework/ir/pass.h" @@ -36,6 +37,13 @@ class AnakinSubgraphPass : public framework::ir::FusePassBase { const std::vector &graph_params, std::vector *repetitive_params) const; void CleanIntermediateOutputs(framework::ir::Node *node); + template <::anakin::Precision PrecisionT> + void CreateAnakinEngine(framework::BlockDesc *block_desc, + const std::vector ¶ms, + const std::set &input_names, + const std::vector &output_mapping, + const std::vector &program_inputs, + const std::string &engine_key) const; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc b/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc index 7c4aab06a1..8f7c6ac755 100644 --- a/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc +++ b/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc @@ -100,7 +100,6 @@ void RenameAndGetOutputs( const std::string arg_value = in_var->arguments(k); const std::string arg_value_with_id = arg_value + std::to_string(var2id[arg_value]); - if (input_names_with_id.count(arg_value_with_id)) { replaced_names.push_back(arg_value); if (graph_var_map.count(arg_value)) { @@ -149,7 +148,6 @@ void RenameAndGetOutputs( const std::string arg_value = out_var->arguments(k); const std::string arg_value_with_id = arg_value + std::to_string(var2id[arg_value]); - if (graph_var_map.count(arg_value)) { add_block_var(arg_value, arg_value_with_id); } 
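The API change in the next two files is easiest to read from the caller's side. The sketch below is illustrative only and not part of the patch: the input name and shape are invented, and it assumes the extended EnableAnakinEngine() signature introduced in paddle_analysis_config.h further down.

#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

void ConfigureAnakinSubgraph(paddle::AnalysisConfig *config) {
  // Optional: select GPU mode first; otherwise the Anakin subgraph now
  // runs on the X86 engine.
  config->EnableUseGpu(100 /*memory_pool_init_size_mb*/, 0 /*device_id*/);
  // Per-input upper bounds on shape; "input" is a placeholder name.
  std::map<std::string, std::vector<int>> max_input_shape;
  max_input_shape["input"] = {1, 3, 224, 224};
  config->EnableAnakinEngine(
      /*max_batch_size=*/1, max_input_shape,
      /*min_subgraph_size=*/6,
      /*precision=*/paddle::AnalysisConfig::Precision::kFloat32,
      /*auto_config_layout=*/false,  // layout tuning, CPU mode only
      /*passes_filter=*/{},  // names listed here are dropped from
                             // kAnakinSubgraphPasses
      /*ops_filter=*/{});    // op types listed here are kept out of
                             // Anakin subgraphs
}

With Precision::kInt8 the subgraph pass instantiates the INT8 engine variant instead, as the anakin_subgraph_pass.cc hunk above shows.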
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index b54ea269ff..8b940b67e3 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -116,6 +116,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(anakin_max_batchsize_);
   CP_MEMBER(anakin_max_input_shape_);
   CP_MEMBER(anakin_min_subgraph_size_);
+  CP_MEMBER(anakin_precision_mode_);
+  CP_MEMBER(anakin_auto_config_layout_);
+  CP_MEMBER(anakin_passes_filter_);
+  CP_MEMBER(anakin_ops_filter_);
 
   // Ir related.
   CP_MEMBER(enable_ir_optim_);
@@ -269,13 +273,18 @@ void AnalysisConfig::Update() {
     PADDLE_ENFORCE(!use_tensorrt_,
                    "Anakin sub-graph and TensorRT sub-graph are not allowed to "
                    "run at the same time!");
-    PADDLE_ENFORCE(
-        use_gpu_,
-        "Anakin sub-graph engine need gpu, please use the EnableGpu API.");
+    if (use_gpu_) {
+      LOG(INFO) << "Run Anakin GPU mode";
+    } else {
+      LOG(INFO) << "Run Anakin CPU mode";
+    }
 
     pass_builder()->ClearPasses();
     for (const auto &pass : kAnakinSubgraphPasses) {
-      pass_builder()->AppendPass(pass);
+      if (std::find(anakin_passes_filter_.begin(), anakin_passes_filter_.end(),
+                    pass) == anakin_passes_filter_.end()) {
+        pass_builder()->AppendPass(pass);
+      }
     }
   }
 
@@ -390,11 +399,17 @@ void AnalysisConfig::SwitchIrDebug(int x) {
 }
 void AnalysisConfig::EnableAnakinEngine(
     int max_batch_size, std::map<std::string, std::vector<int>> max_input_shape,
-    int min_subgraph_size) {
+    int min_subgraph_size, AnalysisConfig::Precision precision_mode,
+    bool auto_config_layout, std::vector<std::string> passes_filter,
+    std::vector<std::string> ops_filter) {
   anakin_max_batchsize_ = max_batch_size;
   anakin_max_input_shape_ = max_input_shape;
   anakin_min_subgraph_size_ = min_subgraph_size;
+  anakin_passes_filter_ = passes_filter;
+  anakin_ops_filter_ = ops_filter;
   use_anakin_ = true;
+  anakin_precision_mode_ = precision_mode;
+  anakin_auto_config_layout_ = auto_config_layout;
   Update();
 }
 }  // namespace paddle
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index a84c909b3b..321107377c 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -383,10 +383,14 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
   }
 
-  if (config_.use_gpu() && config_.anakin_engine_enabled()) {
+  if (config_.anakin_engine_enabled()) {
     argument_.SetAnakinMaxBatchSize(config_.anakin_max_batchsize_);
     argument_.SetAnakinMaxInputShape(config_.anakin_max_input_shape_);
     argument_.SetAnakinMinSubgraphSize(config_.anakin_min_subgraph_size_);
+    argument_.SetAnakinPrecisionMode(config_.anakin_precision_mode_);
+    argument_.SetAnakinAutoConfigLayout(config_.anakin_auto_config_layout_);
+    argument_.SetAnakinPassesFilter(config_.anakin_passes_filter_);
+    argument_.SetAnakinOpsFilter(config_.anakin_ops_filter_);
     LOG(INFO) << "Anakin subgraph engine is enabled";
   }
 
@@ -929,4 +933,9 @@ USE_ANAKIN_CONVERTER(density_prior_box);
 USE_ANAKIN_CONVERTER(dropout);
 USE_ANAKIN_CONVERTER(sum);
 USE_ANAKIN_CONVERTER(prior_box);
+USE_ANAKIN_CONVERTER(leaky_relu);
+USE_ANAKIN_CONVERTER(affine_channel);
+USE_ANAKIN_CONVERTER(relu6);
+USE_ANAKIN_CONVERTER(swish);
+USE_ANAKIN_CONVERTER(shuffle_channel);
 #endif
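The `Update()` change treats `anakin_passes_filter_` as a blacklist: each entry of `kAnakinSubgraphPasses` is appended only when `std::find` fails to locate it in the filter. A self-contained restatement of that loop, using sample pass names:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // A pass is kept only if it does not appear in the user-supplied filter,
    // exactly as in the AppendPass loop of AnalysisConfig::Update().
    std::vector<std::string> FilterPasses(const std::vector<std::string> &all,
                                          const std::vector<std::string> &filter) {
      std::vector<std::string> kept;
      for (const auto &pass : all) {
        if (std::find(filter.begin(), filter.end(), pass) == filter.end()) {
          kept.push_back(pass);
        }
      }
      return kept;
    }

    int main() {
      const std::vector<std::string> anakin_passes = {
          "infer_clean_graph_pass", "fc_fuse_pass", "anakin_subgraph_pass"};
      // fc_fuse_pass is filtered out; the other two are appended.
      for (const auto &p : FilterPasses(anakin_passes, {"fc_fuse_pass"})) {
        std::cout << p << "\n";
      }
      return 0;
    }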
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index c67c4b5bd0..ebe289322b 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -152,7 +152,10 @@ struct AnalysisConfig {
   void EnableAnakinEngine(
       int max_batch_size = 1,
       std::map<std::string, std::vector<int>> max_input_shape = {},
-      int min_subgraph_size = 6);
+      int min_subgraph_size = 6, Precision precision = Precision::kFloat32,
+      bool auto_config_layout = false,
+      std::vector<std::string> passes_filter = {},
+      std::vector<std::string> ops_filter = {});
 
   /** A boolean state indicating whether the Anakin sub-graph engine is used.
   */
@@ -291,6 +294,10 @@ struct AnalysisConfig {
   int anakin_max_batchsize_;
   int anakin_min_subgraph_size_{6};
   std::map<std::string, std::vector<int>> anakin_max_input_shape_;
+  Precision anakin_precision_mode_;
+  bool anakin_auto_config_layout_{false};
+  std::vector<std::string> anakin_passes_filter_;
+  std::vector<std::string> anakin_ops_filter_;
   std::map<std::string, std::string> engine_opt_info_;
 
   bool use_mkldnn_quantizer_{false};
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 2fba560ac2..2a7bd55a76 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -73,15 +73,15 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
 // The following passes works for Anakin sub-graph engine.
 const std::vector<std::string> kAnakinSubgraphPasses({
     "infer_clean_graph_pass",                       //
+    "quant_conv2d_dequant_fuse_pass",               //
     "simplify_anakin_priorbox_detection_out_pass",  //
     "fillconstant_elementwisemul_fuse",             //
     "fc_fuse_pass",                                 //
     "conv_elementwise_add_fuse_pass",               //
-    "conv_bn_fuse_pass",                            //
-    "conv_elementwise_add_fuse_pass",               //
     "fc_gru_fuse_pass",                             //
-    "quant_conv2d_dequant_fuse_pass",               //
-    "anakin_subgraph_pass",
+    "shuffle_channel_detect_pass",                  //
+    "anakin_subgraph_pass",                         //
+    "fc_gru_fuse_pass",                             //
 });
 
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
diff --git a/paddle/fluid/inference/check_symbol.sh b/paddle/fluid/inference/check_symbol.sh
index 12b7b3e7e5..b6b7d1f20b 100755
--- a/paddle/fluid/inference/check_symbol.sh
+++ b/paddle/fluid/inference/check_symbol.sh
@@ -4,7 +4,7 @@ lib=$1
 if [ $# -ne 1 ]; then echo "No input library"; exit -1 ; fi
 
 num_paddle_syms=$(nm -D ${lib} | grep paddle | wc -l)
-num_google_syms=$(nm -D ${lib} | grep google | grep -v paddle | grep T | wc -l)
+num_google_syms=$(nm -D ${lib} | grep google | grep -v paddle | grep "T " | wc -l)
 
 if [ $num_paddle_syms -le 0 ]; then echo "Have no paddle symbols"; exit -1 ; fi
 if [ $num_google_syms -ge 1 ]; then echo "Have some google symbols"; exit -1 ; fi
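Because every new `EnableAnakinEngine` parameter defaults to the old behavior, existing call sites keep compiling unchanged. A usage sketch against the widened signature; the model path, input shapes, and filter entries below are illustrative placeholders, not tested values, and the include assumes the inference library is on the include path:

    #include <map>
    #include <string>
    #include <vector>

    #include "paddle/fluid/inference/api/paddle_analysis_config.h"

    int main() {
      paddle::AnalysisConfig config;
      config.SetModel("/path/to/model_dir");  // placeholder path
      // Everything after min_subgraph_size is new in this patch; the
      // values here are illustrative only.
      config.EnableAnakinEngine(
          /*max_batch_size=*/1,
          /*max_input_shape=*/{{"input", {1, 3, 224, 224}}},
          /*min_subgraph_size=*/6,
          /*precision=*/paddle::AnalysisConfig::Precision::kInt8,
          /*auto_config_layout=*/false,
          /*passes_filter=*/{},
          /*ops_filter=*/{"softmax"});  // keep softmax on the Paddle side
      return 0;
    }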
diff --git a/paddle/fluid/operators/anakin/anakin_engine_op.h b/paddle/fluid/operators/anakin/anakin_engine_op.h
index e4feb14b22..11c394c76c 100644
--- a/paddle/fluid/operators/anakin/anakin_engine_op.h
+++ b/paddle/fluid/operators/anakin/anakin_engine_op.h
@@ -34,28 +34,17 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using FluidDT = framework::proto::VarType_Type;
 using inference::Singleton;
-
-using anakin::graph::GraphGlobalMem;
-using anakin::AK_FLOAT;
-using anakin::Precision;
-using anakin::saber::NV;
-using anakin::saber::X86;
-using anakin::saber::Shape;
-using anakin::PBlock;
-using anakin::PTuple;
 using inference::anakin::AnakinEngine;
 
 class AnakinEngineOp : public framework::OperatorBase {
-  using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
-
  private:
   std::vector<std::string> input_names_;
   std::unordered_set<std::string> param_names_;
-  mutable AnakinNvEngineT *anakin_engine_;
   std::string engine_key_;
   std::string engine_serialized_data_;
+  bool use_gpu_;
+  bool enable_int8_;
 
  public:
   AnakinEngineOp(const std::string &type,
@@ -66,10 +55,11 @@ class AnakinEngineOp : public framework::OperatorBase {
     input_names_ = Inputs("Xs");
     engine_key_ = Attr<std::string>("engine_key");
     auto params = Attr<std::vector<std::string>>("parameters");
+    use_gpu_ = Attr<bool>("use_gpu");
+    enable_int8_ = Attr<bool>("enable_int8");
     for (const auto &param : params) {
       param_names_.insert(param);
     }
-    anakin_engine_ = nullptr;
   }
 
  protected:
@@ -80,19 +70,12 @@ class AnakinEngineOp : public framework::OperatorBase {
 
   void RunAnakin(const framework::Scope &scope,
                  const platform::Place &dev_place) const {
-    auto *engine = GetEngine(scope, dev_place);
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(dev_place);
-    auto stream =
-        reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
-
     PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");
 
     std::vector<std::string> output_maps =
         Attr<std::vector<std::string>>("output_name_mapping");
 
     std::map<std::string, framework::LoDTensor *> inputs;
-    // Convert input tensor from fluid to engine.
     for (const auto &x : Inputs("Xs")) {
       if (param_names_.count(x)) continue;
       auto &t =
@@ -110,17 +93,38 @@ class AnakinEngineOp : public framework::OperatorBase {
       outputs.insert({output_maps[output_index], fluid_t});
       output_index += 1;
     }
-    engine->Execute(inputs, outputs, stream);
+    if (enable_int8_) {
+      Execute<::anakin::Precision::INT8>(inputs, outputs, dev_place);
+    } else {
+      Execute<::anakin::Precision::FP32>(inputs, outputs, dev_place);
+    }
   }
 
-  AnakinNvEngineT *GetEngine(const framework::Scope &scope,
-                             const platform::Place &dev_place) const {
-    if (anakin_engine_ == nullptr) {
-      anakin_engine_ =
-          inference::Singleton<inference::anakin::AnakinEngineManager>::Global()
+  template <::anakin::Precision PrecisionT>
+  void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
+               const std::map<std::string, framework::LoDTensor *> &outputs,
+               const platform::Place &dev_place) const {
+    if (use_gpu_) {
+#ifdef PADDLE_WITH_CUDA
+      platform::DeviceContextPool &pool =
+          platform::DeviceContextPool::Instance();
+      auto &dev_ctx = *pool.Get(dev_place);
+      auto stream =
+          reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx)
+              .stream();
+      auto *engine =
+          inference::Singleton<inference::anakin::AnakinEngineManager<
+              ::anakin::saber::NV, PrecisionT>>::Global()
+              .Get(engine_key_);
+      engine->Execute(inputs, outputs, stream);
+#endif
+    } else {
+      auto *engine =
+          inference::Singleton<inference::anakin::AnakinEngineManager<
+              ::anakin::saber::X86, PrecisionT>>::Global()
               .Get(engine_key_);
+      engine->Execute(inputs, outputs);
     }
-    return anakin_engine_;
   }
 };
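One subtlety in the new `Execute`: the runtime `use_gpu_` flag and the `PADDLE_WITH_CUDA` guard compose so that, in a CPU-only build, a request for the GPU branch compiles to an empty block and silently does nothing rather than failing. A tiny model of that interaction, assuming the same macro convention:

    #include <iostream>

    // Hypothetical stand-ins for the two engine singletons.
    void RunOnNV() { std::cout << "NV engine + CUDA stream\n"; }
    void RunOnX86() { std::cout << "X86 engine, no stream\n"; }

    void Execute(bool use_gpu) {
      if (use_gpu) {
    #ifdef PADDLE_WITH_CUDA
        RunOnNV();  // only compiled into CUDA builds
    #endif
        // In a CPU-only build this branch is an empty block: a GPU
        // request is silently ignored, as in the patched operator.
      } else {
        RunOnX86();
      }
    }

    int main() {
      Execute(false);
      Execute(true);
      return 0;
    }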
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 236afc77f7..b650225c64 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -16,6 +16,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include "paddle/fluid/inference/api/analysis_predictor.h"
@@ -229,6 +230,15 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("min_subgraph_size") = 3,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
            py::arg("use_static") = true)
+      .def("enable_anakin_engine", &AnalysisConfig::EnableAnakinEngine,
+           py::arg("max_batch_size") = 1,
+           py::arg("max_input_shape") =
+               std::map<std::string, std::vector<int>>(),
+           py::arg("min_subgraph_size") = 6,
+           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
+           py::arg("auto_config_layout") = false,
+           py::arg("passes_filter") = std::vector<std::string>(),
+           py::arg("ops_filter") = std::vector<std::string>())
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
       .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
           py::arg("x") = true)
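The added binding forwards Python keyword arguments onto the C++ defaults via `py::arg(...) = value`, so the Python-side defaults mirror the header's exactly. A minimal pybind11 sketch of the same registration pattern; the `Config` class and `demo` module are placeholders, not Paddle's real binding code:

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // enables std::map / std::vector conversions
    #include <map>
    #include <string>
    #include <vector>

    namespace py = pybind11;

    // Placeholder for the bound C++ class.
    struct Config {
      void EnableAnakinEngine(
          int max_batch_size,
          std::map<std::string, std::vector<int>> max_input_shape,
          int min_subgraph_size) {}
    };

    PYBIND11_MODULE(demo, m) {
      py::class_<Config>(m, "Config")
          .def(py::init<>())
          // Each py::arg names a Python keyword and carries its default,
          // as in the enable_anakin_engine binding above.
          .def("enable_anakin_engine", &Config::EnableAnakinEngine,
               py::arg("max_batch_size") = 1,
               py::arg("max_input_shape") =
                   std::map<std::string, std::vector<int>>(),
               py::arg("min_subgraph_size") = 6);
    }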