|
|
|
@ -79,7 +79,7 @@ class TestInstanceNormOpTraining(unittest.TestCase):
|
|
|
|
|
self.init_test_case()
|
|
|
|
|
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.use_global_stats = False
|
|
|
|
|
self.shape = [2, 3, 4, 5]
|
|
|
|
|
self.no_grad_set = set()
|
|
|
|
|
self.fetch_list = [
|
|
|
|
|
'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD',
|
|
|
|
@ -181,12 +181,19 @@ class TestInstanceNormOpTraining(unittest.TestCase):
|
|
|
|
|
"instance_norm"):
|
|
|
|
|
places.append(core.CUDAPlace(0))
|
|
|
|
|
for place in places:
|
|
|
|
|
test_with_place(place, [2, 3, 4, 5])
|
|
|
|
|
test_with_place(place, self.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.use_global_stats = False
|
|
|
|
|
self.shape = [2, 3, 4, 5]
|
|
|
|
|
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
|
|
|
|
|
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestInstanceNormOpTrainingCase2(TestInstanceNormOpTraining):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.shape = [20, 50, 4, 5]
|
|
|
|
|
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
|
|
|
|
|
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
|
|
|
|
|
|
|
|
|
|