Implement the function of OutScaleForTraining/OutScaleForInference in dygraph (#26601)

* Implement the function of OutScaleForTraining/OutScaleForInference in dygraph

test=develop
guofei 5 years ago committed by GitHub
parent 0140d74e23
commit 6bbb6e7f45

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"

@@ -51,6 +51,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}},
{"hierarchical_sigmoid",
{"X", "W", "Label", "PathTable", "PathCode", "Bias"}},
{"moving_average_abs_max_scale", {"X", "InAccum", "InState"}},
};
// NOTE(zhiqiu): Like op_ins_map.
@@ -75,6 +76,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"collect_fpn_proposals", {"FpnRois", "RoisNum"}},
{"distribute_fpn_proposals",
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -118,6 +120,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"check_finite_and_unscale", {"Out", "FoundInfinite"}},
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
};
// clang-format off

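With moving_average_abs_max_scale registered in op_ins_map, op_outs_map and op_passing_outs_map, the auto-generated dygraph function accepts the pre-created scale/state/accum tensors directly. A minimal sketch of invoking it by hand (not part of this diff; the argument order is an assumption that simply mirrors MovingAverageAbsMaxScale.forward further below):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.random.rand(4, 10).astype('float32'))
    one = np.ones([1], dtype='float32')
    # pre-created tensors that the op both reads and writes in place
    scale = fluid.dygraph.to_variable(one)
    state = fluid.dygraph.to_variable(one)
    accum = fluid.dygraph.to_variable(one)
    # inputs (X, InAccum, InState), passed outputs, then flattened attributes
    out_scale, _, _ = core.ops.moving_average_abs_max_scale(
        x, accum, state, scale, state, accum,
        'moving_rate', 0.9, 'is_test', False)
    print(out_scale.numpy())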
@@ -24,7 +24,8 @@ from paddle.fluid.data_feeder import check_variable_and_dtype
__all__ = [
'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D',
'QuantizedLinear', 'FakeChannelWiseQuantDequantAbsMax'
'QuantizedLinear', 'FakeChannelWiseQuantDequantAbsMax',
'MovingAverageAbsMaxScale'
]
@@ -494,3 +495,78 @@ class QuantizedLinear(layers.Layer):
else:
pre_activation = mul_out
return self._helper.append_activation(pre_activation, act=self._act)
class MovingAverageAbsMaxScale(layers.Layer):
def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
"""
MovingAverageAbsMaxScale layer is used to calculate the output quantization scale of a Layer.
Its computational formula is described as below:
:math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
:math:`Out = X`
"""
super(MovingAverageAbsMaxScale, self).__init__()
self._moving_rate = moving_rate
self._dtype = dtype
scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
name=name, initializer=Constant(1), trainable=False)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=self._dtype)
self._scale.stop_gradient = True
state_prefix = "{}.state".format(name) if name else 'outscale.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
self._state = self.create_parameter(
shape=[1], attr=state_attr, dtype=self._dtype)
self._state.stop_gradient = True
accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
self._accum = self.create_parameter(
shape=[1], attr=accum_attr, dtype=self._dtype)
self._accum.stop_gradient = True
MovingAverageAbsMaxScale._has_create = True
def forward(self, input):
if in_dygraph_mode():
attrs = ('moving_rate', self._moving_rate, 'is_test',
not self.training)
state = self._state if self.training else None
accum = self._accum if self.training else None
out_scale, _, _ = core.ops.moving_average_abs_max_scale(
input, accum, state, self._scale, state, accum, *attrs)
return out_scale
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'MovingAverageAbsMaxScale')
scale_out = self._scale
attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
inputs = {"X": [input]}
outputs = {"OutScale": [scale_out]}
if self.training:
inputs['InState'] = [self._state]
inputs['InAccum'] = [self._accum]
outputs['OutState'] = [self._state]
outputs['OutAccum'] = [self._accum]
self._helper.append_op(
type="moving_average_abs_max_scale",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return scale_out

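The update rule that the op implements can be sketched in plain numpy (an illustration of the docstring formula above, not code from this commit; carrying state and accum forward as the new running values is an assumption consistent with that formula):

import numpy as np

def moving_average_abs_max_scale(x, state, accum, moving_rate=0.9):
    # running step counter and running max(abs(x)); their ratio is the scale
    state = moving_rate * state + 1
    accum = moving_rate * accum + np.max(np.abs(x))
    # scale = (moving_rate*accum + max(abs(x))) / (moving_rate*state + 1)
    scale = accum / state
    return scale, state, accum

x = np.random.randn(4, 10).astype('float32')
scale, state, accum = moving_average_abs_max_scale(x, state=1.0, accum=1.0)
print(scale)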
@@ -0,0 +1,83 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization.imperative import quant_nn
paddle.enable_static()
def init_data(batch_size=32, img_shape=[784], label_range=9):
np.random.seed(5)
assert isinstance(img_shape, list)
input_shape = [batch_size] + img_shape
img = np.random.random(size=input_shape).astype(np.float32)
label = np.array(
[np.random.randint(0, label_range) for _ in range(batch_size)]).reshape(
(-1, 1)).astype("int64")
return img, label
class TestMovingAverageAbsMaxScaleOp(unittest.TestCase):
def check_backward(self, use_cuda):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
image = fluid.layers.data(
name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc_tmp = fluid.layers.fc(image, size=10, act='softmax')
out_scale = quant_nn.MovingAverageAbsMaxScale(
name=fc_tmp.name, dtype=fc_tmp.dtype)
fc_tmp_1 = out_scale(fc_tmp)
cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp,
label)
loss = fluid.layers.reduce_mean(cross_entropy)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
moving_average_abs_max_scale_ops = [
op for op in main_program.blocks[0].ops
if op.type == u'moving_average_abs_max_scale'
]
assert len(
moving_average_abs_max_scale_ops
) == 1, "The number of moving_average_abs_max_scale_ops should be 1."
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
binary = fluid.compiler.CompiledProgram(
main_program).with_data_parallel(loss_name=loss.name)
img, label = init_data()
feed_dict = {"image": img, "label": label}
res = exe.run(binary, feed_dict)
def test_fw_bw(self):
if core.is_compiled_with_cuda():
self.check_backward(use_cuda=True)
self.check_backward(use_cuda=False)
if __name__ == '__main__':
unittest.main()
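The unit test above covers the static-graph path; in dygraph mode the layer can be used directly, along the lines of this sketch (variable names chosen for illustration, not part of this commit):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization.imperative import quant_nn

with fluid.dygraph.guard():
    out_scale = quant_nn.MovingAverageAbsMaxScale(name='fc_out', dtype='float32')
    x = fluid.dygraph.to_variable(np.random.rand(4, 10).astype('float32'))
    scale = out_scale(x)  # forward returns the updated OutScale tensor
    print(scale.numpy())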