Source code for sciris.sc_utils

'''
Miscellaneous utilities for type checking, printing, dates and times, etc.

Note: there are a lot! The design philosophy has been that it's easier to
ignore a function that you don't need than write one from scratch that you
do need.

Highlights:
    - ``sc.dcp()``: shortcut to ``copy.deepcopy()``
    - ``sc.pp()``: shortcut to ``pprint.pprint()``
    - ``sc.pr()``: print full representation of an object, including methods and each attribute
    - ``sc.heading()``: print text as a 'large' heading
    - ``sc.colorize()``: print text in a certain color
    - ``sc.sigfig()``: truncate a number to a certain number of significant figures
    - ``sc.isnumber()``: checks if something is any number type
    - ``sc.promotetolist()``: converts any object to a list, for easy iteration
    - ``sc.promotetoarray()``: tries to convert any object to an array, for easy use with numpy
    - ``sc.mergedicts()``: merges any set of inputs into a dictionary
    - ``sc.readdate()``: convert strings to dates using common formats
    - ``sc.daterange()``: create a list of dates
    - ``sc.datedelta()``: perform calculations on date strings
    - ``sc.tic()/sc.toc()``: simple method for timing durations
    - ``sc.runcommand()``: simple way of executing a shell command
'''

##############################################################################
#%% Imports
##############################################################################

import os
import re
import sys
import copy
import time
import json
import zlib
import types
import psutil
import pprint
import hashlib
import dateutil
import subprocess
import itertools
import numbers
import string
import tempfile
import warnings
import numpy as np
import pylab as pl
import random as rnd
import datetime as dt
import uuid as py_uuid
import pkg_resources as pkgr
import traceback as py_traceback
from textwrap import fill
from functools import reduce
from collections import OrderedDict as OD
from distutils.version import LooseVersion

# Handle types
_stringtypes = (str, bytes)
_numtype    = numbers.Number

# Add Windows support for colors (do this at the module level so that colorama.init() only gets called once)
# The flag ansi_support records whether ANSI color codes can be used on this platform.
if 'win' in sys.platform and sys.platform != 'darwin': # pragma: no cover # NB: can't use startswith() because of 'cygwin'
    try:
        import colorama
        colorama.init()
        ansi_support = True
    except Exception: # Import or initialization failed; don't swallow KeyboardInterrupt/SystemExit
        ansi_support = False  # print('Warning: you have called colorize() on Windows but do not have either the colorama or tendo modules.')
else:
    ansi_support = True



##############################################################################
#%% Adaptations from other libraries
##############################################################################

# Define the public names exported by this section (later sections extend this list with "+=")
__all__ = ['fast_uuid', 'uuid', 'dcp', 'cp', 'pp', 'sha', 'wget', 'htmlify', 'freeze', 'require',
           'traceback', 'getplatform', 'iswindows', 'islinux', 'ismac']


