Add cross_entropy loss to mnist ut

revert-15207-remove_op_handle_lock_and_fix_var
minqiyang 6 years ago
parent 7aab39af15
commit 0601f5c4ee

@@ -125,8 +125,8 @@ class TestImperativeMnist(unittest.TestCase):
 label._stop_gradient = True
 cost = mnist(img)
-# loss = fluid.layers.cross_entropy(cost)
-avg_loss = fluid.layers.reduce_mean(cost)
+loss = fluid.layers.cross_entropy(cost, label)
+avg_loss = fluid.layers.mean(loss)
 dy_out = avg_loss._numpy()
 if batch_id == 0:
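
For context, a minimal sketch of the imperative-mode path this hunk edits, assuming the Paddle fluid 1.x dygraph API of this era (fluid.imperative.guard and to_variable, plus the leading-underscore _numpy/_stop_gradient methods that appear in the diff). MNIST, x_data, and y_data are hypothetical stand-ins for the test's network and batch data, not taken from the diff:

    import paddle.fluid as fluid

    # Sketch only: MNIST, x_data, and y_data are hypothetical stand-ins.
    with fluid.imperative.guard():
        mnist = MNIST()
        img = fluid.imperative.base.to_variable(x_data)
        label = fluid.imperative.base.to_variable(y_data)
        label._stop_gradient = True     # labels receive no gradient
        cost = mnist(img)               # assumed to output per-class probabilities
        loss = fluid.layers.cross_entropy(cost, label)  # per-example NLL vs. int labels
        avg_loss = fluid.layers.mean(loss)              # reduce to a scalar loss
        dy_out = avg_loss._numpy()      # fetch the eager value for comparison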
@@ -156,8 +156,8 @@ class TestImperativeMnist(unittest.TestCase):
 name='pixel', shape=[1, 28, 28], dtype='float32')
 label = fluid.layers.data(name='label', shape=[1], dtype='int64')
 cost = mnist(img)
-# loss = fluid.layers.cross_entropy(cost)
-avg_loss = fluid.layers.reduce_mean(cost)
+loss = fluid.layers.cross_entropy(cost, label)
+avg_loss = fluid.layers.mean(loss)
 sgd.minimize(avg_loss)
 # initialize params and fetch them
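
The second hunk applies the same change on the static-graph side of the test, so the two execution modes compute an identical loss. A sketch of that counterpart under the same fluid 1.x API assumptions; the optimizer construction and learning rate are assumptions, not taken from the diff:

    import paddle.fluid as fluid

    # Sketch only: mnist is the same hypothetical network as above.
    img = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    cost = mnist(img)                               # builds ops into the static program
    loss = fluid.layers.cross_entropy(cost, label)  # replaces reduce_mean(cost)
    avg_loss = fluid.layers.mean(loss)
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=1e-3)  # assumed settings
    sgd.minimize(avg_loss)   # appends backward and SGD update ops to the program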
