|
|
@ -17,7 +17,6 @@ from mindspore import context
|
|
|
|
from mindspore import nn
|
|
|
|
from mindspore import nn
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from mindspore.train.model import ParallelMode
|
|
|
|
from mindspore.train.model import ParallelMode
|
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|
|
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|
|
|
from mindspore.communication.management import get_rank, init, get_group_size
|
|
|
|
from mindspore.communication.management import get_rank, init, get_group_size
|
|
|
|
|
|
|
|
|
|
|
@ -63,7 +62,7 @@ def set_context(config):
|
|
|
|
device_id=config.device_id, save_graphs=False)
|
|
|
|
device_id=config.device_id, save_graphs=False)
|
|
|
|
elif config.platform == "GPU":
|
|
|
|
elif config.platform == "GPU":
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
device_target=args_opt.platform, save_graphs=False)
|
|
|
|
device_target=config.platform, save_graphs=False)
|
|
|
|
|
|
|
|
|
|
|
|
def config_ckpoint(config, lr, step_size):
|
|
|
|
def config_ckpoint(config, lr, step_size):
|
|
|
|
cb = None
|
|
|
|
cb = None
|
|
|
|