@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections import defaultdict
from paddle . fluid . framework import Program
import framework
@ -818,8 +818,8 @@ class ModelAverage(Optimizer):
min_average_window , max_average_window and current update times .
Args :
params_grads : A list of parameter - grad variable pairs .
average_window_rate : The rate of average window .
params_grads : A list of parameter - grad variable pairs .
min_average_window : The minimum size of average window .
max_average_window : The maximum size of average window .
@ -840,8 +840,8 @@ class ModelAverage(Optimizer):
"""
def __init__ ( self ,
params_grads ,
average_window_rat e,
average_window_rate= 0.15 ,
params_grads= Non e,
min_average_window = 10000 ,
max_average_window = 10000 ,
* * kwargs ) :
@ -849,25 +849,21 @@ class ModelAverage(Optimizer):
self . average_window = average_window_rate
self . min_average_window = min_average_window
self . max_average_window = max_average_window
self . params_grads = params_grads
# append 'moving mean' and 'moving variance' to self.params_grads
pattern = re . compile ( r " batch_norm_ \ d+ \ .w_[1,2] " )
self . params_grads = [ ] if params_grads is None else params_grads
params = { }
for param , grad in self . params_grads :
params [ param . name ] = ( param , grad )
for param in framework . default_main_program ( ) . global_block (
) . all_parameters ( ) :
if pattern . match ( param . name ) is not None :
self . params_grads . append ( ( param , None ) )
# create a tmp gradient variable to back up the parameter value
# for each parameter whose grad is None
for i , param_grad in enumerate ( self . params_grads ) :
param , grad = param_grad
if grad is None :
if param . name not in params and param . average :
grad = param . block . create_var (
name = unique_name . generate ( " . " . join ( [ param . name , ' tmp ' ] ) ) ,
dtype = param . dtype ,
persistable = False ,
stop_gradient = stop_gradient )
self . params_grads [ i ] = ( param , grad )
stop_gradient = True )
params [ param . name ] = ( param , grad )
self . params_grads = params . values ( )
for param , grad in self . params_grads :
self . _append_average_accumulate_op ( param )