You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/v2/parameters.py

112 lines
3.1 KiB

import numpy as np
from paddle.proto.ModelConfig_pb2 import ModelConfig
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
# Names exported by `from <this module> import *`.
__all__ = ['IParameterPool', 'create', 'ParameterFlag']
class ParameterFlag(object):
    """
    Access-mode flags accepted by IParameterPool.get_parameter.

    The values are bit flags. When a writable flag is used, operations on the
    returned numpy array are also applied to the Paddle parameter, which is
    slower in GPU mode.
    """

    # Bit 0: the caller reads the returned array.
    READ_ONLY = 1 << 0
    # Bit 1: the caller writes the returned array.
    WRITE_ONLY = 1 << 1
    # Both bits set.
    READ_WRITE = WRITE_ONLY | READ_ONLY
class IParameterPool(object):
    """
    Abstract interface of a parameter pool.

    A parameter pool behaves like a dictionary of parameters keyed by name.
    Users inspect, modify or customize a parameter value through the array
    returned by `get_parameter`.

    ..  code-block:: python

        pool = paddle.parameters.create(topo1, topo2)
        embedding = pool.get_parameter("embedding")
        assert isinstance(embedding, numpy.ndarray)
        print(embedding[1:])
    """

    def get_parameter(self, name, flag=ParameterFlag.READ_WRITE):
        """
        Look up a parameter by its name.

        :param name: name of the parameter.
        :type name: basestring
        :param flag: access mode of the returned value (readable and/or
                     writable), one of the ParameterFlag values.
        :type flag: int
        :return: the value of the parameter.
        :rtype: np.ndarray
        """
        # Concrete pools must override this method.
        raise NotImplementedError

    def get_names(self):
        """
        List the names of all parameters in the pool.

        :return: every parameter name.
        :rtype: list
        """
        # Concrete pools must override this method.
        raise NotImplementedError
class NumpyParameterPool(IParameterPool):
    """
    Parameter pool that stores every parameter value as a numpy array.

    Parameter configurations are registered with `append`. The array that
    backs a parameter is allocated lazily, on the first `get_parameter` call
    for that name, and the same array object is returned afterwards.
    """

    def __init__(self):
        # Maps parameter name -> ParameterConfig.
        self.__param_configs__ = dict()
        # Maps parameter name -> np.ndarray, or None while not yet allocated.
        self.__params__ = dict()

    def append(self, conf):
        """
        Register a parameter configuration in the pool.

        Appending a configuration whose name is already registered replaces
        the stored configuration and discards any array allocated for it.

        :param conf: configuration of the parameter to register.
        :type conf: ParameterConfig
        :raise ValueError: if conf is not a ParameterConfig, or if it is not
                           initialized.
        """
        if not isinstance(conf, ParameterConfig):
            raise ValueError("conf must be ParameterConfig")

        if not conf.IsInitialized():
            raise ValueError("conf is not initialized")

        self.__param_configs__[conf.name] = conf
        # The value is created on demand by get_parameter.
        self.__params__[conf.name] = None

    def get_config(self, name):
        """
        Get the configuration registered under a parameter name.

        :param name: parameter name.
        :return: the configuration passed to `append` for that name.
        :rtype: ParameterConfig
        :raise ValueError: if no parameter with that name was appended.
        """
        if name not in self.__param_configs__:
            raise ValueError("parameter %s is not appended" % name)

        return self.__param_configs__[name]

    def get_parameter(self, name, *args, **kwargs):
        """
        Get the numpy array holding the value of a parameter.

        Extra positional and keyword arguments (such as the `flag` of
        IParameterPool.get_parameter) are accepted but ignored.

        :param name: parameter name.
        :return: float32 array shaped after the `dims` of the parameter
                 configuration; the same object is returned on every call.
        :rtype: np.ndarray
        :raise ValueError: if no parameter with that name was appended, or if
                           its configuration has empty `dims`.
        """
        if name not in self.__params__:
            raise ValueError("parameter %s is not appended" % name)

        param = self.__params__[name]
        if param is None:
            shape = self.__param_configs__[name].dims
            if len(shape) == 0:
                raise ValueError("parameter %s is no shape" % name)
            # Allocate zero-filled storage: the raw np.ndarray constructor
            # would expose uninitialized memory as the parameter value.
            param = np.zeros(
                shape=[int(item) for item in shape], dtype='float32')
            self.__params__[name] = param
        return param

    def get_names(self):
        """
        Get the names of all appended parameters.

        :return: all parameter names, as a new list.
        :rtype: list
        """
        # Build a real list so the documented return type also holds on
        # Python 3, where dict.keys() is a live view object.
        return list(self.__param_configs__)
def create(*topologies):
    """
    Build a parameter pool from one or more topologies.

    Every parameter configuration found in each topology is appended to a
    fresh NumpyParameterPool, in the order the topologies are given.

    :param topologies: the ModelConfig objects to collect parameters from.
    :return: a pool containing the parameters of all topologies.
    :raise ValueError: if any topology is not a ModelConfig.
    """
    result = NumpyParameterPool()
    for model_config in topologies:
        if isinstance(model_config, ModelConfig):
            for param_conf in model_config.parameters:
                result.append(param_conf)
        else:
            raise ValueError(
                'create must pass a topologies which type is ModelConfig')
    return result