|
|
|
@ -50,6 +50,8 @@ if __name__ == "__main__":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
|
|
|
|
ds_train = create_dataset(os.path.join(args.data_path, "train"),
|
|
|
|
|
cfg.batch_size)
|
|
|
|
|
if ds_train.get_dataset_size() == 0:
|
|
|
|
|
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
|
|
|
|
|
|
|
|
|
network = LeNet5(cfg.num_classes)
|
|
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
|
|
|
|