port^2
Yancey1989 7 years ago
parent e896926b9c
commit a158bd9173

@ -27,6 +27,7 @@ import signal
SEED = 1 SEED = 1
DTYPE = "float32" DTYPE = "float32"
paddle.dataset.mnist.fetch()
# random seed must set before configuring the network. # random seed must set before configuring the network.
@ -147,7 +148,7 @@ class TestDistMnist(unittest.TestCase):
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
def test_with_place(self): def test_with_place(self):
p = fluid.CUDAPlace() if core.is_compiled_with_cuda( p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace() ) else fluid.CPUPlace()
pserver_pid = self.start_pserver(self._ps_endpoints) pserver_pid = self.start_pserver(self._ps_endpoints)

Loading…
Cancel
Save