|
|
@ -462,10 +462,17 @@ class TestSeResnet(unittest.TestCase):
|
|
|
|
self.assertTrue(
|
|
|
|
self.assertTrue(
|
|
|
|
np.allclose(dy_jit_pre, st_pre),
|
|
|
|
np.allclose(dy_jit_pre, st_pre),
|
|
|
|
msg="dy_jit_pre:\n {}\n, st_pre: \n{}.".format(dy_jit_pre, st_pre))
|
|
|
|
msg="dy_jit_pre:\n {}\n, st_pre: \n{}.".format(dy_jit_pre, st_pre))
|
|
|
|
self.assertTrue(
|
|
|
|
|
|
|
|
np.allclose(predictor_pre, st_pre),
|
|
|
|
flat_st_pre = st_pre.flatten()
|
|
|
|
msg="predictor_pre:\n {}\n, st_pre: \n{}.".format(predictor_pre,
|
|
|
|
flat_predictor_pre = np.array(predictor_pre).flatten()
|
|
|
|
st_pre))
|
|
|
|
for i in range(len(flat_predictor_pre)):
|
|
|
|
|
|
|
|
# modify precision to 1e-6, avoid unittest failed
|
|
|
|
|
|
|
|
self.assertAlmostEqual(
|
|
|
|
|
|
|
|
flat_predictor_pre[i],
|
|
|
|
|
|
|
|
flat_st_pre[i],
|
|
|
|
|
|
|
|
delta=1e-6,
|
|
|
|
|
|
|
|
msg="predictor_pre:\n {}\n, st_pre: \n{}.".format(
|
|
|
|
|
|
|
|
flat_predictor_pre[i], flat_st_pre[i]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_result(self):
|
|
|
|
def test_check_result(self):
|
|
|
|
pred_1, loss_1, acc1_1, acc5_1 = train(
|
|
|
|
pred_1, loss_1, acc1_1, acc5_1 = train(
|
|
|
|