You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/python/paddle/v2/master/client.py

96 lines
3.0 KiB

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import os
__lib__ = None
def get_c_lib():
    """Return the handle to ``libpaddle_master.so``, loading it on first use.

    The shared library is looked up in the directory that contains this
    module. The loaded handle is cached in the module-level ``__lib__`` so
    that later calls return the same object without loading it again.

    Returns:
        The ``ctypes`` library handle stored in ``__lib__``.
    """
    global __lib__
    if __lib__ is None:
        path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
        __lib__ = ctypes.cdll.LoadLibrary(path)
    return __lib__
class client(object):
    """
    client is a client to the master server.

    Every method forwards to a function of the C library returned by
    ``get_c_lib()``, passing the handle stored in ``self.c``.
    """

    @staticmethod
    def _as_bytes(value):
        """Encode text to UTF-8 bytes; return any other value unchanged."""
        # ``type(u"")`` is ``unicode`` on Python 2 and ``str`` on Python 3, so
        # byte strings (and None) are passed through untouched on both.
        if isinstance(value, type(u"")):
            return value.encode("utf-8")
        return value

    def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
        """Create the master client handle and store it in ``self.c``.

        :param etcd_endpoints: etcd endpoints string; text is encoded to
            UTF-8 before being handed to the C library.
        :param timeout_sec: forwarded unchanged to the C library.
        :param buf_size: forwarded unchanged to the C library.
        """
        self.c = get_c_lib().paddle_new_etcd_master_client(
            self._as_bytes(etcd_endpoints), timeout_sec, buf_size)

    def request_save_model(self, trainer_id, block_ms):
        """request to save model

        Conventionally the 0-th trainer will save model. But in
        distributed training, any trainer could be killed. This
        function asks the master server if the trainer should proceed
        with saving model.

        :param trainer_id: trainer id.
        :param block_ms: number of milliseconds that other save model
            requests will be blocked if this save model request succeeded.

        Returns:
            int: 1 if the save model request is approved, 0 if the
                request is rejected because another trainer is saving
                the model, -1 if an error happened.
        """
        return get_c_lib().paddle_request_save_model(self.c, trainer_id,
                                                     block_ms)

    def release(self):
        """Release the master client handle and reset ``self.c`` to None.

        Calling this again after the handle was released is a no-op, so a
        ``None`` handle is never passed to the C library.
        """
        if self.c is None:
            return
        get_c_lib().paddle_release_master_client(self.c)
        self.c = None

    def set_dataset(self, paths):
        """Pass the dataset paths to the C library.

        :param paths: sized sequence of paths; text entries are encoded to
            UTF-8, byte strings are used as they are.
        """
        holder_type = ctypes.c_char_p * len(paths)
        # The array keeps a reference to every byte string it points to, so
        # the buffers stay valid for the duration of the call.
        holder = holder_type(*[self._as_bytes(path) for path in paths])
        get_c_lib().paddle_set_dataset(self.c, holder, len(paths))

    def next_record(self):
        """gets next record for training

        Returns:
            bytes: the record (None on error, empty if the record is empty).
            int: error code, 0 if successful, < 0 otherwise.
        """
        p = ctypes.c_char_p()
        ret = ctypes.pointer(p)
        size = get_c_lib().paddle_next_record(self.c, ret)
        if size < 0:
            # Error
            return None, size

        if size == 0:
            # Empty record
            return b"", 0

        # The record length comes from the return value, so copy exactly
        # ``size`` bytes. Reading ``p.value`` would stop at the first NUL
        # byte (truncating binary records) and relies on a terminator
        # being present.
        record = ctypes.string_at(p, size)
        # Memory created from C should be freed.
        get_c_lib().mem_free(p)
        return record, 0

    def paddle_start_get_records(self, pass_id):
        """Forward ``pass_id`` to ``paddle_start_get_records`` in the C library."""
        get_c_lib().paddle_start_get_records(self.c, pass_id)