@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum";
 const char kOptimizerType[] = "sgd";
 
 void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
 
   // We could collect all weights' name from SGD, where
   // W1 <- SGD(W0, Grad0)
@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
   for (auto* node : graph->Nodes()) {
     if (IsOpNamed(node, kOptimizerType)) {
       auto& param_out_vars = node->Op()->Output("ParamOut");
-      PADDLE_ENFORCE(param_out_vars.size() == 1u);
+      PADDLE_ENFORCE_EQ(
+          param_out_vars.size(), 1u,
+          platform::errors::InvalidArgument(
+              "In op(%s), find output(ParamOut) failed.", node->Name()));
       weight_var_set.insert(param_out_vars[0]);
     }
   }
@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
 
         VLOG(3) << "Found forward_op " << forward_op->Name();
 
-        PADDLE_ENFORCE(forward_op);
+        PADDLE_ENFORCE_NOT_NULL(
+            forward_op, platform::errors::NotFound(
+                            "Can not find forward op for backward op(%s).",
+                            backward_op->Name()));
 
         Node* new_optimizer_node = CreateNewSGDNode(
             graph, forward_op, backward_op, node, opt_node);
 
-        PADDLE_ENFORCE(new_optimizer_node);
+        PADDLE_ENFORCE_NOT_NULL(
+            new_optimizer_node,
+            platform::errors::InvalidArgument(
+                "Create new SGD node failed, backward op is %s.",
+                backward_op->Name()));
       }
     }
   }
@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
 ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
     ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
     ir::Node* grad_sum_node, ir::Node* optimize_node) const {
-  PADDLE_ENFORCE(graph);
-  PADDLE_ENFORCE(forward_node);
-  PADDLE_ENFORCE(backward_node);
-  PADDLE_ENFORCE(grad_sum_node);
-  PADDLE_ENFORCE(optimize_node);
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Input argument graph cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      forward_node, platform::errors::InvalidArgument(
+                        "Input argument forward_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      backward_node, platform::errors::InvalidArgument(
+                         "Input argument backward_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      grad_sum_node, platform::errors::InvalidArgument(
+                         "Input argument grad_sum_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      optimize_node, platform::errors::InvalidArgument(
+                         "Input argument optimize_node cannot be nullptr."));
 
   // find the grad var node between the grad sum node and backward_node
   std::vector<ir::Node*> grad_vars =
@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
       grad_node = node;
     }
   }
-  PADDLE_ENFORCE(grad_node);
+  PADDLE_ENFORCE_NOT_NULL(grad_node, platform::errors::NotFound(
+                                         "Can not find control dep variable."));
 
   // create a new SGD node
   OpDesc* old_desc = optimize_node->Op();
@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
   }
 
   // SGD must have only one param and LR in
-  PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u);
-  PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u);
+  PADDLE_ENFORCE_EQ(
+      old_desc->Input("LearningRate").size(), 1u,
+      platform::errors::InvalidArgument(
+          "In op(%s), find input(LearningRate) failed.", old_desc->Type()));
+  PADDLE_ENFORCE_EQ(
+      old_desc->Input("Param").size(), 1u,
+      platform::errors::InvalidArgument("In op(%s), find input(Param) failed.",
+                                        old_desc->Type()));
 
   // LR and weight nodes should be copied
   for (Node* upstream_node : optimize_node->inputs) {
@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
 void LockFreeOptimizePass::ReplaceUpstreamNode(
     ir::Node* upstream_node, ir::Node* old_optimizer_node,
     ir::Node* new_optimizer_node) const {
-  PADDLE_ENFORCE(upstream_node);
-  PADDLE_ENFORCE(old_optimizer_node);
-  PADDLE_ENFORCE(new_optimizer_node);
+  PADDLE_ENFORCE_NOT_NULL(
+      upstream_node, platform::errors::InvalidArgument(
+                         "Input argument upstream_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      old_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument old_optimizer_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      new_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument new_optimizer_node cannot be nullptr."));
 
   // Remove the old_optimizer_node from upstream_node's outputs vector
   auto& output_node_vec = upstream_node->outputs;
@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode(
 
 void LockFreeOptimizePass::ReplaceAllDownstreamNode(
     ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
-  PADDLE_ENFORCE(old_optimizer_node);
-  PADDLE_ENFORCE(new_optimizer_node);
+  PADDLE_ENFORCE_NOT_NULL(
+      old_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument old_optimizer_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      new_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument new_optimizer_node cannot be nullptr."));
 
   for (ir::Node* downstream_node : old_optimizer_node->outputs) {
     // Remove the old_optimizer_node from downstream_node's inputs vector
@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode(
 
 ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
     ir::Graph* graph, ir::Node* backward_node) const {
-  PADDLE_ENFORCE(graph);
-  PADDLE_ENFORCE(backward_node);
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Input argument graph cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      backward_node, platform::errors::InvalidArgument(
+                         "Input argument backward_node cannot be nullptr."));
 
   // strip the suffix _grad of backward_node's name
   std::string forward_op_name = backward_node->Name();
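
The pattern applied throughout: each untyped PADDLE_ENFORCE(cond) becomes a typed check that names the failing input and carries a platform::errors category with a printf-style message. Below is a minimal sketch of the two macro forms, assuming the Paddle headers paddle/fluid/platform/enforce.h and paddle/fluid/framework/ir/node.h; the helper CheckParamOut is hypothetical and only illustrates the convention:

// Hypothetical sketch of the enforce convention used throughout this diff.
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

void CheckParamOut(Node* node) {
  // Null checks: name the argument and tag the error InvalidArgument,
  // instead of the bare PADDLE_ENFORCE(node) this patch removes.
  PADDLE_ENFORCE_NOT_NULL(
      node, platform::errors::InvalidArgument(
                "Input argument node cannot be nullptr."));
  // Equality checks: report actual vs. expected plus a formatted message.
  PADDLE_ENFORCE_EQ(
      node->Op()->Output("ParamOut").size(), 1u,
      platform::errors::InvalidArgument(
          "In op(%s), find output(ParamOut) failed.", node->Name()));
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle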