"""
Fast hash of Python objects.
Copyright 2021, Bill & Melinda Gates Foundation. All rights reserved.
"""
from typing import Union, BinaryIO
import decimal
import hashlib
import io
import pickle
import types
from dataclasses import fields, _MISSING_TYPE
from logging import getLogger, Logger
logger = getLogger(__name__)
Pickler = pickle._Pickler
class _ConsistentSet(object):
"""
Class used to ensure the hash of sets is preserved whatever the order of its items.
"""
def __init__(self, set_sequence):
"""
Force the order of elements in a set to ensure consistent hashing.
"""
try:
# Trying first to order the set using sorted
self._sequence = sorted(set_sequence)
except (TypeError, decimal.InvalidOperation):
# If elements are unorderable, sort them using their hash.
self._sequence = sorted((hash(e) for e in set_sequence))
class _MyHash(object):
"""
A class used to hash objects that won't normally pickle.
"""
def __init__(self, *args):
self.args = args
[docs]class Hasher(Pickler):
"""
A subclass of pickler to do hashing, rather than pickling.
"""
[docs] def __init__(self, hash_name='md5'):
"""
Initialize our hasher.
Args:
hash_name: Hash type to use. Defaults to md5
"""
self.stream = io.BytesIO()
Pickler.__init__(self, self.stream)
# Initialise the hash obj
self._hash = hashlib.new(hash_name)
[docs] def hash(self, obj, return_digest=True):
"""
Hash an object.
Args:
obj: Object to hash
return_digest: Should the digest be returned?
Returns:
None if return_digest is False, otherwise the hash digest is returned
"""
try:
self.dump(obj)
except pickle.PicklingError as e:
e.args += ('PicklingError while hashing %r: %r' % (obj, e),)
raise
dumps = self.stream.getvalue()
self._hash.update(dumps)
if return_digest:
return self._hash.hexdigest()
[docs] def save(self, obj):
"""
Save an object to hash.
Args:
obj: Obj to save.
Returns:
None
"""
from idmtools.utils.collections import ExperimentParentIterator
import abc
if isinstance(obj, abc.ABCMeta):
pass
elif isinstance(obj, ExperimentParentIterator):
pass
elif isinstance(obj, Logger):
pass
else:
if isinstance(obj, (types.MethodType, type({}.pop))):
# the Pickler cannot pickle instance methods; here we decompose
# them into components that make them uniquely identifiable
if hasattr(obj, '__func__'):
func_name = obj.__func__.__name__
else:
func_name = obj.__name__
inst = obj.__self__
if isinstance(inst, pickle):
obj = _MyHash(func_name, inst.__name__)
elif inst is None:
# type(None) or type(module) do not pickle
obj = _MyHash(func_name, inst)
else:
cls = obj.__self__.__class__
obj = _MyHash(func_name, inst, cls)
Pickler.save(self, obj)
[docs] def memoize(self, obj):
"""
Disable memoization for strings so hashing happens on value and not reference.
"""
if isinstance(obj, (str, bytes)):
return
Pickler.memoize(self, obj)
def _batch_setitems(self, items):
"""
Force the order of keys in dictionary to ensure consistent hashing.
"""
try:
# First try quick way of sorting keys if possible
Pickler._batch_setitems(self, iter(sorted(items)))
except TypeError:
# If keys are unorderable, sort them using their hash
Pickler._batch_setitems(self, iter(sorted((hash(k), v) for k, v in items)))
[docs] def save_set(self, set_items):
"""
Save set hashing.
Args:
set_items: Set items
Returns:
None
"""
# forces order of items in Set to ensure consistent hash
Pickler.save(self, _ConsistentSet(set_items))
[docs]def hash_obj(obj, hash_name='md5'):
"""
Quick calculation of a hash to identify uniquely Python objects.
Args:
obj: Object to hash
hash_name: The hashing algorithm to use. 'md5' is faster; 'sha1' is considered safer.
"""
hasher = Hasher(hash_name=hash_name)
return hasher.hash(obj)
[docs]def ignore_fields_in_dataclass_on_pickle(item):
"""
Ignore certain fields for pickling on dataclasses.
Args:
item: Item to pickle
Returns:
State of item to pickle
"""
state = item.__dict__.copy()
attrs = set(vars(item).keys())
# Retrieve fields default values
fds = fields(item)
field_default = {f.name: f.default for f in fds}
# Update default with parent's pre-populated values
if hasattr(item, 'pre_getstate'):
pre_state = item.pre_getstate()
pre_state = pre_state or {}
field_default.update(pre_state)
# Don't pickle ignore_pickle fields: set values to default
for field_name in attrs.intersection(item.pickle_ignore_fields):
if field_name in state:
if isinstance(field_default[field_name], _MISSING_TYPE):
state[field_name] = None
else:
state[field_name] = field_default[field_name]
return state
[docs]def calculate_md5(filename: str, chunk_size: int = 8192) -> str:
"""
Calculate MD5.
Args:
filename: Filename to caclulate md5 for
chunk_size: Chunk size
Returns:
md5 as string
"""
with open(filename, "rb") as f:
return calculate_md5_stream(f, chunk_size)
[docs]def calculate_md5_stream(stream: Union[io.BytesIO, BinaryIO], chunk_size: int = 8192, hash_type: str = 'md5', file_hash=None):
"""
Calculate md5 on stream.
Args:
chunk_size:
stream:
hash_type: Hash function
file_hash: File hash
Returns:
md5 of stream
"""
if file_hash is None:
if not hasattr(hashlib, hash_type):
raise ValueError(f"Could not find hash function {hash_type}")
else:
file_hash = getattr(hashlib, hash_type)()
while True:
chunk = stream.read(chunk_size)
if not chunk:
break
file_hash.update(chunk)
return file_hash.hexdigest()