|
|
|
@ -1,8 +1,15 @@
|
|
|
|
|
import ctypes
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
|
|
|
|
|
lib = ctypes.cdll.LoadLibrary(path)
|
|
|
|
|
__lib__ = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_c_lib():
|
|
|
|
|
global __lib__
|
|
|
|
|
if __lib__ is None:
|
|
|
|
|
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
|
|
|
|
|
__lib__ = ctypes.cdll.LoadLibrary(path)
|
|
|
|
|
return __lib__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class client(object):
|
|
|
|
@ -11,8 +18,8 @@ class client(object):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
|
|
|
|
|
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
|
|
|
|
|
buf_size)
|
|
|
|
|
self.c = get_c_lib().paddle_new_etcd_master_client(
|
|
|
|
|
etcd_endpoints, timeout_sec, buf_size)
|
|
|
|
|
|
|
|
|
|
def request_save_model(self, trainer_id, block_ms):
|
|
|
|
|
"""request to save model
|
|
|
|
@ -32,10 +39,11 @@ class client(object):
|
|
|
|
|
saving the model, -1 if error happened.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
|
|
|
|
|
return get_c_lib().paddle_request_save_model(self.c, trainer_id,
|
|
|
|
|
block_ms)
|
|
|
|
|
|
|
|
|
|
def release(self):
|
|
|
|
|
lib.paddle_release_master_client(self.c)
|
|
|
|
|
get_c_lib().paddle_release_master_client(self.c)
|
|
|
|
|
self.c = None
|
|
|
|
|
|
|
|
|
|
def set_dataset(self, paths):
|
|
|
|
@ -45,7 +53,7 @@ class client(object):
|
|
|
|
|
for idx, path in enumerate(paths):
|
|
|
|
|
c_ptr = ctypes.c_char_p(path)
|
|
|
|
|
holder[idx] = c_ptr
|
|
|
|
|
lib.paddle_set_dataset(self.c, holder, len(paths))
|
|
|
|
|
get_c_lib().paddle_set_dataset(self.c, holder, len(paths))
|
|
|
|
|
|
|
|
|
|
def next_record(self):
|
|
|
|
|
"""gets next record for training
|
|
|
|
@ -56,7 +64,7 @@ class client(object):
|
|
|
|
|
"""
|
|
|
|
|
p = ctypes.c_char_p()
|
|
|
|
|
ret = ctypes.pointer(p)
|
|
|
|
|
size = lib.paddle_next_record(self.c, ret)
|
|
|
|
|
size = get_c_lib().paddle_next_record(self.c, ret)
|
|
|
|
|
if size < 0:
|
|
|
|
|
# Error
|
|
|
|
|
return None, size
|
|
|
|
@ -67,5 +75,5 @@ class client(object):
|
|
|
|
|
|
|
|
|
|
record = ret.contents.value[:size]
|
|
|
|
|
# Memory created from C should be freed.
|
|
|
|
|
lib.mem_free(ret.contents)
|
|
|
|
|
get_c_lib().mem_free(ret.contents)
|
|
|
|
|
return record, 0
|
|
|
|
|