From 276950291aa672a23f2fffa3d62f51300504783b Mon Sep 17 00:00:00 2001 From: nhzlx Date: Tue, 21 Aug 2018 11:19:28 +0000 Subject: [PATCH 01/10] 1. fix ssa bug with batchnorm, 2. refine the trt --- paddle/fluid/inference/analysis/analyzer.cc | 3 ++- .../analysis/data_flow_graph_to_fluid_pass.cc | 5 ----- .../analysis/data_flow_graph_to_fluid_pass.h | 3 --- .../analysis/fluid_to_data_flow_graph_pass.cc | 2 +- .../api/api_tensorrt_subgraph_engine.cc | 13 +++++++++++- .../inference/api/paddle_inference_api.h | 8 +++++++ paddle/fluid/operators/tensorrt_engine_op.cc | 4 ++-- paddle/fluid/operators/tensorrt_engine_op.h | 21 +++++++++++++------ .../operators/tensorrt_engine_op_test.cc | 7 +++++-- 9 files changed, 45 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 9318f10897..912615c945 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -44,7 +44,8 @@ class DfgPassManagerImpl final : public DfgPassManager { if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( - {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"}); + {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", + "depthwise_conv2d", "batch_norm"}); if (!node->IsFunction()) return false; const auto* func = static_cast(node); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index f40d471cbf..ce0639a616 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -23,9 +23,6 @@ namespace paddle { namespace inference { -DEFINE_int32(tensorrt_max_batchsize, 1, "TensorRT maximum batch size"); -DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size"); - namespace analysis { using framework::proto::ProgramDesc; @@ -190,8 +187,6 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, // Set attrs SetAttr(desc.Proto(), "subgraph", block->SerializeAsString()); SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); - SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize); - SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size); SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); SetAttr(desc.Proto(), "output_name_mapping", output_mapping); node->SetPbMsg(desc.Proto()->SerializeAsString()); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h index 59c47365aa..0c9a8a0b7c 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h @@ -27,9 +27,6 @@ namespace paddle { namespace inference { -DECLARE_int32(tensorrt_max_batchsize); -DECLARE_int32(tensorrt_workspace_size); - namespace analysis { class DataFlowGraphToFluidPass final : public DataFlowGraphPass { public: diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 511631d3e0..16d82b5aa1 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -92,6 +92,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k))); in->outlinks.push_back(o); o->inlinks.push_back(in); + unique_written_vars.insert(in); } } for (int j = 0; j < op.outputs_size(); j++) { @@ -112,7 +113,6 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { } out->inlinks.push_back(o); o->outlinks.push_back(out); - unique_written_vars.insert(out); } } } diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index 45b5a7638b..9ac0372971 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/operators/tensorrt_engine_op.h" @@ -32,7 +33,8 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { bool Init(const std::shared_ptr& parent_scope) { VLOG(3) << "Predictor::init()"; - + FLAGS_tensorrt_max_batch_size = config_.max_batch_size; + FLAGS_tensorrt_workspace_size = config_.workspace_size; if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); } else { @@ -150,3 +152,12 @@ CreatePaddlePredictor( } } // namespace paddle + +USE_TRT_CONVERTER(elementwise_add_weight); +USE_TRT_CONVERTER(mul); +USE_TRT_CONVERTER(conv2d); +USE_TRT_CONVERTER(relu); +USE_TRT_CONVERTER(fc); +USE_TRT_CONVERTER(pool2d); +USE_TRT_CONVERTER(softmax); +USE_TRT_CONVERTER(batch_norm); diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 794534467b..da6c2cfc21 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -137,6 +137,14 @@ struct AnakinConfig : public PaddlePredictor::Config { struct TensorRTConfig : public NativeConfig { // Determine whether a subgraph will be executed by TRT. int min_subgraph_size{1}; + // While TensorRT allows an engine optimized for a given max batch size + // to run at any smaller size, the performance for those smaller + // sizes may not be as well-optimized. Therefore, Max batch is best + // equivalent to the runtime batch size. + int max_batch_size{1}; + // For workspace_size, refer it from here: + // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting + int workspace_size{1 << 30}; }; // A factory to help create different predictors. diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 4d930e9cec..1048d30171 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -22,6 +22,8 @@ namespace paddle { DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT"); +DEFINE_int32(tensorrt_max_batch_size, 1, "TensorRT maximum batch size"); +DEFINE_int32(tensorrt_workspace_size, 16 << 20, "TensorRT workspace size"); namespace operators { @@ -32,8 +34,6 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Ys", "A list of outputs").AsDuplicable(); AddAttr("subgraph", "the subgraph."); AddAttr("engine_uniq_key", "unique key for the TRT engine."); - AddAttr("max_batch", "the maximum batch size."); - AddAttr("max_workspace", "the maximum batch size."); AddComment("TensorRT engine operator."); } }; diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index f2ec7f066a..bc556ab364 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -28,6 +28,8 @@ namespace paddle { DECLARE_int32(tensorrt_engine_batch_size); +DECLARE_int32(tensorrt_max_batch_size); +DECLARE_int32(tensorrt_workspace_size); namespace operators { @@ -54,8 +56,10 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { "TensorRT' tensor input requires at least 2 dimensions"); PADDLE_ENFORCE_LE(shape.size(), 4UL, "TensorRT' tensor input requires at most 4 dimensions"); - PADDLE_ENFORCE_EQ(shape.size(), 4UL); - return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); + PADDLE_ENFORCE(shape.size() == 4UL || shape.size() == 2UL); + if (shape.size() == 4UL) + return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); + return nvinfer1::DimsCHW(shape[1], 1, 1); } } // namespace @@ -95,7 +99,7 @@ class TensorRTEngineKernel : public framework::OpKernel { auto input_names = context.op().Inputs("Xs"); PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, - context.Attr("max_batch")); + FLAGS_tensorrt_max_batch_size); std::vector output_maps = context.Attr>("output_name_mapping"); @@ -132,7 +136,12 @@ class TensorRTEngineKernel : public framework::OpKernel { nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. - std::vector ddim(dims.d, dims.d + dims.nbDims); + // The ITensor doesn't contain the batch size dim. + std::vector ddim; + ddim.push_back(FLAGS_tensorrt_engine_batch_size); + for (int i = 0; i < dims.nbDims; i++) { + ddim.push_back(dims.d[i]); + } auto* fluid_v = context.scope().FindVar(y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); @@ -168,8 +177,8 @@ class TensorRTEngineKernel : public framework::OpKernel { // Get the ProgramDesc and pass to convert. framework::proto::BlockDesc block_desc; block_desc.ParseFromString(context.Attr("subgraph")); - int max_batch = context.Attr("max_batch"); - auto max_workspace = context.Attr("max_workspace"); + int max_batch = FLAGS_tensorrt_max_batch_size; + auto max_workspace = FLAGS_tensorrt_workspace_size; auto params = context.Attr>("parameters"); std::unordered_set parameters; for (const auto& param : params) { diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 97c375361f..27c1d29762 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/tensorrt_engine_op.h" #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -57,6 +58,8 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block, using inference::analysis::SetAttr; TEST(TensorRTEngineOp, manual) { + FLAGS_tensorrt_engine_batch_size = 2; + FLAGS_tensorrt_max_batch_size = 2; framework::ProgramDesc program; auto* block_ = program.Proto()->add_blocks(); block_->set_idx(0); @@ -98,8 +101,6 @@ TEST(TensorRTEngineOp, manual) { engine_op_desc.SetOutput("Ys", std::vector({"z0"})); SetAttr(engine_op_desc.Proto(), "subgraph", block_->SerializeAsString()); - SetAttr(engine_op_desc.Proto(), "max_batch", 100); - SetAttr(engine_op_desc.Proto(), "max_workspace", 1 << 10); SetAttr(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); SetAttr>(engine_op_desc.Proto(), "parameters", std::vector({})); @@ -128,6 +129,8 @@ TEST(TensorRTEngineOp, manual) { } void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { + FLAGS_tensorrt_engine_batch_size = batch_size; + FLAGS_tensorrt_max_batch_size = batch_size; framework::ProgramDesc program; framework::Scope scope; platform::CUDAPlace place; From a2c0e52f3e8a63c2a80b5b62073d372cfbe0e6c6 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 22 Aug 2018 10:33:48 +0800 Subject: [PATCH 02/10] speed up while_op --- paddle/fluid/operators/while_op.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 733157ea05..48e37796e1 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase { PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), "Condition of while op must in CPU memory."); + + auto ctx = executor.Prepare(*program, block->ID()); while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); - - executor.Run(*program, ¤t_scope, block->ID(), - false /*create_local_scope*/); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false); } } }; @@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase { framework::Executor executor(dev_place); auto *block = Attr(kStepBlock); auto *program = block->Program(); + auto ctx = executor.Prepare(*program, block->ID()); auto *step_scopes = scope.FindVar(Input(kStepScopes))->GetMutable(); @@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase { } } } - - executor.Run(*program, *cur_scope_iter, block->ID(), false); + executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false); auto &pg_names = Outputs(kXGRAD); auto &p_names = Inputs(kX); From f8c6b4641573e1b4ae449253822eda486cb01646 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 22 Aug 2018 10:50:49 +0800 Subject: [PATCH 03/10] fix profiler test --- python/paddle/fluid/tests/unittests/test_profiler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index 38a7c913bf..7934164b84 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -25,9 +25,6 @@ import paddle.fluid.core as core class TestProfiler(unittest.TestCase): def net_profiler(self, state, profile_path='/tmp/profile'): - enable_if_gpu = state == 'GPU' or state == "All" - if enable_if_gpu and not core.is_compiled_with_cuda(): - return startup_program = fluid.Program() main_program = fluid.Program() @@ -81,8 +78,6 @@ class TestProfiler(unittest.TestCase): pass_acc_calculator.add(value=acc, weight=b_size) pass_acc = pass_acc_calculator.eval() - @unittest.skipIf(not core.is_compiled_with_cuda(), - "profiler is enabled only with GPU") def test_cpu_profiler(self): self.net_profiler('CPU') From bc4f53754fc02f012ca5512639a790230a6324f7 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 22 Aug 2018 11:34:04 +0800 Subject: [PATCH 04/10] Doc: append PADDLE_ENFORCE rules to new_op.md (#12727) * doc: append PADDLE_ENFORCE rules to new_op_cn.md * doc: polish writing * refactor: polish doc based on advice --- doc/fluid/dev/new_op_cn.md | 67 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/doc/fluid/dev/new_op_cn.md b/doc/fluid/dev/new_op_cn.md index 587d819f79..63d471ff52 100644 --- a/doc/fluid/dev/new_op_cn.md +++ b/doc/fluid/dev/new_op_cn.md @@ -334,3 +334,70 @@ ctest -R test_mul_op - 注册Op时的类型名,需要和该Op的名字一样。即不允许在`A_op.cc`里面,注册`REGISTER_OPERATOR(B, ...)`等,这将会导致单元测试出错。 - 如果Op没有实现CUDA Kernel,请不要创建空的`*_op.cu`,这将会导致单元测试出错。 - 如果多个Op依赖一些共用的函数,可以创建非`*_op.*`格式的文件来存放,如`gather.h`文件。 + +### PADDLE_ENFORCE使用注意 + +实现Op时检查数据的合法性需要使用PADDLE_ENFORCE以及PADDLE_ENFORCE_EQ等宏定义,基本格式如下: + +``` +PADDLE_ENFORCE(表达式, 错误提示信息) +PADDLE_ENFORCE_EQ(比较对象A, 比较对象B, 错误提示信息) +``` + +如果表达式为真,或者比较对象A=B,则检查通过,否则会终止程序运行,向用户反馈相应的错误提示信息。 +为了确保提示友好易懂,开发者需要注意其使用方法。 + +#### 总体原则 + +任何使用了PADDLE_ENFORCE与PADDLE_ENFORCE_**检查的地方,必须有详略得当的备注解释!**错误提示信息**不能为空! + +#### 提示信息书写标准 + +1. [required] 哪里错了?为什么错了? + - 例如:`ValueError: Mismatched label shape` +2. [optional] 期望的输入是什么样的?实际的输入是怎样的? + - 例如:`Expected labels dimension=1. Received 4.` +3. [optional] 能否给出修改意见? + - 例如:`Suggested Fix:If your classifier expects one-hot encoding label,check your n_classes argument to the estimatorand/or the shape of your label.Otherwise, check the shape of your label.` + +如果并非必要或者简洁的描述即可表达清楚以上要点,根据情况书写亦可。 + +##### FAQ 典型问题 + +1. 无报错信息或报错信息过于简单,不能给用户提供有效的提示! + +问题示例1 :未写提示信息 +``` +PADDLE_ENFORCE(ctx->HasInput("X"), ""); +``` +问题示例2 :提示信息过于简单 +``` +PADDLE_ENFORCE(i != nullptr, "I must be set"); // I是什么? +``` + +2. 在报错信息中使用开发人员定义的变量缩写,不易理解! + +问题示例: +``` +PADDLE_ENFORCE(forward_pd != nullptr, + "Fail to find eltwise_fwd_pd in device context"); //eltwise_fwd_pd用户可能看不懂 +``` + +#### OP InferShape检查提示信息特别说明 + +- 检查输入输出变量,请统一遵循以下格式 +`Input(变量名) of OP名 operator should not be null.` + +正确示例: +``` +PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTMP operator should not be null."); +``` + +- 反向Op的输入输出检查,要写明反向Op的名字 + +正确示例: +``` +PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of LoDResetGrad opreator should not be null."); +``` From 9ee698e6059e15488100e1b905100031cfb357e5 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 22 Aug 2018 13:09:03 +0800 Subject: [PATCH 05/10] enhance/ditu rnn with fc fuse (#12831) * make fc fuse work with ditu rnn * add ditu rnn data download to CMAKE --- paddle/fluid/framework/ir/graph_helper.cc | 2 +- .../fluid/inference/analysis/CMakeLists.txt | 45 ++- paddle/fluid/inference/analysis/analyzer.cc | 3 +- paddle/fluid/inference/analysis/analyzer.h | 3 +- .../inference/analysis/analyzer_tester.cc | 266 +++++++++++++++++- paddle/fluid/inference/api/CMakeLists.txt | 5 +- paddle/fluid/inference/api/api_impl.cc | 7 +- paddle/fluid/inference/api/helper.h | 110 ++++++++ .../inference/api/paddle_inference_api.h | 2 + paddle/fluid/operators/mul_op.cc | 6 +- 10 files changed, 423 insertions(+), 26 deletions(-) create mode 100644 paddle/fluid/inference/api/helper.h diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index b1c19e6535..dc81a2cac5 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -104,7 +104,7 @@ std::map> BuildOperationAdjList( for (auto &adj_n : var->inputs) { PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); adj_list[n].insert(adj_n); - VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) + VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) << " -> " << n->Name() << reinterpret_cast(n) << " via " << var->Name() << reinterpret_cast(var); } diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index f1271ddb75..4feaed2b0d 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -22,7 +22,7 @@ function (inference_analysis_test TARGET) if(WITH_TESTING) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS) + set(multiValueArgs SRCS EXTRA_DEPS) cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) set(mem_opt "") @@ -31,22 +31,43 @@ function (inference_analysis_test TARGET) endif() cc_test(${TARGET} SRCS "${analysis_test_SRCS}" - DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass + DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS} ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) endif(WITH_TESTING) endfunction(inference_analysis_test) -cc_test(test_analyzer SRCS analyzer_tester.cc DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis - # ir - fc_fuse_pass - graph_viz_pass - infer_clean_graph_pass - graph_pattern_detecter - pass - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -#set_tests_properties(test_analyzer PROPERTIES DEPENDS test_word2vec) -#inference_api_test(test_analyzer SRC analyzer_tester.cc ARGS test_word2vec) +set(DITU_RNN_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz") +set(DITU_RNN_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz") +set(DITU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/ditu_rnn" CACHE PATH "Ditu RNN model and data root." FORCE) +set(DITU_RNN_MODEL ${DITU_INSTALL_DIR}/model) +set(DITU_RNN_DATA ${DITU_INSTALL_DIR}/data.txt) + +function (inference_download_and_uncompress target url gz_filename) + message(STATUS "Download inference test stuff ${gz_filename} from ${url}") + execute_process(COMMAND bash -c "mkdir -p ${DITU_INSTALL_DIR}") + execute_process(COMMAND bash -c "cd ${DITU_INSTALL_DIR} && wget -q ${url}") + execute_process(COMMAND bash -c "cd ${DITU_INSTALL_DIR} && tar xzf ${gz_filename}") + message(STATUS "finish downloading ${gz_filename}") +endfunction(inference_download_and_uncompress) + +if (NOT EXISTS ${DITU_INSTALL_DIR}) + inference_download_and_uncompress(ditu_rnn_model ${DITU_RNN_MODEL_URL} "ditu_rnn_fluid%2Fmodel.tar.gz") + inference_download_and_uncompress(ditu_rnn_data ${DITU_RNN_DATA_URL} "ditu_rnn_fluid%2Fdata.txt.tar.gz") +endif() + +inference_analysis_test(test_analyzer SRCS analyzer_tester.cc + EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis + # ir + fc_fuse_pass + graph_viz_pass + infer_clean_graph_pass + graph_pattern_detecter + infer_clean_graph_pass + pass + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model + --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model + --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index fc8b1f6864..7d16364609 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -23,8 +23,6 @@ #include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" -namespace paddle { - DEFINE_bool(IA_enable_tensorrt_subgraph_engine, false, "Enable subgraph to TensorRT engine for acceleration"); @@ -35,6 +33,7 @@ DEFINE_string(IA_graphviz_log_root, "./", DEFINE_string(IA_output_storage_path, "", "optimized model output path"); +namespace paddle { namespace inference { namespace analysis { diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index a72875d36f..2e107c82dd 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -39,8 +39,6 @@ limitations under the License. */ #include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass_manager.h" -namespace paddle { - // TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this // flag if not available. DECLARE_bool(IA_enable_tensorrt_subgraph_engine); @@ -48,6 +46,7 @@ DECLARE_string(IA_graphviz_log_root); DECLARE_string(IA_output_storage_path); DECLARE_bool(IA_enable_ir); +namespace paddle { namespace inference { namespace analysis { diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 3be336dd5c..52f5c4f5ae 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -13,11 +13,17 @@ // limitations under the License. #include "paddle/fluid/inference/analysis/analyzer.h" + #include +#include #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/inference/analysis/ut_helper.h" +#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); +DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); + namespace paddle { namespace inference { namespace analysis { @@ -38,7 +44,7 @@ TEST(Analyzer, analysis_with_tensorrt) { analyser.Run(&argument); } -void TestWord2vecPrediction(const std::string& model_path) { +void TestWord2vecPrediction(const std::string &model_path) { NativeConfig config; config.model_dir = model_path; config.use_gpu = false; @@ -69,12 +75,245 @@ void TestWord2vecPrediction(const std::string& model_path) { // The outputs' buffers are in CPU memory. for (size_t i = 0; i < std::min(5UL, num_elements); i++) { LOG(INFO) << "data: " - << static_cast(outputs.front().data.data())[i]; - PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], + << static_cast(outputs.front().data.data())[i]; + PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], result[i]); } } +namespace { + +struct DataRecord { + std::vector>> link_step_data_all; + std::vector> week_data_all, minute_data_all; + std::vector lod1, lod2, lod3; + std::vector> rnn_link_data, rnn_week_datas, + rnn_minute_datas; + size_t batch_iter{0}; + size_t batch_size{1}; + DataRecord() = default; + DataRecord(const std::string &path, int batch_size = 1) + : batch_size(batch_size) { + Load(path); + } + DataRecord NextBatch() { + DataRecord data; + size_t batch_end = batch_iter + batch_size; + // NOTE skip the final batch, if no enough data is provided. + if (batch_end <= link_step_data_all.size()) { + data.link_step_data_all.assign(link_step_data_all.begin() + batch_iter, + link_step_data_all.begin() + batch_end); + data.week_data_all.assign(week_data_all.begin() + batch_iter, + week_data_all.begin() + batch_end); + data.minute_data_all.assign(minute_data_all.begin() + batch_iter, + minute_data_all.begin() + batch_end); + // Prepare LoDs + data.lod1.push_back(0); + data.lod2.push_back(0); + data.lod3.push_back(0); + CHECK(!data.link_step_data_all.empty()) << "empty"; + CHECK(!data.week_data_all.empty()); + CHECK(!data.minute_data_all.empty()); + CHECK_EQ(data.link_step_data_all.size(), data.week_data_all.size()); + CHECK_EQ(data.minute_data_all.size(), data.link_step_data_all.size()); + for (size_t j = 0; j < data.link_step_data_all.size(); j++) { + for (const auto &d : data.link_step_data_all[j]) { + data.rnn_link_data.push_back(d); + } + data.rnn_week_datas.push_back(data.week_data_all[j]); + data.rnn_minute_datas.push_back(data.minute_data_all[j]); + // calculate lod + data.lod1.push_back(data.lod1.back() + + data.link_step_data_all[j].size()); + data.lod3.push_back(data.lod3.back() + 1); + for (size_t i = 1; i < data.link_step_data_all[j].size() + 1; i++) { + data.lod2.push_back(data.lod2.back() + + data.link_step_data_all[j].size()); + } + } + } + batch_iter += batch_size; + return data; + } + void Load(const std::string &path) { + std::ifstream file(path); + std::string line; + int num_lines = 0; + while (std::getline(file, line)) { + num_lines++; + std::vector data; + split(line, ':', &data); + std::vector> link_step_data; + std::vector link_datas; + split(data[0], '|', &link_datas); + for (auto &step_data : link_datas) { + std::vector tmp; + split_to_float(step_data, ',', &tmp); + link_step_data.push_back(tmp); + } + // load week data + std::vector week_data; + split_to_float(data[2], ',', &week_data); + // load minute data + std::vector minute_data; + split_to_float(data[1], ',', &minute_data); + link_step_data_all.push_back(std::move(link_step_data)); + week_data_all.push_back(std::move(week_data)); + minute_data_all.push_back(std::move(minute_data)); + } + } +}; +void PrepareInputs(std::vector *input_slots, DataRecord *data, + int batch_size) { + // DataRecord data(FLAGS_datapath, batch_size); + PaddleTensor lod_attention_tensor, init_zero_tensor, lod_tensor_tensor, + week_tensor, minute_tensor; + lod_attention_tensor.name = "data_lod_attention"; + init_zero_tensor.name = "cell_init"; + lod_tensor_tensor.name = "data"; + week_tensor.name = "week"; + minute_tensor.name = "minute"; + auto one_batch = data->NextBatch(); + // clang-format off + std::vector rnn_link_data_shape + ({static_cast(one_batch.rnn_link_data.size()), static_cast(one_batch.rnn_link_data.front().size())}); + lod_attention_tensor.shape.assign({1, 2}); + lod_attention_tensor.lod.assign({one_batch.lod1, one_batch.lod2}); + init_zero_tensor.shape.assign({batch_size, 15}); + init_zero_tensor.lod.assign({one_batch.lod3}); + lod_tensor_tensor.shape = rnn_link_data_shape; + lod_tensor_tensor.lod.assign({one_batch.lod1}); + week_tensor.shape.assign({(int) one_batch.rnn_week_datas.size(), (int) one_batch.rnn_week_datas.front().size()}); + week_tensor.lod.assign({one_batch.lod3}); + minute_tensor.shape.assign({(int) one_batch.rnn_minute_datas.size(), + (int) one_batch.rnn_minute_datas.front().size()}); + minute_tensor.lod.assign({one_batch.lod3}); + // assign data + TensorAssignData(&lod_attention_tensor, std::vector>({{0, 0}})); + std::vector tmp_zeros(batch_size * 15, 0.); + TensorAssignData(&init_zero_tensor, {tmp_zeros}); + TensorAssignData(&lod_tensor_tensor, one_batch.rnn_link_data); + TensorAssignData(&week_tensor, one_batch.rnn_week_datas); + TensorAssignData(&minute_tensor, one_batch.rnn_minute_datas); + // clang-format on + // Set inputs. + auto init_zero_tensor1 = init_zero_tensor; + init_zero_tensor1.name = "hidden_init"; + input_slots->assign({week_tensor, init_zero_tensor, minute_tensor, + init_zero_tensor1, lod_attention_tensor, + lod_tensor_tensor}); + for (auto &tensor : *input_slots) { + tensor.dtype = PaddleDType::FLOAT32; + } +} + +std::string DescribeTensor(const PaddleTensor &tensor) { + std::stringstream os; + os << "Tensor [" << tensor.name << "]\n"; + os << " - type: "; + switch (tensor.dtype) { + case PaddleDType::FLOAT32: + os << "float32"; + break; + case PaddleDType::INT64: + os << "int64"; + break; + default: + os << "unset"; + } + os << '\n'; + + os << " - shape: " << to_string(tensor.shape) << '\n'; + os << " - lod: "; + for (auto &l : tensor.lod) { + os << to_string(l) << "; "; + } + os << "\n"; + os << " - data: "; + + // clang-format off + int dim = std::accumulate(tensor.shape.begin(), + tensor.shape.end(), + 1, + [](int a, int b) { return a * b; }); // clang-format on + for (size_t i = 0; i < dim; i++) { + os << static_cast(tensor.data.data())[i] << " "; + } + os << '\n'; + return os.str(); +} + +} // namespace + +const float ditu_rnn_target_data[] = { + 104.711, 11.2431, 1.35422, 0, 0, 0, 0, 0, + 27.7039, 1.41486, 7.09526, 0, 0, 0, 0, 0, + 7.6481, 6.5324, 56.383, 2.88018, 8.92918, 132.007, 4.27429, 2.02934, + 14.1727, 10.7461, 25.0616, 16.0197, 14.4163, 16.9199, 6.75517, 0, + 80.0249, 4.77739, 0, 0, 0, 0, 0, 0, + 47.5643, 2.67029, 8.76252, 0, 0, 0, 0, 0, + 51.8822, 4.4411, 0, 0, 0, 0, 0, 0, + 10.7286, 12.0595, 10.6672, 0, 0, 0, 0, 0, + 93.5771, 3.84641, 0, 0, 0, 0, 0, 0, + 169.426, 0, 0, 0, 0, 0, 0, 0}; +// Test with a really complicate model. +void TestDituRNNPrediction(const std::string &model_path, + const std::string &data_path, int batch_size, + bool use_analysis, bool activate_ir, + int num_times = 1) { + FLAGS_IA_enable_ir = activate_ir; + FLAGS_IA_enable_tensorrt_subgraph_engine = false; + FLAGS_IA_output_storage_path = "./analysis.out"; + + std::string model_out; + if (use_analysis) { + Argument argument(model_path); + argument.model_output_store_path.reset(new std::string("./analysis.out")); + + Analyzer analyzer; + analyzer.Run(&argument); + + // Should get the transformed model stored to ./analysis.out + model_out = "./analysis.out"; + ASSERT_TRUE(PathExists(model_out)); + } else { + model_out = FLAGS_infer_ditu_rnn_model; + } + + NativeConfig config; + config.prog_file = model_out + "/__model__"; + config.param_file = model_out + "/param"; + config.use_gpu = false; + config.device = 0; + config.specify_input_name = true; + + auto predictor = + CreatePaddlePredictor(config); + std::vector input_slots; + DataRecord data(data_path, batch_size); + // Prepare inputs. + PrepareInputs(&input_slots, &data, batch_size); + std::vector outputs; + + Timer timer; + timer.tic(); + for (int i = 0; i < num_times; i++) { + predictor->Run(input_slots, &outputs); + } + LOG(INFO) << "time/batch: " << timer.toc() / num_times; + + for (auto &out : outputs) { + size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, + [](int a, int b) { return a * b; }); + float *data = static_cast(out.data.data()); + for (int i = 0; + i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size); + i++) { + EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3); + } + } +} + // Turn on the IR pass supportion, run a real inference and check the result. TEST(Analyzer, SupportIRPass) { FLAGS_IA_enable_ir = true; @@ -94,6 +333,27 @@ TEST(Analyzer, SupportIRPass) { TestWord2vecPrediction("./analysis.out"); } +// Directly infer with the original model. +TEST(Analyzer, DituRNN_without_analysis) { + TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, + 10, false, false); +} + +// Inference with the original model with the analysis turned on, the analysis +// module will transform the program to a data flow graph. +TEST(Analyzer, DituRNN_with_analysis) { + LOG(INFO) << "ditu rnn with analysis"; + TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, + 10, true, false, 1); +} + +// Inference with analysis and IR. The IR module will fuse some large kernels. +TEST(Analyzer, DituRNN_with_analysis_with_IR) { + LOG(INFO) << "ditu rnn with analysis and IR fuse"; + TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, + 10, true, true, 1); +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index ce6c8f0474..6da9a6385f 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -18,7 +18,10 @@ if(APPLE) endif(APPLE) -set(inference_deps paddle_inference_api paddle_fluid_api) +set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager + graph_viz_pass fc_fuse_pass + infer_clean_graph_pass + ) if(WITH_GPU AND TENSORRT_FOUND) set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine) diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index e31c637e96..32a691b81f 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -137,8 +137,11 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, return false; } for (size_t i = 0; i < feed_target_names_.size(); ++i) { - VLOG(4) << "setting " << i << "-th target"; - feed_targets[feed_target_names_[i]] = &feeds[i]; + if (config_.specify_input_name) { + feed_targets[inputs[i].name] = &feeds[i]; + } else { + feed_targets[feed_target_names_[i]] = &feeds[i]; + } } // get fetch variable std::map fetch_targets; diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h new file mode 100644 index 0000000000..2c166cc062 --- /dev/null +++ b/paddle/fluid/inference/api/helper.h @@ -0,0 +1,110 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +namespace paddle { +namespace inference { + +// Timer for timer +class Timer { + public: + double start; + double startu; + void tic() { + struct timeval tp; + gettimeofday(&tp, NULL); + start = tp.tv_sec; + startu = tp.tv_usec; + } + double toc() { + struct timeval tp; + gettimeofday(&tp, NULL); + double used_time_ms = + (tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0; + return used_time_ms; + } +}; + +void split(const std::string &str, char sep, std::vector *pieces) { + pieces->clear(); + if (str.empty()) { + return; + } + size_t pos = 0; + size_t next = str.find(sep, pos); + while (next != std::string::npos) { + pieces->push_back(str.substr(pos, next - pos)); + pos = next + 1; + next = str.find(sep, pos); + } + if (!str.substr(pos).empty()) { + pieces->push_back(str.substr(pos)); + } +} +void split_to_float(const std::string &str, char sep, std::vector *fs) { + std::vector pieces; + split(str, sep, &pieces); + std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs), + [](const std::string &v) { return std::stof(v); }); +} +template +std::string to_string(const std::vector &vec) { + std::stringstream ss; + for (const auto &c : vec) { + ss << c << " "; + } + return ss.str(); +} +template <> +std::string to_string>( + const std::vector> &vec) { + std::stringstream ss; + for (const auto &piece : vec) { + ss << to_string(piece) << "\n"; + } + return ss.str(); +} +template <> +std::string to_string>>( + const std::vector>> &vec) { + std::stringstream ss; + for (const auto &line : vec) { + for (const auto &rcd : line) { + ss << to_string(rcd) << ";\t"; + } + ss << '\n'; + } + return ss.str(); +} +// clang-format off +void TensorAssignData(PaddleTensor *tensor, const std::vector> &data) { + // Assign buffer + int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; }); + tensor->data.Resize(sizeof(float) * dim); + int c = 0; + for (const auto &f : data) { + for (float v : f) { static_cast(tensor->data.data())[c++] = v; } + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index da6c2cfc21..3b36377274 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -120,6 +120,8 @@ struct NativeConfig : public PaddlePredictor::Config { bool use_gpu{false}; int device{0}; float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization. + // Specify the variable's name of each input. + bool specify_input_name{false}; std::string prog_file; std::string param_file; diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 51993398bd..2a8e4af516 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -54,9 +54,9 @@ class MulOp : public framework::OperatorWithKernel { auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); - PADDLE_ENFORCE_EQ( - x_mat_dims[1], y_mat_dims[0], - "First matrix's width must be equal with second matrix's height."); + PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0], + "First matrix's width must be equal with second matrix's " + "height. %s, %s"); std::vector output_dims; output_dims.reserve( static_cast(x_num_col_dims + y_dims.size() - y_num_col_dims)); From 0b3d8fcd948b3e62e49f46c980f467aa736020af Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 22 Aug 2018 13:34:23 +0800 Subject: [PATCH 06/10] Feature/op standard (#12860) * new doc * standard --- doc/fluid/dev/new_op_cn.md | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/doc/fluid/dev/new_op_cn.md b/doc/fluid/dev/new_op_cn.md index 63d471ff52..c00f73be95 100644 --- a/doc/fluid/dev/new_op_cn.md +++ b/doc/fluid/dev/new_op_cn.md @@ -119,10 +119,29 @@ $$Out = scale*X$$ 这个例子有`AddAttr("scale", "...").SetDefault(1.0);` : 增加`scale`系数,作为参数属性,并且设置默认值为1.0。 +### 定义GradProtoMaker类 +每个Op的必须有一个对应的GraProtoMaker,若未定制对应前向Op的GradProtoMaker,fluid提供了DefaultGradProtoMaker,默认注册会使用全部输入输出,包括Input, Output, Output@Grad等,使用不需要的变量的会造成显存浪费。 +下面示例定义了ScaleOp的GradProtoMaker。 + +```cpp +class ScaleGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDesc(); + grad_op->SetType("scale"); + grad_op->SetInput("X", OutputGrad("Out")); + grad_op->SetOutput("Out", InputGrad("X")); + grad_op->SetAttr("scale", GetAttr("scale")); + return std::unique_ptr(grad_op); + } +}; +``` ### 定义Operator类 -下面的点实现了MulOp的定义: +下面实现了MulOp的定义: ```cpp class MulOp : public framework::OperatorWithKernel { @@ -383,6 +402,19 @@ PADDLE_ENFORCE(forward_pd != nullptr, "Fail to find eltwise_fwd_pd in device context"); //eltwise_fwd_pd用户可能看不懂 ``` +3. OP内部调用非法接口:Op内部如果出现Output = ShareDataWith(Input) +问题示例: +```cpp +auto *out = ctx.Output("Out"); +auto *in = ctx.Input("X"); +out->ShareDataWith(*in); +``` +Op内部如果出现Output = ShareDataWith(Input),相当于operator图的中有一条隐藏边,连接了Input和Output,这条边无法在图分析中表达,引发基于图优化的错误。 + +4. OP实现的性能实践 +调用了eigen的broadcast, chop等操作,性能会比手写cuda kernel差几倍以上。此时cpu的实现可以复用eigen,gpu实现可以实现cuda kernel. + + #### OP InferShape检查提示信息特别说明 - 检查输入输出变量,请统一遵循以下格式 From decda738b0bf1dbba5ff4b0b035d2630ab0b7919 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Wed, 22 Aug 2018 15:14:08 +0800 Subject: [PATCH 07/10] fea/anakin compile with demo (#12772) * anakin support x86 * fix code style * add anakin ditu cnn demo * add timer * add rnn * fix inference_anakin_cnn/rnn_test compile error * make anakin_rnn_tester run * add anakin_enable_op_time option * update api/CMakeLists.txt * enlarge the max_batch_size in anakin.config * update with comments --- cmake/external/anakin.cmake | 29 +- paddle/fluid/inference/api/CMakeLists.txt | 11 +- paddle/fluid/inference/api/api.cc | 3 - .../fluid/inference/api/api_anakin_engine.cc | 137 ++++++-- .../fluid/inference/api/api_anakin_engine.h | 6 +- .../api/api_anakin_engine_rnn_tester.cc | 315 ++++++++++++++++++ .../inference/api/paddle_inference_api.h | 4 +- 7 files changed, 463 insertions(+), 42 deletions(-) create mode 100644 paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 5d11d238cd..78be074909 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -2,6 +2,11 @@ if (NOT WITH_ANAKIN) return() endif() +option(ANAKIN_ENABLE_OP_TIMER "Get more detailed information with Anakin op time" OFF) +if(ANAKIN_ENABLE_OP_TIMER) + add_definitions(-DPADDLE_ANAKIN_ENABLE_OP_TIMER) +endif() + INCLUDE(ExternalProject) set(ANAKIN_SOURCE_DIR ${THIRD_PARTY_PATH}/anakin) # the anakin install dir is only default one now @@ -11,23 +16,34 @@ set(ANAKIN_LIBRARY ${ANAKIN_INSTALL_DIR}) set(ANAKIN_SHARED_LIB ${ANAKIN_LIBRARY}/libanakin.so) set(ANAKIN_SABER_LIB ${ANAKIN_LIBRARY}/libanakin_saber_common.so) -# TODO(luotao): ANAKIN_MODLE_URL will move to demo ci later. -set(ANAKIN_MODLE_URL "http://paddle-inference-dist.bj.bcebos.com/mobilenet_v2.anakin.bin") +# TODO(luotao): ANAKIN_MODLE_URL etc will move to demo ci later. +set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com") +set(ANAKIN_MODLE_URL "${INFERENCE_URL}/mobilenet_v2.anakin.bin") +set(ANAKIN_RNN_MODLE_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn.anakin2.model.bin") +set(ANAKIN_RNN_DATA_URL "${INFERENCE_URL}/anakin_test%2Fditu_rnn_data.txt") execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_SOURCE_DIR}") -execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL}") +execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL} -N") +execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_MODLE_URL} -N") +execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_RNN_DATA_URL} -N") include_directories(${ANAKIN_INCLUDE}) include_directories(${ANAKIN_INCLUDE}/saber/) +include_directories(${ANAKIN_INCLUDE}/saber/core/) +include_directories(${ANAKIN_INCLUDE}/saber/funcs/impl/x86/) +include_directories(${ANAKIN_INCLUDE}/saber/funcs/impl/cuda/base/cuda_c/) set(ANAKIN_COMPILE_EXTRA_FLAGS -Wno-error=unused-but-set-variable -Wno-unused-but-set-variable -Wno-error=unused-variable -Wno-unused-variable -Wno-error=format-extra-args -Wno-format-extra-args - -Wno-error=comment -Wno-comment - -Wno-error=format -Wno-format + -Wno-error=comment -Wno-comment + -Wno-error=format -Wno-format + -Wno-error=maybe-uninitialized -Wno-maybe-uninitialized -Wno-error=switch -Wno-switch -Wno-error=return-type -Wno-return-type -Wno-error=non-virtual-dtor -Wno-non-virtual-dtor + -Wno-error=ignored-qualifiers + -Wno-ignored-qualifiers -Wno-sign-compare -Wno-reorder -Wno-error=cpp) @@ -38,7 +54,7 @@ ExternalProject_Add( DEPENDS ${MKLML_PROJECT} # Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code. GIT_REPOSITORY "https://github.com/luotao1/Anakin" - GIT_TAG "bcf17aabe7921ceb7bce591244b4f9dce7dba5c8" + GIT_TAG "211d1fc5d813d70c0c14072f9083cf25f40940ea" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES @@ -48,6 +64,7 @@ ExternalProject_Add( -DMKLML_ROOT=${THIRD_PARTY_PATH}/install/mklml -DCUDNN_ROOT=${CUDNN_ROOT} -DCUDNN_INCLUDE_DIR=${CUDNN_INCLUDE_DIR} + -DENABLE_OP_TIMER=${ANAKIN_ENABLE_OP_TIMER} ${EXTERNAL_OPTIONAL_ARGS} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ANAKIN_INSTALL_DIR} ) diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 6da9a6385f..0ca1af455c 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -65,7 +65,7 @@ endif() if (WITH_ANAKIN AND WITH_GPU) # only needed in CI # compile the libinference_anakin_api.a and anakin.so. - cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) + cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber mklml) cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) function(anakin_target target_name) target_compile_options(${target_name} BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) @@ -73,9 +73,12 @@ if (WITH_ANAKIN AND WITH_GPU) # only needed in CI anakin_target(inference_anakin_api) anakin_target(inference_anakin_api_shared) if (WITH_TESTING) - cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc + cc_test(api_anakin_engine_tester SRCS api_anakin_engine_tester.cc ARGS --model=${ANAKIN_SOURCE_DIR}/mobilenet_v2.anakin.bin - DEPS inference_anakin_api dynload_cuda SERIAL) - target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) + DEPS inference_anakin_api_shared dynload_cuda SERIAL) + cc_test(api_anakin_engine_rnn_tester SRCS api_anakin_engine_rnn_tester.cc + ARGS --model=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin + --datapath=${ANAKIN_SOURCE_DIR}/anakin_test%2Fditu_rnn_data.txt + DEPS inference_anakin_api_shared dynload_cuda SERIAL) endif(WITH_TESTING) endif() diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index 63c3f0d7b3..5f1e1b548c 100644 --- a/paddle/fluid/inference/api/api.cc +++ b/paddle/fluid/inference/api/api.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/inference/api/api_anakin_engine.cc b/paddle/fluid/inference/api/api_anakin_engine.cc index 6b374ceefb..ea66aa89b8 100644 --- a/paddle/fluid/inference/api/api_anakin_engine.cc +++ b/paddle/fluid/inference/api/api_anakin_engine.cc @@ -13,9 +13,22 @@ // limitations under the License. #include "paddle/fluid/inference/api/api_anakin_engine.h" + +#ifdef PADDLE_WITH_CUDA #include +#endif + +#include +#include +#include +#include +#include #include +#include "framework/core/net/net.h" +#include "framework/operators/ops.h" +#include "saber/funcs/timer.h" + namespace paddle { template @@ -23,16 +36,24 @@ PaddleInferenceAnakinPredictor::PaddleInferenceAnakinPredictor( const AnakinConfig &config) { CHECK(Init(config)); } - +template <> +PaddleInferenceAnakinPredictor::PaddleInferenceAnakinPredictor( + const AnakinConfig &config) { + omp_set_dynamic(0); + omp_set_num_threads(1); + mkl_set_num_threads(1); + CHECK(Init(config)); +} template bool PaddleInferenceAnakinPredictor::Init(const AnakinConfig &config) { if (!(graph_.load(config.model_file))) { - LOG(FATAL) << "fail to load graph from " << config.model_file; + VLOG(3) << "fail to load graph from " << config.model_file; return false; } auto inputs = graph_.get_ins(); for (auto &input_str : inputs) { graph_.ResetBatchSize(input_str, config.max_batch_size); + max_batch_size_ = config.max_batch_size; } // optimization for graph if (!(graph_.Optimize())) { @@ -52,15 +73,15 @@ bool PaddleInferenceAnakinPredictor::Run( std::vector *output_data, int batch_size) { for (const auto &input : inputs) { if (input.dtype != PaddleDType::FLOAT32) { - LOG(ERROR) << "Only support float type inputs. " << input.name - << "'s type is not float"; + VLOG(3) << "Only support float type inputs. " << input.name + << "'s type is not float"; return false; } auto d_tensor_in_p = executor_p_->get_in(input.name); - auto net_shape = d_tensor_in_p->valid_shape(); + auto net_shape = d_tensor_in_p->shape(); if (net_shape.size() != input.shape.size()) { - LOG(ERROR) << " input " << input.name - << "'s shape size should be equal to that of net"; + VLOG(3) << " input " << input.name + << "'s shape size should be equal to that of net"; return false; } int sum = 1; @@ -79,21 +100,45 @@ bool PaddleInferenceAnakinPredictor::Run( } d_tensor_in_p->reshape(tmp_shape); + if (input.lod.size() > 0) { + if (input.lod.size() > 1) { + VLOG(3) << " input lod first dim should <=1, but you set " + << input.lod.size(); + return false; + } + std::vector offset(input.lod[0].begin(), input.lod[0].end()); + d_tensor_in_p->set_seq_offset(offset); + VLOG(3) << "offset.size(): " << offset.size(); + for (int i = 0; i < offset.size(); i++) { + VLOG(3) << offset[i]; + } + } + float *d_data_p = d_tensor_in_p->mutable_data(); - if (cudaMemcpy(d_data_p, static_cast(input.data.data()), - d_tensor_in_p->valid_size() * sizeof(float), - cudaMemcpyHostToDevice) != 0) { - LOG(ERROR) << "copy data from CPU to GPU error"; - return false; + +#ifdef PADDLE_WITH_CUDA + if (std::is_same::value) { + if (cudaMemcpy(d_data_p, static_cast(input.data.data()), + d_tensor_in_p->valid_size() * sizeof(float), + cudaMemcpyHostToDevice) != 0) { + VLOG(3) << "copy data from CPU to GPU error"; + return false; + } + } +#endif + if (std::is_same::value) { + memcpy(d_data_p, static_cast(input.data.data()), + d_tensor_in_p->valid_size() * sizeof(float)); } - cudaStreamSynchronize(NULL); } +#ifdef PADDLE_WITH_CUDA cudaDeviceSynchronize(); executor_p_->prediction(); cudaDeviceSynchronize(); +#endif if (output_data->empty()) { - LOG(ERROR) << "At least one output should be set with tensors' names."; + VLOG(3) << "At least one output should be set with tensors' names."; return false; } for (auto &output : *output_data) { @@ -102,14 +147,22 @@ bool PaddleInferenceAnakinPredictor::Run( if (output.data.length() < tensor->valid_size() * sizeof(float)) { output.data.Resize(tensor->valid_size() * sizeof(float)); } - // Copy data from GPU -> CPU - if (cudaMemcpy(output.data.data(), tensor->mutable_data(), - tensor->valid_size() * sizeof(float), - cudaMemcpyDeviceToHost) != 0) { - LOG(ERROR) << "copy data from GPU to CPU error"; - return false; + +#if PADDLE_WITH_CUDA + if (std::is_same::value) { + // Copy data from GPU -> CPU + if (cudaMemcpy(output.data.data(), tensor->mutable_data(), + tensor->valid_size() * sizeof(float), + cudaMemcpyDeviceToHost) != 0) { + VLOG(3) << "copy data from GPU to CPU error"; + return false; + } + } +#endif + if (std::is_same::value) { + memcpy(output.data.data(), tensor->mutable_data(), + tensor->valid_size() * sizeof(float)); } - cudaStreamSynchronize(NULL); } return true; } @@ -132,7 +185,7 @@ PaddleInferenceAnakinPredictor::Clone() { auto anakin_predictor_p = dynamic_cast *>(cls.get()); if (!anakin_predictor_p) { - LOG(ERROR) << "fail to call Init"; + VLOG(3) << "fail to call Init"; return nullptr; } anakin_predictor_p->get_executer().init(graph_); @@ -162,6 +215,44 @@ std::unique_ptr CreatePaddlePredictor< VLOG(3) << "Anakin Predictor create on unknown platform."; return nullptr; } -}; +} + +#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER +template +using executor_t = + anakin::Net; + +template +void DisplayOpTimer(executor_t *net_executor, int epoch) { + std::vector op_time = net_executor->get_op_time(); + auto exec_funcs = net_executor->get_exec_funcs(); + auto op_param = net_executor->get_op_param(); + for (int i = 0; i < op_time.size(); i++) { + LOG(INFO) << "name: " << exec_funcs[i].name + << " op_type: " << exec_funcs[i].op_name + << " op_param: " << op_param[i] << " time " << op_time[i] / epoch; + } + std::map op_map; + for (int i = 0; i < op_time.size(); i++) { + auto it = op_map.find(op_param[i]); + if (it != op_map.end()) + op_map[op_param[i]] += op_time[i]; + else + op_map.insert(std::pair(op_param[i], op_time[i])); + } + for (auto it = op_map.begin(); it != op_map.end(); ++it) { + LOG(INFO) << it->first << " " << (it->second) / epoch << " ms"; + } +} +#endif + +template +PaddleInferenceAnakinPredictor::~PaddleInferenceAnakinPredictor() { +#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER + DisplayOpTimer(executor_p_, max_batch_size_); +#endif + delete executor_p_; + executor_p_ = nullptr; +} } // namespace paddle diff --git a/paddle/fluid/inference/api/api_anakin_engine.h b/paddle/fluid/inference/api/api_anakin_engine.h index 836badd979..dd08661880 100644 --- a/paddle/fluid/inference/api/api_anakin_engine.h +++ b/paddle/fluid/inference/api/api_anakin_engine.h @@ -47,10 +47,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor { anakin::Net& get_executer(); - ~PaddleInferenceAnakinPredictor() override { - delete executor_p_; - executor_p_ = nullptr; - }; + ~PaddleInferenceAnakinPredictor() override; private: bool Init(const AnakinConfig& config); @@ -60,6 +57,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor { anakin::Net* executor_p_{nullptr}; AnakinConfig config_; + int max_batch_size_{0}; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc b/paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc new file mode 100644 index 0000000000..6183864234 --- /dev/null +++ b/paddle/fluid/inference/api/api_anakin_engine_rnn_tester.cc @@ -0,0 +1,315 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include "framework/core/net/net.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +DEFINE_string(model, "", "Directory of the inference model."); +DEFINE_string(datapath, "", "Path of the dataset."); +DEFINE_int32(batch_size, 1, "batch size."); +DEFINE_int32(repeat, 1, "Running the inference program repeat times."); + +// Timer for timer +class Timer { + public: + double start; + double startu; + void tic() { + struct timeval tp; + gettimeofday(&tp, NULL); + start = tp.tv_sec; + startu = tp.tv_usec; + } + double toc() { + struct timeval tp; + gettimeofday(&tp, NULL); + double used_time_ms = + (tp.tv_sec - start) * 1000.0 + (tp.tv_usec - startu) / 1000.0; + return used_time_ms; + } +}; + +std::vector string_split(std::string in_str, + std::string delimiter) { + std::vector seq; + int found = in_str.find(delimiter); + int pre_found = -1; + while (found != std::string::npos) { + if (pre_found == -1) { + seq.push_back(in_str.substr(0, found)); + } else { + seq.push_back(in_str.substr(pre_found + delimiter.length(), + found - delimiter.length() - pre_found)); + } + pre_found = found; + found = in_str.find(delimiter, pre_found + delimiter.length()); + } + seq.push_back( + in_str.substr(pre_found + 1, in_str.length() - (pre_found + 1))); + return seq; +} +std::vector string_split( + std::string in_str, std::vector& delimiter) { // NOLINT + std::vector in; + std::vector out; + out.push_back(in_str); + for (auto del : delimiter) { + in = out; + out.clear(); + for (auto s : in) { + auto out_s = string_split(s, del); + for (auto o : out_s) { + out.push_back(o); + } + } + } + return out; +} + +class Data { + public: + Data(std::string file_name, int batch_size) + : _batch_size(batch_size), _total_length(0) { + _file.open(file_name); + _file.seekg(_file.end); + _total_length = _file.tellg(); + _file.seekg(_file.beg); + } + void get_batch_data(std::vector>& fea, // NOLINT + std::vector>& week_fea, // NOLINT + std::vector>& time_fea, // NOLINT + std::vector& seq_offset); // NOLINT + + private: + std::fstream _file; + int _total_length; + int _batch_size; +}; + +void Data::get_batch_data( + std::vector>& fea, // NOLINT + std::vector>& week_fea, // NOLINT + std::vector>& time_fea, // NOLINT + std::vector& seq_offset) { // NOLINT + int seq_num = 0; + long unsigned int cum = 0; // NOLINT + + char buf[10000]; + seq_offset.clear(); + seq_offset.push_back(0); + fea.clear(); + week_fea.clear(); + time_fea.clear(); + while (_file.getline(buf, 10000)) { + std::string s = buf; + std::vector deli_vec = {":"}; + std::vector data_vec = string_split(s, deli_vec); + + std::vector seq; + seq = string_split(data_vec[0], {"|"}); + + for (auto link : seq) { + std::vector data = string_split(link, ","); + std::vector vec; + for (int i = 0; i < data.size(); i++) { + vec.push_back(atof(data[i].c_str())); + } + fea.push_back(vec); + } + std::vector week_data; + std::vector time_data; + + week_data = string_split(data_vec[2], ","); + std::vector vec_w; + for (int i = 0; i < week_data.size(); i++) { + vec_w.push_back(atof(week_data[i].c_str())); + } + week_fea.push_back(vec_w); + + time_data = string_split(data_vec[1], ","); + std::vector vec_t; + for (int i = 0; i < time_data.size(); i++) { + vec_t.push_back(atof(time_data[i].c_str())); + } + time_fea.push_back(vec_t); + + cum += seq.size(); + seq_offset.push_back(cum); + + seq_num++; + if (seq_num >= _batch_size) { + break; + } + } +} + +namespace paddle { + +AnakinConfig GetConfig() { + AnakinConfig config; + // using AnakinConfig::X86 if you need to use cpu to do inference + config.target_type = AnakinConfig::X86; + config.model_file = FLAGS_model; + config.device = 0; + config.max_batch_size = 1000; // the max number of token + return config; +} + +void set_tensor(std::string name, std::vector shape, + std::vector& vec) { // NOLINT + int sum = 1; + std::for_each(shape.begin(), shape.end(), [&](int n) { sum *= n; }); + float* data = new float[sum]; + PaddleTensor tensor; + tensor.name = name; + tensor.shape = shape; + tensor.data = PaddleBuf(data, sum); + tensor.dtype = PaddleDType::FLOAT32; + vec.push_back(tensor); +} + +void single_test() { + AnakinConfig config = GetConfig(); + auto predictor = + CreatePaddlePredictor(config); + + int max_batch_size = 1000; + std::string feature_file = FLAGS_datapath; + Data map_data(feature_file, FLAGS_batch_size); + std::vector> fea; + std::vector> week_fea; + std::vector> time_fea; + std::vector seq_offset; // NOLINT + + paddle::PaddleTensor tensor_0, tensor_1, tensor_2; + tensor_0.name = "input_0"; + tensor_1.name = "input_4"; + tensor_2.name = "input_5"; + + PaddleTensor tensor_out; + tensor_out.name = "final_output.tmp_1_gout"; + tensor_out.shape = std::vector({}); + tensor_out.data = PaddleBuf(); + tensor_out.dtype = PaddleDType::FLOAT32; + + std::vector inputs; + std::vector outputs(1, tensor_out); + + int data_0_dim = 38; + int data_1_dim = 10; + int data_2_dim = 10; + float data_0[max_batch_size * data_0_dim]; // NOLINT + float data_1[max_batch_size * data_1_dim]; // NOLINT + float data_2[max_batch_size * data_2_dim]; // NOLINT + + int count = 0; + while (true) { + if (count++ > 0) break; // only run the first batch in ci. + seq_offset.clear(); + map_data.get_batch_data(fea, week_fea, time_fea, seq_offset); + if (seq_offset.size() <= 1) { + LOG(FATAL) << "seq_offset.size() <= 1, exit."; + break; + } + + std::vector> seq_offset_vec; // NOLINT + seq_offset_vec.push_back(seq_offset); + tensor_0.lod = seq_offset_vec; + + int p_shape_0[] = {(int)fea.size(), 1, 1, data_0_dim}; // NOLINT + int p_shape_1[] = {(int)week_fea.size(), data_1_dim, 1, 1}; // NOLINT + int p_shape_2[] = {(int)time_fea.size(), data_2_dim, 1, 1}; // NOLINT + + std::vector shape_0(p_shape_0, p_shape_0 + 4); + std::vector shape_1(p_shape_1, p_shape_1 + 4); + std::vector shape_2(p_shape_2, p_shape_2 + 4); + + tensor_0.shape = shape_0; + tensor_1.shape = shape_1; + tensor_2.shape = shape_2; + + for (int i = 0; i < fea.size(); i++) { + memcpy(data_0 + i * data_0_dim, &fea[i][0], sizeof(float) * data_0_dim); + } + for (int i = 0; i < week_fea.size(); i++) { + memcpy(data_1 + i * data_1_dim, &week_fea[i][0], + sizeof(float) * data_1_dim); + } + for (int i = 0; i < time_fea.size(); i++) { + memcpy(data_2 + i * data_2_dim, &time_fea[i][0], + sizeof(float) * data_2_dim); + } + + tensor_0.data = + paddle::PaddleBuf(data_0, fea.size() * sizeof(float) * data_0_dim); + tensor_1.data = + paddle::PaddleBuf(data_1, week_fea.size() * sizeof(float) * data_1_dim); + tensor_2.data = + paddle::PaddleBuf(data_2, time_fea.size() * sizeof(float) * data_2_dim); + + tensor_0.dtype = paddle::PaddleDType::FLOAT32; + tensor_1.dtype = paddle::PaddleDType::FLOAT32; + tensor_2.dtype = paddle::PaddleDType::FLOAT32; + + inputs.clear(); + inputs.push_back(tensor_1); + inputs.push_back(tensor_2); + inputs.push_back(tensor_0); + + Timer timer; + timer.tic(); + for (int i = 0; i < FLAGS_repeat; i++) predictor->Run(inputs, &outputs); + + LOG(INFO) << "batch_size = " << FLAGS_batch_size + << ", repeat = " << FLAGS_repeat + << ", sequence_length = " << seq_offset[seq_offset.size() - 1] + << ", latency: " << timer.toc() / FLAGS_repeat << "ms"; + + float* data_o = static_cast(outputs[0].data.data()); + VLOG(3) << "outputs[0].data.length() = " << outputs[0].data.length(); + for (size_t j = 0; j < outputs[0].data.length(); ++j) { + VLOG(3) << "output[" << j << "]: " << data_o[j]; + } + } +} +} // namespace paddle + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + logger::init(argv[0]); + + paddle::single_test(); + /* multi-threads + std::vector threads; + int num = 1; + for (int i = 0; i < num; i++) { + LOG(INFO) << " thread id : " << i; + threads.emplace_back(paddle::single_test); + } + for (int i = 0; i < num; i++) { + threads[i].join(); + } + threads.clear(); + */ + + return 0; +} diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 3b36377274..36fd0727aa 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -45,7 +45,7 @@ class PaddleBuf { PaddleBuf(void* data, size_t length) : data_(data), length_(length), memory_owned_{false} {} // Own memory. - explicit PaddleBuf(size_t length) + PaddleBuf(size_t length) : data_(new char[length]), length_(length), memory_owned_(true) {} // Resize to `length` bytes. void Resize(size_t length); @@ -70,7 +70,7 @@ struct PaddleTensor { std::vector shape; PaddleBuf data; // blob of data. PaddleDType dtype; - std::vector> lod; // lod data + std::vector> lod; // Tensor+LoD equals LoDTensor }; enum class PaddleEngineKind { From e8b4e0d627176e4d4028698e48e1d67440913d88 Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 22 Aug 2018 16:53:21 +0800 Subject: [PATCH 08/10] fix load_vars bug (#12869) --- python/paddle/fluid/io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index b3ed094c89..5c4ec99c53 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -406,6 +406,9 @@ def load_vars(executor, attrs={'file_path': os.path.join(dirname, filename)}) executor.run(load_prog) + if main_program is None: + main_program = default_main_program() + # load slice vars on pserver, if have it. _load_slice_up_vars(executor, dirname, main_program._slice_vars_and_attrs) From 774896347943f7100adc9763dad529ffd5754f6e Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 22 Aug 2018 18:57:30 +0800 Subject: [PATCH 09/10] refine op_test (#12846) --- python/paddle/fluid/tests/unittests/op_test.py | 2 +- python/paddle/fluid/tests/unittests/testsuite.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 972e44c952..44cd073379 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -56,8 +56,8 @@ def get_numeric_gradient(place, def get_output(): sum = [] + op.run(scope, place) for output_name in output_names: - op.run(scope, place) sum.append( np.array(scope.find_var(output_name).get_tensor()).mean()) return np.array(sum).mean() diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index 31ae25f02c..34fbb1b549 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -153,9 +153,6 @@ def append_input_output(block, op_proto, np_list, is_input, dtype): def append_loss_ops(block, output_names): mean_inputs = list(map(block.var, output_names)) - # for item in mean_inputs: - # print(item) - # print("Item", item.dtype) if len(mean_inputs) == 1: loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1]) From f5d5d7b2d989e8aa5b5e637fd04318566b23f2fe Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 22 Aug 2018 20:06:48 +0800 Subject: [PATCH 10/10] Disable in_place in batch_norm API. (#12736) * Disable in_place in batch_norm API. --- paddle/fluid/operators/batch_norm_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 9 +++++++-- python/paddle/fluid/nets.py | 2 +- .../paddle/fluid/tests/book/test_image_classification.py | 5 ++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 5912a1a17c..969f75544f 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Variance", "The global variance (for training) " "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization").Reuse("X"); + AddOutput("Y", "result after normalization"); AddOutput("MeanOut", "Share memory with Mean. " "Store the global mean when training") diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 71592618f5..a815ba0f2f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -27,6 +27,7 @@ from . import utils import random from .. import unique_name from functools import reduce +import warnings __all__ = [ 'fc', @@ -2046,7 +2047,7 @@ def batch_norm(input, param_attr(ParamAttr): The parameter attribute for Parameter `scale`. bias_attr(ParamAttr): The parameter attribute for Parameter `bias`. data_layout(string, default NCHW): NCHW|NHWC - in_place(bool, Default False): Make the input and output of batch norm reuse memory. + in_place(bool, Default False): This argument is deprecated since 0.15.0. use_mkldnn(bool, Default false): ${use_mkldnn_comment} name(string, Default None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2068,6 +2069,10 @@ def batch_norm(input, helper = LayerHelper('batch_norm', **locals()) dtype = helper.input_dtype() + if in_place: + raise warnings.warn("The argument in_place is deprecated since 0.15.0, " + "please do not set it True.") + input_shape = input.shape if data_layout == 'NCHW': channel_num = input_shape[1] @@ -2117,7 +2122,7 @@ def batch_norm(input, saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) + batch_norm_out = helper.create_tmp_variable(dtype) helper.append_op( type="batch_norm", diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 051fe84364..01563cbbb7 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -229,7 +229,7 @@ def img_conv_group(input, use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) + tmp = layers.batch_norm(input=tmp, act=conv_act) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index 9fe361425c..cd1e8cd682 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -256,7 +256,10 @@ def main(net_type, use_cuda, is_local=True): save_dirname = "image_classification_" + net_type + ".inference.model" train(net_type, use_cuda, save_dirname, is_local) - infer(use_cuda, save_dirname) + + # There is bug in fluid.InferenceTranspiler for VGG. + if net_type == "resnet": + infer(use_cuda, save_dirname) class TestImageClassification(unittest.TestCase):