From 17833d30e3c4eedf72839becec7aa0c144929261 Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Mon, 9 Apr 2018 16:48:07 +0800
Subject: [PATCH 01/10] fuse batch norm for conv operator without bias

---
 python/paddle/fluid/__init__.py             |   1 +
 python/paddle/fluid/framework.py            |   9 +
 python/paddle/fluid/inference_transpiler.py | 174 ++++++++++++++++++
 .../tests/book/test_image_classification.py |  15 ++
 4 files changed, 199 insertions(+)
 create mode 100644 python/paddle/fluid/inference_transpiler.py

diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index f01d638efd..445204b2fd 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -36,6 +36,7 @@ from distribute_transpiler import DistributeTranspiler
 from distribute_transpiler_simple import SimpleDistributeTranspiler
 from concurrency import (Go, make_channel, channel_send, channel_recv,
                          channel_close, Select)
+from inference_transpiler import InferenceTranspiler
 import clip
 from memory_optimization_transpiler import memory_optimize, release_memory
 import profiler
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 33cf691817..0ca853d3c6 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -920,6 +920,15 @@ class Block(object):
                 ops_in_cpp_index += 1
                 ops_in_python_index += 1
 
+        # sync ops inserted from c++ end
+        if len(self.ops) != len(ops_in_cpp) and start_index == 0 and len(
+                self.ops) == end_index:
+            del self.ops[:]
+            for index in range(len(ops_in_cpp)):
+                op_desc = ops_in_cpp[index]
+                op = Operator(self, op_desc)
+                self.ops.append(op)
+
         assert len(self.ops) == len(ops_in_cpp)
         for index in range(len(self.ops)):
             assert self.ops[index].desc == ops_in_cpp[index]
diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
new file mode 100644
index 0000000000..6a45de5741
--- /dev/null
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import os
+import shutil
+from . import core
+
+
+class InferenceTranspiler:
+    def transpile(self, program, scope, place):
+        '''
+        Transpile the program to an inference program by fusing batch normalization.
+
+        A batch normalization layer that follows a convolution or fully connected
+        layer can be folded into that layer. Doing so speeds up the forward pass,
+        especially in environments like mobile or embedded devices.
+
+        For input X:
+        - Conv process:        X = input * W + bias
+        - Batch norm process:  X' = (X - mean) / std
+        - Scale Process:       Y = a * X' + b
+
+        After fusing into one operation:
+
+        Y = (input * W + bias - mean) / std * a + b
+          = input * a * W / std + ((bias - mean) / std * a + b)
+
+        The operator transformation is:
+        - before:
+          - conv->batch_norm->any_other_op (bias == 0)
+          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
+        - after:
+          - conv->elementwise_add->any_other_op
+
+        The transpile stages are:
+        1. insert elementwise_add op when bias == 0, and adjust its input and output.
+        2. fuse the batch_norm's parameters to conv and elementwise_add operators.
+        3. remove batch_norm ops and their variables which are not used in any other ops.
+        4. remove unused variables.
+
+        :param program: program to transpile
+        :type program: Program
+        :param scope: inference scope
+        :type scope: Scope
+        :param place: inference place
+        :type place: Place
+        :return: program by fused batch normalization
+        :rtype: Program
+        '''
+        self.scope = scope
+        self.place = place
+        self.block_desc = program.get_desc().block(0)
+        i = 0
+        while i < self.block_desc.op_size():
+            current_op = self.block_desc.op(i)
+            # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
+            if current_op.type() in ['conv2d']:
+                next_op = self.block_desc.op(i + 1)
+                # TODO(luotao1): consider only conv2d without bias now.
+                # If conv2d with bias, the next_op.type is elementwise_add.
+                if (next_op.type() == 'batch_norm'):
+                    # insert bias op
+                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
+                    program.sync_with_cpp()
+                    # fuse batch_norm
+                    self._fuse_param(current_op, next_op, bias_op)
+                    # remove batch_norm_op
+                    self.block_desc.remove_op(i + 2, i + 3)
+                    program.sync_with_cpp()
+                    i = i + 1
+            i = i + 1
+
+        self._remove_unused_var()
+        program.sync_with_cpp()
+
+        return program
+
+    # ====================== private transpiler functions =====================
+    def _insert_bias_op(self, index, current_op, bn_op):
+        '''
+        Construct elementwise_add operator for adding bias
+        and insert it into program.
+
+        :param index: insert location of bias_op
+        :type index: Int
+        :param current_op: current operator (conv or fc)
+        :type current_op: Operator
+        :param bn_op: batch norm operator
+        :type bn_op: Operator
+        :return: bias_op
+        :rtype: Operator
+        '''
+        bias_op = self.block_desc.insert_op(index)
+        bias_op.set_type("elementwise_add")
+        # The input of bias_op is current_op's output and Bias of bn_op
+        # The output of bias_op is bn_op's output
+        bias_op.set_input("X", current_op.output("Output"))
+        bias_op.set_input("Y", bn_op.input("Bias"))
+        bias_op.set_output("Out", bn_op.output("Y"))
+        bias_op.set_attr('axis', 1)  # dim_start=1
+        return bias_op
+
+    def _fuse_param(self, current_op, bn_op, bias_op):
+        '''
+        fuse the batch_norm_op's parameters to current_op (conv or fc)
+
+        :param current_op: current operator (conv or fc)
+        :type current_op: Operator
+        :param bn_op: batch norm operator
+        :type bn_op: Operator
+        :param bias_op: elementwise_add operator for adding bias
+        :type bias_op: Operator
+        '''
+
+        def _load_tensor(param_name):
+            return self.scope.find_var(param_name[0]).get_tensor()
+
+        def _load_param(param_name):
+            return np.array(_load_tensor(param_name))
+
+        bias_bn = _load_param(bn_op.input("Bias"))      #Bias
+        scale_bn = _load_param(bn_op.input("Scale"))    #Scale
+        mean_bn = _load_param(bn_op.input("Mean"))      #Mean
+        var_bn = _load_param(bn_op.input("Variance"))   #Variance
+
+        # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
+        current_param = _load_param(current_op.input("Filter"))
+        current_tensor = _load_tensor(current_op.input("Filter"))
+
+        std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
+        tmp = np.float32(np.divide(scale_bn, std_bn))
+
+        # add bias of batch_norm_op to conv2d
+        bias = np.zeros(bias_bn.shape)
+        bias = np.float32(
+            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
+        bias_tensor = _load_tensor(bias_op.input("Y"))
+        bias_tensor.set(bias, self.place)
+
+        # re-compute weight of conv2d
+        tmp = tmp.reshape(tmp.shape[0], -1)
+        dst_param = current_param.reshape((tmp.shape[0], -1))
+        dst_param = np.float32(np.multiply(dst_param, tmp))
+        dst_param = dst_param.reshape(current_param.shape)
+
+        # set the updated parameters
+        current_tensor.set(np.array(dst_param), self.place)
+
+    def _remove_unused_var(self):
+        '''
+        remove unused variables in program desc
+        '''
+        args = []
+        for i in xrange(0, self.block_desc.op_size()):
+            current_op = self.block_desc.op(i)
+            args += current_op.input_arg_names()
+            args += current_op.output_arg_names()
+        args = list(set(args))  # unique the input and output arguments
+
+        for var in self.block_desc.all_vars():
+            if var.name() not in args:
+                self.block_desc.remove_var(var.name())
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index e8bb082be1..87cbe98c9b 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -22,6 +22,7 @@ import sys
 import numpy
 import unittest
 import os
+import numpy as np
 
 
 def resnet_cifar10(input, depth=32):
@@ -224,6 +225,20 @@ def infer(use_cuda, save_dirname=None):
         results = exe.run(inference_program,
                           feed={feed_target_names[0]: tensor_img},
                           fetch_list=fetch_targets)
+
+        # Use inference_transpiler to speed up
+        t = fluid.InferenceTranspiler()
+        inference_transpiler_program = t.transpile(inference_program,
+                                                   inference_scope, place)
+        transpiler_results = exe.run(inference_transpiler_program,
+                                     feed={feed_target_names[0]: tensor_img},
+                                     fetch_list=fetch_targets)
+
+        assert len(results[0]) == len(transpiler_results[0])
+        for i in range(len(results[0])):
+            np.testing.assert_almost_equal(results[0][i],
+                                           transpiler_results[0][i])
+
         print("infer results: ", results[0])
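The algebra that patch 01 relies on is easy to verify numerically. Below is a minimal numpy sketch of the no-bias case, using made-up scalar values in place of the real per-channel tensors (the pass itself applies the same rescaling to every output channel of the conv filter):

    import numpy as np

    # toy stand-ins for one output channel (values are illustrative)
    x, w = 2.0, 0.5                    # input and conv weight, bias == 0
    a, b = 1.5, 0.25                   # batch_norm scale and shift
    mean, var, eps = 0.1, 0.04, 1e-5   # saved statistics; eps matches the 1e-5 above

    std = np.sqrt(var + eps)
    y_ref = a * ((x * w - mean) / std) + b       # conv -> batch_norm
    w_fused = a * w / std                        # rescaled conv weight
    bias_fused = (0.0 - mean) / std * a + b      # bias for the inserted elementwise_add
    assert np.isclose(y_ref, x * w_fused + bias_fused)

This mirrors _fuse_param: tmp corresponds to a / std, the recomputed filter to w_fused, and the tensor written into the inserted elementwise_add to bias_fused.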
From ea0cf6f3829ddce01b2e5a2c0dea36e3ff7fce40 Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Tue, 10 Apr 2018 13:00:28 +0800
Subject: [PATCH 02/10] rewrite inference_transpiler in Python end

---
 paddle/fluid/framework/block_desc.cc        | 50 +-----------------
 python/paddle/fluid/framework.py            | 20 ++++----
 python/paddle/fluid/inference_transpiler.py | 51 ++++++++++---------
 .../tests/book/test_image_classification.py |  4 +-
 .../tests/unittests/test_protobuf_descs.py  | 20 --------
 5 files changed, 41 insertions(+), 104 deletions(-)

diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc
index fbe08349c3..b8847e4b90 100644
--- a/paddle/fluid/framework/block_desc.cc
+++ b/paddle/fluid/framework/block_desc.cc
@@ -13,11 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/block_desc.h"
+#include
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 
-#include
-
 namespace paddle {
 namespace framework {
 
@@ -147,52 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
   if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
     return;
   }
-  auto get_vars = [](std::deque>::iterator &op,
-                     std::vector &v) {
-    auto in_names = (*op)->InputArgumentNames();
-    v.insert(v.end(), in_names.begin(), in_names.end());
-    auto out_names = (*op)->OutputArgumentNames();
-    v.insert(v.end(), out_names.begin(), out_names.end());
-    std::sort(v.begin(), v.end());
-    auto last = std::unique(v.begin(), v.end());
-    v.erase(last, v.end());
-  };
-  need_update_ = true;
-
-  for (size_t i = s; i < e; i++) {
-    // since remove op one by one, every time remove the first op.
-    auto op = ops_.begin() + s;
-
-    // collect input and output variables from current delete op
-    std::vector cur_vars;
-    get_vars(op, cur_vars);
-
-    // remove current op
-    ops_.erase(ops_.begin() + s);
-
-    // collect input and output variables from other ops
-    std::vector other_vars;
-    for (auto it = ops_.begin(); it != ops_.end(); it++) {
-      get_vars(it, other_vars);
-    }
-
-    // variables should be deleted
-    std::vector delete_vars;
-    // delete_vars = cur_vars - cur_vars ^ other_input_vars
-    std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
-                        other_vars.end(),
-                        std::inserter(delete_vars, delete_vars.end()));
-    // remove variables
-    for (size_t i = 0; i < delete_vars.size(); i++) {
-      auto name = delete_vars[i];
-      auto it = vars_.find(name);
-      PADDLE_ENFORCE(it != vars_.end(),
-                     "%s is not in variable list, it should not be deleted",
-                     name);
-      vars_.erase(it);
-      VLOG(3) << "deleting variable " << name;
-    }
-  }
+  ops_.erase(ops_.begin() + s, ops_.begin() + e);
 }
 
 std::vector BlockDesc::AllOps() const {
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 0ca853d3c6..793421a22f 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -818,6 +818,11 @@ class Block(object):
         del self.vars[name]
         self.sync_with_cpp()
 
+    def remove_var(self, name):
+        self.sync_with_cpp()
+        self.desc.remove_var(name)
+        del self.vars[name]
+
     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
         param = Parameter(global_block, *args, **kwargs)
@@ -838,6 +843,11 @@ class Block(object):
         self.ops.insert(index, op)
         return op
 
+    def remove_op(self, index):
+        self.sync_with_cpp()
+        self.desc.remove_op(index, index + 1)
+        del self.ops[index]
+
     def delete_ops(self, ops):
         # remove from cpp
         # FIXME(typhoonzero): remove only the first occurrence.
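The two helpers added above are the building blocks the rewritten transpiler uses in the following hunks. A rough usage sketch (the block index and variable name here are hypothetical):

    block = program.block(0)
    block.remove_op(2)             # removes ops[2] from both the desc and the Python list
    block.remove_var("tmp_var")    # removes a no-longer-referenced variable by name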
@@ -846,6 +856,7 @@ class Block(object):
             end = list(self.ops).index(ops[-1])
         except Exception, e:
             raise e
+        self.desc.remove_op(start, end + 1)
 
     def slice_ops(self, start, end):
@@ -920,15 +931,6 @@ class Block(object):
                 ops_in_cpp_index += 1
                 ops_in_python_index += 1
 
-        # sync ops inserted from c++ end
-        if len(self.ops) != len(ops_in_cpp) and start_index == 0 and len(
-                self.ops) == end_index:
-            del self.ops[:]
-            for index in range(len(ops_in_cpp)):
-                op_desc = ops_in_cpp[index]
-                op = Operator(self, op_desc)
-                self.ops.append(op)
-
         assert len(self.ops) == len(ops_in_cpp)
         for index in range(len(self.ops)):
             assert self.ops[index].desc == ops_in_cpp[index]
diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
index 6a45de5741..3791e93576 100644
--- a/python/paddle/fluid/inference_transpiler.py
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -61,30 +61,26 @@ class InferenceTranspiler:
         '''
         self.scope = scope
         self.place = place
-        self.block_desc = program.get_desc().block(0)
+        self.block = program.block(0)
         i = 0
-        while i < self.block_desc.op_size():
-            current_op = self.block_desc.op(i)
+        while i < len(self.block.ops):
+            current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
-            if current_op.type() in ['conv2d']:
-                next_op = self.block_desc.op(i + 1)
+            if current_op.type in ['conv2d']:
+                next_op = self.block.ops[i + 1]
                 # TODO(luotao1): consider only conv2d without bias now.
                 # If conv2d with bias, the next_op.type is elementwise_add.
-                if (next_op.type() == 'batch_norm'):
+                if (next_op.type == 'batch_norm'):
                     # insert bias op
                     bias_op = self._insert_bias_op(i + 1, current_op, next_op)
-                    program.sync_with_cpp()
                     # fuse batch_norm
                     self._fuse_param(current_op, next_op, bias_op)
                     # remove batch_norm_op
-                    self.block_desc.remove_op(i + 2, i + 3)
-                    program.sync_with_cpp()
+                    self.block.remove_op(i + 2)
                     i = i + 1
             i = i + 1
 
         self._remove_unused_var()
-        program.sync_with_cpp()
-
         return program
 
     # ====================== private transpiler functions =====================
@@ -102,14 +98,19 @@ class InferenceTranspiler:
         :return: bias_op
         :rtype: Operator
         '''
-        bias_op = self.block_desc.insert_op(index)
-        bias_op.set_type("elementwise_add")
         # The input of bias_op is current_op's output and Bias of bn_op
         # The output of bias_op is bn_op's output
-        bias_op.set_input("X", current_op.output("Output"))
-        bias_op.set_input("Y", bn_op.input("Bias"))
-        bias_op.set_output("Out", bn_op.output("Y"))
-        bias_op.set_attr('axis', 1)  # dim_start=1
+        x_var = self.block.var(current_op.output("Output")[0])
+        y_var = self.block.var(bn_op.input("Bias")[0])
+        out_var = self.block.var(bn_op.output("Y")[0])
+
+        bias_op = self.block.insert_op(
+            index,
+            type="elementwise_add",
+            inputs={"X": x_var,
+                    "Y": y_var},
+            outputs={"Out": out_var},
+            attrs={"axis": 1})  # dim_start=1
         return bias_op
 
     def _fuse_param(self, current_op, bn_op, bias_op):
@@ -160,15 +161,15 @@ class InferenceTranspiler:
 
     def _remove_unused_var(self):
         '''
-        remove unused variables in program desc
+        remove unused variables in program
         '''
         args = []
-        for i in xrange(0, self.block_desc.op_size()):
-            current_op = self.block_desc.op(i)
-            args += current_op.input_arg_names()
-            args += current_op.output_arg_names()
+        for i in range(len(self.block.ops)):
+            current_op = self.block.ops[i]
+            args += current_op.input_arg_names
+            args += current_op.output_arg_names
         args = list(set(args))  # unique the input and output arguments
 
-        for var in self.block_desc.all_vars():
-            if var.name() not in args:
-                self.block_desc.remove_var(var.name())
+        for var in self.block.vars.keys():
+            if var not in args:
+                self.block.remove_var(var)
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index 87cbe98c9b..bca42a89cd 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -236,8 +236,8 @@ def infer(use_cuda, save_dirname=None):
 
         assert len(results[0]) == len(transpiler_results[0])
         for i in range(len(results[0])):
-            np.testing.assert_almost_equal(results[0][i],
-                                           transpiler_results[0][i])
+            np.testing.assert_almost_equal(
+                results[0][i], transpiler_results[0][i], decimal=6)
 
         print("infer results: ", results[0])
diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py
index f98a8bbc68..3f9059fb5b 100644
--- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py
+++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py
@@ -201,24 +201,6 @@ class TestBlockDesc(unittest.TestCase):
         op1.set_type("test")
         op2.set_type("test")
 
-        var0 = block.var("var0")
-        var1 = block.var("var1")
-        var2 = block.var("var2")
-        var3 = block.var("var3")
-        var4 = block.var("var4")
-        var5 = block.var("var5")
-
-        op0.set_input("X", ["var0"])
-        op0.set_output("Y", ["var0"])
-        op1.set_input("X", ["var1", "var2"])
-        op1.set_output("Y", ["var3", "var4"])
-        op2.set_input("X", ["var1"])
-        op2.set_output("Y", ["var4", "var5"])
-
-        program.sync_with_cpp()
-
-        # remove op1, its input var2 and output var3 will be removed at the same time,
-        # but its input var1 and output var4 will not be removed since they are used for op2.
         block.remove_op(1, 2)
         program.sync_with_cpp()
 
@@ -226,8 +208,6 @@ class TestBlockDesc(unittest.TestCase):
         for idx in xrange(0, block.op_size()):
             all_ops.append(block.op(idx))
         self.assertEqual(all_ops, [op0, op2])
-        all_vars = block.all_vars()
-        self.assertEqual(set(all_vars), {var0, var1, var4, var5})
 
 
 if __name__ == '__main__':
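A consequence of patch 02 worth noting: since BlockDesc::RemoveOp no longer deletes variables, orphaned variables are now cleaned up on the Python side, which is what _remove_unused_var above does. In sketch form:

    # collect every argument still referenced by some op, then drop the rest
    args = set()
    for op in block.ops:
        args.update(op.input_arg_names)
        args.update(op.output_arg_names)
    for name in list(block.vars.keys()):
        if name not in args:
            block.remove_var(name)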
From 5483258b7e8c6e56968b1d63091539ed34609b1c Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Tue, 10 Apr 2018 19:01:18 +0800
Subject: [PATCH 03/10] fuse batch norm for conv operator with bias

---
 python/paddle/fluid/inference_transpiler.py | 44 +++++++++++++++----
 .../tests/book/test_image_classification.py | 12 +++--
 2 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
index 3791e93576..194f7adf46 100644
--- a/python/paddle/fluid/inference_transpiler.py
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -45,10 +45,11 @@ class InferenceTranspiler:
           - conv->elementwise_add->any_other_op
 
         The transpile stages are:
-        1. insert elementwise_add op when bias == 0, and adjust its input and output.
+        1. insert elementwise_add op when bias == 0.
         2. fuse the batch_norm's parameters to conv and elementwise_add operators.
-        3. remove batch_norm ops and their variables which are not used in any other ops.
-        4. remove unused variables.
+        3. remove batch_norm ops which are not used in any other ops.
+        4. adjust the input of any_other_op to be the output of elementwise_add operator.
+        5. remove unused variables.
 
         :param program: program to transpile
         :type program: Program
         :param scope: inference scope
         :type scope: Scope
         :param place: inference place
         :type place: Place
         '''
         self.scope = scope
         self.place = place
         self.block = program.block(0)
+        self.input_map = {}  # store the input names should be adjusted
+
         i = 0
         while i < len(self.block.ops):
             current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
             if current_op.type in ['conv2d']:
                 next_op = self.block.ops[i + 1]
-                # TODO(luotao1): consider only conv2d without bias now.
-                # If conv2d with bias, the next_op.type is elementwise_add.
+                # conv2d without bias
                 if (next_op.type == 'batch_norm'):
                     # insert bias op
                     bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                     # fuse batch_norm
-                    self._fuse_param(current_op, next_op, bias_op)
+                    self._fuse_param(current_op, next_op, bias_op, 0)
                     # remove batch_norm_op
                     self.block.remove_op(i + 2)
                     i = i + 1
+                # conv2d with bias, the next_op.type is elementwise_add
+                elif (next_op.type == 'elementwise_add'):
+                    next_next_op = self.block.ops[i + 2]
+                    if (next_next_op.type == 'batch_norm'):
+                        # fuse batch_norm
+                        self._fuse_param(current_op, next_next_op, next_op, 1)
+                        # remove batch_norm_op
+                        self.block.remove_op(i + 2)
+                        i = i + 1
             i = i + 1
 
+        self._adjust_input()
         self._remove_unused_var()
         return program
 
@@ -113,7 +125,7 @@ class InferenceTranspiler:
                     attrs={"axis": 1})  # dim_start=1
         return bias_op
 
-    def _fuse_param(self, current_op, bn_op, bias_op):
+    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
         '''
         fuse the batch_norm_op's parameters to current_op (conv or fc)
 
@@ -123,6 +135,8 @@ class InferenceTranspiler:
         :type bn_op: Operator
         :param bias_op: elementwise_add operator for adding bias
         :type bias_op: Operator
+        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
+        :type with_bias: Int
         '''
 
         def _load_tensor(param_name):
             return self.scope.find_var(param_name[0]).get_tensor()
 
@@ -144,7 +158,10 @@ class InferenceTranspiler:
         tmp = np.float32(np.divide(scale_bn, std_bn))
 
         # add bias of batch_norm_op to conv2d
-        bias = np.zeros(bias_bn.shape)
+        if with_bias:
+            bias = _load_param(bias_op.input("Y"))
+        else:
+            bias = np.zeros(bias_bn.shape)
         bias = np.float32(
             np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
         bias_tensor = _load_tensor(bias_op.input("Y"))
         bias_tensor.set(bias, self.place)
@@ -159,6 +176,17 @@ class InferenceTranspiler:
 
         # set the updated parameters
         current_tensor.set(np.array(dst_param), self.place)
 
+        # collect the renamed input
+        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
+
+    def _adjust_input(self):
+        for i in range(len(self.block.ops)):
+            current_op = self.block.ops[i]
+            for input_arg in current_op.input_arg_names:
+                if input_arg in self.input_map:
+                    current_op.rename_input(input_arg,
+                                            self.input_map[input_arg])
+
     def _remove_unused_var(self):
         '''
         remove unused variables in program
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index bca42a89cd..5e47bcb2cb 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -26,7 +26,13 @@ import numpy as np
 
 
 def resnet_cifar10(input, depth=32):
-    def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
+    def conv_bn_layer(input,
+                      ch_out,
+                      filter_size,
+                      stride,
+                      padding,
+                      act='relu',
+                      bias_attr=False):
         tmp = fluid.layers.conv2d(
             input=input,
             filter_size=filter_size,
@@ -34,7 +40,7 @@ def resnet_cifar10(input, depth=32):
             stride=stride,
             padding=padding,
             act=None,
-            bias_attr=False)
+            bias_attr=bias_attr)
         return fluid.layers.batch_norm(input=tmp, act=act)
 
     def shortcut(input, ch_in, ch_out, stride):
@@ -45,7 +51,7 @@ def resnet_cifar10(input, depth=32):
 
     def basicblock(input, ch_in, ch_out, stride):
         tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
-        tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
+        tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
         short = shortcut(input, ch_in, ch_out, stride)
         return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
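Patch 03 extends the fusion to conv2d with bias, where an existing elementwise_add sits between conv2d and batch_norm. The same toy numpy check as before covers this case, now with a nonzero bias folded into that elementwise_add (values are again illustrative scalars):

    import numpy as np

    x, w, bias = 2.0, 0.5, 0.3            # conv weight and its existing bias
    a, b, mean, var, eps = 1.5, 0.25, 0.1, 0.04, 1e-5

    std = np.sqrt(var + eps)
    y_ref = a * ((x * w + bias - mean) / std) + b   # conv -> elementwise_add -> batch_norm
    w_fused = a * w / std
    bias_fused = (bias - mean) / std * a + b        # replaces the old elementwise_add bias
    assert np.isclose(y_ref, x * w_fused + bias_fused)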
From 7815cdffbaacd49f8ba875f5d53a8972b9c3b060 Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Wed, 11 Apr 2018 15:22:17 +0800
Subject: [PATCH 04/10] use clone method to flush by force

---
 python/paddle/fluid/inference_transpiler.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
index 194f7adf46..a215b98c61 100644
--- a/python/paddle/fluid/inference_transpiler.py
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -93,7 +93,10 @@ class InferenceTranspiler:
 
         self._adjust_input()
         self._remove_unused_var()
-        return program
+        # TODO(luotao): use clone() method to flush the program.desc by force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        return program.clone()
 
From f45818e7f9291322a13fd9e137c151282e0bd1a9 Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Fri, 13 Apr 2018 16:55:02 +0800
Subject: [PATCH 05/10] create new variable in scope

---
 python/paddle/fluid/__init__.py             |  1 +
 python/paddle/fluid/inference_transpiler.py | 48 ++++++++++++++-----
 .../tests/book/test_image_classification.py |  9 ++--
 3 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index bb4b6d5fc4..e9ca0d45f9 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -67,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
     'clip',
     'SimpleDistributeTranspiler',
     'DistributeTranspiler',
+    'InferenceTranspiler',
     'memory_optimize',
     'release_memory',
     'profiler',
diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
index a215b98c61..7b7bd899ea 100644
--- a/python/paddle/fluid/inference_transpiler.py
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -21,7 +21,20 @@ from . import core
 class InferenceTranspiler:
     def transpile(self, program, scope, place):
         '''
-        Transpile the program to an inference program by fusing batch normalization.
+        Transpile the program. Only fusing batch normalization is supported now.
+
+        :param program: program to transpile
+        :type program: Program
+        :param scope: inference scope
+        :type scope: Scope
+        :param place: inference place
+        :type place: Place
+        '''
+        self.fuse_batch_norm(program, scope, place)
+
+    def fuse_batch_norm(self, program, scope, place):
+        '''
+        Transpile the program by fusing batch normalization.
 
@@ -57,8 +70,6 @@ class InferenceTranspiler:
         :type scope: Scope
         :param place: inference place
         :type place: Place
-        :return: program by fused batch normalization
-        :rtype: Program
         '''
         self.scope = scope
         self.place = place
@@ -96,7 +107,7 @@ class InferenceTranspiler:
         # TODO(luotao): use clone() method to flush the program.desc by force,
         # since some large program.desc will not be flushed immediately.
         # And a better solution will be considered later.
-        return program.clone()
+        program = program.clone()
 
     # ====================== private transpiler functions =====================
     def _insert_bias_op(self, index, current_op, bn_op):
@@ -142,11 +153,25 @@ class InferenceTranspiler:
         :type with_bias: Int
         '''
 
-        def _load_tensor(param_name):
-            return self.scope.find_var(param_name[0]).get_tensor()
+        def _update_param(op, old_param_name, new_param):
+            # For the sake of keeping the original variables the same as before,
+            # create new variables in scope to store the new parameters.
+            old_param_name = old_param_name[0]
+            old_var = self.block.vars[old_param_name]
+            new_param_name = old_param_name + '_fuse_bn'
+            new_var = self.block.create_parameter(
+                name=new_param_name.encode('ascii'),
+                type=old_var.type,
+                dtype=old_var.dtype,
+                shape=old_var.shape)
+            op.rename_input(old_param_name, new_param_name)
+            self.scope.var(new_param_name)
+
+            tensor = self.scope.find_var(new_param_name).get_tensor()
+            tensor.set(np.array(new_param), self.place)
 
         def _load_param(param_name):
-            return np.array(_load_tensor(param_name))
+            return np.array(self.scope.find_var(param_name[0]).get_tensor())
 
         bias_bn = _load_param(bn_op.input("Bias"))      #Bias
         scale_bn = _load_param(bn_op.input("Scale"))    #Scale
         mean_bn = _load_param(bn_op.input("Mean"))      #Mean
         var_bn = _load_param(bn_op.input("Variance"))   #Variance
 
         # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
         current_param = _load_param(current_op.input("Filter"))
-        current_tensor = _load_tensor(current_op.input("Filter"))
-
         std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
         tmp = np.float32(np.divide(scale_bn, std_bn))
 
         bias = np.float32(
             np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
-        bias_tensor = _load_tensor(bias_op.input("Y"))
-        bias_tensor.set(bias, self.place)
 
         # re-compute weight of conv2d
         tmp = tmp.reshape(tmp.shape[0], -1)
         dst_param = np.float32(np.multiply(dst_param, tmp))
         dst_param = dst_param.reshape(current_param.shape)
 
-        # set the updated parameters
-        current_tensor.set(np.array(dst_param), self.place)
+        # update parameters
+        _update_param(current_op, current_op.input("Filter"), dst_param)
+        _update_param(bias_op, bias_op.input("Y"), bias)
 
         # collect the renamed input
         self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index 5e47bcb2cb..aeacca5753 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -226,16 +226,17 @@ def infer(use_cuda, save_dirname=None):
         batch_size = 1
         tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype("float32")
 
+        # Use inference_transpiler to speed up
+        inference_transpiler_program = inference_program.clone()
+        t = fluid.InferenceTranspiler()
+        t.transpile(inference_transpiler_program, inference_scope, place)
+
         # Construct feed as a dictionary of {feed_target_name: feed_target_data}
         # and results will contain a list of data corresponding to fetch_targets.
         results = exe.run(inference_program,
                           feed={feed_target_names[0]: tensor_img},
                           fetch_list=fetch_targets)
 
-        # Use inference_transpiler to speed up
-        t = fluid.InferenceTranspiler()
-        inference_transpiler_program = t.transpile(inference_program,
-                                                   inference_scope, place)
         transpiler_results = exe.run(inference_transpiler_program,
                                      feed={feed_target_names[0]: tensor_img},
                                      fetch_list=fetch_targets)
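The scope round-trip that _update_param performs above can be summarized in a few lines. This is only a sketch: the parameter name is hypothetical and the recomputation itself is elided:

    import numpy as np

    old_name, new_name = "conv2d_0.w_0", "conv2d_0.w_0_fuse_bn"
    weights = np.array(scope.find_var(old_name).get_tensor())  # copy the old values out
    fused = np.float32(weights)      # stand-in for the dst_param recomputation above
    scope.var(new_name)              # create the new variable in the scope
    scope.find_var(new_name).get_tensor().set(fused, place)
    op.rename_input(old_name, new_name)

The original parameter tensor is left untouched, so the untranspiled program keeps working against the same scope.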
From ec512cdcce5cf6e5952c7184db222f3d293d218d Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Fri, 13 Apr 2018 18:45:09 +0800
Subject: [PATCH 06/10] add comment for branch network

---
 python/paddle/fluid/inference_transpiler.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
index 7b7bd899ea..be8a627953 100644
--- a/python/paddle/fluid/inference_transpiler.py
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -81,6 +81,9 @@ class InferenceTranspiler:
             current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
             if current_op.type in ['conv2d']:
+                # TODO(luotao1): consider single chain network now.
+                # For branch network, we couldn't use block.ops[i + 1] as
+                # the judgment condition.
                 next_op = self.block.ops[i + 1]
                 # conv2d without bias
                 if (next_op.type == 'batch_norm'):
 
From 81c47b21ef742dca9a7bfad16059575ce57f20aa Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Tue, 17 Apr 2018 19:38:20 +0800
Subject: [PATCH 07/10] add type check and default scope

---
 python/paddle/fluid/inference_transpiler.py | 29 ++++++++++++-------
 .../tests/book/test_image_classification.py |  2 +-
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py
index be8a627953..39b01610f9 100644
--- a/python/paddle/fluid/inference_transpiler.py
+++ b/python/paddle/fluid/inference_transpiler.py
@@ -13,26 +13,35 @@
 # limitations under the License.
 
 import numpy as np
-import os
-import shutil
+from framework import Program
+from executor import global_scope
 from . import core
 
 
 class InferenceTranspiler:
-    def transpile(self, program, scope, place):
+    def transpile(self, program, place, scope=None):
         '''
         Transpile the program. Only fusing batch normalization is supported now.
 
         :param program: program to transpile
         :type program: Program
-        :param scope: inference scope
-        :type scope: Scope
         :param place: inference place
         :type place: Place
+        :param scope: inference scope
+        :type scope: Scope or None
         '''
-        self.fuse_batch_norm(program, scope, place)
-
-    def fuse_batch_norm(self, program, scope, place):
+        if not isinstance(program, Program):
+            raise TypeError("program should be as Program type")
+        if not isinstance(place, core.CPUPlace) and not isinstance(
+                place, core.CUDAPlace):
+            raise TypeError("place should be as CPUPlace/CUDAPlace type")
+        if scope is None:
+            scope = global_scope()
+        if not isinstance(scope, core.Scope):
+            raise TypeError("scope should be as Scope type or None")
+        self.fuse_batch_norm(program, place, scope)
+
+    def fuse_batch_norm(self, program, place, scope):
         '''
         Transpile the program by fusing batch normalization.
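With patch 07 in place, the intended calling convention looks like the updated test below: clone first so the original program stays usable, pass the place, and let the scope default to global_scope(). A minimal sketch:

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    inference_transpiler_program = inference_program.clone()
    t = fluid.InferenceTranspiler()
    t.transpile(inference_transpiler_program, place)  # scope defaults to global_scope()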
@@ -66,10 +75,10 @@ class InferenceTranspiler:
         :param program: program to transpile
         :type program: Program
-        :param scope: inference scope
-        :type scope: Scope
         :param place: inference place
         :type place: Place
+        :param scope: inference scope
+        :type scope: Scope
         '''
         self.scope = scope
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index aeacca5753..0027b651e8 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -229,7 +229,7 @@ def infer(use_cuda, save_dirname=None):
         # Use inference_transpiler to speed up
         inference_transpiler_program = inference_program.clone()
         t = fluid.InferenceTranspiler()
-        t.transpile(inference_transpiler_program, inference_scope, place)
+        t.transpile(inference_transpiler_program, place)
 
         # Construct feed as a dictionary of {feed_target_name: feed_target_data}
         # and results will contain a list of data corresponding to fetch_targets.
 
From acdf7cbd19c105cc20a7e32a57512bbc3b8dc844 Mon Sep 17 00:00:00 2001
From: Jacek Czaja
Date: Mon, 16 Apr 2018 08:24:51 -0700
Subject: [PATCH 08/10] - Added EPS for softmax MKLDNN op

- EPS added to softmax mkldnn primitive outcome is limited to training phase

Fixes after review

clang format fixes

clang format fixes
---
 paddle/fluid/operators/softmax_mkldnn_op.cc | 9 +++++++++
 paddle/fluid/operators/softmax_op.cc        | 3 +++
 python/paddle/fluid/layers/nn.py            | 6 +++++-
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/softmax_mkldnn_op.cc
index dc2f176344..d00bd1447e 100644
--- a/paddle/fluid/operators/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/softmax_mkldnn_op.cc
@@ -73,6 +73,15 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel {
                                             softmax_dst_memory);
     std::vector pipeline{softmax};
     stream(stream::kind::eager).submit(pipeline).wait();
+
+    const bool is_test = ctx.Attr("is_test");
+    if (!is_test) {
+      T threshold = exp(-64);
+      for (size_t i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
+        output_data[i] =
+            output_data[i] < threshold ? threshold : output_data[i];
+      }
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 6bdefc0f23..e1f286f9ba 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -97,6 +97,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr("use_mkldnn",
             "(bool, default false) Only used in mkldnn kernel")
         .SetDefault(false);
+    AddAttr("is_test",
+            "Disable adding epsilon to softmax results. Used by MKLDNN.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Softmax Operator.
 
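The effect of the is_test guard above can be sketched in numpy: during training, probabilities that underflow to zero are clamped to exp(-64), presumably so that a downstream log() (for example in a cross-entropy loss) stays finite. Values here are illustrative:

    import numpy as np

    logits = np.array([100.0, 0.0, -100.0], dtype=np.float32)
    e = np.exp(logits - logits.max())
    probs = e / e.sum()                    # the smallest entry underflows to 0
    threshold = np.float32(np.exp(-64))    # ~1.6e-28, the same floor as the kernel
    probs = np.maximum(probs, threshold)   # np.log(probs) is now finite everywhere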
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 5c2c2dd7ab..fb41cc6009 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -87,6 +87,7 @@ def fc(input,
        bias_attr=None,
        use_mkldnn=False,
        act=None,
+       is_test=False,
        name=None):
     """
     **Fully Connected Layer**
@@ -133,6 +134,7 @@ def fc(input,
         bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias of this layer. If it is set to None, no bias will be added to the output units.
         act (str, default None): Activation to be applied to the output of this layer.
+        is_test(bool): A flag indicating whether execution is in test phase.
         use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn library is installed. Default: False
         name (str, default None): The name of this layer.
@@ -177,7 +179,9 @@ def fc(input,
                     "W": w},
             outputs={"Out": tmp},
             attrs={"use_mkldnn": use_mkldnn,
-                   "bias_attr": bias_attr})
+                   "is_test": is_test,
+                   "bias_attr": bias_attr
+                   })
         return helper.append_activation(tmp)
     else:
         for input_var, param_attr in helper.iter_inputs_and_params():
 
From de8094f57eee01abc62724ed8aa7becdc0474314 Mon Sep 17 00:00:00 2001
From: Jacek Czaja
Date: Tue, 17 Apr 2018 06:31:56 -0700
Subject: [PATCH 09/10] Cosmetic fixes

---
 python/paddle/fluid/layers/nn.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index fb41cc6009..e25400c68e 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -178,10 +178,11 @@ def fc(input,
             inputs={"Input": input,
                     "W": w},
             outputs={"Out": tmp},
-            attrs={"use_mkldnn": use_mkldnn,
-                   "is_test": is_test,
-                   "bias_attr": bias_attr
-                   })
+            attrs={
+                "use_mkldnn": use_mkldnn,
+                "is_test": is_test,
+                "bias_attr": bias_attr
+            })
         return helper.append_activation(tmp)
     else:
         for input_var, param_attr in helper.iter_inputs_and_params():
 
From ed681d5235fca44ba985c2380168afe8e09a1e7b Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Tue, 17 Apr 2018 17:02:44 -0700
Subject: [PATCH 10/10] Fix conv_mkldnn_op.cc which is causing CI failure

---
 paddle/fluid/operators/conv_mkldnn_op.cc | 32 +++++++++++++-----------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index d7a8f918ed..63d371310d 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
   auto dst_md = platform::MKLDNNMemDesc(
       dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
 
-  auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
-                                   reinterpret_cast(input_data));
-  auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
-                                       reinterpret_cast(filter_data));
+  auto src_memory =
+      mkldnn::memory({src_md, mkldnn_engine},
+                     reinterpret_cast(const_cast(input_data)));
+  auto weights_memory =
+      mkldnn::memory({weights_md, mkldnn_engine},
+                     reinterpret_cast(const_cast(filter_data)));
   auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
 
   std::shared_ptr conv_pd =
@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
       dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
 
   // create memory
-  auto diff_dst_memory =
-      mkldnn::memory({diff_weights_md, mkldnn_engine},
-                     reinterpret_cast(output_grad_data));
+  auto diff_dst_memory = mkldnn::memory(
+      {diff_weights_md, mkldnn_engine},
+      reinterpret_cast(const_cast(output_grad_data)));
 
   // Retrieve conv_pd from device context
   auto conv_pd = std::static_pointer_cast(
@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
   auto diff_weights_memory =
       mkldnn::memory({diff_weights_md, mkldnn_engine},
                      reinterpret_cast(filter_grad_data));
-  auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
-                                   reinterpret_cast(input_data));
+  auto src_memory =
+      mkldnn::memory({src_md, mkldnn_engine},
+                     reinterpret_cast(const_cast(input_data)));
 
   // create backward conv primitive for weights
   auto conv_bwd_weights_prim =
      mkldnn::convolution_backward_weights(
@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
                                            strides, paddings, *conv_pd,
                                            mkldnn_engine);
 
   // create memory
-  auto diff_src_memory =
-      mkldnn::memory({diff_src_md, mkldnn_engine},
-                     reinterpret_cast(input_grad_data));
-  auto weights_memory = mkldnn::memory(
-      {weights_md, mkldnn_engine}, reinterpret_cast(filter_data));
+  auto diff_src_memory = mkldnn::memory(
+      {diff_src_md, mkldnn_engine},
+      reinterpret_cast(const_cast(input_grad_data)));
+  auto weights_memory =
+      mkldnn::memory({weights_md, mkldnn_engine},
+                     reinterpret_cast(const_cast(filter_data)));
 
   // create backward conv primitive for data
   auto conv_bwd_data_prim = mkldnn::convolution_backward_data(