|
|
|
@ -79,10 +79,10 @@ def plot_metric(metric,
|
|
|
|
|
plt.figure()
|
|
|
|
|
plt.title(graph_title)
|
|
|
|
|
if line_num == 1:
|
|
|
|
|
plt.plot(batch_id, metric, line_style, line_label)
|
|
|
|
|
plt.plot(batch_id, metric, line_style, label=line_label)
|
|
|
|
|
else:
|
|
|
|
|
for i in range(line_num):
|
|
|
|
|
plt.plot(batch_id, metric[i], line_style[i], line_label[i])
|
|
|
|
|
plt.plot(batch_id, metric[i], line_style[i], label=line_label[i])
|
|
|
|
|
plt.xlabel('batch')
|
|
|
|
|
plt.ylabel(graph_title)
|
|
|
|
|
plt.legend()
|
|
|
|
@ -102,12 +102,12 @@ def main():
|
|
|
|
|
accuracy_sample = sample(accuracy, args.sample_rate)
|
|
|
|
|
|
|
|
|
|
plot_metric(loss_sample, batch_sample, 'loss', line_label='loss')
|
|
|
|
|
plot_metric(accuracy_sample,
|
|
|
|
|
batch_sample,
|
|
|
|
|
'accuracy',
|
|
|
|
|
line_style='g-',
|
|
|
|
|
line_label='accuracy')
|
|
|
|
|
|
|
|
|
|
plot_metric(
|
|
|
|
|
accuracy_sample,
|
|
|
|
|
batch_sample,
|
|
|
|
|
'accuracy',
|
|
|
|
|
line_style='g-',
|
|
|
|
|
line_label='accuracy')
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
main()
|
|
|
|
|