|
|
|
@ -32,6 +32,7 @@ feed_dict = {
|
|
|
|
|
class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
def initParameter(self):
|
|
|
|
|
self.use_cuda = True
|
|
|
|
|
self.fuse_all_optimizer_ops = False
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.initParameter()
|
|
|
|
@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
self.device_count = fluid.core.get_cuda_device_count()
|
|
|
|
|
else:
|
|
|
|
|
self.device_count = 4
|
|
|
|
|
|
|
|
|
|
assert batch_size % self.device_count == 0
|
|
|
|
|
|
|
|
|
|
def build_program_and_scope(self):
|
|
|
|
@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
build_strategy = fluid.BuildStrategy()
|
|
|
|
|
build_strategy.memory_optimize = memory_optimize
|
|
|
|
|
build_strategy.enable_inplace = enable_inplace
|
|
|
|
|
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
|
|
|
|
|
compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
|
|
|
|
|
loss_name=loss.name,
|
|
|
|
|
build_strategy=build_strategy,
|
|
|
|
@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
build_strategy = fluid.BuildStrategy()
|
|
|
|
|
build_strategy.memory_optimize = memory_optimize
|
|
|
|
|
build_strategy.enable_inplace = enable_inplace
|
|
|
|
|
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
|
|
|
|
|
compiled_program = fluid.CompiledProgram(
|
|
|
|
|
prog).with_data_parallel(
|
|
|
|
|
loss_name=loss.name,
|
|
|
|
@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
|
|
|
|
|
class CPUInplaceTest(InplaceTestBase):
|
|
|
|
|
def initParameter(self):
|
|
|
|
|
self.use_cuda = False
|
|
|
|
|
self.fuse_all_optimizer_ops = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
|
|
|
|
|
def initParameter(self):
|
|
|
|
|
self.use_cuda = True
|
|
|
|
|
self.fuse_all_optimizer_ops = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
|
|
|
|
|
def initParameter(self):
|
|
|
|
|
self.use_cuda = True
|
|
|
|
|
self.fuse_all_optimizer_ops = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|