Merge pull request #9051 from panyx0718/profiler

Add a test to ensure profiler works on multi-gpu
shanyi15-patch-2
Yibing Liu 7 years ago committed by GitHub
commit a4b0e4a196
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,6 +15,7 @@
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import numpy import numpy
@ -60,20 +61,23 @@ class BaseParallelForTest(unittest.TestCase):
feed=feed, feed=feed,
fetch=fetch, fetch=fetch,
place=gpu, place=gpu,
use_parallel=False) use_parallel=False,
use_gpu=True)
result_gpu_parallel = self._run_test_impl_( result_gpu_parallel = self._run_test_impl_(
callback=callback, callback=callback,
feed=feed, feed=feed,
fetch=fetch, fetch=fetch,
place=gpu, place=gpu,
use_parallel=True) use_parallel=True,
use_gpu=True)
result_gpu_nccl = self._run_test_impl_( result_gpu_nccl = self._run_test_impl_(
callback=callback, callback=callback,
feed=feed, feed=feed,
fetch=fetch, fetch=fetch,
place=gpu, place=gpu,
use_parallel=True, use_parallel=True,
use_nccl=True) use_nccl=True,
use_gpu=True)
self._assert_same_(fetch, result_cpu, result_cpu_parallel, self._assert_same_(fetch, result_cpu, result_cpu_parallel,
result_gpu, result_gpu_parallel, result_gpu_nccl) result_gpu, result_gpu_parallel, result_gpu_nccl)
else: else:
@ -85,7 +89,8 @@ class BaseParallelForTest(unittest.TestCase):
fetch, fetch,
place, place,
use_parallel=False, use_parallel=False,
use_nccl=False): use_nccl=False,
use_gpu=False):
""" """
Run a single test, returns the fetch values Run a single test, returns the fetch values
Args: Args:
@ -132,7 +137,12 @@ class BaseParallelForTest(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
return exe.run(main, feed=feed, fetch_list=fetch) if use_gpu:
profile_type = 'GPU'
else:
profile_type = 'CPU'
with profiler.profiler(profile_type, 'total', '/tmp/profiler'):
return exe.run(main, feed=feed, fetch_list=fetch)
def _assert_same_(self, fetch, *args): def _assert_same_(self, fetch, *args):
""" """

Loading…
Cancel
Save