|
|
|
@ -24,11 +24,12 @@ from mindspore import context
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.nn.optim import Momentum
|
|
|
|
|
from mindspore.nn import TrainOneStepCell, WithLossCell
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext,_checkpoint_cb_for_save_op,\
|
|
|
|
|
LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist,\
|
|
|
|
|
_build_callbacks, CheckpointConfig, _set_cur_net
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \
|
|
|
|
|
LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
|
|
|
|
|
_build_callbacks, CheckpointConfig, _set_cur_net
|
|
|
|
|
from mindspore.common.api import ms_function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
"""Net definition."""
|
|
|
|
|
|
|
|
|
@ -52,6 +53,7 @@ class Net(nn.Cell):
|
|
|
|
|
|
|
|
|
|
class LossNet(nn.Cell):
|
|
|
|
|
""" LossNet definition """
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(LossNet, self).__init__()
|
|
|
|
|
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
|
|
|
|
@ -110,8 +112,8 @@ def test_save_checkpoint():
|
|
|
|
|
os.remove('./test_files/test_ckpt-model.pkl')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_loss_monitor_graph_model():
|
|
|
|
|
"""Test lossmonitor Graph model."""
|
|
|
|
|
def test_loss_monitor_sink_model():
|
|
|
|
|
"""Test loss monitor sink model."""
|
|
|
|
|
cb_params = _InternalCallbackParam()
|
|
|
|
|
cb_params.cur_epoch_num = 4
|
|
|
|
|
cb_params.cur_step_num = 2
|
|
|
|
@ -129,8 +131,8 @@ def test_loss_monitor_graph_model():
|
|
|
|
|
callbacklist.end(run_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_Loss_Monitor_feed_feed_model():
|
|
|
|
|
"""Test Loss Monitor feed feed mode."""
|
|
|
|
|
def test_loss_monitor_feed_model():
|
|
|
|
|
"""Test loss monitor non-sink mode."""
|
|
|
|
|
cb_params = _InternalCallbackParam()
|
|
|
|
|
run_context = RunContext(cb_params)
|
|
|
|
|
loss_cb = LossMonitor(1)
|
|
|
|
|