|
|
|
@ -47,7 +47,7 @@ def train_lenet():
|
|
|
|
|
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
|
|
|
|
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
|
|
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|
|
|
|
|
|
|
|
|
print("============== Starting Training Lenet==============")
|
|
|
|
@ -58,7 +58,7 @@ def train_lenet():
|
|
|
|
|
def train_lenet_quant():
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
|
|
|
|
cfg = quant_cfg
|
|
|
|
|
ckpt_path = './checkpoint_lenet-10_1875.ckpt'
|
|
|
|
|
ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
|
|
|
|
|
ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
|
|
|
|
|
step_size = ds_train.get_dataset_size()
|
|
|
|
|
|
|
|
|
@ -81,7 +81,7 @@ def train_lenet_quant():
|
|
|
|
|
# call back and monitor
|
|
|
|
|
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
|
|
|
|
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
|
|
|
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
|
|
|
|
ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant", config=config_ckpt)
|
|
|
|
|
|
|
|
|
|
# define model
|
|
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|
|
|
@ -96,7 +96,7 @@ def eval_quant():
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
|
|
|
|
cfg = quant_cfg
|
|
|
|
|
ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
|
|
|
|
|
ckpt_path = './checkpoint_lenet_1-10_937.ckpt'
|
|
|
|
|
ckpt_path = './ckpt_lenet_quant-10_937.ckpt'
|
|
|
|
|
# define fusion network
|
|
|
|
|
network = LeNet5Fusion(cfg.num_classes)
|
|
|
|
|
# convert fusion network to quantization aware network
|
|
|
|
|