|
|
@ -96,6 +96,8 @@ def train():
|
|
|
|
if eval_acc >= val_acc_max and eval_loss < val_loss_min:
|
|
|
|
if eval_acc >= val_acc_max and eval_loss < val_loss_min:
|
|
|
|
val_acc_model = eval_acc
|
|
|
|
val_acc_model = eval_acc
|
|
|
|
val_loss_model = eval_loss
|
|
|
|
val_loss_model = eval_loss
|
|
|
|
|
|
|
|
if os.path.exists("ckpts/gat.ckpt"):
|
|
|
|
|
|
|
|
os.remove("ckpts/gat.ckpt")
|
|
|
|
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
|
|
|
|
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
|
|
|
|
val_acc_max = np.max((val_acc_max, eval_acc))
|
|
|
|
val_acc_max = np.max((val_acc_max, eval_acc))
|
|
|
|
val_loss_min = np.min((val_loss_min, eval_loss))
|
|
|
|
val_loss_min = np.min((val_loss_min, eval_loss))
|
|
|
|