@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License . */
# include "paddle/framework/op_desc.h"
# include <functional>
# include <unordered_map>
# include "paddle/framework/block_desc.h"
# include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
@ -184,5 +187,37 @@ void OpDescBind::Sync() {
need_update_ = false ;
}
}
using InferShapeFuncMap =
std : : unordered_map < std : : string /*op_type*/ ,
std : : function < void ( InferShapeContext * ) > > ;
static InferShapeFuncMap & InferShapeFuncs ( ) {
static InferShapeFuncMap * g_map = nullptr ;
if ( g_map = = nullptr ) {
g_map = new InferShapeFuncMap ( ) ;
auto & info_map = OpInfoMap : : Instance ( ) ;
// all registered kernels
for ( auto & pair : OperatorWithKernel : : AllOpKernels ( ) ) {
auto & info = info_map . Get ( pair . first ) ;
auto op =
static_cast < OperatorWithKernel * > ( info . Creator ( ) ( " " , { } , { } , { } ) ) ;
g_map - > insert (
{ pair . first , [ op ] ( InferShapeContext * ctx ) { op - > InferShape ( ctx ) ; } } ) ;
}
}
return * g_map ;
}
void OpDescBind : : InferShape ( const BlockDescBind & block ) const {
auto & funcs = InferShapeFuncs ( ) ;
auto it = funcs . find ( this - > Type ( ) ) ;
if ( it = = funcs . end ( ) ) {
PADDLE_THROW ( " Operator %s has not been registered " , this - > Type ( ) ) ;
}
CompileTimeInferShapeContext ctx ( * this , block ) ;
it - > second ( & ctx ) ;
}
} // namespace framework
} // namespace paddle