Optimize fleet API: add input check for some interfaces (#18971)
* fleet api add input check, test=developsigmoid_bug
parent
ed8f44ea21
commit
a25a716e87
@ -0,0 +1,208 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
|
||||
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedRoleMaker
|
||||
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRoleMaker
|
||||
from paddle.fluid.incubate.fleet.base.role_maker import Role
|
||||
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
|
||||
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer
|
||||
|
||||
|
||||
class DistributeTranspilerConfigTest(unittest.TestCase):
|
||||
def set_runtime_split_send_recv(self, config, value):
|
||||
config.runtime_split_send_recv = value
|
||||
|
||||
def set_sync_mode(self, config, value):
|
||||
config.sync_mode = value
|
||||
|
||||
def testConfig(self):
|
||||
config = DistributeTranspilerConfig()
|
||||
self.assertRaises(Exception, self.set_sync_mode, config, None)
|
||||
self.assertRaises(Exception, self.set_runtime_split_send_recv, config,
|
||||
None)
|
||||
self.assertRaises(Exception, self.set_runtime_split_send_recv, config,
|
||||
True)
|
||||
self.set_sync_mode(config, False)
|
||||
self.assertFalse(config.sync_mode)
|
||||
self.set_runtime_split_send_recv(config, True)
|
||||
self.assertRaises(Exception, self.set_sync_mode, config, True)
|
||||
|
||||
|
||||
class FleetTest(unittest.TestCase):
|
||||
def testInvalidInputs(self):
|
||||
self.assertRaises(Exception, fleet.split_files, "files")
|
||||
self.assertRaises(Exception, fleet.init, "pserver")
|
||||
|
||||
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
|
||||
hidden = fluid.layers.fc(input=data, size=10)
|
||||
loss = fluid.layers.mean(hidden)
|
||||
adam = fluid.optimizer.Adam()
|
||||
adam.minimize(loss)
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
pe = fluid.ParallelExecutor(use_cuda=False, loss_name=loss.name)
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
fleet.save_inference_model,
|
||||
dirname='/tmp/',
|
||||
feeded_var_names=['X'],
|
||||
target_vars=[loss],
|
||||
executor=pe)
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
fleet.save_inference_model,
|
||||
dirname='/tmp/',
|
||||
feeded_var_names=['X'],
|
||||
target_vars=[loss],
|
||||
executor="executor")
|
||||
compiled_prog = fluid.compiler.CompiledProgram(
|
||||
fluid.default_main_program())
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
fleet.save_inference_model,
|
||||
dirname='/tmp/',
|
||||
feeded_var_names=['X'],
|
||||
target_vars=[loss],
|
||||
executor=exe,
|
||||
main_program=compiled_prog)
|
||||
self.assertRaises(
|
||||
Exception, fleet.save_persistables, executor=pe, dirname='/tmp/')
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
fleet.save_persistables,
|
||||
executor="executor",
|
||||
dirname='/tmp/')
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
fleet.save_persistables,
|
||||
executor=exe,
|
||||
dirname='/tmp/',
|
||||
main_program=compiled_prog)
|
||||
self.assertRaises(Exception, fleet._transpile, "config")
|
||||
|
||||
|
||||
class TranspilerOptimizerTest(unittest.TestCase):
|
||||
def testInvalidInputs(self):
|
||||
self.assertRaises(Exception, TranspilerOptimizer, "Adam", None)
|
||||
self.assertRaises(Exception, TranspilerOptimizer,
|
||||
fluid.optimizer.Adam(0.001), "strategy")
|
||||
|
||||
transpiler = TranspilerOptimizer(fluid.optimizer.Adam(0.001))
|
||||
self.assertRaises(Exception, transpiler.minimize, loss=[])
|
||||
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
|
||||
hidden = fluid.layers.fc(input=data, size=10)
|
||||
loss = fluid.layers.mean(hidden)
|
||||
self.assertRaises(
|
||||
Exception, transpiler.minimize, loss=loss.name, startup_program=[])
|
||||
|
||||
|
||||
class UserDefinedRoleMakerTest(unittest.TestCase):
|
||||
def createRoleMaker(self,
|
||||
current_id=0,
|
||||
role=Role.WORKER,
|
||||
worker_num=1,
|
||||
server_endpoints=["127.0.0.1:8080"]):
|
||||
role = UserDefinedRoleMaker(current_id, role, worker_num,
|
||||
server_endpoints)
|
||||
|
||||
def testRoleMaker(self):
|
||||
self.createRoleMaker()
|
||||
## test all invalid server_endpoints
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
server_endpoints=None) # server_endpoints must be as list
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
server_endpoints=[]) # server_endpoints can't be empty
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker, server_endpoints=[
|
||||
3, []
|
||||
]) # element in server_endpoints must be as string
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
self.createRoleMaker,
|
||||
server_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"]
|
||||
) # element in server_endpoints can't be duplicate
|
||||
## test all invalid current_id
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
current_id="0") # current_id must be as int
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
current_id=-1) # current_id must be greater than or equal to 0
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
self.createRoleMaker,
|
||||
current_id=1,
|
||||
role=Role.SERVER,
|
||||
server_endpoints=["127.0.0.1:8080"]
|
||||
) # if role is server, current_id must be less than len(server_endpoints)
|
||||
## test all invalid worker_num
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
worker_num="1") # worker_num must be as int
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
worker_num=0) # worker_num must be greater than 0
|
||||
## test all invalid role
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
role=3) # role must be as Role(Role.WORKER=1, Role.SERVER=2)
|
||||
|
||||
|
||||
class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
|
||||
def createRoleMaker(self, current_id=0,
|
||||
worker_endpoints=["127.0.0.1:8080"]):
|
||||
role = UserDefinedCollectiveRoleMaker(current_id, worker_endpoints)
|
||||
|
||||
def testRoleMaker(self):
|
||||
self.createRoleMaker()
|
||||
## test all invalid worker_endpoints
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
worker_endpoints=None) # worker_endpoints must be as list
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
worker_endpoints=[]) # worker_endpoints can't be empty
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
worker_endpoints=[3,
|
||||
[]]) # element worker_endpoints must be as string
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
self.createRoleMaker,
|
||||
worker_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"]
|
||||
) # element in worker_endpoints can't be duplicate
|
||||
## test all invalid current_id
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
current_id="0") # current_id must be as int
|
||||
self.assertRaises(
|
||||
Exception, self.createRoleMaker,
|
||||
current_id=-1) # current_id must be greater than or equal to 0
|
||||
self.assertRaises(
|
||||
Exception,
|
||||
self.createRoleMaker,
|
||||
current_id=1,
|
||||
worker_endpoints=["127.0.0.1:8080"]
|
||||
) # current_id must be less than len(worker_endpoints)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue