|
|
|
@ -59,8 +59,6 @@ def run_pretrain():
|
|
|
|
|
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
|
|
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
|
|
|
|
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
|
|
|
|
parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
|
|
|
|
|
parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
|
|
|
|
|
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
|
|
|
|
|
parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.")
|
|
|
|
|
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
|
|
|
@ -75,8 +73,6 @@ def run_pretrain():
|
|
|
|
|
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
|
|
|
|
context.set_context(enable_loop_sink=(args_opt.enable_loop_sink == "true"),
|
|
|
|
|
enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
|
|
|
|
|
context.set_context(reserve_class_name_in_scope=False)
|
|
|
|
|
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|