diff --git a/dface/core/detect.py b/dface/core/detect.py index 6d5cefa..5b64407 100644 --- a/dface/core/detect.py +++ b/dface/core/detect.py @@ -14,23 +14,30 @@ def create_mtcnn_net(p_model_path=None, r_model_path=None, o_model_path=None, us if p_model_path is not None: pnet = PNet(use_cuda=use_cuda) - pnet.load_state_dict(torch.load(p_model_path)) if(use_cuda): + pnet.load_state_dict(torch.load(p_model_path)) pnet.cuda() + else: + # forcing all GPU tensors to be in CPU while loading + pnet.load_state_dict(torch.load(p_model_path, map_location=lambda storage, loc: storage)) pnet.eval() if r_model_path is not None: rnet = RNet(use_cuda=use_cuda) - rnet.load_state_dict(torch.load(r_model_path)) if (use_cuda): + rnet.load_state_dict(torch.load(r_model_path)) rnet.cuda() + else: + rnet.load_state_dict(torch.load(r_model_path, map_location=lambda storage, loc: storage)) rnet.eval() if o_model_path is not None: onet = ONet(use_cuda=use_cuda) - onet.load_state_dict(torch.load(o_model_path)) if (use_cuda): + onet.load_state_dict(torch.load(o_model_path)) onet.cuda() + else: + onet.load_state_dict(torch.load(o_model_path, map_location=lambda storage, loc: storage)) onet.eval() return pnet,rnet,onet diff --git a/environment_osx.yaml b/environment_osx.yaml new file mode 100644 index 0000000..0c579cf --- /dev/null +++ b/environment_osx.yaml @@ -0,0 +1,76 @@ +name: dface +channels: +- pytorch +- defaults +dependencies: +- backports=1.0=py27hb4f9756_1 +- backports.functools_lru_cache=1.4=py27h2aca819_1 +- backports_abc=0.5=py27h6972548_0 +- bzip2=1.0.6=h649919c_2 +- ca-certificates=2017.08.26=ha1e5d58_0 +- cairo=1.14.10=h913ea44_6 +- certifi=2017.11.5=py27hfa9a1c4_0 +- cffi=1.11.4=py27h342bebf_0 +- cycler=0.10.0=py27hfc73c78_0 +- ffmpeg=3.4=h766ddd1_0 +- fontconfig=2.12.4=hffb9db1_2 +- freetype=2.8=h12048fb_1 +- functools32=3.2.3.2=py27h8ceab06_1 +- gettext=0.19.8.1=h15daf44_3 +- glib=2.53.6=h33f6a65_2 +- graphite2=1.3.10=h233cf8b_0 +- harfbuzz=1.5.0=h6db888e_0 +- hdf5=1.10.1=ha036c08_1 +- icu=58.2=h4b95b61_1 +- intel-openmp=2018.0.0=h8158457_8 +- jasper=1.900.1=h1f36771_4 +- jpeg=9b=he5867d9_2 +- libcxx=4.0.1=h579ed51_0 +- libcxxabi=4.0.1=hebd6815_0 +- libedit=3.1=hb4e282d_0 +- libffi=3.2.1=h475c297_4 +- libgfortran=3.0.1=h93005f0_2 +- libiconv=1.15=hdd342a3_7 +- libopus=1.2.1=h169cedb_0 +- libpng=1.6.32=hd1e8b91_4 +- libprotobuf=3.4.1=h326466f_0 +- libtiff=4.0.9=h0dac147_0 +- libvpx=1.6.1=h057a404_0 +- libxml2=2.9.4=hf05c021_6 +- matplotlib=2.1.1=py27hb768455_0 +- mkl=2018.0.1=hfbd8650_4 +- ncurses=6.0=hd04f020_2 +- numpy=1.14.0=py27h8a80b8c_0 +- olefile=0.44=py27h73ba740_0 +- opencv=3.3.1=py27h60a5f38_1 +- openssl=1.0.2n=hdbc3d79_0 +- pcre=8.41=hfb6ab37_1 +- pillow=5.0.0=py27hfcce615_0 +- pip=9.0.1=py27h1567d89_4 +- pixman=0.34.0=hca0a616_3 +- pycparser=2.18=py27h0d28d88_1 +- pyparsing=2.2.0=py27h5bb6aaf_0 +- python=2.7.14=hde5916a_29 +- python-dateutil=2.6.1=py27hd56c96b_1 +- pytz=2017.3=py27h803c07a_0 +- readline=7.0=hc1231fa_4 +- setuptools=38.4.0=py27_0 +- singledispatch=3.4.0.3=py27he22c18d_0 +- six=1.11.0=py27h7252ba3_1 +- sqlite=3.20.1=h7e4c145_2 +- ssl_match_hostname=3.5.0.1=py27h8780752_2 +- subprocess32=3.2.7=py27h24b2887_0 +- tk=8.6.7=h35a86e2_3 +- tornado=4.5.3=py27_0 +- wheel=0.30.0=py27h677a027_1 +- xz=5.2.3=h0278029_2 +- zlib=1.2.11=hf3cbc9b_2 +- pytorch=0.3.0=py27_cuda0.0_cudnn0.0he480db7_4 +- torchvision=0.2.0=py27hfc0307a_1 +- pip: + - backports-abc==0.5 + - backports.functools-lru-cache==1.4 + - backports.ssl-match-hostname==3.5.0.1 + - torch==0.3.0.post4 +prefix: /Users/hfu/anaconda2/envs/dface +