@ -23,6 +23,8 @@ import six
import numpy as np
import paddle
import paddle . fluid as fluid
from paddle . fluid . framework import IrGraph
from paddle . fluid import core
from paddle . fluid . contrib . slim . core import Compressor
from paddle . fluid . log_helper import get_logger
@ -112,6 +114,41 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
return new_config_path
def _transform_depthwise_conv ( self , graph ) :
'''
Transform depthwise_conv2d into conv2d , with MKL - DNN only
'''
ops = graph . all_op_nodes ( )
for op_node in ops :
name = op_node . name ( )
if name in [ ' depthwise_conv2d ' ] :
input_var_node = graph . _find_node_by_name (
op_node . inputs , op_node . input ( " Input " ) [ 0 ] )
weight_var_node = graph . _find_node_by_name (
op_node . inputs , op_node . input ( " Filter " ) [ 0 ] )
output_var_node = graph . _find_node_by_name (
graph . all_var_nodes ( ) , op_node . output ( " Output " ) [ 0 ] )
attrs = {
name : op_node . op ( ) . attr ( name )
for name in op_node . op ( ) . attr_names ( )
}
conv_op_node = graph . create_op_node (
op_type = ' conv2d ' ,
attrs = attrs ,
inputs = {
' Input ' : input_var_node ,
' Filter ' : weight_var_node
} ,
outputs = { ' Output ' : output_var_node } )
graph . link_to ( input_var_node , conv_op_node )
graph . link_to ( weight_var_node , conv_op_node )
graph . link_to ( conv_op_node , output_var_node )
graph . safe_remove_nodes ( op_node )
return graph
def _predict ( self , test_reader = None , model_path = None ) :
place = fluid . CPUPlace ( )
exe = fluid . Executor ( place )
@ -125,6 +162,13 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
fetch_targets ] = fluid . io . load_inference_model (
model_path , exe , ' model ' , ' params ' )
use_mkldnn = bool ( os . getenv ( " FLAGS_use_mkldnn " , False ) )
if ( use_mkldnn ) :
graph = IrGraph (
core . Graph ( inference_program . desc ) , for_test = True )
graph = self . _transform_depthwise_conv ( graph )
inference_program = graph . to_program ( )
dshape = [ 3 , 224 , 224 ]
top1 = 0.0
top5 = 0.0