|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Context for parameter server training mode"""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
from mindspore._c_expression import PSContext
|
|
|
|
|
|
|
|
|
|
_ps_context = None
|
|
|
|
@ -134,4 +135,10 @@ def _clone_hash_table(dest_param_name, src_param_name):
|
|
|
|
|
ps_context().clone_hash_table(dest_param_name, src_param_name)
|
|
|
|
|
|
|
|
|
|
def _set_cache_enable(cache_enable):
|
|
|
|
|
# Environment variables are used to specify a maximum number of OpenBLAS threads:
|
|
|
|
|
# In ubuntu(GPU) environment, numpy will use too many threads for computing,
|
|
|
|
|
if cache_enable:
|
|
|
|
|
os.environ['OPENBLAS_NUM_THREADS'] = '2'
|
|
|
|
|
os.environ['GOTO_NUM_THREADS'] = '2'
|
|
|
|
|
os.environ['OMP_NUM_THREADS'] = '2'
|
|
|
|
|
ps_context().set_cache_enable(cache_enable)
|
|
|
|
|