@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License . */
# include "paddle/fluid/operators/controlflow/conditional_block_op.h"
# include "paddle/fluid/operators/assign_op.h"
namespace paddle {
namespace operators {
@ -58,13 +59,15 @@ class ConditionalBlockOp : public ConditionalOp {
scopes - > resize ( 1 ) ;
scopes - > front ( ) = & scope . NewScope ( ) ;
auto & cur_scope = * scopes - > front ( ) ;
framework : : Executor exec ( dev_place ) ;
auto * block = Attr < framework : : BlockDesc * > ( " sub_block " ) ;
VLOG ( 3 ) < < " Conditional block.idx = " < < block - > ID ( )
< < " , scope = " < < & cur_scope ;
auto & skip_vars =
Attr < std : : vector < std : : string > > ( ConditionalOp : : kSkipEagerDeletionVars ) ;
exec . Run ( * block - > Program ( ) , & cur_scope , block - > ID ( ) , false , true ,
skip_vars ) ;
skip_vars , /* force_disable_gc */ false ,
/* keep_kid_scopes */ true ) ;
}
}
} ;
@ -92,60 +95,65 @@ class ConditionalBlockGradOp : public ConditionalOp {
}
if ( need_run ) {
const auto & inputs = Inputs ( ConditionalOp : : kInputs ) ;
const auto & outside_grads =
Outputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ;
std : : vector < std : : string > inside_grads ;
inside_grads . reserve ( inputs . size ( ) ) ;
for ( auto & in : inputs ) {
inside_grads . emplace_back ( framework : : GradVarName ( in ) ) ;
}
auto * scope_var = scope . FindVar ( Input ( ConditionalOp : : kScope ) ) ;
PADDLE_ENFORCE ( scope_var ! = nullptr , " Must set scope " ) ;
PADDLE_ENFORCE_NE ( scope_var , nullptr ,
platform : : errors : : InvalidArgument (
" Scope must be set in conditional block op " ) ) ;
auto & scopes = scope_var - > Get < std : : vector < framework : : Scope * > > ( ) ;
PADDLE_ENFORCE_GT ( scopes . size ( ) , 0 ,
platform : : errors : : InvalidArgument (
" Scope must be set in conditional block op " ) ) ;
framework : : Scope & cur_scope = * scopes [ 0 ] ;
framework : : Executor exec ( dev_place ) ;
auto * block = Attr < framework : : BlockDesc * > ( " sub_block " ) ;
const auto & ins = Inputs ( ConditionalOp : : kInputs ) ;
const auto & d_ins =
Outputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ;
const auto & conds = Inputs ( ConditionalOp : : kCondition ) ;
const auto & d_conds =
Outputs ( framework : : GradVarName ( ConditionalOp : : kCondition ) ) ;
std : : vector < std : : string > ins_conds_grads ;
ins_conds_grads . reserve ( ins . size ( ) + conds . size ( ) ) ;
for ( auto & in : ins ) {
ins_conds_grads . emplace_back ( framework : : GradVarName ( in ) ) ;
}
for ( auto & cond : conds ) {
ins_conds_grads . emplace_back ( framework : : GradVarName ( cond ) ) ;
}
VLOG ( 3 ) < < " Conditional Grad block.idx = " < < block - > ID ( )
< < " , scope = " < < & cur_scope ;
exec . Run ( * block - > Program ( ) , & cur_scope , block - > ID ( ) , false , true ,
ins_conds_grads ) ;
AssignLocalGradientToGlobal ( dev_place , cur_scope , ins_conds_grads . data ( ) ,
ins . size ( ) , d_ins ) ;
inside_grads , /* force_disable_gc */ false ,
/* keep_kid_scopes */ false ) ;
AssignLocalGradientToGlobal ( dev_place , cur_scope ,
ins_conds_grads . data ( ) + ins . size ( ) ,
conds . size ( ) , d_conds ) ;
AssignLocalGradientToParentScope ( dev_place , cur_scope , scope ,
inside_grads , outside_grads ) ;
}
}
private :
void AssignLocalGradientTo Global (
void AssignLocalGradientToParentScope (
const platform : : Place & place , const framework : : Scope & cur_scope ,
const std : : string * p_grad_names , size_t p_grad_names_num ,
const std : : vector < std : : string > & pg_names ) const {
for ( size_t i = 0 ; i < p_grad_names_num ; + + i ) {
auto out_grad_name = pg_names [ i ] ;
const auto & in_grad_name = p_grad_names [ i ] ;
auto * in_var = cur_scope . FindVar ( in_grad_name ) ;
if ( in_var = = nullptr ) {
const framework : : Scope & parent_scope ,
const std : : vector < std : : string > & inside_grads ,
const std : : vector < std : : string > & outside_grads ) const {
for ( size_t i = 0 ; i < outside_grads . size ( ) ; + + i ) {
const std : : string & outside_grad_name = outside_grads [ i ] ;
const std : : string & inside_grad_name = inside_grads [ i ] ;
VLOG ( 4 ) < < " inside_grad_name = " < < inside_grad_name
< < " , outside_grad_name = " < < outside_grad_name ;
framework : : Variable * inside_var =
cur_scope . FindLocalVar ( inside_grad_name ) ;
if ( inside_var = = nullptr ) {
continue ;
}
auto new_in_grad_name = cur_scope . Rename ( in_grad_name ) ;
auto assign = framework : : OpRegistry : : CreateOp (
" assign " , { { " X " , { new_in_grad_name } } } , { { " Out " , { out_grad_name } } } ,
framework : : AttributeMap { } ) ;
assign - > Run ( cur_scope , place ) ;
cur_scope . Rename ( new_in_grad_name , in_grad_name ) ;
framework : : Variable * outside_var =
parent_scope . FindVar ( outside_grad_name ) ;
if ( outside_var = = nullptr ) {
continue ;
}
platform : : DeviceContext * dev_ctx =
platform : : DeviceContextPool : : Instance ( ) . Get ( place ) ;
framework : : VisitVarType ( * inside_var ,
AssignFunctor ( outside_var , * dev_ctx ) ) ;
}
}
} ;
@ -154,17 +162,11 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public :
void operator ( ) ( framework : : InferShapeContext * context ) const override {
PADDLE_ENFORCE ( context - > HasInputs ( ConditionalOp : : kCondition ) ) ;
if ( context - > HasInputs ( ConditionalOp : : kInputs ) ) {
PADDLE_ENFORCE (
context - > HasOutputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ) ;
if ( context - > HasInputs ( ConditionalOp : : kInputs ) & &
context - > HasOutputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ) {
context - > SetOutputsDim ( framework : : GradVarName ( ConditionalOp : : kInputs ) ,
context - > GetInputsDim ( ConditionalOp : : kInputs ) ) ;
}
if ( context - > HasOutputs (
framework : : GradVarName ( ConditionalOp : : kCondition ) ) ) {
context - > SetOutputsDim ( framework : : GradVarName ( ConditionalOp : : kCondition ) ,
context - > GetInputsDim ( ConditionalOp : : kCondition ) ) ;
}
}
} ;
@ -187,8 +189,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
this - > OutputGrad ( ConditionalOp : : kOutputs ) ) ;
grad_op - > SetInput ( ConditionalOp : : kScope ,
this - > Output ( ConditionalOp : : kScope ) ) ;
grad_op - > SetOutput ( framework : : GradVarName ( ConditionalOp : : kCondition ) ,
this - > InputGrad ( ConditionalOp : : kCondition , false ) ) ;
grad_op - > SetOutput ( framework : : GradVarName ( ConditionalOp : : kInputs ) ,
this - > InputGrad ( ConditionalOp : : kInputs , false ) ) ;
grad_op - > SetBlockAttr ( " sub_block " , this - > grad_block_ [ 0 ] ) ;