|
|
|
@ -48,54 +48,65 @@ class TestFlagsUseMkldnn(unittest.TestCase):
|
|
|
|
|
returncode = proc.returncode
|
|
|
|
|
|
|
|
|
|
assert returncode == 0
|
|
|
|
|
return out
|
|
|
|
|
return out, err
|
|
|
|
|
|
|
|
|
|
def found(self, regex, out):
|
|
|
|
|
return re.search(regex, out, re.MULTILINE)
|
|
|
|
|
def _print_when_false(self, cond, out, err):
|
|
|
|
|
if not cond:
|
|
|
|
|
print('out', out)
|
|
|
|
|
print('err', err)
|
|
|
|
|
return cond
|
|
|
|
|
|
|
|
|
|
def found(self, regex, out, err):
|
|
|
|
|
_found = re.search(regex, out, re.MULTILINE)
|
|
|
|
|
return self._print_when_false(_found, out, err)
|
|
|
|
|
|
|
|
|
|
def not_found(self, regex, out, err):
|
|
|
|
|
_not_found = not re.search(regex, out, re.MULTILINE)
|
|
|
|
|
return self._print_when_false(_not_found, out, err)
|
|
|
|
|
|
|
|
|
|
def test_flags_use_mkl_dnn_on_empty_off_empty(self):
|
|
|
|
|
out = self.flags_use_mkl_dnn_common({})
|
|
|
|
|
assert self.found(self.relu_regex, out)
|
|
|
|
|
assert self.found(self.ew_add_regex, out)
|
|
|
|
|
assert self.found(self.matmul_regex, out)
|
|
|
|
|
out, err = self.flags_use_mkl_dnn_common({})
|
|
|
|
|
assert self.found(self.relu_regex, out, err)
|
|
|
|
|
assert self.found(self.ew_add_regex, out, err)
|
|
|
|
|
assert self.found(self.matmul_regex, out, err)
|
|
|
|
|
|
|
|
|
|
def test_flags_use_mkl_dnn_on(self):
|
|
|
|
|
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu")}
|
|
|
|
|
out = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.found(self.relu_regex, out)
|
|
|
|
|
assert not self.found(self.ew_add_regex, out)
|
|
|
|
|
assert not self.found(self.matmul_regex, out)
|
|
|
|
|
out, err = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.found(self.relu_regex, out, err)
|
|
|
|
|
assert self.not_found(self.ew_add_regex, out, err)
|
|
|
|
|
assert self.not_found(self.matmul_regex, out, err)
|
|
|
|
|
|
|
|
|
|
def test_flags_use_mkl_dnn_on_multiple(self):
|
|
|
|
|
env = {str("FLAGS_tracer_mkldnn_ops_on"): str("relu,elementwise_add")}
|
|
|
|
|
out = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.found(self.relu_regex, out)
|
|
|
|
|
assert self.found(self.ew_add_regex, out)
|
|
|
|
|
assert not self.found(self.matmul_regex, out)
|
|
|
|
|
out, err = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.found(self.relu_regex, out, err)
|
|
|
|
|
assert self.found(self.ew_add_regex, out, err)
|
|
|
|
|
assert self.not_found(self.matmul_regex, out, err)
|
|
|
|
|
|
|
|
|
|
def test_flags_use_mkl_dnn_off(self):
|
|
|
|
|
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")}
|
|
|
|
|
out = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.found(self.relu_regex, out)
|
|
|
|
|
assert self.found(self.ew_add_regex, out)
|
|
|
|
|
assert not self.found(self.matmul_regex, out)
|
|
|
|
|
out, err = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.found(self.relu_regex, out, err)
|
|
|
|
|
assert self.found(self.ew_add_regex, out, err)
|
|
|
|
|
assert self.not_found(self.matmul_regex, out, err)
|
|
|
|
|
|
|
|
|
|
def test_flags_use_mkl_dnn_off_multiple(self):
|
|
|
|
|
env = {str("FLAGS_tracer_mkldnn_ops_off"): str("matmul,relu")}
|
|
|
|
|
out = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert not self.found(self.relu_regex, out)
|
|
|
|
|
assert self.found(self.ew_add_regex, out)
|
|
|
|
|
assert not self.found(self.matmul_regex, out)
|
|
|
|
|
out, err = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.not_found(self.relu_regex, out, err)
|
|
|
|
|
assert self.found(self.ew_add_regex, out, err)
|
|
|
|
|
assert self.not_found(self.matmul_regex, out, err)
|
|
|
|
|
|
|
|
|
|
def test_flags_use_mkl_dnn_on_off(self):
|
|
|
|
|
env = {
|
|
|
|
|
str("FLAGS_tracer_mkldnn_ops_on"): str("elementwise_add"),
|
|
|
|
|
str("FLAGS_tracer_mkldnn_ops_off"): str("matmul")
|
|
|
|
|
}
|
|
|
|
|
out = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert not self.found(self.relu_regex, out)
|
|
|
|
|
assert self.found(self.ew_add_regex, out)
|
|
|
|
|
assert not self.found(self.matmul_regex, out)
|
|
|
|
|
out, err = self.flags_use_mkl_dnn_common(env)
|
|
|
|
|
assert self.not_found(self.relu_regex, out, err)
|
|
|
|
|
assert self.found(self.ew_add_regex, out, err)
|
|
|
|
|
assert self.not_found(self.matmul_regex, out, err)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|