Test for passing global env FLAGS_use_mkldnn (#26938)
	
		
	
				
					
				
			
							parent
							
								
									21565e8d85
								
							
						
					
					
						commit
						8aebcabf34
					
				| @ -0,0 +1,48 @@ | ||||
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from __future__ import unicode_literals | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import numpy as np | ||||
| import paddle.fluid as fluid | ||||
| import os | ||||
| from paddle.fluid.layer_helper import LayerHelper | ||||
| 
 | ||||
| 
 | ||||
| def check(): | ||||
|     print("check: fluid.core.globals()['FLAGS_use_mkldnn']=", | ||||
|           fluid.core.globals()["FLAGS_use_mkldnn"]) | ||||
|     print("check: fluid.get_flags('FLAGS_use_mkldnn')=", | ||||
|           fluid.get_flags(['FLAGS_use_mkldnn'])) | ||||
|     print("check: DNNL_VERBOSE=", os.environ['DNNL_VERBOSE']) | ||||
|     a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32) | ||||
|     helper = LayerHelper(fluid.unique_name.generate(str("test")), act="relu") | ||||
|     func = helper.append_activation | ||||
|     with fluid.dygraph.guard(fluid.core.CPUPlace()): | ||||
|         a = fluid.dygraph.to_variable(a_np) | ||||
|         res1 = func(a) | ||||
|         res2 = np.maximum(a_np, 0) | ||||
|     assert (np.array_equal(res1.numpy(), res2)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     try: | ||||
|         check() | ||||
|         for k, v in sorted(os.environ.items()): | ||||
|             print(k + ':', v) | ||||
|         print('\n') | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         print(type(e)) | ||||
| @ -0,0 +1,58 @@ | ||||
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from __future__ import unicode_literals | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import unittest | ||||
| import os | ||||
| import sys | ||||
| import subprocess | ||||
| 
 | ||||
| 
 | ||||
| class TestFlagsUseMkldnn(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         self._python_interp = sys.executable | ||||
|         self._python_interp += " check_flags_use_mkldnn.py" | ||||
| 
 | ||||
|         self.env = os.environ.copy() | ||||
|         self.env[str("GLOG_v")] = str("3") | ||||
|         self.env[str("DNNL_VERBOSE")] = str("1") | ||||
|         self.env[str("FLAGS_use_mkldnn")] = str("1") | ||||
| 
 | ||||
|     def test_flags_use_mkl_dnn(self): | ||||
|         cmd = self._python_interp | ||||
| 
 | ||||
|         proc = subprocess.Popen( | ||||
|             cmd.split(" "), | ||||
|             stdout=subprocess.PIPE, | ||||
|             stderr=subprocess.PIPE, | ||||
|             env=self.env) | ||||
| 
 | ||||
|         out, err = proc.communicate() | ||||
|         returncode = proc.returncode | ||||
| 
 | ||||
|         print('out', out) | ||||
|         print('err', err) | ||||
| 
 | ||||
|         assert returncode == 0 | ||||
|         # in python3, type(out) is 'bytes', need use encode | ||||
|         assert out.find( | ||||
|             "dnnl_verbose,exec,cpu,eltwise,jit:avx512_common,forward_training," | ||||
|             "data_f32::blocked:abc:f0 diff_undef::undef::f0,,alg:eltwise_relu". | ||||
|             encode()) != -1 | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
					Loading…
					
					
				
		Reference in new issue