Add unittests to check channelwise add

wangkuiyi-patch-1
fengjiayi 8 years ago
parent 238124909e
commit e819f1b3a5

@ -252,5 +252,25 @@ class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp):
self.axis = 1
class TestElementwiseAddOp_channelwise_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(3, 20, 20).astype(self.dtype)
self.y = np.random.rand(3, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(3, 10, 20).astype(self.dtype)
self.y = np.random.rand(3, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save