fix instance norm (#21042)

* fix instance norm

* update unitest,test=develop
custom_op_abi
ceci3 6 years ago committed by GitHub
parent 7041eb2170
commit f62a929151
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -328,7 +328,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
epsilon, saved_mean_data, saved_var_data));
} else {
if (d_x) {
GradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
d_x->data<T>());

@ -79,7 +79,7 @@ class TestInstanceNormOpTraining(unittest.TestCase):
self.init_test_case()
def init_test_case(self):
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
self.no_grad_set = set()
self.fetch_list = [
'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD',
@ -181,12 +181,19 @@ class TestInstanceNormOpTraining(unittest.TestCase):
"instance_norm"):
places.append(core.CUDAPlace(0))
for place in places:
test_with_place(place, [2, 3, 4, 5])
test_with_place(place, self.shape)
class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining):
def init_test_case(self):
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
class TestInstanceNormOpTrainingCase2(TestInstanceNormOpTraining):
def init_test_case(self):
self.shape = [20, 50, 4, 5]
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']

Loading…
Cancel
Save