@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License . */
# include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
# include <string>
namespace paddle {
namespace framework {
@ -21,9 +22,16 @@ namespace ir {
std : : unique_ptr < ir : : Graph > MKLDNNPlacementPass : : ApplyImpl (
std : : unique_ptr < ir : : Graph > graph ) const {
VLOG ( 3 ) < < " Aplies MKL-DNN placement strategy. " ;
const auto & op_types_list =
Get < std : : unordered_set < std : : string > > ( " mkldnn_enabled_op_types " ) ;
for ( const Node * n : graph - > Nodes ( ) ) {
if ( n - > IsOp ( ) & & n - > RuntimeHasAttr ( " use_mkldnn " ) ) {
if ( op_types_list . empty ( ) ) {
n - > Op ( ) - > SetAttr ( " use_mkldnn " , true ) ;
} else if ( std : : find ( op_types_list . begin ( ) , op_types_list . end ( ) ,
n - > Name ( ) ) ! = op_types_list . end ( ) ) {
n - > Op ( ) - > SetAttr ( " use_mkldnn " , true ) ;
}
}
}
return graph ;
@ -33,5 +41,5 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
} // namespace framework
} // namespace paddle
REGISTER_PASS ( mkldnn_placement_pass ,
paddle : : framework : : ir : : MKLDNNPlacementPass ) ;
REGISTER_PASS ( mkldnn_placement_pass , paddle : : framework : : ir : : MKLDNNPlacementPass )
. RequirePassAttr ( " mkldnn_enabled_op_types " ) ;