|
|
|
@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore.train.model import Model
|
|
|
|
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
|
|
|
|
from mindspore.train.serialization import _exec_save_checkpoint
|
|
|
|
|
from mindspore.train.serialization import save_checkpoint
|
|
|
|
|
from mindspore.common import set_seed
|
|
|
|
|
|
|
|
|
|
from src.dataset import create_dataset, extract_features
|
|
|
|
@ -116,7 +116,7 @@ if __name__ == '__main__':
|
|
|
|
|
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
|
|
|
|
|
end="")
|
|
|
|
|
if (epoch + 1) % config.save_checkpoint_epochs == 0:
|
|
|
|
|
_exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
|
|
|
|
|
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
|
|
|
|
|
f"mobilenetv2_head_{epoch+1}.ckpt"))
|
|
|
|
|
print("total cost {:5.4f} s".format(time.time() - start))
|
|
|
|
|
|
|
|
|
|