modify wrong args

pull/5881/head
Payne 5 years ago
parent b7425d3e0c
commit da41c9d683

@ -26,8 +26,9 @@ def set_config(args):
"batch_size": 150, "batch_size": 150,
"epoch_size": 15, "epoch_size": 15,
"warmup_epochs": 0, "warmup_epochs": 0,
"lr_max": 0.03, "lr_init": .0,
"lr_end": 0.03, "lr_end": 0.03,
"lr_max": 0.03,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 4e-5, "weight_decay": 4e-5,
"label_smooth": 0.1, "label_smooth": 0.1,
@ -45,9 +46,9 @@ def set_config(args):
"batch_size": 150, "batch_size": 150,
"epoch_size": 200, "epoch_size": 200,
"warmup_epochs": 0, "warmup_epochs": 0,
"lr": 0.8, "lr_init": .0,
"lr_max": 0.03, "lr_end": .0,
"lr_end": 0.03, "lr_max": 0.8,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 4e-5, "weight_decay": 4e-5,
"label_smooth": 0.1, "label_smooth": 0.1,
@ -66,9 +67,9 @@ def set_config(args):
"batch_size": 256, "batch_size": 256,
"epoch_size": 200, "epoch_size": 200,
"warmup_epochs": 4, "warmup_epochs": 4,
"lr": 0.4, "lr_init": 0.00,
"lr_max": 0.03, "lr_end": 0.00,
"lr_end": 0.03, "lr_max": 0.4,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 4e-5, "weight_decay": 4e-5,
"label_smooth": 0.1, "label_smooth": 0.1,

@ -17,7 +17,6 @@ from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode from mindspore.train.model import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init, get_group_size from mindspore.communication.management import get_rank, init, get_group_size
@ -63,7 +62,7 @@ def set_context(config):
device_id=config.device_id, save_graphs=False) device_id=config.device_id, save_graphs=False)
elif config.platform == "GPU": elif config.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target=args_opt.platform, save_graphs=False) device_target=config.platform, save_graphs=False)
def config_ckpoint(config, lr, step_size): def config_ckpoint(config, lr, step_size):
cb = None cb = None

@ -77,7 +77,7 @@ if __name__ == '__main__':
# get learning rate # get learning rate
lr = Tensor(get_lr(global_step=0, lr = Tensor(get_lr(global_step=0,
lr_init=0, lr_init=config.lr_init,
lr_end=config.lr_end, lr_end=config.lr_end,
lr_max=config.lr_max, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, warmup_epochs=config.warmup_epochs,

Loading…
Cancel
Save