|
|
|
@ -130,7 +130,9 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestResnet(TestParallelExecutorBase):
|
|
|
|
|
def check_resnet_convergence(self, balance_parameter_opt_between_cards):
|
|
|
|
|
def check_resnet_convergence(self,
|
|
|
|
|
balance_parameter_opt_between_cards,
|
|
|
|
|
use_cuda=True):
|
|
|
|
|
import functools
|
|
|
|
|
batch_size = 2
|
|
|
|
|
self.check_network_convergence(
|
|
|
|
@ -138,14 +140,17 @@ class TestResnet(TestParallelExecutorBase):
|
|
|
|
|
SE_ResNeXt50Small, batch_size=batch_size),
|
|
|
|
|
iter=20,
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
use_cuda=use_cuda,
|
|
|
|
|
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def test_resnet(self):
|
|
|
|
|
self.check_resnet_convergence(False)
|
|
|
|
|
self.check_resnet_convergence(False, use_cuda=True)
|
|
|
|
|
# self.check_resnet_convergence(False,use_cuda=False)
|
|
|
|
|
|
|
|
|
|
def test_resnet_with_new_strategy(self):
|
|
|
|
|
self.check_resnet_convergence(True)
|
|
|
|
|
self.check_resnet_convergence(True, use_cuda=True)
|
|
|
|
|
self.check_resnet_convergence(True, use_cuda=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|