From f1b09ba30ec4b77b6a4747432fecf75ac9168f0c Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Sun, 24 Nov 2019 16:49:29 +0800 Subject: [PATCH] adapt test_collective_base.py for only two GPU cards available. (#21307) * adapt test_collective_base.py for only two GPU cards available. test=develop * fix bug of issue #21259 test=develop --- python/paddle/fluid/input.py | 6 +++--- python/paddle/fluid/tests/unittests/test_collective_base.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py index e949939211..1506b9d4f2 100644 --- a/python/paddle/fluid/input.py +++ b/python/paddle/fluid/input.py @@ -104,16 +104,16 @@ def one_hot(input, depth, allow_out_of_range=False): if in_dygraph_mode(): inputs = {'X': input} - attrs = {'depth': depth} + attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range} else: if not isinstance(depth, Variable): # user attribute inputs = {'X': input} - attrs = {'depth': depth} + attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range} else: depth.stop_gradient = True inputs = {'X': input, 'depth_tensor': depth} - attrs = {} + attrs = {'allow_out_of_range': allow_out_of_range} helper.append_op( type="one_hot_v2", inputs=inputs, diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py index e0789178b3..3f3a5642ab 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_base.py @@ -163,7 +163,7 @@ class TestDistBase(unittest.TestCase): w0_ep, w1_ep = worker_endpoints #print("w0_ep:",w0_ep," w1_ep:",w1_ep) env0 = { - "FLAGS_selected_gpus": "2", + "FLAGS_selected_gpus": "0", "PADDLE_TRAINER_ID": "0", "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, @@ -171,7 +171,7 @@ class TestDistBase(unittest.TestCase): } env1 = { - "FLAGS_selected_gpus": "3", + "FLAGS_selected_gpus": "1", "PADDLE_TRAINER_ID": "1", "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,