@@ -19,14 +19,19 @@ from . import core
 import collections
 import copy
 import six
+import logging
 from .. import compat as cpt
 from . import unique_name
+from . import log_helper
 
 __all__ = [
     'append_backward',
     'gradients',
 ]
 
+_logger = log_helper.get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
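+# _logger is used by ProgramStats.sort_checkpoints below to report checkpoint
+# names that are dropped because they never appear in the program.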
 
 
 class ProgramStats(object):
     def __init__(self, block, ops):
@@ -38,7 +43,7 @@ class ProgramStats(object):
     def get_input_nodes(self):
         input_names = []
         for name in self.var_op_deps:
-            if len(self.var_op_deps[name]["var_as_output_ops"]) <= 0 and \
+            if len(self.var_op_deps[name]["var_as_output_ops"]) == 0 and \
                     len(self.var_op_deps[name]["var_as_input_ops"]) > 0:
                 if self.block.var(name).persistable:
                     continue
@@ -115,6 +120,22 @@ class ProgramStats(object):
             for op_idx in self.op_deps[i]["in_ops"]:
                 self.op_deps[op_idx]["out_ops"].extend([i])
 
+    def sort_checkpoints(self, checkpoints_name):
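+        # Order checkpoint names by the index of the last op that outputs
+        # them, so later segmentation walks them in program order. Names not
+        # used in the program are dropped (with a debug log), and pure inputs
+        # (never produced by any op) sort first with key -1. For example, if
+        # "x" is last written by op 3, "y" by op 7, and "z" is a graph input,
+        # ["y", "x", "z"] is returned as ["z", "x", "y"].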
+        sorted_checkpoints = []
+        for name in checkpoints_name:
+            if name not in self.var_op_deps:
+                _logger.debug(
+                    "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
+                    % name)
+            elif self.var_op_deps[name]["var_as_output_ops"] == []:
+                # input nodes
+                sorted_checkpoints.append((name, -1))
+            else:
+                sorted_checkpoints.append(
+                    (name, max(self.var_op_deps[name]["var_as_output_ops"])))
+        sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1])
+        return [x[0] for x in sorted_checkpoints]
 
 def _pretty_op_desc_(op_desc, prefix):
     out_s = "%s\tname:[%s]\n%s\tinputs:[%s]\n%s\toutputs:[%s]" % \
@@ -584,15 +605,17 @@ def _append_backward_ops_with_checkpoints_(
     """
     checkpoints_name = [x.name for x in checkpoints]
+    checkpoints_name = list(set(checkpoints_name))
     local_block = block.program._create_block()
     buffer_block = block.program._create_block()
 
     # 1) find ops between checkpoints, i.e. recompute_segments
     program_stat = ProgramStats(block, ops)
     program_stat.build_stats()
+    checkpoints_name = program_stat.sort_checkpoints(checkpoints_name)
     segments = []
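     # each entry of segments is a half-open [start, end) range of op indices
     # whose forward ops will be recomputed during the backward pass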
 
-    if len(checkpoints) == 1:
+    if len(checkpoints_name) == 1:
         # only one checkpoint
         max_op_idx = -1
         var_group = [checkpoints_name[0]]
@@ -616,8 +639,6 @@ def _append_backward_ops_with_checkpoints_(
             segments.append([min_idx, max_idx + 1])
         start_idx += 1
 
-    checkpoints_name = list(set(checkpoints_name))
-
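     # ops that run before the first discovered segment are grouped into a
     # leading recompute segment of their own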
     if segments != [] and segments[0][0] != 0:
         recompute_segments = [[0, segments[0][0]]] + segments
     else:
@@ -625,7 +646,7 @@ def _append_backward_ops_with_checkpoints_(
     # 2) go through all forward ops and deduce all variables that will be held in memory
     vars_should_be_hold = []
     # a. variables that are used across segments will be held in memory
     for segment in recompute_segments:
         vars_should_be_hold.extend(
             program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
@@ -635,10 +656,6 @@ def _append_backward_ops_with_checkpoints_(
     vars_should_be_hold.extend(program_stat.get_input_nodes())
     vars_should_be_hold = list(set(vars_should_be_hold))
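     # variables collected in vars_should_be_hold are kept in memory across
     # the backward pass instead of being recomputed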
-    # find variables that can not be deleted
-    grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold]
-    vars_should_be_hold.extend(grad_should_be_hold)
 
     # 3) go through each recompute_segments, add backward ops with forward recomputation
     grad_op_descs = []
     var_name_dict = {}
@@ -647,7 +664,7 @@ def _append_backward_ops_with_checkpoints_(
     max_calculated_op_position = len(ops)
     if recompute_segments == []:
         # if there is no recompute segment, add backward ops like
         # the _append_backward_ops_ function does
         gap_ops = ops[0:max_calculated_op_position]
         for op in reversed(gap_ops):