|
|
@ -28,6 +28,7 @@ from src.alexnet import AlexNet
|
|
|
|
from src.get_param_groups import get_param_groups
|
|
|
|
from src.get_param_groups import get_param_groups
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.nn as nn
|
|
|
|
from mindspore.communication.management import init, get_rank
|
|
|
|
from mindspore.communication.management import init, get_rank
|
|
|
|
|
|
|
|
from mindspore import dataset as de
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore import Tensor
|
|
|
|
from mindspore import Tensor
|
|
|
|
from mindspore.train import Model
|
|
|
|
from mindspore.train import Model
|
|
|
@ -37,6 +38,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
|
|
|
|
from mindspore.common import set_seed
|
|
|
|
from mindspore.common import set_seed
|
|
|
|
|
|
|
|
|
|
|
|
set_seed(1)
|
|
|
|
set_seed(1)
|
|
|
|
|
|
|
|
de.config.set_seed(1)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
|
|
|
|
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
|
|
|
|