|
|
|
@ -115,7 +115,13 @@ class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
fetch_val2, = exe.run(compiled_prog,
|
|
|
|
|
feed=feed_dict,
|
|
|
|
|
fetch_list=[fetch_var])
|
|
|
|
|
self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.array_equal(fetch_val1, fetch_val2),
|
|
|
|
|
"error var name: {}, fetch_val1: {}, fetch_val2: {}".
|
|
|
|
|
format(
|
|
|
|
|
fetch_var,
|
|
|
|
|
fetch_val1[~np.equal(fetch_val1, fetch_val2)],
|
|
|
|
|
fetch_val2[~np.equal(fetch_val1, fetch_val2)]))
|
|
|
|
|
|
|
|
|
|
def check_multi_card_fetch_var(self):
|
|
|
|
|
if self.is_invalid_test():
|
|
|
|
@ -160,6 +166,12 @@ class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
for item in fetch_vals:
|
|
|
|
|
self.assertTrue(np.array_equal(fetch_vals[0], item))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.array_equal(fetch_vals[0], item),
|
|
|
|
|
"error var name: {}, fetch_vals[0]: {}, item: {}".
|
|
|
|
|
format(fetch_var,
|
|
|
|
|
fetch_vals[0][~np.equal(fetch_vals[0], item)],
|
|
|
|
|
item[~np.equal(fetch_vals[0], item)]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CUDAInplaceTest(InplaceTestBase):
|
|
|
|
|