test seresnext

wangkuiyi-patch-1
chengduoZH 7 years ago
parent 27073c284d
commit d09fd1f6f0

@ -130,7 +130,9 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
class TestResnet(TestParallelExecutorBase):
def check_resnet_convergence(self, balance_parameter_opt_between_cards):
def check_resnet_convergence(self,
balance_parameter_opt_between_cards,
use_cuda=True):
import functools
batch_size = 2
self.check_network_convergence(
@ -138,14 +140,17 @@ class TestResnet(TestParallelExecutorBase):
SE_ResNeXt50Small, batch_size=batch_size),
iter=20,
batch_size=batch_size,
use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
def test_resnet(self):
self.check_resnet_convergence(False)
self.check_resnet_convergence(False, use_cuda=True)
# self.check_resnet_convergence(False,use_cuda=False)
def test_resnet_with_new_strategy(self):
self.check_resnet_convergence(True)
self.check_resnet_convergence(True, use_cuda=True)
self.check_resnet_convergence(True, use_cuda=False)
if __name__ == '__main__':

Loading…
Cancel
Save