modify _exec_save_checkpoint

pull/5834/head
liuyang_655 4 years ago
parent 6a851ee252
commit 18c442e724

@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed
from src.dataset import create_dataset, extract_features
@ -116,7 +116,7 @@ if __name__ == '__main__':
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
end="")
if (epoch + 1) % config.save_checkpoint_epochs == 0:
_exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start))

Loading…
Cancel
Save