@@ -134,12 +134,8 @@ class LossGet(Callback):
         return self._loss


-def train_process(device_id, epoch_size, num_classes, batch_size):
-    os.system("mkdir " + str(device_id))
-    os.chdir(str(device_id))
+def train_process(epoch_size, num_classes, batch_size):
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(device_id=device_id)
+    context.set_context(mode=context.GRAPH_MODE)
     net = resnet50(batch_size, num_classes)
     loss = CrossEntropyLoss()
     opt = Momentum(filter(lambda x: x.requires_grad,
@@ -148,34 +144,15 @@ def train_process(device_id, epoch_size, num_classes, batch_size):
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

     dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
-    batch_num = dataset.get_dataset_size()
-    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
-    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
-                                 config=config_ck)
     loss_cb = LossGet()
-    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
+    model.train(epoch_size, dataset, callbacks=[loss_cb])
-
-
-def eval(batch_size, num_classes):
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(device_id=0)
-
-    net = resnet50(batch_size, num_classes)
-    loss = CrossEntropyLoss()
-    opt = Momentum(filter(lambda x: x.requires_grad,
-                          net.get_parameters()), 0.01, 0.9)
-
-    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
-    checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt"
-    param_dict = load_checkpoint(checkpoint_path)
-    load_param_into_net(net, param_dict)
     net.set_train(False)
     eval_dataset = create_dataset(1, training=False)
     res = model.eval(eval_dataset)
     print("result: ", res)
     return res


 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -184,11 +161,7 @@ def test_resnet_cifar_1p():
     epoch_size = 1
     num_classes = 10
     batch_size = 32
-    device_id = 0
-    train_process(device_id, epoch_size, num_classes, batch_size)
-    time.sleep(3)
-    acc = eval(batch_size, num_classes)
-    os.chdir("../")
-    os.system("rm -rf " + str(device_id))
+    acc = train_process(epoch_size, num_classes, batch_size)
+    os.system("rm -rf kernel_meta")
     print("End training...")
     assert acc['acc'] > 0.35