|
|
|
@ -165,6 +165,7 @@ class TestWeightDecay(unittest.TestCase):
|
|
|
|
|
for place in get_places():
|
|
|
|
|
loss = self.check_weight_decay(place, model, use_parallel_exe=False)
|
|
|
|
|
|
|
|
|
|
# TODO(zcd): should test use_reduce=True
|
|
|
|
|
loss2 = self.check_weight_decay(
|
|
|
|
|
place, model, use_parallel_exe=True, use_reduce=False)
|
|
|
|
|
|
|
|
|
@ -175,16 +176,6 @@ class TestWeightDecay(unittest.TestCase):
|
|
|
|
|
"Expect " + str(loss[i]) + "\n" + "But Got" + str(loss2[i])
|
|
|
|
|
+ " in class " + self.__class__.__name__)
|
|
|
|
|
|
|
|
|
|
loss3 = self.check_weight_decay(
|
|
|
|
|
place, model, use_parallel_exe=True, use_reduce=True)
|
|
|
|
|
|
|
|
|
|
for i in range(len(loss)):
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.isclose(
|
|
|
|
|
a=loss[i], b=loss3[i], rtol=5e-5),
|
|
|
|
|
"Expect " + str(loss[i]) + "\n" + "But Got" + str(loss2[i])
|
|
|
|
|
+ " in class " + self.__class__.__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|