|
|
|
@ -16,7 +16,9 @@ import os
|
|
|
|
|
import re
|
|
|
|
|
import six
|
|
|
|
|
import sys
|
|
|
|
|
import json
|
|
|
|
|
import glob
|
|
|
|
|
import hashlib
|
|
|
|
|
import logging
|
|
|
|
|
import collections
|
|
|
|
|
import textwrap
|
|
|
|
@ -219,6 +221,106 @@ class CustomOpInfo:
|
|
|
|
|
return next(reversed(self.op_info_map.items()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
VersionFields = collections.namedtuple('VersionFields', [
|
|
|
|
|
'sources',
|
|
|
|
|
'extra_compile_args',
|
|
|
|
|
'extra_link_args',
|
|
|
|
|
'library_dirs',
|
|
|
|
|
'runtime_library_dirs',
|
|
|
|
|
'include_dirs',
|
|
|
|
|
'define_macros',
|
|
|
|
|
'undef_macros',
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VersionManager:
|
|
|
|
|
def __init__(self, version_field):
|
|
|
|
|
self.version_field = version_field
|
|
|
|
|
self.version = self.hasher(version_field)
|
|
|
|
|
|
|
|
|
|
def hasher(self, version_field):
|
|
|
|
|
from paddle.fluid.layers.utils import flatten
|
|
|
|
|
|
|
|
|
|
md5 = hashlib.md5()
|
|
|
|
|
for field in version_field._fields:
|
|
|
|
|
elem = getattr(version_field, field)
|
|
|
|
|
if not elem: continue
|
|
|
|
|
if isinstance(elem, (list, tuple, dict)):
|
|
|
|
|
flat_elem = flatten(elem)
|
|
|
|
|
md5 = combine_hash(md5, tuple(flat_elem))
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Support types with list, tuple and dict, but received {} with {}.".
|
|
|
|
|
format(type(elem), elem))
|
|
|
|
|
|
|
|
|
|
return md5.hexdigest()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def details(self):
|
|
|
|
|
return self.version_field._asdict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def combine_hash(md5, value):
|
|
|
|
|
"""
|
|
|
|
|
Return new hash value.
|
|
|
|
|
DO NOT use `hash()` beacuse it doesn't generate stable value between different process.
|
|
|
|
|
See https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions
|
|
|
|
|
"""
|
|
|
|
|
md5.update(repr(value).encode())
|
|
|
|
|
return md5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_object_if_change_cflags(so_path, extension):
|
|
|
|
|
"""
|
|
|
|
|
If already compiling source before, we should check whether cflags
|
|
|
|
|
have changed and delete the built object to re-compile the source
|
|
|
|
|
even though source file content keeps unchanaged.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def serialize(path, version_info):
|
|
|
|
|
assert isinstance(version_info, dict)
|
|
|
|
|
with open(path, 'w') as f:
|
|
|
|
|
f.write(json.dumps(version_info, indent=4, sort_keys=True))
|
|
|
|
|
|
|
|
|
|
def deserialize(path):
|
|
|
|
|
assert os.path.exists(path)
|
|
|
|
|
with open(path, 'r') as f:
|
|
|
|
|
content = f.read()
|
|
|
|
|
return json.loads(content)
|
|
|
|
|
|
|
|
|
|
# version file
|
|
|
|
|
VERSION_FILE = "version.txt"
|
|
|
|
|
base_dir = os.path.dirname(so_path)
|
|
|
|
|
so_name = os.path.basename(so_path)
|
|
|
|
|
version_file = os.path.join(base_dir, VERSION_FILE)
|
|
|
|
|
|
|
|
|
|
# version info
|
|
|
|
|
args = [getattr(extension, field, None) for field in VersionFields._fields]
|
|
|
|
|
version_field = VersionFields._make(args)
|
|
|
|
|
versioner = VersionManager(version_field)
|
|
|
|
|
|
|
|
|
|
if os.path.exists(so_path) and os.path.exists(version_file):
|
|
|
|
|
old_version_info = deserialize(version_file)
|
|
|
|
|
so_version = old_version_info.get(so_name, None)
|
|
|
|
|
# delete shared library file if versison is changed to re-compile it.
|
|
|
|
|
if so_version is not None and so_version != versioner.version:
|
|
|
|
|
log_v(
|
|
|
|
|
"Re-Compiling {}, because specified cflags have been changed. New signature {} has been saved into {}.".
|
|
|
|
|
format(so_name, versioner.version, version_file))
|
|
|
|
|
os.remove(so_path)
|
|
|
|
|
# upate new version information
|
|
|
|
|
new_version_info = versioner.details
|
|
|
|
|
new_version_info[so_name] = versioner.version
|
|
|
|
|
serialize(version_file, new_version_info)
|
|
|
|
|
else:
|
|
|
|
|
# If compile at first time, save compiling detail information for debug.
|
|
|
|
|
if not os.path.exists(base_dir):
|
|
|
|
|
os.makedirs(base_dir)
|
|
|
|
|
details = versioner.details
|
|
|
|
|
details[so_name] = versioner.version
|
|
|
|
|
serialize(version_file, details)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_unix_cudaflags(cflags):
|
|
|
|
|
"""
|
|
|
|
|
Prepare all necessary compiled flags for nvcc compiling CUDA files.
|
|
|
|
|