def fast_uuid(which=None, length=None, n=1, secure=False, forcelist=False, safety=1000, recursion=0, recursion_limit=10, verbose=True):
    '''
    Create a fast UID or set of UIDs.

    Args:
        which (str): the set of characters to choose from (default ascii); one of
            'lowercase', 'letters', 'numeric', 'digits', 'hex', 'hexdigits', 'alphanumeric', 'ascii'
        length (int): length of UID (default 6)
        n (int): number of UIDs to generate
        secure (bool): if True, draw characters with ``random.SystemRandom`` (slower) instead of the default generator
        forcelist (bool): whether or not to return a list even for a single UID (used for recursive calls)
        safety (float): ensure that the space of possible UIDs is at least this much larger than the number requested
        recursion (int): the recursion level of the call (since the function calls itself if not all UIDs are unique)
        recursion_limit (int): maximum number of times to try regenerating keys
        verbose (bool): whether to print a warning each time duplicates are found and regenerated

    Returns:
        uid (str or list): a string UID, or a list of ``n`` unique string UIDs

    Raises:
        KeyError: if ``which`` is not a recognized character set
        ValueError: if too many UIDs are requested for the available space, or if
            unique UIDs could not be generated within ``recursion_limit`` tries

    **Example**::

        uuids = sc.fast_uuid(n=100) # Generate 100 UUIDs

    Inspired by https://stackoverflow.com/questions/2257441/random-string-generation-with-upper-case-letters-and-digits/30038250#30038250
    '''

    # Set defaults
    if which  is None: which  = 'ascii'
    if length is None: length = 6
    length = int(length)
    n = int(n)

    # NB: string.hexdigits contains both cases, so lowercasing it would duplicate a-f
    # (biasing the draw and overcounting the possibilities); build the 16 characters explicitly
    hexchars = string.digits + 'abcdef'
    choices = {
        'lowercase':    string.ascii_lowercase,
        'letters':      string.ascii_letters,
        'numeric':      string.digits,
        'digits':       string.digits,
        'hex':          hexchars,
        'hexdigits':    hexchars,
        'alphanumeric': string.ascii_lowercase + string.digits,
        'ascii':        string.ascii_letters + string.digits,
        }

    if which not in choices: # pragma: no cover
        errormsg = f'Choice {which} not found; choices are: {strjoin(choices.keys())}'
        raise KeyError(errormsg)
    charlist = choices[which]

    # Check that there are enough options
    if n > 1:
        n_possibilities = len(charlist)**length
        allowed = n_possibilities//safety
        if n > allowed:
            errormsg = f'With a UID of type "{which}" and length {length}, there are {n_possibilities} possible UIDs, and you requested {n}, which exceeds the maximum allowed ({allowed})'
            raise ValueError(errormsg)

    # Secure uses system random which is secure, but >10x slower
    if secure:
        choices_func = rnd.SystemRandom().choices
    else:
        choices_func = rnd.choices

    # Generate the UUID(s) string as one big block
    uid_str = ''.join(choices_func(charlist, k=length*n))

    # Parse if n==1
    if n == 1:
        if forcelist:
            output = [uid_str]
        else:
            output = uid_str

    # Otherwise, we're generating multiple, so do additional checking to ensure they're actually unique
    else:
        # Split from one long string into multiple, then drop duplicates (preserving order)
        output = [uid_str[chunk*length:(chunk+1)*length] for chunk in range(len(uid_str)//length)]
        output = list(dict.fromkeys(output))

        # Keep topping up until there are exactly n unique UIDs
        while len(output) != n:

            # Set recursion and do error checking
            n_unique_keys = len(output)
            recursion += 1
            if recursion > recursion_limit:
                errormsg = f'Could only generate {n_unique_keys}/{n} unique UIDs after {recursion_limit} tries: please increase UID length or character set size to ensure more unique options'
                raise ValueError(errormsg)
            if verbose:
                print(f'Warning: duplicates found in UID list ({n_unique_keys}/{n} unique); regenerating...')

            # Generate only the missing number of UIDs, and discard any that collide with existing ones
            new_n = n - n_unique_keys
            new_uuids = fast_uuid(which=which, length=length, n=new_n, secure=secure, safety=safety, recursion=recursion, recursion_limit=recursion_limit, verbose=verbose, forcelist=True)
            output = list(dict.fromkeys(output + new_uuids))

    return output


def uuid(uid=None, which=None, die=False, tostring=False, length=None, n=1, **kwargs):
    '''
    Shortcut for creating a UUID; default is to create a UUID4. Can also convert a UUID.

    Args:
        uid (str or uuid): if a string, convert to an actual UUID; otherwise, return unchanged
        which (int or str): if 1, 3, 4 or 5, choose the corresponding Python UUID function; otherwise, generate a random string via ``fast_uuid()`` (default 4)
        die (bool): whether to fail for converting a supplied uuid (default False)
        tostring (bool): whether or not to return a string instead of a UUID object (default False); only applies to newly generated UUIDs
        length (int): number of characters to trim to, if returning a string (at most the full string length of the UUID)
        n (int): number of UUIDs to generate; if n>1, return a list
        kwargs (dict): passed to the chosen UUID function (or to ``fast_uuid()``)

    Returns:
        uid (UUID or str): the UID object, or a list of them if n>1

    Raises:
        TypeError: if ``uid`` could not be converted and ``die`` is True
        ValueError: if ``length`` is larger than the string length of the UUID

    **Examples**::

        sc.uuid() # Alias to uuid.uuid4()
        sc.uuid(which='hex') # Creates a length-6 hex string
        sc.uuid(which='ascii', length=10, n=50) # Creates 50 UUIDs of length 10 each using the full ASCII character set
    '''

    # Set default UUID type
    if which is None:
        which = 4
    n = int(n)

    # Choose the different functions
    if   which==1: uuid_func = py_uuid.uuid1
    elif which==3: uuid_func = py_uuid.uuid3
    elif which==4: uuid_func = py_uuid.uuid4
    elif which==5: uuid_func = py_uuid.uuid5
    else:
        return fast_uuid(which=which, length=length, n=n, **kwargs) # ...or just go to fast_uuid()

    # If a UUID was supplied, try to parse it
    if uid is not None:
        try:
            if isinstance(uid, py_uuid.UUID):
                output = uid # Use directly
            else: # Convert
                output = py_uuid.UUID(uid)
        except Exception as E: # pragma: no cover
            errormsg = f'Could not convert "{uid}" to a UID ({repr(E)})'
            if die:
                raise TypeError(errormsg)
            else:
                print(errormsg)
                uid = None # Just create a new one

    # If not, make a new one
    if uid is None:
        uuid_list = []
        for _ in range(n): # Loop over the number requested
            new_uid = uuid_func(**kwargs)  # If not supplied, create a new UUID

            # Convert to a string, and optionally trim
            if tostring or length:
                new_uid = str(new_uid)
            if length:
                if length <= len(new_uid): # Asking for the whole string is allowed
                    new_uid = new_uid[:length]
                else:
                    errormsg = f'Cannot choose first {length} chars since UID has length {len(new_uid)}'
                    raise ValueError(errormsg)
            uuid_list.append(new_uid)

        # Process the output: single item if 1, list if more
        if len(uuid_list) == 1:
            output = uuid_list[0]
        else:
            output = uuid_list

    return output


def dcp(obj, verbose=True, die=False):
    '''
    Shortcut to perform a deep copy operation

    Almost identical to ``copy.deepcopy()``, except that if the deep copy fails,
    a shallow copy (via ``cp()``) is attempted instead unless ``die`` is True.

    Args:
        obj (any): the object to copy
        verbose (bool): whether to print a warning when falling back to a shallow copy
        die (bool): if True, raise an exception instead of falling back to a shallow copy

    Returns:
        A deep copy of the object (or a shallow copy if the deep copy failed and die is False)

    Raises:
        RuntimeError: if the deep copy failed and ``die`` is True
    '''
    try:
        output = copy.deepcopy(obj)
    except Exception as E: # pragma: no cover
        errormsg = f'Warning: could not perform deep copy, performing shallow instead: {str(E)}'
        if die: # Check this first so the fallback cannot mask the error with a different exception
            raise RuntimeError(errormsg) from E
        output = cp(obj)
        if verbose:
            print(errormsg)
    return output


def cp(obj, verbose=True, die=True):
    '''
    Shortcut to perform a shallow copy operation

    Almost identical to ``copy.copy()``, except that failures can optionally be
    downgraded to a printed message, in which case the original object is returned.

    Args:
        obj (any): the object to copy
        verbose (bool): whether to print a message if the copy fails and die is False
        die (bool): whether to raise an exception if the copy fails (default True)

    Returns:
        A shallow copy of the object, or the original object if copying failed and die is False

    Raises:
        ValueError: if the shallow copy failed and ``die`` is True
    '''
    try:
        output = copy.copy(obj)
    except Exception as E:
        errormsg = 'Could not perform shallow copy, returning original object'
        if die:
            raise ValueError(errormsg) from E
        if verbose:
            print(errormsg)
        output = obj
    return output


def pp(obj, jsonify=True, verbose=False, doprint=True, *args, **kwargs):
    '''
    Pretty-print an object, or return its pretty-printed string.

    Almost identical to ``pprint.pprint()``, with an optional round trip through
    JSON first (which normalizes things like OrderedDicts).

    Args:
        obj (any): the object to print
        jsonify (bool): whether to try converting the object to JSON and back before printing
        verbose (bool): whether to print a message if the JSON conversion fails
        doprint (bool): if True, print and return None; if False, return the formatted string
        args (list): passed to ``pprint.pprint()`` / ``pprint.pformat()``
        kwargs (dict): passed to ``pprint.pprint()`` / ``pprint.pformat()``
    '''
    # Start from the raw object; swap in the JSON round trip only if it succeeds
    toprint = obj
    if jsonify:
        try:
            toprint = json.loads(json.dumps(obj))
        except Exception as E:
            if verbose: print(f'Could not jsonify object ("{str(E)}"), printing default...')

    # Either hand back the string or print it
    if not doprint:
        return pprint.pformat(toprint, *args, **kwargs)
    pprint.pprint(toprint, *args, **kwargs)
    return None


def sha(obj, encoding='utf-8', digest=False):
    '''
    Compute the SHA-224 hash of an object.

    Equivalent to ``hashlib.sha224()``, except that non-string objects are first
    converted using ``repr()`` and text is encoded to bytes automatically.

    Args:
        obj (any): the object to be hashed; if not a string, converted to one
        encoding (str): the encoding to use when converting text to bytes
        digest (bool): whether to return the hex digest instead of the hash object

    **Example**::

        sha1 = sc.sha(dict(foo=1, bar=2), digest=True)
        sha2 = sc.sha(dict(foo=1, bar=2), digest=True)
        sha3 = sc.sha(dict(foo=1, bar=3), digest=True)
        assert sha1 == sha2
        assert sha2 != sha3
    '''
    # Anything that isn't already a string gets its repr hashed
    payload = obj if isstring(obj) else repr(obj)

    # hashlib needs bytes, so encode unicode text first
    if isinstance(payload, str):
        payload = payload.encode(encoding)

    hashed = hashlib.sha224(payload)
    return hashed.hexdigest() if digest else hashed


def wget(url, convert=True):
    '''
    Download a URL

    Alias to urllib.request.urlopen(url).read(), except that the connection is
    closed once the content has been read.

    Args:
        url (str): the URL to download
        convert (bool): whether to decode the downloaded bytes to a string

    Returns:
        The content of the URL, as a string if convert is True, otherwise as bytes

    **Example**::

        html = sc.wget('http://sciris.org')
    '''
    from urllib import request # Bizarrely, urllib.request sometimes fails
    with request.urlopen(url) as response: # Ensure the response is closed even if read() fails
        output = response.read()
    if convert:
        output = output.decode()
    return output


def htmlify(string, reverse=False, tostring=False):
    '''
    Convert a string to its HTML representation by converting unicode characters,
    characters that need to be escaped, and newlines. If reverse=True, will convert
    HTML to string. If tostring=True, will convert the bytestring back to Unicode.

    Args:
        string (str): the text to convert (if reverse=True, bytes are also accepted, e.g. the output of a forward conversion)
        reverse (bool): convert from HTML back to plain text instead
        tostring (bool): for the forward conversion, return a str rather than bytes

    Returns:
        bytes for the forward conversion (str if tostring=True); str for the reverse conversion

    Note: the forward conversion turns tabs into four ``&nbsp;`` entities; the
    reverse conversion does not turn these back into tabs.

    **Examples**::

        output = sc.htmlify('foo&\\nbar') # Returns b'foo&amp;<br>bar'
        output = sc.htmlify('föö&\\nbar', tostring=True) # Returns 'f&#246;&#246;&amp;<br>bar'
        output = sc.htmlify('foo&amp;<br>bar', reverse=True) # Returns 'foo&\\nbar'
    '''
    import html
    if not reverse: # Convert to HTML
        output = html.escape(string).encode('ascii', 'xmlcharrefreplace') # Replace non-ASCII characters
        output = output.replace(b'\n', b'<br>') # Replace newlines with <br>
        output = output.replace(b'\t', b'&nbsp;&nbsp;&nbsp;&nbsp;') # Replace tabs with 4 spaces
        if tostring: # Convert from bytestring to unicode
            output = output.decode()
    else: # Convert from HTML
        if isinstance(string, bytes): # Accept the bytes produced by the forward conversion
            string = string.decode()
        # Replace line-break tags before unescaping, so escaped text such as "&lt;br&gt;" is left as literal text
        output = re.sub(r'<br\s*/?>', '\n', string, flags=re.IGNORECASE)
        output = html.unescape(output)
    return output


def freeze(lower=False):
    '''
    Return the installed packages and their versions; alias for pip freeze.

    Args:
        lower (bool): convert all keys to lowercase

    Returns:
        A dict of {package name: version}, with entries in alphabetical order of the original names

    **Example**::

        assert 'numpy' in sc.freeze() # One way to check for versions

    New in version 1.2.2.
    '''
    # Each working-set entry is converted to a string and split on whitespace into a (name, version) pair
    raw = dict(tuple(str(ws).split()) for ws in pkgr.working_set)

    # Rebuild the dict in sorted order, optionally lowercasing the names
    data = {}
    for name in sorted(raw.keys()):
        label = name.lower() if lower else name
        data[label] = raw[name]
    return data


def require(reqs=None, *args, exact=False, detailed=False, die=True, verbose=True, **kwargs):
    '''
    Check whether environment requirements are met. Alias to pkg_resources.require().

    Args:
        reqs (list/dict): a list of strings, or a dict of package names and versions
        args (list): additional requirements
        kwargs (dict): additional requirements
        exact (bool): use '==' instead of '>=' as the default comparison operator if not specified
        detailed (bool): return a tuple of (dict of which requirements are/aren't met, dict of the exceptions raised)
        die (bool): whether to raise an exception if requirements aren't met
        verbose (bool): print out the exception if it's not being raised

    Returns:
        True/False for whether all requirements were met; or, if detailed=True,
        the tuple ``(data, errs)`` described above

    Raises:
        ModuleNotFoundError: if any requirement is not met and ``die`` is True

    **Examples**::

        sc.require('numpy')
        sc.require(numpy='')
        sc.require(reqs={'numpy':'1.19.1', 'matplotlib':'3.2.2'})
        sc.require('numpy>=1.19.1', 'matplotlib==3.2.2', die=False)
        sc.require(numpy='1.19.1', matplotlib='==4.2.2', die=False, detailed=True)

    New in version 1.2.2.
    '''

    # Handle inputs
    reqlist = list(args)
    reqdict = kwargs
    if isinstance(reqs, dict):
        reqdict.update(reqs)
    else:
        reqlist = mergelists(reqs, reqlist)

    # Turn into a list of strings
    comparechars = '<>=!~'
    for k,v in reqdict.items():
        if not v:
            entry = k # If no version is provided, entry is just the module name
        else:
            compare = '' if v.startswith(tuple(comparechars)) else ('==' if exact else '>=')
            entry = k + compare + v
        reqlist.append(entry)

    # Check the requirements
    data = dict()
    errs = dict()
    for entry in reqlist:
        try:
            pkgr.require(entry)
            data[entry] = True
        except Exception as E:
            data[entry] = False
            errs[entry] = E

    # Figure out output
    met = all(data.values())

    # Handle exceptions
    if not met:
        errormsg = 'The following requirements were not met:'
        for entry,err in errs.items():
            errormsg += f'\n  {entry}: {str(err)}'
        if die:
            # Chain from the last *failed* requirement; the last entry checked may have succeeded
            last_err = list(errs.values())[-1]
            raise ModuleNotFoundError(errormsg) from last_err
        elif verbose:
            print(errormsg)

    # Handle output
    if detailed:
        return data, errs
    else:
        return met


def traceback(*args, **kwargs):
    '''
    Return the traceback of the exception currently being handled, as a string.

    Alias for ``traceback.format_exc()``; all arguments are passed through to it.
    '''
    formatted = py_traceback.format_exc(*args, **kwargs)
    return formatted


def getplatform(expected=None, die=False):
    '''
    Return the name of the current platform: 'linux', 'windows', 'mac', or 'other'.
    Alias (kind of) to sys.platform.

    Args:
        expected (str): if not None, check if the current platform is this (any of the known aliases is accepted, e.g. 'osx' for mac)
        die (bool): if True and expected is defined, raise an exception if the platform does not match

    Returns:
        The platform name if expected is None; otherwise True/False for whether
        the current platform matches the expected one

    Raises:
        EnvironmentError: if the platform does not match ``expected`` and ``die`` is True

    **Example**::

        sc.getplatform() # Get current name of platform
        sc.getplatform('windows', die=True) # Raise an exception if not on Windows
    '''
    # Define different aliases for each operating system
    mapping = dict(
        linux   = ['linux', 'posix'],
        windows = ['windows', 'win', 'win32', 'cygwin', 'nt'],
        mac     = ['mac', 'macos', 'darwin', 'osx']
    )

    # Check to see what system it is
    sys_plat = sys.platform.lower()
    plat = 'other'
    for key,aliases in mapping.items():
        if sys_plat in aliases:
            plat = key
            break

    # Handle output
    if expected is not None:
        # An unrecognized platform ('other') has no aliases, so it only matches itself
        output = (expected.lower() in mapping.get(plat, [plat])) # Check if it's as expected
        if not output and die:
            errormsg = f'System is "{plat}", not "{expected}"'
            raise EnvironmentError(errormsg)
    else:
        output = plat
    return output


def iswindows(die=False):
    '''
    Check whether the current platform is Windows; alias to sc.getplatform('windows').

    Args:
        die (bool): passed to getplatform(), which raises an exception if the platform does not match
    '''
    return getplatform(expected='windows', die=die)

def islinux(die=False):
    '''
    Check whether the current platform is Linux; alias to sc.getplatform('linux').

    Args:
        die (bool): passed to getplatform(), which raises an exception if the platform does not match
    '''
    return getplatform(expected='linux', die=die)

def ismac(die=False):
    '''
    Check whether the current platform is Mac; alias to sc.getplatform('mac').

    Args:
        die (bool): passed to getplatform(), which raises an exception if the platform does not match
    '''
    return getplatform(expected='mac', die=die)


##############################################################################
#%% Printing/notification functions
##############################################################################

# Extend the list of exported names with the printing/notification functions
__all__ += ['printv', 'blank', 'strjoin', 'newlinejoin', 'createcollist', 'objectid', 'objatt', 'objmeth', 'objprop', 'objrepr',
            'prepr', 'pr', 'indent', 'sigfig', 'printarr', 'printdata', 'printvars',
            'slacknotification', 'printtologfile', 'colorize', 'heading', 'percentcomplete', 'progressbar']

def printv(string, thisverbose=1, verbose=2, newline=True, indent=True):
    '''
    Optionally print a message and automatically indent. The idea is that
    a global or shared "verbose" variable is defined, which is passed to
    subfunctions, determining how much detail to print out.

    The general idea is that verbose is an integer from 0-4 as follows:

    * 0 = no printout whatsoever
    * 1 = only essential warnings, e.g. suppressed exceptions
    * 2 = standard printout
    * 3 = extra debugging detail (e.g., printout on each iteration)
    * 4 = everything possible (e.g., printout on each timestep)

    Thus a very important statement might be e.g.

    >>> sc.printv('WARNING, everything is wrong', 1, verbose)

    whereas a much less important message might be

    >>> sc.printv(f'This is timestep {i}', 4, verbose)

    Args:
        string (any): the message to print
        thisverbose (int): the verbosity level of this message
        verbose (int): the current verbosity level; the message is printed if verbose >= thisverbose
        newline (bool): whether to end the printed message with a newline
        indent (bool): whether to indent the message by two spaces per verbosity level

    Version: 2016jan30
    '''
    if thisverbose>4 or verbose>4: print(f'Warning, verbosity should be from 0-4 (this message: {thisverbose}; current: {verbose})')
    if verbose>=thisverbose: # Only print if sufficiently verbose
        indents = '  '*thisverbose*bool(indent) # Create automatic indenting
        end = '\n' if newline else '' # Suppress the trailing newline if requested
        print(indents+flexstr(string), end=end) # Actually print
    return None


def blank(n=3):
    '''
    Print n blank lines (3 by default).

    Args:
        n (int): the number of newline characters to print (print() itself adds one more)
    '''
    newlines = '\n'*n
    print(newlines)
    return None


def strjoin(*args, sep=', '):
    '''
    Join any mix of inputs into a single string. Like string ``join()``, but
    strings are used as-is, other iterables are flattened one level with each
    item converted to a string, and anything else is converted to a string.
    By default, join with commas.

    Args:
        args (list): the list of items to join
        sep (str): the separator string

    **Example**::

        sc.strjoin([1,2,3], 4, 'five')

    New in version 1.1.0.
    '''
    pieces = []
    for arg in args:
        if isstring(arg): # Strings are iterable, so check for them first
            pieces.append(arg)
        elif isiterable(arg):
            pieces.extend(str(item) for item in arg)
        else:
            pieces.append(str(arg))
    return sep.join(pieces)


def newlinejoin(*args):
    '''
    Join the inputs with newlines; alias to ``strjoin(*args, sep='\\n')``.

    Args:
        args (list): the items to join

    **Example**::

        sc.newlinejoin([1,2,3], 4, 'five')

    New in version 1.1.0.
    '''
    newline = '\n'
    return strjoin(*args, sep=newline)


def createcollist(items, title=None, strlen=18, ncol=3):
    '''
    Creates a string for a nice columnated list (e.g. to use in __repr__ method)

    Items are laid out down the columns (so a sorted list reads top-to-bottom,
    then left-to-right), with each row on its own line.

    Args:
        items (list): the strings to lay out
        title (str): if provided, printed (followed by a colon) before the list
        strlen (int): the width of each column; longer items are truncated with '...'
        ncol (int): the maximum number of columns

    Returns:
        The formatted string, ending in a newline
    '''
    nrow = -(-len(items)//ncol) # Ceiling division: number of rows needed

    attstring = title + ':' if title else ''
    for r in range(nrow):
        attstring += '\n  ' # Start a new line for every row, even if earlier rows were not full
        for x in items[r::nrow]: # Every nrow-th item belongs to this row
            if len(x) > strlen: x = x[:strlen-3] + '...'
            attstring += '%-*s  ' % (strlen,x)
    attstring += '\n'
    return attstring


def objectid(obj):
    '''
    Return the object ID as per the default Python __repr__ method, i.e.
    '<module.ClassName at 0x...>', followed by a newline.
    '''
    klass = obj.__class__
    address = hex(id(obj))
    return '<{}.{} at {}>\n'.format(klass.__module__, klass.__name__, address)


def objatt(obj, strlen=18, ncol=3):
    '''
    Return a sorted string of object attributes for the Python __repr__ method

    Args:
        obj (any): the object whose attributes (from __dict__, or else __slots__) are listed
        strlen (int): the column width, passed to createcollist()
        ncol (int): the number of columns, passed to createcollist()

    Returns:
        The formatted list of attributes, or an empty string if there are none
    '''
    if hasattr(obj, '__dict__'):
        oldkeys = sorted(obj.__dict__.keys())
    elif hasattr(obj, '__slots__'):
        slots = obj.__slots__
        if isinstance(slots, str): # A single slot may be declared as a bare string
            slots = [slots]
        oldkeys = sorted(slots)
    else:
        oldkeys = []
    if len(oldkeys): output = createcollist(oldkeys, 'Attributes', strlen=strlen, ncol=ncol) # Respect the requested layout
    else:            output = ''
    return output


def objmeth(obj, strlen=18, ncol=3):
    '''
    Return a sorted string of object methods for the Python __repr__ method

    Args:
        obj (any): the object whose callable, non-dunder attributes are listed
        strlen (int): the column width, passed to createcollist()
        ncol (int): the number of columns, passed to createcollist()

    Returns:
        The formatted list of methods, or an empty string if there are none
    '''
    try:
        oldkeys = sorted([method + '()' for method in dir(obj) if callable(getattr(obj, method)) and not method.startswith('__')])
    except Exception: # pragma: no cover # Attribute access can fail on unusual objects, but don't swallow KeyboardInterrupt
        oldkeys = ['Methods N/A']
    if len(oldkeys): output = createcollist(oldkeys, 'Methods', strlen=strlen, ncol=ncol)
    else:            output = ''
    return output


def objprop(obj, strlen=18, ncol=3):
    '''
    Return a sorted string of object properties for the Python __repr__ method

    Args:
        obj (any): the object whose non-dunder properties (as defined on its type) are listed
        strlen (int): the column width, passed to createcollist()
        ncol (int): the number of columns, passed to createcollist()

    Returns:
        The formatted list of properties, or an empty string if there are none
    '''
    try:
        oldkeys = sorted([prop for prop in dir(obj) if isinstance(getattr(type(obj), prop, None), property) and not prop.startswith('__')])
    except Exception: # pragma: no cover # Attribute access can fail on unusual objects, but don't swallow KeyboardInterrupt
        oldkeys = ['Properties N/A']
    if len(oldkeys): output = createcollist(oldkeys, 'Properties', strlen=strlen, ncol=ncol)
    else:            output = ''
    return output


def objrepr(obj, showid=True, showmeth=True, showprop=True, showatt=True, dividerchar='—', dividerlen=60):
    '''
    Return useful printout for the Python __repr__ method

    Args:
        obj (any): the object to represent
        showid (bool): whether to include the object ID line
        showmeth (bool): whether to include the list of methods
        showprop (bool): whether to include the list of properties
        showatt (bool): whether to include the list of attributes
        dividerchar (str): the character used to draw the divider between sections
        dividerlen (int): the number of divider characters

    Returns:
        A string with each non-empty requested section followed by a divider line
    '''
    divider = dividerchar*dividerlen + '\n'
    output = ''
    if showid:
        output += objectid(obj)
        output += divider
    if showmeth:
        meths = objmeth(obj)
        if meths:
            output += meths # Reuse the result rather than computing it twice
            output += divider
    if showprop:
        props = objprop(obj)
        if props:
            output += props
            output += divider
    if showatt:
        attrs = objatt(obj)
        if attrs:
            output += attrs
            output += divider
    return output


def prepr(obj, maxlen=None, maxitems=None, skip=None, dividerchar='—', dividerlen=60, use_repr=True, maxtime=3, die=False):
    '''
    Akin to "pretty print", returns a pretty representation of an object --
    all attributes (except any that are skipped), plus methods and ID. Usually
    used via the interactive sc.pr() (which prints), rather than this (which returns
    a string).

    Args:
        obj (anything): the object to be represented
        maxlen (int): maximum number of characters to show for each item (default 80)
        maxitems (int): maximum number of items to show in the object (default 100)
        skip (list): any properties to skip
        dividerchar (str): divider for methods, attributes, etc.
        dividerlen (int): number of divider characters
        use_repr (bool): whether to use repr() or str() to parse the object
        maxtime (float): maximum amount of time to spend on trying to print the object
        die (bool): whether to raise an exception if an error is encountered

    Returns:
        The string representation of the object

    Raises:
        RuntimeError: if the representation could not be created and ``die`` is True
    '''

    # Decide how to handle representation function -- repr is dangerous since can lead to recursion
    repr_fn = repr if use_repr else str
    T = time.time() # Start the timer
    time_exceeded = False

    # Handle input arguments
    divider = dividerchar*dividerlen + '\n'
    if maxlen   is None: maxlen   = 80
    if maxitems is None: maxitems = 100
    if skip     is None: skip = []
    else:                skip = promotetolist(skip)

    # Initialize things to print out
    labels = []
    values = []

    # Wrap entire process in a try-except in case it fails
    try:
        if not (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')):
            # It's a plain object
            labels = [f'{type(obj)}']
            values = [repr_fn(obj)]
        else:
            if hasattr(obj, '__dict__'):
                labels = sorted(set(obj.__dict__.keys()) - set(skip))  # Get the dict attribute keys
            else:
                labels = sorted(set(obj.__slots__) - set(skip))  # Get the slots attribute keys

            if len(labels):
                extraitems = len(labels) - maxitems
                if extraitems>0:
                    labels = labels[:maxitems]
                n_labels = len(labels) # Store the total before any truncation, so the number skipped can be reported
                values = []
                for a,attr in enumerate(labels):
                    if (time.time() - T) < maxtime:
                        try: value = repr_fn(getattr(obj, attr))
                        except Exception: value = 'N/A'
                        values.append(value)
                    else:
                        labels = labels[:a]
                        labels.append('etc. (time exceeded)')
                        values.append(f'{n_labels-a} entries not shown')
                        time_exceeded = True
                        break
            else: # No attributes were found, so fall back to everything listed by dir()
                items = dir(obj)
                extraitems = len(items) - maxitems
                if extraitems > 0:
                    items = items[:maxitems]
                for a,attr in enumerate(items):
                    if not attr.startswith('__'):
                        if (time.time() - T) < maxtime:
                            try:    value = repr_fn(getattr(obj, attr))
                            except Exception: value = 'N/A'
                            labels.append(attr)
                            values.append(value)
                        else:
                            # Count the remaining non-dunder entries, then stop so the marker is only added once
                            n_skipped = sum(1 for item in items[a:] if not item.startswith('__'))
                            labels.append('etc. (time exceeded)')
                            values.append(f'{n_skipped} entries not shown')
                            time_exceeded = True
                            break
            if extraitems > 0:
                labels.append('etc. (too many items)')
                values.append(f'{extraitems} entries not shown')

        # Decide how to print them
        maxkeylen = 0
        if len(labels):
            maxkeylen = max([len(label) for label in labels]) # Find the maximum length of the attribute keys
        if maxkeylen<maxlen:
            maxlen = maxlen - maxkeylen # Shorten the amount of data shown if the keys are long
        formatstr = '%'+ '%i'%maxkeylen + 's' # Assemble the format string for the keys, e.g. '%21s'
        output  = objrepr(obj, showatt=False, dividerchar=dividerchar, dividerlen=dividerlen) # Get the methods
        for label,value in zip(labels,values): # Loop over each attribute
            if len(value)>maxlen: value = value[:maxlen] + ' [...]' # Shorten it
            prefix = formatstr%label + ': ' # The format key
            output += indent(prefix, value)
        output += divider
        if time_exceeded:
            timestr = f'\nNote: the object did not finish printing within maxtime={maxtime} s.\n'
            timestr += 'To see the full object, call prepr() with increased maxtime.'
            output += timestr

    # If that failed, try progressively simpler approaches
    except Exception as E: # pragma: no cover
        if die:
            errormsg = 'Failed to create pretty representation of object'
            raise RuntimeError(errormsg) from E
        else:
            try: # Next try the objrepr, which is the same except doesn't print attribute values
                output = objrepr(obj, dividerchar=dividerchar, dividerlen=dividerlen)
                output += f'\nWarning: showing simplified output since full repr failed {str(E)}'
            except Exception: # If that fails, try just the string representation
                output = str(obj)

    return output


def pr(obj, *args, **kwargs):
    '''
    Print the pretty repr of an object, as generated by prepr() -- similar to prettyprint.
    All arguments are passed through to prepr().

    **Example**::

        import pandas as pd
        df = pd.DataFrame({'a':[1,2,3], 'b':[4,5,6]})
        print(df) # See just the data
        sc.pr(df) # See all the methods too
    '''
    representation = prepr(obj, *args, **kwargs)
    print(representation)
    return None


def indent(prefix=None, text=None, suffix='\n', n=0, pretty=False, simple=True, width=70, **kwargs):
    '''
    Small wrapper to make textwrap more user friendly: wraps text to a given
    width, with the first line starting with a prefix and subsequent lines
    indented to line up with it.

    Args:
        prefix: text to begin with (optional)
        text: text to wrap
        suffix: what to put on the end of each paragraph (by default, a newline)
        n: if prefix is not specified, the size of the indent; if nonzero, the first n characters of the output are removed
        pretty: whether to use pprint to format the text
        simple: not used by this function
        width: the maximum line width, passed to textwrap.fill()
        kwargs: anything else to pass to textwrap.fill()

    **Examples**::

        prefix = 'and then they said:'
        text = 'blah '*100
        print(indent(prefix, text))

        print('my fave is: ' + indent(text=rand(100), n=14))

    Version: 2017feb20
    '''
    # With no prefix, use a placeholder of n spaces (stripped off again at the end)
    if prefix is None:
        prefix = ' '*n

    # Convert the text to a string
    text = pprint.pformat(text) if pretty else flexstr(text)

    # Lines after the first are indented to align with the end of the prefix
    hanging = ' '*len(prefix)

    if '\n' not in text:
        # Single paragraph: wrap it in one go
        output = fill(text, initial_indent=prefix, subsequent_indent=hanging, width=width, **kwargs)+suffix
    else:
        # Multiple lines: wrap each separately (only the first gets the prefix), then stitch together
        wrapped = []
        for i, textline in enumerate(text.split('\n')):
            lead = prefix if i == 0 else hanging
            wrapped.append(fill(textline, initial_indent=lead, subsequent_indent=hanging, width=width, **kwargs)+suffix)
        output = ''.join(wrapped)

    if n: output = output[n:] # Need to remove the fake prefix
    return output



def sigfig(x, sigfigs=5, SI=False, sep=False, keepints=False):
    '''
    Return a string representation of variable x with sigfigs number of significant figures

    Args:
        x (int/float/arr): the number(s) to round
        sigfigs (int): number of significant figures to round to; if None, the number is not rounded
        SI (bool): whether to use SI-style suffixes (k, m, b, t, e15, e18)
        sep (bool): if true, insert commas as thousands separators
        keepints (bool): do not round away the integer part of numbers larger in magnitude than 10**sigfigs (ignored if SI is used)

    Returns:
        A string if x is a single number, or a tuple of strings if x has a length.

    **Examples**::

        x = 32433.3842
        sc.sigfig(x, SI=True) # Returns 32.433k
        sc.sigfig(x, sep=True) # Returns 32,433
    '''
    # Thresholds and suffixes for the SI option, largest first so the first match wins
    formats = [(1e18,'e18'), (1e15,'e15'), (1e12,'t'), (1e9,'b'), (1e6,'m'), (1e3,'k')]

    # Anything with a length is treated as a sequence of numbers; otherwise wrap the scalar
    try:
        n = len(x)
        X = x
        islist = True
    except TypeError:
        X = [x]
        n = 1
        islist = False

    output = []
    for i in range(n):
        x = X[i]

        suffix = ''
        if SI:
            for val,suff in formats:
                if abs(x)>=val:
                    x = x/val
                    suffix = suff
                    break # Find at most one match

        try:
            if x==0:
                output.append('0')
            elif sigfigs is None:
                output.append(flexstr(x)+suffix)
            elif abs(x)>(10**sigfigs) and not SI and keepints: # e.g. x = 23432.23, sigfigs=3, output is 23432 (use abs so negative numbers are handled too)
                roundnumber = int(round(x))
                if sep: numstr = format(roundnumber, ',')
                else:   numstr = f'{x:0.0f}'
                output.append(numstr)
            else:
                magnitude = np.floor(np.log10(abs(x))) # Order of magnitude of the number
                factor = 10**(sigfigs-magnitude-1) # Scale so that rounding to an integer keeps sigfigs digits
                x = round(x*factor)/float(factor)
                digits = int(abs(magnitude) + max(0, sigfigs - max(0,magnitude) - 1) + 1 + (x<0) + (abs(x)<1)) # one because, one for decimal, one for minus
                decimals = int(max(0,-magnitude+sigfigs-1))
                strformat = '%' + f'{digits}.{decimals}' + 'f'
                numstr = strformat % x
                if sep: # To insert separators in the right place, have to convert back to a number
                    if decimals>0:  roundnumber = float(numstr)
                    else:           roundnumber = int(numstr)
                    numstr = format(roundnumber, ',') # Allow comma separator
                numstr += suffix
                output.append(numstr)
        except Exception: # pragma: no cover
            output.append(flexstr(x)) # If formatting fails (e.g. NaN or a non-number), fall back to a plain representation

    if islist:
        return tuple(output)
    else:
        return output[0]



def printarr(arr, arrformat='%0.2f  '):
    '''
    Print a numpy array nicely, formatting every entry with ``arrformat``.

    One-dimensional input is printed on a single line and two-dimensional input
    row by row. For three-dimensional input, each 2D slice is preceded by a line
    of equals signs. Anything else is handed to ``print()`` as-is.

    **Example**::

        sc.printarr(pl.rand(3,7,4))

    Version: 2014dec01
    '''
    ndims = np.ndim(arr)
    if ndims == 1:
        print(''.join(arrformat % entry for entry in arr))
    elif ndims == 2:
        for row in arr:
            printarr(row, arrformat)
    elif ndims == 3:
        for page in arr:
            print('=' * len(page[0]) * len(arrformat % 1)) # Divider as wide as one formatted row
            for row in page:
                printarr(row, arrformat)
    else: # pragma: no cover
        print(arr) # Give up
    return None



def printdata(data, name='Variable', depth=1, maxlen=40, indent='', level=0, showcontents=False): # pragma: no cover
    '''
    Nicely print a complicated data structure, a la Matlab.

    Note: this function is deprecated.

    Args:
      data: the data to display
      name: the name of the variable (automatically read except for first one)
      depth: how many levels of recursion to follow
      maxlen: number of characters of data to display (if 0, don't show data)
      indent: where to start the indent (used internally)
      level: the current recursion level (used internally)
      showcontents: whether to show (a truncated version of) the data itself

    Version: 2015aug21
    '''
    datatype = type(data)
    def printentry(data):
        ''' Build the one-line summary (type, size, and optionally contents) for a single entry '''
        if   datatype==dict:              string = (f'dict with {len(data.keys())} keys')
        elif datatype==list:              string = (f'list of length {len(data)}')
        elif datatype==tuple:             string = (f'tuple of length {len(data)}')
        elif datatype==np.ndarray:        string = (f'array of shape {np.shape(data)}')
        elif datatype.__name__=='module': string = (f'module with {len(dir(data))} components')
        elif datatype.__name__=='class':  string = (f'class with {len(dir(data))} components')
        else: string = datatype.__name__
        if showcontents and maxlen>0:
            datastring = ' | '+flexstr(data)
            if len(datastring)>maxlen: datastring = datastring[:maxlen] + ' <etc> ' + datastring[-maxlen:]
        else: datastring=''
        return string+datastring

    string = printentry(data).replace('\n',' ') # Remove newlines
    print(level*'..' + indent + str(name) + ' | ' + string) # str() so non-string names (e.g. integer dict keys) can be displayed

    if depth>0:
        level += 1
        childkwargs = dict(depth=depth-1, level=level, maxlen=maxlen, showcontents=showcontents) # Display settings apply to children too
        if type(data)==dict:
            keys = list(data.keys())
            labels = [str(key) for key in keys] # Keys need not be strings
            maxkeylen = max((len(label) for label in labels), default=0) # default avoids an exception for empty dicts
            for key,label in zip(keys, labels):
                thisindent = ' '*(maxkeylen-len(label))
                printdata(data[key], name=label, indent=indent+thisindent, **childkwargs)
        elif type(data) in [list, tuple]:
            for i,item in enumerate(data):
                printdata(item, name='[%i]'%i, indent=indent, **childkwargs)
        elif type(data).__name__ in ['module', 'class']:
            keys = dir(data)
            maxkeylen = max((len(key) for key in keys), default=0)
            for key in keys:
                if key[0]!='_': # Skip these
                    thisindent = ' '*(maxkeylen-len(key))
                    printdata(getattr(data,key), name=key, indent=indent+thisindent, **childkwargs)
        print('\n')
    return None


def printvars(localvars=None, varlist=None, label=None, divider=True, spaces=1, color=None):
    '''
    Print out a list of variables. Note that the first argument must be locals().

    Args:
        localvars: function must be called with locals() as first argument
        varlist: the list of variables to print out
        label: optional label to print out, so you know where the variables came from
        divider: whether or not to offset the printout with a spacer (i.e. ------)
        spaces: how many spaces to use between variables
        color: optionally label the variable names in color so they're easier to see

    **Example**::

    >>> a = range(5)
    >>> b = 'example'
    >>> sc.printvars(locals(), ['a','b'], color='green')

    Another useful usage case is to print out the kwargs for a function:

    >>> sc.printvars(locals(), kwargs.keys())

    Version: 2017oct28
    '''

    varlist = promotetolist(varlist) # Make sure it's actually a list
    dividerstr = '-'*40

    if label:  print(f'Variables for {label}:')
    if divider: print(dividerstr)
    for varnum,varname in enumerate(varlist):
        controlstr = f'{varnum}. "{varname}": ' # Basis for the control string -- variable number and name
        if color: controlstr = colorize(color, output=True) + controlstr + colorize('reset', output=True) # Optionally add color
        if spaces>1: controlstr += '\n' # Add a newline if the variables are going to be on different lines
        try:
            controlstr += f'{localvars[varname]}' # The variable to be printed
        except Exception: # In case something goes wrong (e.g. missing variable), but still allow KeyboardInterrupt through
            controlstr += 'WARNING, could not be printed'
        controlstr += '\n' * spaces # The number of spaces to add between variables
        print(controlstr) # Print it out
    if divider: print(dividerstr) # If necessary, print the divider again
    return None



def slacknotification(message=None, webhook=None, to=None, fromuser=None, verbose=2, die=False):  # pragma: no cover
    '''
    Send a Slack notification when something is finished.

    The webhook is either a string containing the webhook itself, or a plain text file containing
    a single line which is the Slack webhook. By default it will look for the file
    ".slackurl" in the user's home folder. The webhook needs to look something like
    "https://hooks.slack.com/services/af7d8w7f/sfd7df9sb/lkcpfj6kf93ds3gj". Webhooks are
    effectively passwords and must be kept secure! Alternatively, you can specify the webhook
    in the environment variable SLACKURL (used only if the webhook is neither a URL nor an
    existing file).

    Args:
        message (str): The message to be posted.
        webhook (str): See above
        to (str): The Slack channel or user to post to. Channels begin with #, while users begin with @ (note: ignored by new-style webhooks)
        fromuser (str): The pseudo-user the message will appear from (note: ignored by new-style webhooks)
        verbose (int): How much detail to display (passed to printv(); larger values print more).
        die (bool): If false, prints warnings. If true, raises exceptions.

    **Example**::

        sc.slacknotification('Long process is finished')
        sc.slacknotification(webhook='~/.slackurl', to='@username', message='Hi, how are you going?')

    What's the point? Add this to the end of a very long-running script to notify
    your loved ones that the script has finished.

    Version: 2018sep25
    '''
    try:
        from requests import post # Simple way of posting data to a URL
        from json import dumps # For sanitizing the message
    except Exception as E:
        errormsg = f'Cannot use Slack notification since imports failed: {str(E)}'
        if die: raise ImportError(errormsg) from E
        print(errormsg)
        return None # Nothing can be sent without the imports, so stop here

    # Validate input arguments
    printv('Sending Slack message', 1, verbose)
    if not webhook:  webhook    = os.path.expanduser('~/.slackurl')
    if not to:       to       = '#general'
    if not fromuser: fromuser = 'sciris-bot'
    if not message:  message  = 'This is an automated notification: your notifier is notifying you.'
    printv(f'Channel: {to} | User: {fromuser} | Message: {message}', 3, verbose) # Print details of what's being sent

    # Try opening webhook as a file
    if webhook.find('hooks.slack.com')>=0: # It seems to be a URL, let's proceed
        slackurl = webhook
    elif os.path.exists(os.path.expanduser(webhook)): # If not, look for it as a file
        with open(os.path.expanduser(webhook)) as f:
            slackurl = f.read().strip() # Strip the trailing newline (and any other whitespace) so the URL is valid
    elif os.getenv('SLACKURL'): # See if it's set in the user's environment variables
        slackurl = os.getenv('SLACKURL')
    else:
        slackurl = webhook # It doesn't seem to be a URL but let's try anyway
        errormsg = f'"{webhook}" does not seem to be a valid webhook string or file'
        if die: raise ValueError(errormsg)
        else:   print(errormsg)

    # Package and post payload
    try:
        payload = '{"text": %s, "channel": %s, "username": %s}' % (dumps(message), dumps(to), dumps(fromuser)) # dumps() quotes and escapes each field
        printv(f'Full payload: {payload}', 4, verbose)
        response = post(url=slackurl, data=payload)
        printv(response, 3, verbose) # Optionally print response
        printv('Message sent.', 2, verbose) # We're done
    except Exception as E:
        errormsg = f'Sending of Slack message failed: {repr(E)}'
        if die: raise RuntimeError(errormsg) from E
        else:   print(errormsg)
    return None


def printtologfile(message=None, filename=None):
    '''
    Append a message string to the file given by a filename or path, surrounded
    by newlines. This is handy for capturing information from spawned processes,
    whose print statements are not so easily captured.

    If no message is given, nothing happens. If no filename is given, a file called
    "logfile" in the system temporary folder is used. Failures to write are
    reported with a printed warning rather than an exception.

    Warning: If you pass a file in, existing or not, it will try to append
    text to it!
    '''
    # Nothing to append, so return immediately
    if message is None:
        return None

    # Fall back to a generic file in the temporary folder
    if filename is None:
        import tempfile
        filename = os.path.join(tempfile.gettempdir(), 'logfile')

    # Try appending to the file, failing gracefully
    try:
        with open(filename, 'a') as logfile:
            logfile.write('\n'+message+'\n')
    except Exception as E: # pragma: no cover
        print(f'Warning, could not write to logfile {filename}: {str(E)}')

    return None


def colorize(color=None, string=None, output=False, showhelp=False, enable=True):
    '''
    Colorize output text.

    Args:
        color: the color you want (use 'bg' with background colors, e.g. 'bgblue'); can also be a list of colors
        string: the text to be colored; if None, only the color code itself is produced
        output: whether to return the modified version of the string (otherwise, print it)
        showhelp: if true, print the list of available colors instead
        enable: switch to allow colorize() to be easily turned off

    Returns:
        The colorized string if output is true; otherwise None.

    **Examples**::

        sc.colorize('green', 'hi') # Simple example
        sc.colorize(['yellow', 'bgblack']); print('Hello world'); print('Goodbye world'); colorize() # Colorize all output in between
        bluearray = sc.colorize(color='blue', string=str(range(5)), output=True); print("c'est bleu: " + bluearray)
        sc.colorize('magenta') # Now type in magenta for a while
        sc.colorize() # Stop typing in magenta

    To get available colors, type ``sc.colorize(showhelp=True)``.

    Version: 2018sep09
    '''

    # Handle short-circuit case
    if not enable: # pragma: no cover
        if output:
            return string
        else:
            print(string)
            return None

    # Define ANSI colors
    ansicodes = OD([
        ('black', '30'),
        ('red', '31'),
        ('green', '32'),
        ('yellow', '33'),
        ('blue', '34'),
        ('magenta', '35'),
        ('cyan', '36'),
        ('gray', '37'),
        ('bgblack', '40'),
        ('bgred', '41'),
        ('bggreen', '42'),
        ('bgyellow', '43'),
        ('bgblue', '44'),
        ('bgmagenta', '45'),
        ('bgcyan', '46'),
        ('bggray', '47'),
        ('reset', '0'),
    ])
    ansicolors = OD((key, '\033[' + val + 'm') for key, val in ansicodes.items()) # Wrap each code in the ANSI escape sequence

    # With neither a color nor a string, reset the colors, as described in the examples above
    if color is None and string is None and not showhelp:
        color = 'reset'

    # Determine what color to use
    colorlist = promotetolist(color)  # Make sure it's a list
    for thiscolor in colorlist:
        if thiscolor not in ansicolors: # pragma: no cover
            print(f'Color "{thiscolor}" is not available, use colorize(showhelp=True) to show options.')
            return None  # Don't proceed if the color isn't found
    ansicolor = ''.join(ansicolors[thiscolor] for thiscolor in colorlist)

    # Modify string, if supplied
    if string is None: ansistring = ansicolor # Just return the color
    else:              ansistring = ansicolor + str(string) + ansicolors['reset'] # Add to start and end of the string
    if not ansi_support: # To avoid garbling output on unsupported systems, use plain text (and nothing at all if there is no string)
        ansistring = '' if string is None else str(string)

    if showhelp:
        print('Available colors are:')
        for key in ansicolors.keys():
            if key[:2] == 'bg':
                darks = ['bgblack', 'bgred', 'bgblue', 'bgmagenta']
                if key in darks: foreground = 'gray' # Light text on dark backgrounds
                else:            foreground = 'black'
                helpcolor = [foreground, key]
            else:
                helpcolor = key
            colorize(helpcolor, '  ' + key)
        return None
    elif output:
        return ansistring  # Return the modified string
    else:
        try:              print(ansistring) # Content, so print with newline
        except Exception: print(string) # If that fails (e.g. encoding problems), just go with plain version
        return None


def heading(string=None, *args, color=None, divider=None, spaces=None, minlength=None, maxlength=None, sep=' ', output=True, **kwargs):
    '''
    Create a colorful heading. If just supplied with a string (or list of inputs like print()),
    create cyan text with horizontal lines above and below and 2 blank lines above. You
    can customize the color, the divider character, how many spaces appear before
    the heading, and the minimum length of the divider (otherwise will expand to
    match the length of the string, up to a maximum length).

    Args:
        string (str): The string to print as the heading (or object to convert to a string)
        args (list): Additional strings to print
        color (str): The color to use for the heading (default cyan)
        divider (str): The symbol to use for the divider (default em dash)
        spaces (int): The number of newlines to put before the heading (default 2)
        minlength (int): The minimum length of the divider (default 30)
        maxlength (int): The maximum length of the divider (default 120)
        sep (str): If multiple arguments are supplied, use this separator to join them
        output (bool): Whether to return the string as output (else, print)
        kwargs (dict): Arguments to pass to sc.colorize()

    Returns:
        The colorized heading as a string if output=True (nothing is printed);
        otherwise the heading is printed and None is returned.

    Examples
    --------
    >>> import sciris as sc
    >>> print(sc.heading('This is a heading'))
    >>> sc.heading(string='This is also a heading', color='red', divider='*', spaces=0, minlength=50, output=False)
    '''
    if string    is None: string    = ''
    if color     is None: color     = 'cyan' # Reasonable default for light and dark consoles
    if divider   is None: divider   = '—' # Em dash for a continuous line
    if spaces    is None: spaces    = 2
    if minlength is None: minlength = 30
    if maxlength is None: maxlength = 120

    # Convert to single string
    string = sep.join(str(item) for item in [string] + list(args))

    # Add header and footer
    length = int(np.median([minlength, len(string), maxlength])) # Clamp the divider length between minlength and maxlength
    space = '\n'*spaces
    if divider and length: fulldivider = '\n'+divider*length+'\n'
    else:                  fulldivider = ''
    fullstring = space + fulldivider + string + fulldivider

    # Create output -- explicitly ask colorize() to return the string, since by default it prints and returns None
    outputstring = colorize(color=color, string=fullstring, output=True, **kwargs)

    if output:
        return outputstring
    else:
        print(outputstring)
        return None


def percentcomplete(step=None, maxsteps=None, stepsize=1, prefix=None):
    '''
    Display progress as a percentage, printing only on steps that are a multiple
    of the reporting interval.

    The prefix is printed before the percentage: by default a single space; if a
    number, that many spaces; otherwise it is used as supplied.

    **Example**::

        maxiters = 500
        for i in range(maxiters):
            sc.percentcomplete(i, maxiters) # will print on every 5th iteration
            sc.percentcomplete(i, maxiters, stepsize=10) # will print on every 50th iteration
            sc.percentcomplete(i, maxiters, prefix='Completeness: ') # will print e.g. 'Completeness: 1%'
    '''
    # Sort out the prefix
    if prefix is None:
        prefix = ' '
    elif isnumber(prefix):
        prefix = ' '*prefix

    # Number of steps between printouts -- never smaller than stepsize
    interval = max(stepsize, round(maxsteps/100*stepsize))
    if step % interval: # Not on a reporting step, so nothing to print
        return None

    pct = round(step/maxsteps*100) # Calculate what percent it is
    print(prefix + '%i%%' % pct) # Display the output
    return None


def progressbar(i, maxiters, label='', length=30, empty='—', full='•', newline=False):
    '''
    Call in a loop to create terminal progress bar.

    Args:
        i (int): current iteration
        maxiters (int): maximum number of iterations
        label (str): initial label to print
        length (int): length of progress bar
        empty (str): character for empty steps
        full (str): character for completed steps
        newline (bool): if true, end each update with a newline instead of a carriage return

    **Example**::

        import pylab as pl
        for i in range(100):
            progressbar(i+1, 100)
            pl.pause(0.05)

    Adapted from example by Greenstick (https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console)
    '''
    # Work out the percentage complete and how many bar segments to fill
    pct = i/maxiters*100
    nfull = int(length*i//maxiters)
    bar = full*nfull + empty*(length-nfull)

    # A carriage return lets the next update overwrite this one; None means print's default newline
    if newline:
        ending = None
    else:
        ending = '\r'
    print(f'\r{label} {bar} {pct:0.0f}%', end=ending)

    # Move to a fresh line once the final iteration is reached
    if i == maxiters:
        print()
    return



##############################################################################
#%% Type functions
##############################################################################

# Add the type-related functions defined in this section to the module's list of exported names
__all__ += ['flexstr', 'isiterable', 'checktype', 'isnumber', 'isstring', 'isarray',
            'promotetoarray', 'promotetolist', 'toarray', 'tolist', 'transposelist',
            'mergedicts', 'mergelists']

def flexstr(arg, force=True):
    '''
    Try converting any object to a "regular" string (i.e. ``str``), but proceed
    if it fails. Note: this function calls ``repr()`` rather than ``str()`` to
    ensure a more robust representation of objects.

    Args:
        arg (any): the object to convert; strings are returned unchanged, and bytes are decoded if possible
        force (bool): if true, anything that is not a string (or decodable bytes) is converted with ``repr()``; if false, it is returned unchanged

    Returns:
        A string, unless force=False and the input could not be converted.
    '''
    if isinstance(arg, str):
        return arg
    if isinstance(arg, bytes):
        try:
            return arg.decode() # If it's bytes, decode to unicode
        except UnicodeDecodeError: # pragma: no cover
            pass # Not valid text, so handle it like any other object below
    if force:
        return repr(arg) # Use the representation of the object
    return arg # Optionally don't do anything for non-strings


def isiterable(obj):
    '''
    Simply determine whether or not the input is iterable.

    Works by trying to iterate via iter(), and if that raises an exception, it's
    not iterable.

    From http://stackoverflow.com/questions/1952464/in-python-how-do-i-determine-if-an-object-is-iterable

    Returns:
        True if ``iter(obj)`` succeeds, False otherwise.
    '''
    try:
        iter(obj)
    except Exception: # Not a bare except, so KeyboardInterrupt and SystemExit are not swallowed
        return False
    return True


def checktype(obj=None, objtype=None, subtype=None, die=False):
    '''
    A convenience function for checking instances. If objtype is a type (or a
    tuple of types), then this function works exactly like isinstance(). But,
    it can also be one of the following strings:

        - 'str', 'string': string or bytes object
        - 'num', 'number': any kind of number
        - 'arr', 'array': a Numpy array (equivalent to np.ndarray)
        - 'listlike': a list, tuple, or array
        - 'arraylike': a list, tuple, or array with numeric entries

    If subtype is not None, then checktype will iterate over the object and check
    recursively that each element matches the subtype.

    Args:
        obj     (any):         the object to check the type of
        objtype (str or type): the type to confirm the object belongs to
        subtype (str or type): optionally check the subtype if the object is iterable
        die     (bool):        whether or not to raise an exception if the object is the wrong type

    Returns:
        True or False if die is false; None if die is true and the check passes,
        or if no objtype is supplied.

    Raises:
        TypeError: if die is true and the object is of the wrong type
        ValueError: if objtype is not a recognized string, a type, or a tuple of types

    **Examples**::

        sc.checktype(rand(10), 'array', 'number') # Returns True
        sc.checktype(['a','b','c'], 'listlike') # Returns True
        sc.checktype(['a','b','c'], 'arraylike') # Returns False
        sc.checktype([{'a':3}], list, dict) # Returns True
    '''

    # Handle "objtype" input
    if   objtype in ['str','string']:          objinstance = _stringtypes
    elif objtype in ['num', 'number']:         objinstance = _numtype
    elif objtype in ['arr', 'array']:          objinstance = np.ndarray
    elif objtype in ['listlike', 'arraylike']: objinstance = (list, tuple, np.ndarray) # Anything suitable as a numerical array
    elif isinstance(objtype, type):            objinstance = objtype  # Don't need to do anything; isinstance() also accepts classes with metaclasses, e.g. numbers.Number
    elif isinstance(objtype, tuple) and objtype and all(isinstance(t, type) for t in objtype):
        objinstance = objtype # A tuple of types, as accepted by isinstance()
    elif objtype is None:                      return None # If not supplied, exit
    else: # pragma: no cover
        errormsg = f'Could not understand what type you want to check: should be either a string or a type, not "{objtype}"'
        raise ValueError(errormsg)

    # Do first-round checking
    result = isinstance(obj, objinstance)
    origtype = type(obj) # Store for the error message, since obj may be replaced by a flattened array below

    # Do second round checking
    if result and objtype in ['listlike', 'arraylike']: # Special case for handling arrays which may be multi-dimensional
        obj = promotetoarray(obj).flatten() # Flatten all elements
        if objtype == 'arraylike' and subtype is None: subtype = 'number'
    if result and subtype is not None and isiterable(obj):
        result = all(checktype(item, subtype) for item in obj) # Every element must match the subtype

    # Decide what to do with the information thus gleaned
    if die: # Either raise an exception or do nothing if die is True
        if not result: # It's not an instance
            errormsg = f'Incorrect type: object is {origtype}, but {objtype} is required'
            raise TypeError(errormsg)
        else:
            return None # It's fine, do nothing
    else: # Return the result of the comparison
        return result


def isnumber(obj, isnan=None):
    '''
    Determine whether or not the input is a number, via ``checktype(obj, 'number')``.

    Args:
        obj (any): the object to check if it's a number
        isnan (bool): an optional additional check to determine whether the number is/isn't NaN

    Almost identical to isinstance(obj, numbers.Number).
    '''
    output = checktype(obj, 'number')
    if not output or isnan is None: # Not a number, or no NaN check requested
        return output
    return (np.isnan(obj) == isnan) # It is a number, so see if its NaN status matches


def isstring(obj):
    '''
    Determine whether or not the input is a string (i.e., str or bytes), via
    ``checktype()`` with the 'string' type.

    Equivalent to isinstance(obj, (str, bytes))
    '''
    return checktype(obj=obj, objtype='string')


def isarray(obj, dtype=None):
    '''
    Check whether something is a Numpy array, and optionally check the dtype.

    Almost the same as ``isinstance(obj, np.ndarray)``.

    Args:
        obj (any): the object to check
        dtype (type): if supplied, the array's dtype must also equal this

    Returns:
        True if obj is an array (with matching dtype, if specified); False otherwise.

    **Example**::

        sc.isarray(np.array([1,2,3]), dtype=float) # False, dtype is int

    New in version 1.0.0.
    '''
    if not isinstance(obj, np.ndarray):
        return False # Explicitly return False rather than falling through and returning None
    if dtype is None:
        return True
    return bool(obj.dtype == dtype)


def promotetoarray(x, keepnone=False, **kwargs):
    '''
    Small function to ensure consistent format for things that should be arrays
    (note: toarray()/promotetoarray() are identical).

    Very similar to ``np.array``, with the main difference being that ``sc.promotetoarray(3)``
    will return ``np.array([3])`` (i.e. a 1-d array that can be iterated over), while
    ``np.array(3)`` will return a 0-d array that can't be iterated over.

    Args:
        x (any): the object to convert to an array
        keepnone (bool): whether ``sc.promotetoarray(None)`` should return ``np.array([])`` or ``np.array([None], dtype=object)``
        kwargs (dict): passed to ``np.array()``

    **Examples**::

        sc.promotetoarray(5) # Returns np.array([5])
        sc.promotetoarray([3,5]) # Returns np.array([3,5])
        sc.promotetoarray(None) # Returns np.array([])
        sc.promotetoarray(None, keepnone=True) # Returns np.array([None], dtype=object)

    New in version 1.1.0: replaced "skipnone" with "keepnone"; allowed passing
    kwargs to ``np.array()``.
    '''
    skipnone = kwargs.pop('skipnone', None)
    if skipnone is not None: # pragma: no cover
        keepnone = not(skipnone)
        warnmsg = 'sc.promotetoarray() argument "skipnone" has been deprecated as of v1.1.0; use keepnone instead'
        warnings.warn(warnmsg, category=DeprecationWarning, stacklevel=2)
    if x is None:
        x = [None] if keepnone else [] # Wrap None explicitly: np.array(None) would give a 0-d array that can't be iterated over
    elif isnumber(x) or (isinstance(x, np.ndarray) and not np.shape(x)): # e.g. 3 or np.array(3)
        x = [x]
    output = np.array(x, **kwargs)
    return output


def promotetolist(obj=None, objtype=None, keepnone=False, coerce='default'):
    '''
    Make sure object is always a list (note: tolist()/promotetolist() are identical).

    Used so functions can handle inputs like ``'a'``  or ``['a', 'b']``. In other
    words, if an argument can either be a single thing (e.g., a single dict key)
    or a list (e.g., a list of dict keys), this function can be used to do the
    conversion, so it's always safe to iterate over the output.

    While this usually wraps objects in a list rather than converts them to a list,
    the "coerce" argument can be used to change this behavior. Options are:

    - 'none' or None: do not coerce
    - 'default': coerce objects that were lists in Python 2 (range, map, dict_keys, dict_values, dict_items)
    - 'full': all the types in default, plus tuples and arrays

    Args:
        obj (anything): object to ensure is a list
        objtype (anything): optional type to check for each element; see ``sc.checktype()`` for details
        keepnone (bool): if ``keepnone`` is false, then ``None`` is converted to ``[]``; else, it's converted to ``[None]``
        coerce (str/tuple):  one of the options above, or a type/tuple of types to coerce to a list (as opposed to wrapping in a list); only used if objtype is None

    Raises:
        ValueError: if coerce is a string that is not one of the options above
        TypeError: if objtype is supplied and the object (or one of its elements) does not match it

    **Examples**::

        sc.promotetolist(5) # Returns [5]
        sc.promotetolist(np.array([3,5])) # Returns [np.array([3,5])] -- not [3,5]!
        sc.promotetolist(np.array([3,5]), coerce=np.ndarray) # Returns [3,5], since arrays are coerced to lists
        sc.promotetolist(None) # Returns []
        sc.promotetolist(range(3)) # Returns [0,1,2] since range is coerced by default
        sc.promotetolist(['a', 'b', 'c'], objtype='number') # Raises exception

        def myfunc(data, keys):
            keys = sc.promotetolist(keys)
            for key in keys:
                print(data[key])

        data = {'a':[1,2,3], 'b':[4,5,6]}
        myfunc(data, keys=['a', 'b']) # Works
        myfunc(data, keys='a') # Still works, equivalent to needing to supply keys=['a'] without promotetolist()

    New in version 1.1.0: "coerce" argument
    New in version 1.2.2: default coerce values
    '''
    # Handle coerce
    default_coerce = (range, map, type({}.keys()), type({}.values()), type({}.items()))
    if isinstance(coerce, str):
        if coerce == 'none':
            coerce = None
        elif coerce == 'default':
            coerce = default_coerce
        elif coerce == 'full':
            coerce = default_coerce + (tuple, np.ndarray)
        else:
            errormsg = f'Option "{coerce}" not recognized; must be "none", "default", or "full"'
            raise ValueError(errormsg) # Fail here with a clear message, rather than later inside isinstance()

    if objtype is None: # Don't do type checking
        if isinstance(obj, list):
            output = obj # If it's already a list and we're not doing type checking, just return
        elif obj is None:
            if keepnone:
                output = [None] # Wrap in a list
            else:
                output = [] # Return an empty list, the "none" equivalent for a list
        else:
            if coerce is not None and isinstance(obj, coerce):
                output = list(obj) # Coerce to list
            else:
                output = [obj] # Main usage case -- listify it
    else: # Do type checking
        if checktype(obj=obj, objtype=objtype, die=False):
            output = [obj] # If the object is already of the right type, wrap it in a list
        else:
            try:
                if not isiterable(obj): # Ensure it's iterable -- a mini promote-to-list
                    iterable_obj = [obj]
                else:
                    iterable_obj = obj
                for item in iterable_obj:
                    checktype(obj=item, objtype=objtype, die=True)
                output = list(iterable_obj) # If all type checking passes, cast to list instead of wrapping
            except TypeError as E:
                errormsg = f'promotetolist(): type mismatch, expecting type {objtype}'
                raise TypeError(errormsg) from E
    return output


# Aliases for core functions: shorter names bound to the same function objects,
# so toarray() and tolist() behave identically to the functions they refer to
toarray = promotetoarray
tolist = promotetolist


def transposelist(obj):
    '''
    Transpose a list of sequences, e.g. turn a list of key-value tuples into a
    list of keys and a list of values. The output is a list of lists; as with
    ``zip()``, the result is only as long as the shortest input sequence.

    **Example**::

        o = sc.odict(a=1, b=4, c=9, d=16)
        itemlist = o.enumitems()
        inds, keys, vals = sc.transposelist(itemlist)

    New in version 1.1.0.
    '''
    return [list(column) for column in zip(*obj)]


def mergedicts(*args, strict=False, overwrite=True, copy=False):
    '''
    Small function to merge multiple dicts together. By default, skips things
    that are not dicts (e.g., None), and allows keys to be set multiple times.
    Similar to dict.update(), except returns a value. The first dictionary supplied
    will be used for the output type (e.g. if the first dictionary is an odict,
    an odict will be returned).

    Useful for cases, e.g. function arguments, where the default option is ``None``
    but you will need a dict later on.

    Args:
        strict    (bool): if True, raise an exception if an argument isn't a dict
        overwrite (bool): if False, raise an exception if multiple keys are found
        copy      (bool): whether or not to deepcopy the merged dictionary
        *args     (dict): the sequence of dicts to be merged

    Returns:
        A new dictionary (of the same class as the first argument, if that is a
        dict whose class can be instantiated without arguments; otherwise a
        plain dict) containing the merged items.

    Raises:
        TypeError: if ``strict=True`` and an argument is not a dict
        KeyError: if ``overwrite=False`` and a key appears in more than one dict

    **Examples**::

        d0 = sc.mergedicts(user_args) # Useful if user_args might be None, but d0 is always a dict
        d1 = sc.mergedicts({'a':1}, {'b':2}) # Returns {'a':1, 'b':2}
        d2 = sc.mergedicts({'a':1, 'b':2}, {'b':3, 'c':4}) # Returns {'a':1, 'b':3, 'c':4}
        d3 = sc.mergedicts(sc.odict({'b':3, 'c':4}), {'a':1, 'b':2}) # Returns sc.odict({'b':2, 'c':4, 'a':1})
        d4 = sc.mergedicts({'b':3, 'c':4}, {'a':1, 'b':2}, overwrite=False) # Raises exception

    New in version 1.1.0: "copy" argument
    '''
    # Try to get the output type from the first argument, but revert to a standard
    # dict if that fails. An explicit isinstance() check is used rather than an
    # assert, since asserts are stripped when Python runs with -O.
    outputdict = {}
    if args and isinstance(args[0], dict):
        try:
            outputdict = args[0].__class__() # This creates a new, empty instance of the class
        except Exception: # E.g. a dict subclass whose constructor requires arguments
            outputdict = {}

    # Merge over the dictionaries in order
    for arg in args:
        if not isinstance(arg, dict):
            if strict:
                errormsg = f'Argument of "{type(arg)}" found; must be dict since strict=True'
                raise TypeError(errormsg)
            continue # Silently skip non-dict arguments, e.g. None
        if not overwrite:
            intersection = set(outputdict.keys()).intersection(arg.keys())
            if len(intersection):
                errormsg = f'Could not merge dicts since keys "{strjoin(intersection)}" overlap and overwrite=False'
                raise KeyError(errormsg)
        outputdict.update(arg)
    if copy: # Note: "copy" here is the boolean argument, not the copy module
        outputdict = dcp(outputdict)
    return outputdict


def mergelists(*args, copy=False, **kwargs):
    '''
    Concatenate any number of lists (or single items) into one flat list.

    Each argument is first converted to a list with ``sc.promotetolist()``,
    and the resulting lists are joined in the order given.

    Args:
        args (any): the lists, or items, to be joined together into a list
        copy (bool): whether to deepcopy the resultant object
        kwargs (dict): passed to ``sc.promotetolist()``, which is called on each argument

    Returns:
        A new list containing the items of every argument, in order.

    **Examples**::

        sc.mergelists(None) # Returns []
        sc.mergelists([1,2,3], [4,5,6]) # Returns [1, 2, 3, 4, 5, 6]
        sc.mergelists([1,2,3], 4, 5, 6) # Returns [1, 2, 3, 4, 5, 6]
        sc.mergelists([(1,2), (3,4)], (5,6)) # Returns [(1, 2), (3, 4), (5, 6)]
        sc.mergelists((1,2), (3,4), (5,6)) # Returns [(1, 2), (3, 4), (5, 6)]
        sc.mergelists((1,2), (3,4), (5,6), coerce=tuple) # Returns [1, 2, 3, 4, 5, 6]

    New in version 1.1.0.
    '''
    merged = []
    for item in args:
        merged += promotetolist(item, **kwargs) # Listify each argument, then append its contents
    return dcp(merged) if copy else merged


##############################################################################
#%% Time/date functions
##############################################################################

# Add the names of this section's time/date functions to the module's exported names
__all__ += ['now', 'getdate', 'readdate', 'date', 'day', 'daydiff', 'daterange', 'datedelta', 'datetoyear',
            'elapsedtimestr', 'tic', 'toc', 'toctic', 'timedsleep']

def now(timezone=None, utc=False, die=False, astype='dateobj', tostring=False, dateformat=None):
    '''
    Get the current time, optionally in UTC or in a named timezone.

    Args:
        timezone (str): name of the timezone to use, as understood by ``dateutil.tz.gettz()``
        utc (bool/str): if True, use UTC; if a string, it is treated as the timezone name
        die (bool): accepted but not used by this function
        astype (str): the output type, passed to ``sc.getdate()``
        tostring (bool): if True, equivalent to ``astype='str'``
        dateformat (str): the output format, passed to ``sc.getdate()``

    Returns:
        The current time as converted by ``sc.getdate()``

    **Examples**::

        sc.now() # Return current local time, e.g. 2019-03-14 15:09:26
        sc.now('US/Pacific') # Return the time now in a specific timezone
        sc.now(utc=True) # Return the time in UTC
        sc.now(astype='str') # Return the current time as a string instead of a date object
        sc.now(tostring=True) # Backwards-compatible alias for astype='str'
        sc.now(dateformat='%Y-%b-%d') # Return a different date format
    '''
    if isinstance(utc, str): # A string supplied as "utc" is assumed to be a timezone
        timezone = utc

    # Choose the timezone: an explicit name wins, then UTC, otherwise none at all
    if timezone is not None:
        tzinfo = dateutil.tz.gettz(timezone)
    elif utc:
        tzinfo = dateutil.tz.tzutc()
    else:
        tzinfo = None

    if tostring:
        astype = 'str'
    return getdate(dt.datetime.now(tzinfo), astype=astype, dateformat=dateformat)



def getdate(obj=None, astype='str', dateformat=None):
    '''
    Alias for converting a date object to a formatted string.

    Args:
        obj (date/datetime/str): the date to convert; if None, use ``sc.now()``;
            if a string, it is returned unchanged
        astype (str): the output type: "str" (formatted string), "int" (the
            timestamp returned by ``time.mktime()``), or "dateobj" (the date
            object itself)
        dateformat (str): the ``strftime()`` format to use; if supplied, the
            output is always a string, regardless of ``astype``

    Returns:
        The date in the requested representation

    Raises:
        TypeError: if ``obj`` is neither a string nor a date-like object
        ValueError: if ``astype`` is not one of the options above

    **Examples**::

        sc.getdate() # Returns a string for the current date
        sc.getdate(sc.now(), astype='int') # Convert today's time to a timestamp
    '''
    if obj is None:
        obj = now()

    if dateformat is None:
        dateformat = '%Y-%b-%d %H:%M:%S'
    else:
        astype = 'str' # If dateformat is specified, assume type is a string

    try:
        if isstring(obj):
            return obj # Return directly if it's a string
        obj.timetuple() # Try something that will only work if it's a date object
        dateobj = obj # Test passed: it's a date object
    except Exception as E: # pragma: no cover # It's not a date object
        errormsg = f'Getting date failed; date must be a string or a date object: {repr(E)}'
        raise TypeError(errormsg) from E

    if   astype == 'str':     output = dateobj.strftime(dateformat)
    elif astype == 'int':     output = time.mktime(dateobj.timetuple()) # So ugly!! But it works -- return numeric representation of time
    elif astype == 'dateobj': output = dateobj
    else: # pragma: no cover
        errormsg = f'"astype={astype}" not understood; must be "str", "int", or "dateobj"'
        raise ValueError(errormsg)
    return output


def _sanitize_iterables(obj, *args):
    '''
    Combine an input (a list, array, or single item) and any additional
    arguments into one new list, and report what kind of input was given so
    that ``_sanitize_output()`` can convert the result back afterwards.

    Returns:
        A tuple of (list of items, whether to treat the input as a list,
        whether the input was a Numpy array). The list is a deep copy, so
        modifying it does not affect the input.

    **Examples**::

        _sanitize_iterables(1, 2, 3)             # Returns [1,2,3], False, False
        _sanitize_iterables([1, 2], 3)           # Returns [1,2,3], True, False
        _sanitize_iterables(np.array([1, 2]), 3) # Returns [1,2,3], True, True
        _sanitize_iterables(np.array([1, 2, 3])) # Returns [1,2,3], False, True
    '''
    is_array = isinstance(obj, np.ndarray)
    is_list = isinstance(obj, list) or bool(args) # Extra arguments mean the result must be list-like
    contents = obj.tolist() if is_array else obj # Arrays are unpacked to plain Python lists
    items = dcp(promotetolist(contents)) # Listify, and deepcopy so the caller's object is untouched
    items.extend(args)
    return items, is_list, is_array


def _sanitize_output(obj, is_list, is_array, dtype=None):
    '''
    The counterpart of ``_sanitize_iterables()``: convert a list of results
    back to the kind of object that was originally supplied.

    Args:
        obj (list): the list of results
        is_list (bool): whether the original input was treated as a list
        is_array (bool): whether the original input was a Numpy array
        dtype (type): the dtype to use if converting to an array

    Returns:
        An array if ``is_array``; otherwise the sole element if the input was
        not a list and there is exactly one result; otherwise ``obj`` itself.
    '''
    if is_array:
        return np.array(obj, dtype=dtype)
    if is_list or len(obj) != 1:
        return obj
    return obj[0] # A single item went in, so a single item comes out


def readdate(datestr=None, *args, dateformat=None, return_defaults=False):
    '''
    Convenience function for reading one or more dates from strings or numbers.
    If no dateformat is given, strings are tried against a list of standard
    formats, in order, and the first one that parses is used.

    Numbers are handled separately from strings, according to ``dateformat``:

    - 'posix'/None: converted with ``datetime.fromtimestamp()`` (a POSIX timestamp, in seconds)
    - 'ordinal'/'matplotlib': converted with pylab's ``num2date()`` (a number of days, Matplotlib-style)

    Inputs that are already datetime objects are returned unchanged.

    Args:
        datestr (int, float, str or list): the string containing the date, or the timestamp (in seconds), or a list of either
        args (list): additional dates to convert
        dateformat (str or list): the format for the date, if known; if 'dmy' or 'mdy', try as day-month-year or month-day-year formats; can also be a list of options
        return_defaults (bool): don't convert the date, just return the defaults

    Returns:
        A datetime object if a single date was supplied, otherwise a list or
        array of them (matching the input); or the dictionary of default
        formats if ``return_defaults=True``

    Raises:
        ValueError: if a date cannot be converted with the available formats

    **Examples**::

        dateobj  = sc.readdate('2020-03-03') # Standard format, so works
        dateobj  = sc.readdate('04-03-2020', dateformat='dmy') # Date is ambiguous, so need to specify day-month-year order
        dateobj  = sc.readdate(1611661666) # Can read timestamps as well
        dateobj  = sc.readdate(16166, dateformat='ordinal') # Or ordinal numbers of days, as used by Matplotlib
        dateobjs = sc.readdate(['2020-06', '2020-07'], dateformat='%Y-%m') # Can read custom date formats
        dateobjs = sc.readdate('20200321', 1611661666) # Can mix and match formats
    '''

    # The standard formats, tried in this order when no dateformat is supplied
    default_formats = {
        'date':           '%Y-%m-%d', # 2020-03-21
        'date-slash':     '%Y/%m/%d', # 2020/03/21
        'date-dot':       '%Y.%m.%d', # 2020.03.21
        'date-space':     '%Y %m %d', # 2020 03 21
        'date-alpha':     '%Y-%b-%d', # 2020-Mar-21
        'date-alpha-rev': '%d-%b-%Y', # 21-Mar-2020
        'date-alpha-sp':  '%d %b %Y', # 21 Mar 2020
        'date-Alpha':     '%Y-%B-%d', # 2020-March-21
        'date-Alpha-rev': '%d-%B-%Y', # 21-March-2020
        'date-Alpha-sp':  '%d %B %Y', # 21 March 2020
        'date-numeric':   '%Y%m%d',   # 20200321
        'datetime':       '%Y-%m-%d %H:%M:%S',    # 2020-03-21 14:35:21
        'datetime-alpha': '%Y-%b-%d %H:%M:%S',    # 2020-Mar-21 14:35:21
        'default':        '%Y-%m-%d %H:%M:%S.%f', # 2020-03-21 14:35:21.23483
        'ctime':          '%a %b %d %H:%M:%S %Y', # Sat Mar 21 23:09:29 2020
        }

    # If only the available formats were requested, there is nothing else to do
    if return_defaults:
        return default_formats

    # Work out which string formats to try
    format_list = promotetolist(dateformat, keepnone=True) # None is kept, since it signifies the default
    if dateformat is None:
        string_formats = default_formats
    elif dateformat == 'dmy': # Day-month-year formats
        string_formats = {
            'date':           '%d-%m-%Y', # 21-03-2020
            'date-slash':     '%d/%m/%Y', # 21/03/2020
            'date-dot':       '%d.%m.%Y', # 21.03.2020
            'date-space':     '%d %m %Y', # 21 03 2020
        }
    elif dateformat == 'mdy': # Month-day-year formats
        string_formats = {
            'date':           '%m-%d-%Y', # 03-21-2020
            'date-slash':     '%m/%d/%Y', # 03/21/2020
            'date-dot':       '%m.%d.%Y', # 03.21.2020
            'date-space':     '%m %d %Y', # 03 21 2020
        }
    else: # One or more formats supplied by the user
        string_formats = {f'User supplied {f}': fmt for f, fmt in enumerate(format_list)}

    # Gather all inputs into a single list, remembering the input type
    entries, is_list, is_array = _sanitize_iterables(datestr, *args)

    # Convert each entry in turn
    converted = []
    for entry in entries:

        # Already a datetime: use as-is
        if isinstance(entry, dt.datetime):
            converted.append(entry)
            continue

        # Numeric: interpret according to the requested convention
        if isnumber(entry):
            if 'posix' in format_list or None in format_list:
                converted.append(dt.datetime.fromtimestamp(entry))
            elif 'ordinal' in format_list or 'matplotlib' in format_list:
                converted.append(pl.num2date(entry))
            else:
                errormsg = f'Could not convert numeric date {entry} using available formats {strjoin(format_list)}; must be "posix" or "ordinal"'
                raise ValueError(errormsg)
            continue

        # Otherwise, try each string format until one succeeds
        failures = {}
        for key, fmt in string_formats.items():
            try:
                parsed = dt.datetime.strptime(entry, fmt)
            except Exception as E:
                failures[key] = str(E)
            else:
                converted.append(parsed)
                break
        else: # No format worked
            formatstr = newlinejoin([f'{fmt}' for fmt in string_formats.values()])
            errormsg = f'Was unable to convert "{entry}" to a date using the formats:\n{formatstr}'
            if dateformat not in ['dmy', 'mdy']:
                errormsg += '\n\nNote: to read day-month-year or month-day-year dates, use dateformat="dmy" or "mdy" respectively.'
            raise ValueError(errormsg)

    # If only a single date was supplied, return just that; else return the list/array
    return _sanitize_output(converted, is_list, is_array, dtype=object)


def date(obj, *args, start_date=None, readformat=None, outformat=None, as_date=True, **kwargs):
    '''
    Convert any reasonable object -- a string, integer, or datetime object, or
    list/array of any of those -- to a date object. To convert an integer to a
    date, you must supply a start date.

    Caution: while this function and readdate() are similar, and indeed this function
    calls readdate() if the input is a string, in this function an integer is treated
    as a number of days from start_date, while for readdate() it is treated as a
    timestamp in seconds. To change this, supply ``readformat`` (e.g. "posix" or
    "ordinal"), in which case numbers are passed to readdate() as well.

    Args:
        obj (str, int, date, datetime, list, array): the object to convert
        args (str, int, date, datetime): additional objects to convert
        start_date (str, date, datetime): the starting date, if an integer is supplied
        readformat (str/list): the format to read the date in; passed to sc.readdate()
        outformat (str): the format to output the date in, if returning a string (default '%Y-%m-%d')
        as_date (bool): whether to return as a datetime date instead of a string
        kwargs (dict): only used for the deprecated "dateformat" argument (now "outformat")

    Returns:
        dates (date or list): either a single date object, or a list of them (matching
        input data type where possible); None if ``obj`` is None. Entries that are
        None are returned as None.

    Raises:
        ValueError: if any entry could not be converted to a date

    **Examples**::

        sc.date('2020-04-05') # Returns datetime.date(2020, 4, 5)
        sc.date([35,36,37], start_date='2020-01-01', as_date=False) # Returns ['2020-02-05', '2020-02-06', '2020-02-07']
        sc.date(1923288822, readformat='posix') # Interpret as a POSIX timestamp

    New in version 1.0.0.
    New in version 1.2.2: "readformat" argument; renamed "dateformat" to "outformat"
    '''

    # Handle deprecation
    dateformat = kwargs.pop('dateformat', None)
    if dateformat is not None: # pragma: no cover
        outformat = dateformat
        warnmsg = 'sc.date() argument "dateformat" has been deprecated as of v1.2.2; use "outformat" instead'
        warnings.warn(warnmsg, category=DeprecationWarning, stacklevel=2)

    # Convert to list and handle other inputs
    if obj is None:
        return None
    if outformat is None:
        outformat = '%Y-%m-%d'
    obj, is_list, is_array = _sanitize_iterables(obj, *args)

    base_date = None # The converted start_date; only computed if a number needs it, and then only once
    dates = []
    for d in obj:
        if d is None:
            dates.append(d)
            continue
        try:
            if type(d) is dt.date: # Do not use isinstance, since must be the exact type
                pass
            elif isinstance(d, dt.datetime):
                d = d.date()
            elif isstring(d):
                d = readdate(d, dateformat=readformat).date()
            elif isnumber(d):
                if readformat is not None:
                    d = readdate(d, dateformat=readformat).date()
                else:
                    if start_date is None:
                        errormsg = f'To convert the number {d} to a date, you must either specify "posix" or "ordinal" read format, or supply start_date'
                        raise ValueError(errormsg)
                    if base_date is None:
                        base_date = date(start_date)
                    d = base_date + dt.timedelta(days=int(d))
            else: # pragma: no cover
                errormsg = f'Cannot interpret {type(d)} as a date, must be date, datetime, or string'
                raise TypeError(errormsg)
            if as_date:
                dates.append(d)
            else:
                dates.append(d.strftime(outformat))
        except Exception as E:
            errormsg = f'Conversion of "{d}" to a date failed: {str(E)}'
            raise ValueError(errormsg) from E

    # Return a single date rather than a list if only one provided
    output = _sanitize_output(dates, is_list, is_array, dtype=object)
    return output
def day(obj, *args, start_date=None, **kwargs):
    '''
    Convert a string, date/datetime object, or int to a day (int), the number of
    days since the start day. See also sc.date() and sc.daydiff(). If a start day
    is not supplied, it returns the number of days into the year.

    Args:
        obj (str, date, int, list, array): convert any of these objects to a day relative to the start day
        args (list): additional days
        start_date (str or date): the start day; if none is supplied, return days since
            January 1st of the year of the first date supplied (note: that same
            reference is then used for every remaining entry)
        kwargs (dict): only used for the deprecated "start_day" argument (now "start_date")

    Returns:
        days (int or list): the day(s) in simulation time (matching input data type
        where possible); None if ``obj`` is None. Numbers are simply converted to
        integers, and None entries are returned as None.

    Raises:
        ValueError: if an entry could not be interpreted as a date

    **Examples**::

        sc.day(sc.now()) # Returns how many days into the year we are
        sc.day(['2021-01-21', '2024-04-04'], start_date='2022-02-22') # Days can be positive or negative

    New in version 1.0.0.
    New in version 1.2.2: renamed "start_day" to "start_date"
    '''

    # Handle deprecation
    start_day = kwargs.pop('start_day', None)
    if start_day is not None: # pragma: no cover
        start_date = start_day
        warnmsg = 'sc.day() argument "start_day" has been deprecated as of v1.2.2; use "start_date" instead'
        warnings.warn(warnmsg, category=DeprecationWarning, stacklevel=2)

    # Do not process a day if it's not supplied, and ensure it's a list
    if obj is None:
        return None
    obj, is_list, is_array = _sanitize_iterables(obj, *args)

    # NOTE(review): the reference date is fixed by the first date-like entry, so with
    # no start_date and dates spanning several years, later entries are counted from
    # the first entry's year rather than their own -- confirm this is intended.
    reference = None
    days = []
    for d in obj:
        if d is None:
            days.append(d)
        elif isnumber(d):
            days.append(int(d)) # Just convert to an integer
        else:
            try:
                if isstring(d):
                    d = readdate(d).date()
                elif isinstance(d, dt.datetime):
                    d = d.date()
                if reference is None: # Resolve the reference date once, on first use
                    if start_date:
                        reference = date(start_date)
                    else:
                        reference = date(f'{d.year}-01-01')
                d_day = (d - reference).days # Heavy lifting -- actually compute the day
                days.append(d_day)
            except Exception as E: # pragma: no cover
                errormsg = f'Could not interpret "{d}" as a date: {str(E)}'
                raise ValueError(errormsg) from E

    # Return an integer rather than a list if only one provided
    output = _sanitize_output(days, is_list, is_array)
    return output
def daydiff(*args):
    '''
    Convenience function to find the difference between two or more days. With
    only one argument, calculate days since January 1st of the current year.

    Args:
        args (str, date, int): the dates to compare, each converted with sc.date()

    Returns:
        The number of days between each consecutive pair of dates: a single
        integer if there is one pair, otherwise a list

    **Examples**::

        diff  = sc.daydiff('2020-03-20', '2020-04-05') # Returns 16
        diffs = sc.daydiff('2020-03-20', '2020-04-05', '2020-05-01') # Returns [16, 26]

    New in version 1.0.0.
    '''
    days = [date(arg) for arg in args]
    if len(days) == 1:
        days.insert(0, date(f'{now().year}-01-01')) # With one date, return days since Jan. 1st

    # Difference between each date and the one before it
    output = [(later - earlier).days for earlier, later in zip(days[:-1], days[1:])]

    if len(output) == 1:
        output = output[0]
    return output
def daterange(start_date, end_date, inclusive=True, as_date=False, dateformat=None): ''' Return a list of dates from the start date to the end date. To convert a list of days (as integers) to dates, use sc.date() instead. Args: start_date (int/str/date): the starting date, in any format end_date (int/str/date): the end date, in any format inclusive (bool): if True (default), return to end_date inclusive; otherwise, stop the day before as_date (bool): if True, return a list of datetime.date objects instead of strings dateformat (str): passed to date() **Example**:: dates = sc.daterange('2020-03-01', '2020-04-04') New in version 1.0.0. ''' end_day = day(end_date, start_date=start_date) if inclusive: end_day += 1 days = list(range(end_day)) dates = date(days, start_date=start_date, as_date=as_date, dateformat=dateformat) return dates def datedelta(datestr, days=0, months=0, years=0, weeks=0, as_date=None, **kwargs): ''' Perform calculations on a date string (or date object), returning a string (or a date). Wrapper to dateutil.relativedelta(). Args: datestr (str/date): the starting date (typically a string) days (int): the number of days (positive or negative) to increment months (int): as above years (int): as above weeks (int): as above as_date (bool): if True, return a date object; otherwise, return as input type kwargs (dict): passed to ``sc.readdate()`` **Examples**:: sc.datedelta('2021-07-07', 3) # Add 3 days sc.datedelta('2021-07-07', days=-4) # Subtract 4 days sc.datedelta('2021-07-07', weeks=4, months=-1, as_date=True) # Add 4 weeks but subtract a month, and return a dateobj ''' if as_date is None and isinstance(datestr, str): # Typical case as_date = False dateobj = readdate(datestr, **kwargs) newdate = dateobj + dateutil.relativedelta.relativedelta(days=days, months=months, years=years, weeks=weeks) newdate = date(newdate, as_date=as_date) return newdate def datetoyear(dateobj, dateformat=None): """ Convert a DateTime instance to decimal year. 
Args: dateobj (date, str): The datetime instance to convert dateformat (str): If dateobj is a string, the optional date conversion format to use Returns: Equivalent decimal year **Example**:: sc.datetoyear('2010-07-01') # Returns approximately 2010.5 By Luke Davis from https://stackoverflow.com/a/42424261, adapted by Romesh Abeysuriya. New in version 1.0.0. """ if isstring(dateobj): dateobj = readdate(dateobj, dateformat=dateformat) year_part = dateobj - dt.datetime(year=dateobj.year, month=1, day=1) year_length = dt.datetime(year=dateobj.year + 1, month=1, day=1) - dt.datetime(year=dateobj.year, month=1, day=1) return dateobj.year + year_part / year_length def elapsedtimestr(pasttime, maxdays=5, minseconds=10, shortmonths=True): """ Accepts a datetime object or a string in ISO 8601 format and returns a human-readable string explaining when this time was. The rules are as follows: * If a time is within the last hour, return 'XX minutes' * If a time is within the last 24 hours, return 'XX hours' * If within the last 5 days, return 'XX days' * If in the same year, print the date without the year * If in a different year, print the date with the whole year These can be configured as options. **Examples**:: yesterday = sc.datedelta(sc.now(), days=-1) sc.elapsedtimestr(yesterday) """ # Elapsed time function by Alex Chan # https://gist.github.com/alexwlchan/73933442112f5ae431cc def print_date(date, includeyear=True, shortmonths=True): """Prints a datetime object as a full date, stripping off any leading zeroes from the day (strftime() gives the day of the month as a zero-padded decimal number). 
""" # %b/%B are the tokens for abbreviated/full names of months to strftime() if shortmonths: month_token = '%b' else: month_token = '%B' # Get a string from strftime() if includeyear: date_str = date.strftime('%d ' + month_token + ' %Y') else: date_str = date.strftime('%d ' + month_token) # There will only ever be at most one leading zero, so check for this and # remove if necessary if date_str[0] == '0': date_str = date_str[1:] return date_str now_time = dt.datetime.now() # If the user passes in a string, try to turn it into a datetime object before continuing if isinstance(pasttime, str): try: pasttime = readdate(pasttime) except ValueError as E: # pragma: no cover errormsg = f"User supplied string {pasttime} is not in a readable format." raise ValueError(errormsg) from E elif isinstance(pasttime, dt.datetime): pass else: # pragma: no cover errormsg = f"User-supplied value {pasttime} is neither a datetime object nor an ISO 8601 string." raise TypeError(errormsg) # It doesn't make sense to measure time elapsed between now and a future date, so we'll just print the date if pasttime > now_time: includeyear = (pasttime.year != now_time.year) time_str = print_date(pasttime, includeyear=includeyear, shortmonths=shortmonths) # Otherwise, start by getting the elapsed time as a datetime object else: elapsed_time = now_time - pasttime # Check if the time is within the last minute if elapsed_time < dt.timedelta(seconds=60): if elapsed_time.seconds <= minseconds: time_str = "just now" else: time_str = f"{elapsed_time.seconds} secs ago" # Check if the time is within the last hour elif elapsed_time < dt.timedelta(seconds=60 * 60): # We know that seconds > 60, so we can safely round down minutes = int(elapsed_time.seconds / 60) if minutes == 1: time_str = "a minute ago" else: time_str = f"{minutes} mins ago" # Check if the time is within the last day elif elapsed_time < dt.timedelta(seconds=60 * 60 * 24 - 1): # We know that it's at least an hour, so we can safely round down 
hours = int(elapsed_time.seconds / (60 * 60)) if hours == 1: time_str = "1 hour ago" else: time_str = f"{hours} hours ago" # Check if it's within the last N days, where N is a user-supplied argument elif elapsed_time < dt.timedelta(days=maxdays): if elapsed_time.days == 1: time_str = "yesterday" else: time_str = f"{elapsed_time.days} days ago" # If it's not within the last N days, then we're just going to print the date else: includeyear = (pasttime.year != now_time.year) time_str = print_date(pasttime, includeyear=includeyear, shortmonths=shortmonths) return time_str def tic(): ''' With toc(), a little pair of functions to calculate a time difference: **Examples**:: sc.tic() slow_func() sc.toc() T = sc.tic() slow_func2() sc.toc(T, label='slow_func2') ''' global _tictime # The saved time is stored in this global _tictime = time.time() # Store the present time in the global return _tictime # Return the same stored number def toc(start=None, output=False, label=None, sigfigs=None, filename=None, reset=False): ''' With tic(), a little pair of functions to calculate a time difference. Args: start (float): the starting time, as returned by e.g. sc.tic() output (bool): whether to return the output (otherwise print) label (str): optional label to add sigfigs (int): number of significant figures for time estimate filename (str): log file to write results to reset (bool): reset the time; like calling sc.toctic() or sc.tic() again **Examples**:: sc.tic() slow_func() sc.toc() T = sc.tic() slow_func2() sc.toc(T, label='slow_func2') ''' global _tictime # The saved time is stored in this global # Set defaults if label is None: label = '' if sigfigs is None: sigfigs = 3 # If no start value is passed in, try to grab the global _tictime. if start is None: try: start = _tictime except: start = 0 # This doesn't exist, so just leave start at 0. # Get the elapsed time in seconds. elapsed = time.time() - start # Create the message giving the elapsed time. 
if label=='': base = 'Elapsed time: ' else: base = f'Elapsed time for {label}: ' logmessage = base + f'{sigfig(elapsed, sigfigs=sigfigs)} s' # Optionally reset the counter if reset: _tictime = time.time() # Store the present time in the global if output: return elapsed else: if filename is not None: printtologfile(logmessage, filename) # If we passed in a filename, append the message to that file. else: print(logmessage) # Otherwise, print the message. return def toctic(returntic=False, returntoc=False, *args, **kwargs): ''' A convenience function for multiple timings. Can return the default output of either tic() or toc() (default neither). Arguments are passed to toc(). Equivalent to sc.toc(reset=True). **Example**:: sc.tic() slow_operation_1() sc.toctic() slow_operation_2() sc.toc() New in version 1.0.0. ''' tocout = toc(*args, **kwargs) ticout = tic() if returntic: return ticout elif returntoc: return tocout else: return None def timedsleep(delay=None, verbose=True): ''' Delay for a certain amount of time, to ensure accurate timing. **Example**:: for i in range(10): sc.timedsleep('start') # Initialize for j in range(int(1e6)): tmp = pl.rand() sc.timedsleep(1) # Wait for one second including computation time ''' global _delaytime if delay is None or delay=='start': _delaytime = time.time() # Store the present time in the global. return _delaytime # Return the same stored number. else: try: start = _delaytime except: start = time.time() elapsed = time.time() - start remaining = delay-elapsed if remaining>0: if verbose: print(f'Pausing for {remaining:0.1f} s') time.sleep(remaining) else: if verbose: print(f'Warning, delay less than elapsed time ({delay:0.1f} vs. {elapsed:0.1f})') return None ############################################################################## #%% Misc. 
functions ############################################################################## __all__ += ['checkmem', 'checkram', 'runcommand', 'gitinfo', 'compareversions', 'uniquename', 'importbyname', 'suggest', 'profile', 'mprofile', 'getcaller'] def checkmem(var, descend=None, alphabetical=False, plot=False, verbose=False): ''' Checks how much memory the variable or variables in question use by dumping them to file. See also checkram(). Args: var (any): the variable being checked descend (bool): whether or not to descend one level into the object alphabetical (bool): if descending into a dict or object, whether to list items by name rather than size plot (bool): if descending, show the results as a pie chart verbose (bool or int): detail to print, if >1, print repr of objects along the way **Example**:: import sciris as sc sc.checkmem(['spiffy',rand(2483,589)], descend=True) ''' from .sc_fileio import saveobj # Here to avoid recursion def check_one_object(variable): ''' Check the size of one variable ''' if verbose>1: print(f' Checking size of {variable}...') # Create a temporary file, save the object, check the size, remove it filename = tempfile.mktemp() saveobj(filename, variable, die=False) filesize = os.path.getsize(filename) os.remove(filename) # Convert to string factor = 1 label = 'B' labels = ['KB','MB','GB'] for i,f in enumerate([3,6,9]): if filesize>10**f: factor = 10**f label = labels[i] humansize = float(filesize/float(factor)) sizestr = f'{humansize:0.3f} {label}' return filesize, sizestr # Initialize varnames = [] variables = [] sizes = [] sizestrs = [] # Create the object(s) to check the size(s) of varnames = [''] # Set defaults variables = [var] if descend or descend is None: if hasattr(var, '__dict__'): # It's an object if verbose>1: print('Iterating over object') varnames = sorted(list(var.__dict__.keys())) variables = [getattr(var, attr) for attr in varnames] elif np.iterable(var): # Handle dicts and lists if isinstance(var, dict): # Handle 
dicts if verbose>1: print('Iterating over dict') varnames = list(var.keys()) variables = var.values() else: # Handle lists and other things if verbose>1: print('Iterating over list') varnames = [f'item {i}' for i in range(len(var))] variables = var else: if descend: # Could also be None print('Object is not iterable: cannot descend') # Print warning and use default # Compute the sizes for v,variable in enumerate(variables): if verbose: print(f'Processing variable {v} of {len(variables)}') filesize, sizestr = check_one_object(variable) sizes.append(filesize) sizestrs.append(sizestr) if alphabetical: inds = np.argsort(varnames) else: inds = np.argsort(sizes)[::-1] for i in inds: varstr = f'Variable "{varnames[i]}"' if varnames[i] else 'Variable' print(f'{varstr} is {sizestrs[i]}') if plot: # pragma: no cover import pylab as pl # Optional import pl.axes(aspect=1) pl.pie(pl.array(sizes)[inds], labels=pl.array(varnames)[inds], autopct='%0.2f') return None def checkram(unit='mb', fmt='0.2f', start=0, to_string=True): ''' Unlike checkmem(), checkram() looks at actual memory usage, typically at different points throughout execution. **Example**:: import sciris as sc import numpy as np start = sc.checkram(to_string=False) a = np.random.random((1_000, 10_000)) print(sc.checkram(start=start)) New in version 1.0.0. ''' process = psutil.Process(os.getpid()) mapping = {'b':1, 'kb':1e3, 'mb':1e6, 'gb':1e9} try: factor = mapping[unit.lower()] except KeyError: # pragma: no cover raise KeyNotFoundError(f'Unit {unit} not found among {strjoin(mapping.keys())}') mem_use = process.memory_info().rss/factor - start if to_string: output = f'{mem_use:{fmt}} {unit.upper()}' else: output = mem_use return output def runcommand(command, printinput=False, printoutput=False, wait=True): ''' Make it easier to run shell commands. Alias to ``subprocess.Popen()``. 
**Examples**:: myfiles = sc.runcommand('ls').split('\\n') # Get a list of files in the current folder sc.runcommand('sshpass -f %s scp myfile.txt [email protected]:myfile.txt' % 'pa55w0rd', printinput=True, printoutput=True) # Copy a file remotely sc.runcommand('sleep 600; mkdir foo', wait=False) # Waits 10 min, then creates the folder "foo", but the function returns immediately Date: 2019sep04 ''' if printinput: print(command) try: p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) if wait: # Whether to run in the background stderr = p.stdout.read().decode("utf-8") # Somewhat confusingly, send stderr to stdout stdout = p.communicate()[0].decode("utf-8") # ...and then stdout to the pipe output = stdout + '\n' + stderr if stderr else stdout # Only include the error if it was non-empty else: output = '' except Exception as E: # pragma: no cover output = f'runcommand(): shell command failed: {str(E)}' # This is for a Python error, not a shell error -- those get passed to output if printoutput: print(output) return output def gitinfo(path=None, hashlen=7, die=False, verbose=True): """ Retrieve git info This function reads git branch and commit information from a .git directory. Given a path, it will check for a ``.git`` directory. If the path doesn't contain that directory, it will search parent directories for ``.git`` until it finds one. Then, the current information will be parsed. Note: if direct directory reading fails, it will attempt to use the gitpython library. 
Args: path (str): A folder either containing a .git directory, or with a parent that contains a .git directory hashlen (int): Length of hash to return (default: 7) die (bool): whether to raise an exception if git information can't be retrieved (default: no) verbose (bool): if not dying, whether to print information about the exception Returns: Dictionary containing the branch, hash, and commit date **Examples**:: info = sc.gitinfo() # Get git info for current script repository info = sc.gitinfo(my_package.__file__) # Get git info for a particular Python package """ if path is None: path = os.getcwd() gitbranch = "Branch N/A" githash = "Hash N/A" gitdate = "Date N/A" try: # First, get the .git directory curpath = os.path.dirname(os.path.abspath(path)) while curpath: if os.path.exists(os.path.join(curpath, ".git")): gitdir = os.path.join(curpath, ".git") break else: # pragma: no cover parent, _ = os.path.split(curpath) if parent == curpath: curpath = None else: curpath = parent else: # pragma: no cover raise RuntimeError("Could not find .git directory") # Then, get the branch and commit with open(os.path.join(gitdir, "HEAD"), "r") as f1: ref = f1.read() if ref.startswith("ref:"): refdir = ref.split(" ")[1].strip() # The path to the file with the commit gitbranch = refdir.replace("refs/heads/", "") # / is always used (not os.sep) with open(os.path.join(gitdir, refdir), "r") as f2: githash = f2.read().strip() # The hash of the commit else: # pragma: no cover gitbranch = "Detached head (no branch)" githash = ref.strip() # Now read the time from the commit with open(os.path.join(gitdir, "objects", githash[0:2], githash[2:]), "rb") as f3: compressed_contents = f3.read() decompressed_contents = zlib.decompress(compressed_contents).decode() for line in decompressed_contents.split("\n"): if line.startswith("author"): _re_actor_epoch = re.compile(r"^.+? 
(.*) (\d+) ([+-]\d+).*$") m = _re_actor_epoch.search(line) actor, epoch, offset = m.groups() t = time.gmtime(int(epoch)) gitdate = time.strftime("%Y-%m-%d %H:%M:%S UTC", t) except Exception as E: # pragma: no cover try: # Second, try importing gitpython import git rootdir = os.path.abspath(path) # e.g. /user/username/my/folder repo = git.Repo(path=rootdir, search_parent_directories=True) try: gitbranch = str(repo.active_branch.name) # Just make sure it's a string except TypeError: gitbranch = 'Detached head (no branch)' githash = str(repo.head.object.hexsha) gitdate = str(repo.head.object.authored_datetime.isoformat()) except Exception as E2: errormsg = f'''Could not extract git info; please check paths: Method 1 (direct read) error: {str(E)} Method 2 (gitpython) error: {str(E2)}''' if die: raise RuntimeError(errormsg) from E elif verbose: print(errormsg + f'\nError: {str(E)}') # Trim the hash, but not if loading failed if len(githash)>hashlen and 'N/A' not in githash: githash = githash[:hashlen] # Assemble output output = {"branch": gitbranch, "hash": githash, "date": gitdate} return output def compareversions(version1, version2): ''' Function to compare versions, expecting both arguments to be a string of the format 1.2.3, but numeric works too. Returns 0 for equality, -1 for v1<v2, and 1 for v1>v2. If ``version2`` starts with >, >=, <, <=, or ==, the function returns True or False depending on the result of the comparison. 
**Examples**:: sc.compareversions('1.2.3', '2.3.4') # returns -1 sc.compareversions(2, '2') # returns 0 sc.compareversions('3.1', '2.99') # returns 1 sc.compareversions('3.1', '>=2.99') # returns True sc.compareversions(mymodule.__version__, '>=1.0') # common usage pattern sc.compareversions(mymodule, '>=1.0') # alias to the above New in version 1.2.1: relational operators ''' # Handle inputs if isinstance(version1, types.ModuleType): try: version1 = version1.__version__ except Exception as E: errormsg = f'{version1} is a module, but does not have a __version__ attribute' raise AttributeError(errormsg) from E v1 = str(version1) v2 = str(version2) # Process version2 valid = None if v2.startswith('>'): valid = [1] elif v2.startswith('>='): valid = [0,1] elif v2.startswith('='): valid = [0] elif v2.startswith('=='): valid = [0] elif v2.startswith('~='): valid = [-1,1] elif v2.startswith('!='): valid = [-1,1] elif v2.startswith('<='): valid = [0,-1] elif v2.startswith('<'): valid = [-1] v2 = v2.lstrip('<>=!~') # Do comparison if LooseVersion(v1) > LooseVersion(v2): comparison = 1 elif LooseVersion(v1) < LooseVersion(v2): comparison = -1 else: comparison = 0 # Return if valid is None: return comparison else: tf = (comparison in valid) return tf def uniquename(name=None, namelist=None, style=None): """ Given a name and a list of other names, find a replacement to the name that doesn't conflict with the other names, and pass it back. **Example**:: name = sc.uniquename(name='file', namelist=['file', 'file (1)', 'file (2)']) """ if style is None: style = ' (%d)' namelist = promotetolist(namelist) unique_name = str(name) # Start with the passed in name. i = 0 # Reset the counter while unique_name in namelist: # Try adding an index (i) to the name until we find one that's unique i += 1 unique_name = str(name) + style%i return unique_name # Return the found name. def importbyname(name=None, output=False, die=True): ''' A little function to try loading optional imports. 
**Example**:: np = sc.importbyname('numpy') ''' import importlib try: module = importlib.import_module(name) globals()[name] = module except Exception as E: # pragma: no cover errormsg = f'Cannot use "{name}" since {name} is not installed.\nPlease install {name} and try again.' print(errormsg) if die: raise E else: return False if output: return module else: return True def suggest(user_input, valid_inputs, n=1, threshold=None, fulloutput=False, die=False, which='damerau'): """ Return suggested item Returns item with lowest Levenshtein distance, where case substitution and stripping whitespace are not included in the distance. If there are ties, then the additional operations will be included. Args: user_input (str): User's input valid_inputs (list): List/collection of valid strings n (int): Maximum number of suggestions to return threshold (int): Maximum number of edits required for an option to be suggested (by default, two-thirds the length of the input; for no threshold, set to -1) die (bool): If True, an informative error will be raised (to avoid having to implement this in the calling code) which (str): Distance calculation method used; options are "damerau" (default), "levenshtein", or "jaro" Returns: suggestions (str or list): Suggested string. Returns None if no suggestions with edit distance less than threshold were found. This helps to make suggestions more relevant. 
**Examples**:: >>> sc.suggest('foo',['Foo','Bar']) 'Foo' >>> sc.suggest('foo',['FOO','Foo']) 'Foo' >>> sc.suggest('foo',['Foo ','boo']) 'Foo ' """ try: import jellyfish # To allow as an optional import except ModuleNotFoundError as e: # pragma: no cover raise ModuleNotFoundError('The "jellyfish" Python package is not available; please install via "pip install jellyfish"') from e valid_inputs = promotetolist(valid_inputs, objtype='string') mapping = { 'damerau': jellyfish.damerau_levenshtein_distance, 'levenshtein': jellyfish.levenshtein_distance, 'jaro': jellyfish.jaro_distance, } keys = list(mapping.keys()) if which not in keys: # pragma: no cover errormsg = f'Method {which} not available; options are {strjoin(keys)}' raise NotImplementedError(errormsg) dist_func = mapping[which] distance = np.zeros(len(valid_inputs)) cs_distance = np.zeros(len(valid_inputs)) # We will switch inputs to lowercase because we want to consider case substitution a 'free' operation # Similarly, stripping whitespace is a free operation. 
This ensures that something like # 'foo ' will match 'Foo' ahead of 'boo ' for i, s in enumerate(valid_inputs): distance[i] = dist_func(user_input, s.strip().lower()) cs_distance[i] = dist_func(user_input, s.strip()) # If there is a tie for the minimum distance, use the case sensitive comparison if sum(distance==min(distance)) > 1: distance = cs_distance # Order by distance, then pull out the right inputs, then turn them into a list order = np.argsort(distance) suggestions = [valid_inputs[i] for i in order] suggestionstr = strjoin([f'"{sugg}"' for sugg in suggestions[:n]]) # Handle threshold if threshold is None: threshold = np.ceil(len(user_input)*2/3) if threshold < 0: threshold = np.inf # Output if min(distance) > threshold: if die: # pragma: no cover errormsg = f'"{user_input}" not found' raise ValueError(errormsg) else: return None elif die: errormsg = f'"{user_input} not found - did you mean {suggestionstr}' raise ValueError(errormsg) else: if fulloutput: output = dict(zip(suggestions, distance[order])) return output else: if n==1: return suggestions[0] else: return suggestions[:n] def profile(run, follow=None, print_stats=True, *args, **kwargs): ''' Profile the line-by-line time required by a function. 
Args: run (function): The function to be run follow (function): The function or list of functions to be followed in the profiler; if None, defaults to the run function print_stats (bool): whether to print the statistics of the profile to stdout args, kwargs: Passed to the function to be run Returns: LineProfiler (by default, the profile output is also printed to stdout) **Example**:: def slow_fn(): n = 10000 int_list = [] int_dict = {} for i in range(n): int_list.append(i) int_dict[i] = i return class Foo: def __init__(self): self.a = 0 return def outer(self): for i in range(100): self.inner() return def inner(self): for i in range(1000): self.a += 1 return foo = Foo() sc.profile(run=foo.outer, follow=[foo.outer, foo.inner]) sc.profile(slow_fn) # Profile the constructor for Foo f = lambda: Foo() sc.profile(run=f, follow=[foo.__init__]) ''' try: from line_profiler import LineProfiler except ModuleNotFoundError as E: # pragma: no cover if 'win' in sys.platform: errormsg = 'The "line_profiler" package is not included by default on Windows;' \ 'please install using "pip install line_profiler" (note: you will need a ' \ 'C compiler installed, e.g. Microsoft Visual Studio)' else: errormsg = 'The "line_profiler" Python package is required to perform profiling' raise ModuleNotFoundError(errormsg) from E if follow is None: follow = run orig_func = run lp = LineProfiler() follow = promotetolist(follow) for f in follow: lp.add_function(f) lp.enable_by_count() wrapper = lp(run) if print_stats: # pragma: no cover print('Profiling...') wrapper(*args, **kwargs) run = orig_func if print_stats: # pragma: no cover lp.print_stats() print('Done.') return lp def mprofile(run, follow=None, show_results=True, *args, **kwargs): ''' Profile the line-by-line memory required by a function. See profile() for a usage example. 
Args: run (function): The function to be run follow (function): The function or list of functions to be followed in the profiler; if None, defaults to the run function show_results (bool): whether to print the statistics of the profile to stdout args, kwargs: Passed to the function to be run Returns: LineProfiler (by default, the profile output is also printed to stdout) ''' try: import memory_profiler as mp except ModuleNotFoundError as E: # pragma: no cover if 'win' in sys.platform: errormsg = 'The "memory_profiler" package is not included by default on Windows;' \ 'please install using "pip install memory_profiler" (note: you will need a ' \ 'C compiler installed, e.g. Microsoft Visual Studio)' else: errormsg = 'The "memory_profiler" Python package is required to perform profiling' raise ModuleNotFoundError(errormsg) from E if follow is None: follow = run lp = mp.LineProfiler() follow = promotetolist(follow) for f in follow: lp.add_function(f) lp.enable_by_count() try: wrapper = lp(run) except TypeError as e: # pragma: no cover raise TypeError('Function wrapping failed; are you profiling an already-profiled function?') from e if show_results: print('Profiling...') wrapper(*args, **kwargs) if show_results: mp.show_results(lp) print('Done.') return lp def getcaller(frame=2, tostring=True): ''' Try to get information on the calling function, but fail gracefully. Frame 1 is the current file (this one), so not very useful. Frame 2 is the default assuming it is being called directly. Frame 3 is used if another function is calling this function internally. Args: frame (int): how many frames to descend (e.g. the caller of the caller of the...) tostring (bool): whether to return a string instead of a dict Returns: output (str/dict): the filename and line number of the calling function, either as a string or dict New in version 1.0.0. 
''' try: import inspect result = inspect.getouterframes(inspect.currentframe(), 2) fname = str(result[frame][1]) lineno = str(result[frame][2]) if tostring: output = f'{fname}, line {lineno}' else: output = {'filename':fname, 'lineno':lineno} except Exception as E: # pragma: no cover if tostring: output = f'Calling function information not available ({str(E)})' else: output = {'filename':'N/A', 'lineno':'N/A'} return output ############################################################################## #%% Nested dictionary functions ############################################################################## __all__ += ['getnested', 'setnested', 'makenested', 'iternested', 'mergenested', 'flattendict', 'search', 'nestedloop'] def makenested(nesteddict, keylist=None, value=None, overwrite=False, generator=None): ''' Little functions to get and set data from nested dictionaries. The first two were adapted from: http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys "getnested" will get the value for the given list of keys: >>> sc.getnested(foo, ['a','b']) "setnested" will set the value for the given list of keys: >>> sc.setnested(foo, ['a','b'], 3) "makenested" will recursively update a dictionary with the given list of keys: >>> sc.makenested(foo, ['a','b']) "iternested" will return a list of all the twigs in the current dictionary: >>> twigs = sc.iternested(foo) **Example 1**:: foo = {} sc.makenested(foo, ['a','b']) foo['a']['b'] = 3 print(sc.getnested(foo, ['a','b'])) # 3 sc.setnested(foo, ['a','b'], 7) print(sc.getnested(foo, ['a','b'])) # 7 sc.makenested(foo, ['bar','cat']) sc.setnested(foo, ['bar','cat'], 'in the hat') print(foo['bar']) # {'cat': 'in the hat'} **Example 2**:: foo = {} sc.makenested(foo, ['a','x']) sc.makenested(foo, ['a','y']) sc.makenested(foo, ['a','z']) sc.makenested(foo, ['b','a','x']) sc.makenested(foo, ['b','a','y']) count = 0 for twig in sc.iternested(foo): count += 1 sc.setnested(foo, 
twig, count) # {'a': {'y': 1, 'x': 2, 'z': 3}, 'b': {'a': {'y': 4, 'x': 5}}} Version: 2014nov29 ''' if generator is None: generator = nesteddict.__class__ # By default, generate new dicts of the same class as the original one currentlevel = nesteddict for i,key in enumerate(keylist[:-1]): if not(key in currentlevel): currentlevel[key] = generator() # Create a new dictionary currentlevel = currentlevel[key] lastkey = keylist[-1] if isinstance(currentlevel, dict): if overwrite or lastkey not in currentlevel: currentlevel[lastkey] = value elif not overwrite and value is not None: errormsg = f'Not overwriting entry {keylist} since overwrite=False' raise ValueError(errormsg) elif value is not None: errormsg = f'Cannot set value {value} since entry {keylist} is a {type(currentlevel)}, not a dict' raise TypeError(errormsg) return def getnested(nesteddict, keylist, safe=False): ''' Get the value for the given list of keys >>> sc.getnested(foo, ['a','b']) See sc.makenested() for full documentation. ''' output = reduce(lambda d, k: d.get(k) if d else None if safe else d[k], keylist, nesteddict) return output def setnested(nesteddict, keylist, value, force=True): ''' Set the value for the given list of keys >>> sc.setnested(foo, ['a','b'], 3) See sc.makenested() for full documentation. ''' if force: makenested(nesteddict, keylist, overwrite=False) currentlevel = getnested(nesteddict, keylist[:-1]) if not isinstance(currentlevel, dict): errormsg = f'Cannot set {keylist} since parent is a {type(currentlevel)}, not a dict' raise TypeError(errormsg) else: currentlevel[keylist[-1]] = value return # Modify nesteddict in place def iternested(nesteddict, previous=None): ''' Return a list of all the twigs in the current dictionary >>> twigs = sc.iternested(foo) See sc.makenested() for full documentation. 
''' if previous is None: previous = [] output = [] for k in nesteddict.items(): if isinstance(k[1],dict): output += iternested(k[1], previous+[k[0]]) # Need to add these at the first level else: output.append(previous+[k[0]]) return output def mergenested(dict1, dict2, die=False, verbose=False, _path=None): ''' Merge different nested dictionaries See sc.makenested() for full documentation. Adapted from https://stackoverflow.com/questions/7204805/dictionaries-of-dictionaries-merge ''' if _path is None: _path = [] if _path: a = dict1 # If we're being recursive, work in place else: a = dcp(dict1) # Otherwise, make a copy b = dict2 # Don't need to make a copy for key in b: keypath = ".".join(_path + [str(key)]) if verbose: print(f'Working on {keypath}') if key in a: if isinstance(a[key], dict) and isinstance(b[key], dict): mergenested(dict1=a[key], dict2=b[key], _path=_path+[str(key)], die=die, verbose=verbose) elif a[key] == b[key]: pass # same leaf value else: errormsg = f'Warning! Conflict at {keypath}: {a[key]} vs. {b[key]}' if die: raise ValueError(errormsg) else: a[key] = b[key] if verbose: print(errormsg) else: a[key] = b[key] return a def flattendict(nesteddict, sep=None, _prefix=None): """ Flatten nested dictionary **Example**:: >>> sc.flattendict({'a':{'b':1,'c':{'d':2,'e':3}}}) {('a', 'b'): 1, ('a', 'c', 'd'): 2, ('a', 'c', 'e'): 3} >>> sc.flattendict({'a':{'b':1,'c':{'d':2,'e':3}}}, sep='_') {'a_b': 1, 'a_c_d': 2, 'a_c_e': 3} Args: d: Input dictionary potentially containing dicts as values sep: Concatenate keys using string separator. 
def flattendict(nesteddict, sep=None, _prefix=None):
    """
    Flatten nested dictionary

    **Example**::

        >>> sc.flattendict({'a':{'b':1,'c':{'d':2,'e':3}}})
        {('a', 'b'): 1, ('a', 'c', 'd'): 2, ('a', 'c', 'e'): 3}
        >>> sc.flattendict({'a':{'b':1,'c':{'d':2,'e':3}}}, sep='_')
        {'a_b': 1, 'a_c_d': 2, 'a_c_e': 3}

    Args:
        nesteddict: Input dictionary potentially containing dicts as values
        sep: Concatenate keys using string separator. If ``None`` the returned
            dictionary will have tuples as keys. Nested keys that are not
            strings are converted with ``str()`` when joined; a top-level key
            whose value is not a dict is kept unchanged.
        _prefix: Internal argument for recursively accumulating the nested keys

    Returns:
        A flat dictionary where no values are dicts
    """
    output_dict = {}
    for k, v in nesteddict.items():
        if sep is None:
            k2 = (k,) if _prefix is None else _prefix + (k,)
        elif _prefix is None:
            k2 = k
        else:
            k2 = f'{_prefix}{sep}{k}' # Formatting (not "+") so that non-string keys do not raise

        if isinstance(v, dict):
            output_dict.update(flattendict(v, sep=sep, _prefix=k2))
        else:
            output_dict[k2] = v

    return output_dict


def search(obj, attribute, _trace='', _ancestors=()):
    """
    Find a key or attribute within a dictionary or object.

    This function facilitates finding nested key(s) or attributes within an object,
    by searching recursively through keys or attributes.

    Args:
        obj: A dict or class with __dict__ attribute
        attribute: The substring to search for (keys are compared via their ``str()``)
        _trace: Not for user input - internal variable used for recursion
        _ancestors: Not for user input - ids of the containers currently being
            searched, used to stop at self-referencing containers

    Returns:
        A list of matching attributes. The items in the list are the Python
        strings used to access the attribute (via attribute or dict indexing)

    **Example**::

        nested = {'a':{'foo':1, 'bar':2}, 'b':{'bar':3, 'cat':4}}
        matches = sc.search(nested, 'bar') # Returns ['["a"]["bar"]', '["b"]["bar"]']
    """
    matches = []

    is_dict = isinstance(obj, dict)
    if is_dict:
        d = obj
    elif hasattr(obj, '__dict__'):
        d = obj.__dict__
    else:
        return matches # Not a container: nothing to search

    if id(obj) in _ancestors:
        return matches # This container is already being searched further up: stop here
    lineage = _ancestors + (id(obj),)

    for attr in d:
        if is_dict:
            s = _trace + f'["{attr}"]'
        else:
            s = _trace + f'.{attr}'
        if attribute in str(attr): # str() so that non-string dict keys can be searched too
            matches.append(s)
        matches += search(d[attr], attribute, s, lineage)

    return matches
def nestedloop(inputs, loop_order):
    """
    Zip list of lists in order

    This function takes in a list of lists to iterate over, and their nesting order.
    It then yields lists of items in the given order. Only tested for two levels
    but in theory supports an arbitrary number of items.

    Args:
        inputs (list): List of lists. All lists should have the same length
        loop_order (list): Nesting order for the lists

    Returns:
        Generator yielding lists of items, one item for each input list, always
        arranged in the same order as ``inputs``

    Example usage:

    >>> list(sc.nestedloop([['a','b'],[1,2]],[0,1]))
    [['a', 1], ['a', 2], ['b', 1], ['b', 2]]

    Notice how the first two items have the same value for the first list
    while the items from the second list vary. If the `loop_order` is
    reversed, then:

    >>> list(sc.nestedloop([['a','b'],[1,2]],[1,0]))
    [['a', 1], ['b', 1], ['a', 2], ['b', 2]]

    Notice how now the first two items have different values from the first list
    but the same items from the second list.

    From Atomica by Romesh Abeysuriya.

    New in version 1.0.0.
    """
    loop_order = list(loop_order) # Convert to list, in case loop order was passed in as a generator e.g. from map()
    ordered_inputs = [inputs[i] for i in loop_order] # Outermost loop first
    for combination in itertools.product(*ordered_inputs): # This is in the loop order
        out = [None] * len(loop_order)
        for position, value in zip(loop_order, combination):
            out[position] = value # Put each value back in the slot of the list it came from
        yield out


class KeyNotFoundError(KeyError):
    '''
    A tiny class to fix repr for KeyErrors. KeyError prints the repr of the error
    message, rather than the actual message, so e.g. newline characters print as
    the character rather than the actual newline.

    **Example**::

        raise sc.KeyNotFoundError('The key "foo" is not available, but these are: "bar", "cat"')
    '''

    def __str__(self): # pragma: no cover
        ''' Use the plain Exception string rather than KeyError's repr-based one '''
        return Exception.__str__(self)
class LinkException(Exception):
    '''
    An exception to raise when links are broken, for exclusive use with the Link
    class.
    '''

    def __init__(self, *args, **kwargs):
        # Previously spelled "__init", which Python name-mangles into an unused private method
        super().__init__(*args, **kwargs)


class prettyobj(object):
    '''
    Use pretty repr for objects, instead of just showing the type and memory pointer
    (the Python default for objects). Can also be used as the base class for custom
    classes.

    **Examples**

    >>> myobj = sc.prettyobj()
    >>> myobj.a = 3
    >>> myobj.b = {'a':6}
    >>> print(myobj)
    <sciris.sc_utils.prettyobj at 0x7ffa1e243910>
    ————————————————————————————————————————————————————————————
    a: 3
    b: {'a': 6}
    ————————————————————————————————————————————————————————————

    >>> class MyObj(sc.prettyobj):
    >>>
    >>>     def __init__(self, a, b):
    >>>         self.a = a
    >>>         self.b = b
    >>>
    >>>     def mult(self):
    >>>         return self.a * self.b
    >>>
    >>> myobj = MyObj(a=4, b=6)
    >>> print(myobj)
    <__main__.MyObj at 0x7fd9acd96c10>
    ————————————————————————————————————————————————————————————
    Methods:
      mult()
    ————————————————————————————————————————————————————————————
    a: 4
    b: 6
    ————————————————————————————————————————————————————————————
    '''

    def __repr__(self):
        ''' Delegate to prepr() for the detailed representation '''
        output = prepr(self)
        return output


class autolist(list):
    '''
    A simple extension to a list that defines add methods to simplify appending
    and extension.

    **Examples**::

        ls = sc.autolist(3) # Quickly convert a scalar to a list

        ls = sc.autolist()
        for i in range(5):
            ls += i # No need for ls += [i]
    '''

    def __init__(self, *args):
        arglist = mergelists(*args) # Convert non-iterables to iterables
        super().__init__(arglist)

    def __add__(self, obj=None):
        ''' Allows non-lists to be concatenated '''
        obj = promotetolist(obj)
        new = super().__add__(obj)
        return new

    def __radd__(self, obj):
        ''' Allows sum() to work correctly '''
        # Reflected addition is "obj + self", so obj goes first; this used to return
        # "self + obj", which reversed the order of e.g. [1,2] + sc.autolist(3).
        # list(self) makes a plain list so this method is not re-entered.
        return promotetolist(obj) + list(self)

    def __iadd__(self, obj):
        ''' Allows += to work correctly '''
        obj = promotetolist(obj)
        self.extend(obj)
        return self
class Link(object):
    '''
    A class to differentiate between an object and a link to an object. The idea
    is that this object is parsed differently from other objects -- most notably,
    a recursive method (such as a pickle) would skip over Link objects, and then
    would fix them up after the other objects had been reinstated.

    Copying a Link (shallow or deep) does not copy the linked object: the copy
    holds a LinkException instead, which is raised when the copy is called.

    Version: 2017jan31
    '''

    def __init__(self, obj=None):
        self.obj = obj # Store the object -- or rather a reference to it, if it's mutable
        try:
            self.uid = obj.uid # If the object has a UID, store it separately
        except Exception: # Best effort, but do not swallow KeyboardInterrupt/SystemExit
            self.uid = None # If not, just use None

    def __repr__(self): # pragma: no cover
        ''' Just use default '''
        output = prepr(self)
        return output

    def __call__(self, obj=None):
        ''' If called with no argument, return the stored object; if called with argument, update object '''
        if obj is None:
            if isinstance(self.obj, LinkException): # If the link is broken, raise it now
                raise self.obj
            return self.obj
        else: # pragma: no cover
            self.__init__(obj)
            return None

    def __copy__(self, *args, **kwargs):
        ''' Do NOT automatically copy link objects!! '''
        return Link(LinkException('Link object copied but not yet repaired'))

    def __deepcopy__(self, *args, **kwargs):
        ''' Same as copy '''
        return self.__copy__(*args, **kwargs)
class Timer(object):
    '''
    Simple timer class

    A thin wrapper around ``tic`` and ``toc``: the start time is recorded when
    the object is created, and any keyword arguments given at construction are
    passed on to ``toc`` each time the elapsed time is reported.

    It can also be used in a ``with...as`` block, in which case the start time
    is reset on entry and the elapsed time is reported when the block finishes.

    Implementation based on https://preshing.com/20110924/timing-your-code-using-pythons-with-statement/

    Example making repeated calls to the same Timer::

        >>> timer = Timer()
        >>> timer.toc()
        Elapsed time: 2.63 s
        >>> timer.toc()
        Elapsed time: 5.00 s

    Example wrapping code using with-as::

        >>> with Timer(label='mylabel') as t:
        >>>     foo()
    '''

    def __init__(self, **kwargs):
        self.kwargs = kwargs #: Store kwargs to pass to :func:`toc` at the end of the block
        self.tic()
        return

    def __enter__(self):
        ''' Restart the clock on entering a with-as block, and hand back the timer '''
        self.tic()
        return self

    def __exit__(self, *args):
        ''' Report the elapsed time on leaving a with-as block '''
        self.toc()

    def tic(self):
        ''' Record the current start time '''
        self.start = tic()

    def toc(self):
        ''' Report the time elapsed since the recorded start time '''
        toc(self.start, **self.kwargs)