Reschedules the outline

pull/4865/head
huangxinjing 5 years ago
parent 26ac496093
commit cd34a553b9

File diff suppressed because it is too large Load Diff

@ -22,25 +22,28 @@ def argparse_init():
parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
help="device where the code will be implemented. (Default: Ascend)")
parser.add_argument("--data_path", type=str, default="./test_raw_data/")
parser.add_argument("--epochs", type=int, default=15)
parser.add_argument("--full_batch", type=bool, default=False)
parser.add_argument("--batch_size", type=int, default=16000)
parser.add_argument("--eval_batch_size", type=int, default=16000)
parser.add_argument("--field_size", type=int, default=39)
parser.add_argument("--vocab_size", type=int, default=200000)
parser.add_argument("--emb_dim", type=int, default=80)
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
parser.add_argument("--deep_layer_act", type=str, default='relu')
parser.add_argument("--keep_prob", type=float, default=1.0)
parser.add_argument("--dropout_flag", type=int, default=0)
parser.add_argument("--data_path", type=str, default="./test_raw_data/",
help="This should be set to the same directory given to the data_download's data_dir argument")
parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.")
parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.")
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128],
help="The dimension of all deep layers.")
parser.add_argument("--deep_layer_act", type=str, default='relu',
help="The activation function of all deep layers.")
parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.")
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
parser.add_argument("--eval_file_name", type=str, default="eval.log")
parser.add_argument("--loss_file_name", type=str, default="loss.log")
parser.add_argument("--host_device_mix", type=int, default=0)
parser.add_argument("--dataset_type", type=str, default="tfrecord")
parser.add_argument("--parameter_server", type=int, default=0)
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="tfrecord/mindrecord/hd5")
parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not")
return parser
@ -48,6 +51,7 @@ class WideDeepConfig():
"""
WideDeepConfig
"""
def __init__(self):
self.device_target = "Ascend"
self.data_path = "./test_raw_data/"

Loading…
Cancel
Save