@ -86,7 +86,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
auto * out_var = block_ . FindVarRecursive ( Outputs ( out ) [ j ] ) ;
if ( in_var - > GetType ( ) ! = proto : : VarType : : LOD_TENSOR & &
in_var - > GetType ( ) ! = proto : : VarType : : LOD_TENSOR_ARRAY ) {
VLOG ( 3 ) < < " input " < < in < < " is not Lo dTensor or Lod TensorArray." ;
VLOG ( 3 ) < < " input " < < in < < " is not Lo DTensor or LoD TensorArray." ;
return ;
}
out_var - > SetLoDLevel ( in_var - > GetLoDLevel ( ) ) ;
@ -94,6 +94,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void DecreaseLoDLevel ( const std : : string & in , const std : : string & out ,
size_t i = 0 , size_t j = 0 ) const override {
// When in is a LoDTensor and out is a LoDTensorArray, there may need to
// decrease the lod_level.
PADDLE_ENFORCE_LT ( i , Inputs ( in ) . size ( ) ) ;
PADDLE_ENFORCE_LT ( j , Outputs ( out ) . size ( ) ) ;
PADDLE_ENFORCE ( Inputs ( in ) [ i ] ! = framework : : kEmptyVarName ,
@ -102,17 +104,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
" The %s[%d] is @EMPTY@ " , out , j ) ;
auto * in_var = block_ . FindVarRecursive ( Inputs ( in ) [ i ] ) ;
auto * out_var = block_ . FindVarRecursive ( Outputs ( out ) [ j ] ) ;
PADDLE_ENFORCE ( out_var - > GetType ( ) = = proto : : VarType : : LOD_TENSOR_ARRAY | |
out_var - > GetType ( ) = = proto : : VarType : : LOD_TENSOR ,
" The input %s should be LodTensorArray or LodTensor. " ,
out_var - > Name ( ) ) ;
PADDLE_ENFORCE ( in_var - > GetType ( ) = = proto : : VarType : : LOD_TENSOR ,
" The input %s should be LodTensor. " , in_var - > Name ( ) ) ;
PADDLE_ENFORCE_EQ ( in_var - > GetType ( ) , proto : : VarType : : LOD_TENSOR ,
" The input %s should be LoDTensor. " , in_var - > Name ( ) ) ;
PADDLE_ENFORCE_EQ ( out_var - > GetType ( ) , proto : : VarType : : LOD_TENSOR_ARRAY ,
" The output %s should be LoDTensorArray. " ,
out_var - > Name ( ) ) ;
if ( in_var - > GetLoDLevel ( ) > 0 ) {
out_var - > SetLoDLevel ( in_var - > GetLoDLevel ( ) - 1 ) ;
}
}
void IncreaseLoDLevel ( const std : : string & in , const std : : string & out ,
size_t i = 0 , size_t j = 0 ) const override {
// When in is a LoDTensorArray and out is a LoDTensor, there may need to
// increase the lod_level.
PADDLE_ENFORCE_LT ( i , Inputs ( in ) . size ( ) ) ;
PADDLE_ENFORCE_LT ( j , Outputs ( out ) . size ( ) ) ;
PADDLE_ENFORCE_NE ( Inputs ( in ) [ i ] , framework : : kEmptyVarName ,
" The %s[%d] is @EMPTY@ " , in , i ) ;
PADDLE_ENFORCE_NE ( Outputs ( out ) [ j ] , framework : : kEmptyVarName ,
" The %s[%d] is @EMPTY@ " , out , j ) ;
auto * in_var = block_ . FindVarRecursive ( Inputs ( in ) [ i ] ) ;
auto * out_var = block_ . FindVarRecursive ( Outputs ( out ) [ j ] ) ;
PADDLE_ENFORCE_EQ ( in_var - > GetType ( ) , proto : : VarType : : LOD_TENSOR_ARRAY ,
" The input %s should be LoDTensorArray. " , in_var - > Name ( ) ) ;
PADDLE_ENFORCE_EQ ( out_var - > GetType ( ) , proto : : VarType : : LOD_TENSOR ,
" The output %s should be LoDTensor. " , out_var - > Name ( ) ) ;
out_var - > SetLoDLevel ( in_var - > GetLoDLevel ( ) + 1 ) ;
}
std : : vector < InferShapeVarPtr > GetInputVarPtrs (
const std : : string & name ) override {
const std : : vector < std : : string > arg_names = Inputs ( name ) ;