@@ -254,6 +254,172 @@ class TestImperativeQat(unittest.TestCase):
            np.allclose(after_save, before_save.numpy()),
            msg='Failed to save the inference quantized model.')

    def test_qat_acc(self):
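        # Train the same LeNet under QAT in dynamic (imperative) mode and in
        # static graph mode from identical initial weights, then check that
        # the two loss curves match batch by batch.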
        def _build_static_lenet(main, startup, is_test=False, seed=1000):
            with fluid.unique_name.guard():
                with fluid.program_guard(main, startup):
                    main.random_seed = seed
                    startup.random_seed = seed
                    img = fluid.layers.data(
                        name='image', shape=[1, 28, 28], dtype='float32')
                    label = fluid.layers.data(
                        name='label', shape=[1], dtype='int64')
                    prediction = StaticLenet(img)
                    if not is_test:
                        loss = fluid.layers.cross_entropy(
                            input=prediction, label=label)
                        avg_loss = fluid.layers.mean(loss)
                    else:
                        avg_loss = prediction
            return img, label, avg_loss

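        # Shared configuration: both runs see the same MNIST batches,
        # quantization types, seed and learning rate.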
        reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=32, drop_last=True)
        weight_quantize_type = 'abs_max'
        activation_quant_type = 'moving_average_abs_max'
        param_init_map = {}
        seed = 1000
        lr = 0.1

        # imperative train
        _logger.info(
            "--------------------------dynamic graph qat--------------------------"
        )
        imperative_qat = ImperativeQuantAware(
            weight_quantize_type=weight_quantize_type,
            activation_quantize_type=activation_quant_type)

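        # Dynamic graph run: seed NumPy and the default programs so the
        # random weight initialization below is reproducible.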
        with fluid.dygraph.guard():
            np.random.seed(seed)
            fluid.default_main_program().random_seed = seed
            fluid.default_startup_program().random_seed = seed

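            # Build a deterministic initial state: zero biases and small
            # normal weights. The same values are replayed into the static
            # program later through param_init_map.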
            lenet = ImperativeLenet()
            fixed_state = {}
            for name, param in lenet.named_parameters():
                p_shape = param.numpy().shape
                p_value = param.numpy()
                if name.endswith("bias"):
                    value = np.zeros_like(p_value).astype('float32')
                else:
                    value = np.random.normal(
                        loc=0.0, scale=0.01, size=np.product(p_shape)).reshape(
                            p_shape).astype('float32')
                fixed_state[name] = value
                param_init_map[param.name] = value
            lenet.set_dict(fixed_state)

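            # Insert the fake-quantization logic into the model, then train
            # it for one pass over the reader.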
            imperative_qat.quantize(lenet)
            adam = AdamOptimizer(
                learning_rate=lr, parameter_list=lenet.parameters())
            dynamic_loss_rec = []
            lenet.train()
            for batch_id, data in enumerate(reader()):
                x_data = np.array([x[0].reshape(1, 28, 28)
                                   for x in data]).astype('float32')
                y_data = np.array(
                    [x[1] for x in data]).astype('int64').reshape(-1, 1)

                img = fluid.dygraph.to_variable(x_data)
                label = fluid.dygraph.to_variable(y_data)

                out = lenet(img)
                loss = fluid.layers.cross_entropy(out, label)
                avg_loss = fluid.layers.mean(loss)
                avg_loss.backward()
                adam.minimize(avg_loss)
                lenet.clear_gradients()
                dynamic_loss_rec.append(avg_loss.numpy()[0])
                if batch_id % 100 == 0:
                    _logger.info('{}: {}'.format('loss', avg_loss.numpy()))

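            # Export the trained quantized model as an inference program; the
            # feed/fetch lists pick the model's inputs and outputs by index.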
            imperative_qat.save_quantized_model(
                dirname="./dynamic_mnist",
                model=lenet,
                input_shape=[(1, 28, 28)],
                input_dtype=['float32'],
                feed=[0],
                fetch=[0])

        # static graph train
        _logger.info(
            "--------------------------static graph qat--------------------------"
        )
        static_loss_rec = []
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
        else:
            place = core.CPUPlace()
        exe = fluid.Executor(place)

        main = fluid.Program()
        infer = fluid.Program()
        startup = fluid.Program()
        static_img, static_label, static_loss = _build_static_lenet(
            main, startup, False, seed)
        infer_img, _, infer_pre = _build_static_lenet(infer, startup, True,
                                                      seed)
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                opt = AdamOptimizer(learning_rate=lr)
                opt.minimize(static_loss)

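        # Run the startup program, then overwrite every parameter with the
        # exact values the imperative model started from.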
        scope = core.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup)
        for param in main.all_parameters():
            param_tensor = scope.var(param.name).get_tensor()
            param_tensor.set(param_init_map[param.name], place)

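        # Insert fake quant/dequant ops into both graphs, mirroring the
        # quantization settings used in the imperative run.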
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        infer_graph = IrGraph(core.Graph(infer.desc), for_test=True)
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type,
            quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
        transform_pass.apply(main_graph)
        transform_pass.apply(infer_graph)
        build_strategy = fluid.BuildStrategy()
        build_strategy.fuse_all_reduce_ops = False
        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
            loss_name=static_loss.name, build_strategy=build_strategy)

        feeder = fluid.DataFeeder(
            feed_list=[static_img, static_label], place=place)
        with fluid.scope_guard(scope):
            for batch_id, data in enumerate(reader()):
                loss_v, = exe.run(binary,
                                  feed=feeder.feed(data),
                                  fetch_list=[static_loss])
                static_loss_rec.append(loss_v[0])
                if batch_id % 100 == 0:
                    _logger.info('{}: {}'.format('loss', loss_v))

        save_program = infer_graph.to_program()
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model("./static_mnist", [infer_img.name],
                                          [infer_pre], exe, save_program)

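        # Compare the two loss curves with np.allclose semantics and log the
        # first batch where they diverge, to ease debugging on failure.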
        rtol = 1e-05
        atol = 1e-08
        for i, (loss_d,
                loss_s) in enumerate(zip(dynamic_loss_rec, static_loss_rec)):
            diff = np.abs(loss_d - loss_s)
            if diff > (atol + rtol * np.abs(loss_s)):
                _logger.info(
                    "diff({}) at {}, dynamic loss = {}, static loss = {}".
                    format(diff, i, loss_d, loss_s))
                break

        self.assertTrue(
            np.allclose(
                np.array(dynamic_loss_rec),
                np.array(static_loss_rec),
                rtol=rtol,
                atol=atol,
                equal_nan=True),
            msg='Failed to do the imperative qat.')


if __name__ == '__main__':
    unittest.main()