|
|
|
@ -19,7 +19,7 @@ import contextlib
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_simulator(use_cuda, test_batch_size=10):
|
|
|
|
|
def train_simulator(test_batch_size=10):
|
|
|
|
|
if test_batch_size <= 0:
|
|
|
|
|
raise ValueError("batch_size should be a positive integeral value, "
|
|
|
|
|
"but got batch_size={}".format(test_batch_size))
|
|
|
|
@ -34,14 +34,7 @@ def train_simulator(use_cuda, test_batch_size=10):
|
|
|
|
|
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
|
|
|
|
|
sgd_optimizer.minimize(avg_cost)
|
|
|
|
|
|
|
|
|
|
train_reader = paddle.batch(
|
|
|
|
|
paddle.reader.shuffle(
|
|
|
|
|
paddle.dataset.uci_housing.train(), buf_size=500),
|
|
|
|
|
batch_size=test_batch_size)
|
|
|
|
|
|
|
|
|
|
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
|
|
|
|
|
# Calculate memory usage in current network config
|
|
|
|
|
lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
|
|
|
|
|
fluid.default_main_program(), batch_size=test_batch_size)
|
|
|
|
|
|
|
|
|
@ -50,21 +43,17 @@ def train_simulator(use_cuda, test_batch_size=10):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMemoryUsage(unittest.TestCase):
|
|
|
|
|
def test_cpu(self):
|
|
|
|
|
with self.program_scope_guard():
|
|
|
|
|
train_simulator(use_cuda=False)
|
|
|
|
|
|
|
|
|
|
def test_cpu_with_unit_KB(self):
|
|
|
|
|
def test_with_unit_B(self):
|
|
|
|
|
with self.program_scope_guard():
|
|
|
|
|
train_simulator(use_cuda=False, test_batch_size=1000)
|
|
|
|
|
train_simulator()
|
|
|
|
|
|
|
|
|
|
def test_cpu_with_unit_MB(self):
|
|
|
|
|
def test_with_unit_KB(self):
|
|
|
|
|
with self.program_scope_guard():
|
|
|
|
|
train_simulator(use_cuda=False, test_batch_size=100000)
|
|
|
|
|
train_simulator(test_batch_size=1000)
|
|
|
|
|
|
|
|
|
|
def test_cuda(self):
|
|
|
|
|
def test_with_unit_MB(self):
|
|
|
|
|
with self.program_scope_guard():
|
|
|
|
|
train_simulator(use_cuda=True)
|
|
|
|
|
train_simulator(test_batch_size=100000)
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def program_scope_guard(self):
|
|
|
|
|