|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
import paddle.trainer.config_parser as cp
|
|
|
|
|
import struct
|
|
|
|
|
import tarfile
|
|
|
|
@ -62,7 +63,7 @@ class Parameters(object):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.__param_conf__ = dict()
|
|
|
|
|
self.__param_conf__ = OrderedDict()
|
|
|
|
|
self.__gradient_machines__ = []
|
|
|
|
|
self.__tmp_params__ = dict()
|
|
|
|
|
|
|
|
|
@ -231,6 +232,9 @@ class Parameters(object):
|
|
|
|
|
:rtype: np.ndarray
|
|
|
|
|
"""
|
|
|
|
|
import py_paddle.swig_paddle as api
|
|
|
|
|
if self.__param_conf__[key].is_static:
|
|
|
|
|
return np.zeros(self.__param_conf__[key].size, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
return self.__getter_inner(key, api.PARAMETER_GRADIENT)
|
|
|
|
|
|
|
|
|
|
def set(self, parameter_name, value):
|
|
|
|
|