@ -10,13 +10,15 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
""" Defination of Server and Worker. """
from . import ps_pb2 as pslib
class Server ( object ) :
"""
A Server basic class .
A Server basic class
it ' s a base class, does not have implementation
"""
def __init__ ( self ) :
@ -26,6 +28,7 @@ class Server(object):
class Worker ( object ) :
"""
A Worker basic class .
it ' s a base class, does not have implementation
"""
def __init__ ( self ) :
@ -169,7 +172,10 @@ class DownpourServer(Server):
"""
Args :
table_id ( int ) : id of sparse params table
strategy ( dict ) : the dense config dict .
param_var ( list ) : param vars
grad_var ( list ) : param grad vars
strategy ( dict ) : the dense config dict
sparse_table_names ( list ) : sparse table names
Returns :
return None
"""
@ -230,7 +236,11 @@ class DownpourServer(Server):
"""
Args :
table_id ( int ) : id of datanorm table
strategy ( dict ) : the datanorm config dict .
learning_rate ( float ) : the learning rate used to update parameters
param_var ( list ) : param vars
grad_var ( list ) : param grad vars
strategy ( dict ) : the datanorm config dict
sparse_table_names ( list ) : sparse table names
Returns :
return None
"""
@ -296,43 +306,60 @@ class DownpourWorker(Worker):
self . window = window
self . _worker = pslib . DownpourTrainerParameter ( )
def add_sparse_table ( self , table_id , slot_key_vars , slot_value_vars ) :
def add_sparse_table ( self ,
table_id ,
slot_key_vars ,
slot_value_vars ,
slot_value_grads = None ) :
"""
Args :
table_id ( int ) : id of sparse params table
slot_key_vars ( string ) : slot key id
slot_value_var ( string ) : slot key value after embedding
slot_key_vars ( list ) : slot key id
slot_value_vars ( list ) : slot key value after embedding
slot_value_grads ( list ) : grad of all params , default is None
Returns :
return None
"""
if slot_value_grads is None :
slot_value_grad_names = \
[ var . name + " @GRAD " for var in slot_value_vars ]
else :
value_to_key = { }
for i in range ( len ( slot_key_vars ) ) :
value_to_key [ slot_value_vars [ i ] . name ] = slot_key_vars [ i ]
slot_value_grad_names = [ ]
all_grad_names = [ var . name for var in slot_value_grads ]
for var in slot_value_vars :
if var . name + " @GRAD " in all_grad_names :
slot_value_grad_names . append ( var . name + " @GRAD " )
sorted_slot_value_vars = [ i for i in slot_value_vars if \
i . name + " @GRAD " in slot_value_grad_names ]
sorted_slot_value_vars + = [ i for i in slot_value_vars if \
i . name + " @GRAD " not in slot_value_grad_names ]
sorted_slot_key_vars = \
[ value_to_key [ v . name ] for v in sorted_slot_value_vars ]
target_table = None
for table in self . _worker . sparse_table :
if table . table_id == table_id :
if [ var . name for var in slot_key_vars
] == self . _worker . sparse_table [ table_id ] . slot_key :
if [ var . name for var in slot_value_vars
] == self . _worker . sparse_table [ table_id ] . slot_value :
if [
var . name + " @GRAD " for var in slot_value_vars
] == self . _worker . sparse_table [ table_id ] . slot_gradient :
return
else :
raise ValueError (
" sparse table %s slot_gradient error " %
table_id )
else :
raise ValueError ( " sparse table %s slot_value error " %
keys = self . _worker . sparse_table [ table_id ] . slot_key
key_names = [ var . name for var in sorted_slot_key_vars ]
for key_name in key_names :
if key_name not in keys :
raise ValueError ( " sparse table %s slot_key error " %
table_id )
else :
raise ValueError ( " sparse table %s slot_key error " %
table_id )
target_table = table
break
table = target_table
if table is not None :
self . _worker . sparse_table . remove ( table )
table = self . _worker . sparse_table . add ( )
table . table_id = table_id
table . slot_key . extend ( [ var . name for var in slot_key_vars ] )
table . slot_value . extend ( [ var . name for var in slot_value_vars ] )
table . slot_gradient . extend (
[ var . name + " @GRAD " for var in slot_value_vars ] )
table . slot_key . extend ( [ var . name for var in sorted_slot_key_vars ] )
table . slot_value . extend ( [ var . name for var in sorted_slot_value_vars ] )
table . slot_gradient . extend ( slot_value_grad_names )
def add_dense_table ( self , table_id , learning_rate , param_vars , grad_vars ,
dense_start_table_id , sparse_table_names ) :
@ -341,8 +368,10 @@ class DownpourWorker(Worker):
table_id ( int ) : id of sparse params table
learning_rate ( float ) : the learning rate used to update parameters . \
Can be a float value
param_var ( list ) : all dense param . it is a list .
grad_var ( list ) : all dense grad parm it is a list .
param_vars ( list ) : all dense param . it is a list .
grad_vars ( list ) : all dense grad parm it is a list .
dense_start_table_id ( int ) : dense table start index
sparse_table_names ( list ) : sparse table names
Returns :
return None
"""
@ -365,21 +394,19 @@ class DownpourWorker(Worker):
for table in self . _worker . dense_table :
if table . table_id == table_id :
desc_dense_param_name = list ( self . _worker . dense_table [
table_id - dense_start_table_id ] . dense_variable_name )
desc_dense_param_name = list ( table . dense_variable_name )
desc_dense_param_name . sort ( )
if dense_param_name == desc_dense_param_name :
desc_dense_grad_name = list ( self . _worker . dense_table [
table_id - dense_start_table_id ]
. dense_gradient_variable_name )
desc_dense_grad_name = list (
table . dense_gradient_variable_name )
desc_dense_grad_name . sort ( )
if dense_grad_name == desc_dense_grad_name :
return
else :
raise ValueError (
" dense table %s dense_gradient_variable_name error "
% table_id )
" dense table %s dense_gradient_variable_name "
" error " % table_id )
else :
raise ValueError (
" dense table %s dense_variable_name error " % table_id )