|
|
|
@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.nn import TrainOneStepCell, WithLossCell
|
|
|
|
|
from mindspore.nn.optim import Momentum
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
|
|
|
|
|
_CallbackManager, Callback, CheckpointConfig
|
|
|
|
|
from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
|
|
|
|
|
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
|
|
|
|
|
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op():
|
|
|
|
|
one_param['name'] = "conv1.weight"
|
|
|
|
|
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
|
|
|
|
|
parameter_list.append(one_param)
|
|
|
|
|
checkpoint_cb_for_save_op(parameter_list)
|
|
|
|
|
_checkpoint_cb_for_save_op(parameter_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_cb_for_save_op_update_net():
|
|
|
|
@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net():
|
|
|
|
|
one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
|
|
|
|
|
parameter_list.append(one_param)
|
|
|
|
|
net = Net()
|
|
|
|
|
set_cur_net(net)
|
|
|
|
|
checkpoint_cb_for_save_op(parameter_list)
|
|
|
|
|
_set_cur_net(net)
|
|
|
|
|
_checkpoint_cb_for_save_op(parameter_list)
|
|
|
|
|
assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|