@ -34,10 +34,10 @@ extern std::once_flag PaddleInferenceAnakinPredictor<T, P, R>::init_anakin_;
template < typename T , Precision P , OpRunType R >
void PaddleInferenceAnakinPredictor < T , P , R > : : InitEnv ( ) {
anakin : : TargetWrapper < T > : : set_device ( this - > config_ . device_id ) ;
std : : call_once ( this - > init_anakin_ , [ this ] ( ) {
anakin : : Env < T > : : env_init ( this - > config_ . max_stream ) ;
} ) ;
anakin : : TargetWrapper < T > : : set_device ( this - > config_ . device_id ) ;
}
template < typename T , Precision P , OpRunType R >
void PaddleInferenceAnakinPredictor < T , P , R > : : InitNet ( ) {
@ -54,14 +54,19 @@ template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor < T , P , R > : : InitGraph ( ) {
this - > graph_p_ =
std : : make_shared < anakin : : graph : : Graph < T , anakin : : Precision : : FP32 > > ( ) ;
if ( ! ( this - > graph_p_ - > load ( this - > config_ . model_file ) ) ) {
LOG ( FATAL ) < < " fail to load graph from " < < this - > config_ . model_file ;
if ( ! this - > config_ . model_file . empty ( ) ) {
this - > graph_p_ - > load ( this - > config_ . model_file ) ;
} else if ( this - > config_ . model_buf_p ) {
this - > graph_p_ - > load ( this - > config_ . model_buf_p ,
this - > config_ . model_buf_len ) ;
} else {
LOG ( FATAL ) < < " Model load error. " ;
}
auto inputs = this - > graph_p_ - > get_ins ( ) ;
for ( auto & input_str : inputs ) {
if ( this - > config_ . init_inputs_shape . find ( input_str ) = =
this - > config_ . init_inputs_shape . end ( ) ) {
LOG ( FATAL ) < < input_str < < " is not implemented ." ;
LOG ( FATAL ) < < input_str < < " should be set in init_inputs_shape ." ;
}
std : : vector < int > shape =
this - > config_ . init_inputs_shape . find ( input_str ) - > second ;
@ -189,6 +194,7 @@ template <typename T, Precision P, OpRunType R>
bool PaddleInferenceAnakinPredictor < T , P , R > : : RunImpl (
const std : : vector < PaddleTensor > & inputs ,
std : : vector < PaddleTensor > * output_data ) {
anakin : : TargetWrapper < T > : : set_device ( this - > config_ . device_id ) ;
for ( const auto & input : inputs ) {
if ( input . dtype ! = PaddleDType : : FLOAT32 ) {
LOG ( FATAL ) < < " Only support float type inputs. " < < input . name
@ -321,6 +327,27 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::Predict() {
}
# endif
# ifdef ANAKIN_BM_PLACE
template < Precision P , OpRunType R >
void PaddleInferenceAnakinBMPredictor < P , R > : : OptimizeGraph ( ) {
if ( ! this - > graph_p_ - > fusion_optimize ( ) ) {
LOG ( FATAL ) < < " Graph optimization error. " ;
}
}
template < Precision P , OpRunType R >
void PaddleInferenceAnakinBMPredictor < P , R > : : InitNet ( ) {
std : : unique_lock < std : : mutex > lock ( this - > mutex_ ) ;
this - > executor_p_ = new anakin : : Net < anakin : : BM , P , R > ( ) ;
this - > executor_p_ - > fusion_init ( * this - > graph_p_ , this - > ctx_p_ , true ) ;
}
template < Precision P , OpRunType R >
void PaddleInferenceAnakinBMPredictor < P , R > : : Predict ( ) {
anakin : : TargetWrapper < anakin : : BM > : : device_sync ( ) ;
this - > executor_p_ - > fusion_prediction ( ) ;
anakin : : TargetWrapper < anakin : : BM > : : device_sync ( ) ;
}
# endif
# ifdef PADDLE_WITH_CUDA
template class PaddleInferenceAnakinPredictor <
anakin : : NV , anakin : : Precision : : FP32 , : : anakin : : OpRunType : : ASYNC > ;
@ -333,6 +360,10 @@ template class PaddleInferenceAnakinPredictor<
template class PaddleInferenceAnakinMLUPredictor < anakin : : Precision : : FP32 ,
: : anakin : : OpRunType : : SYNC > ;
# endif
# ifdef ANAKIN_BM_PLACE
template class PaddleInferenceAnakinBMPredictor < anakin : : Precision : : FP32 ,
: : anakin : : OpRunType : : ASYNC > ;
# endif
// A factory to help create difference predictor.
template < >
@ -361,7 +392,16 @@ CreatePaddlePredictor<contrib::AnakinConfig, PaddleEngineKind::kAnakin>(
config ) ) ;
}
# endif
LOG ( FATAL ) < < " Anakin Predictor create on unknown platform. " ;
# ifdef ANAKIN_BM_PLACE
if ( config . target_type = = contrib : : AnakinConfig : : BM ) {
return std : : unique_ptr < PaddlePredictor > (
new PaddleInferenceAnakinBMPredictor < anakin : : Precision : : FP32 ,
: : anakin : : OpRunType : : ASYNC > (
config ) ) ;
}
# endif
LOG ( FATAL ) < < " Anakin Predictor create on unknown platform: "
< < config . target_type ;
return nullptr ;
}
template < typename T , Precision P , OpRunType R >