Source code for COMPS.AuthManager

import os
import base64
import errno
import logging
import getpass
import tempfile
import xdg
from datetime import datetime, timedelta
from future.utils import raise_from

import COMPS
from COMPS.CredentialPrompt import get_credential_prompt, CredentialPrompt

logger = logging.getLogger(__name__)


[docs]class AuthManager(object): """ Manage authentication to COMPS. """ __comps_auth_token_key = 'X-COMPS-Token' __comps_client_version = 12 __token_filename_format = 'COMPS_Authtoken_%s_%s.txt' __token_tokentype_index = 1 __token_username_index = 2 __token_expiration_index = 6 __token_groups_index = 12 __token_environments_index = 13 __token_renewal_buffer = 120 def __init__(self, hoststring, verify_certs=False, credential_prompt=None): self._hoststring = AuthManager.__normalize_hoststring(hoststring) override_hoststring = os.environ.get('COMPS_SERVER') if override_hoststring: override_hoststring = AuthManager.__normalize_hoststring(override_hoststring) if self._hoststring != override_hoststring: logger.info('Overriding hoststring with COMPS_SERVER environment variable') self._hoststring = override_hoststring self._verify_certs = verify_certs self._auth_token = None self._username = None self._group_list = None self._env_list = None self._token_expiration = datetime.min self._token_renewal_time = None if credential_prompt is None: # Using default credential prompt self._credential_prompt = get_credential_prompt() elif isinstance(credential_prompt, CredentialPrompt): # Using user-specified credential prompt self._credential_prompt = credential_prompt else: raise RuntimeError('Invalid credential_prompt; must pass object of type "CredentialPrompt".') @property def username(self): return self._username @property def hoststring(self): return self._hoststring @property def groups(self): return self._group_list @property def environments(self): return self._env_list
[docs] def has_auth_token(self): return self._auth_token is not None
[docs] def get_auth_token(self): now = datetime.utcnow() if (self._auth_token is None or # <--- first time in this execution that we're getting an auth token now > self._token_renewal_time): # <--- token is expired or near expiry, so check to see whether anyone else (e.g. another thread) has cached a newer token token_filename = self.__get_token_path() if os.path.exists(token_filename): with open(token_filename, 'r') as tf: token_content = tf.readlines() token = token_content[0] if token_content and token_content[0].strip() != '' else None if token: self.__process_token(token) if now > self._token_expiration: self.__acquire_credentials() elif now > self._token_renewal_time: self.__renew_auth_token() return self.__comps_auth_token_key, self._auth_token
[docs] def clear_auth_token(self): path = self.__get_token_path() # delete file if it exists try: os.remove(path) except OSError as e: if e.errno != errno.ENOENT: # errno.ENOENT = no such file or directory logger.error('Error deleting cached credentials: {0}'.format(e.message)) self._auth_token = None
[docs] @staticmethod def get_environment_macros(environment_name): """ Retrieve the environment macros for a COMPS environment. This may be a somewhat temporary requirement until the Asset Service handles file dependencies more completely (allows uploads, etc). :param environment_name: the COMPS environment to retrieve macros for :return: a dictionary of environment macro key/value pairs """ path = '/Environments/{0}'.format(environment_name) resp = COMPS.Client.get(path) json_resp = resp.json() if 'Environments' not in json_resp or \ len(json_resp['Environments']) != 1 or \ 'Macros' not in json_resp['Environments'][0]: logger.debug(json_resp) raise RuntimeError('Malformed Experiments retrieve response!') return json_resp['Environments'][0]['Macros']
[docs] @staticmethod def get_group_name_for_environment(environment_name): """ Retrieve the Group associated with a particular COMPS environment. :param environment_name: the COMPS environment to retrieve the Group for :return: a string of the Group name """ path = '/Environments/{0}'.format(environment_name) resp = COMPS.Client.get(path) json_resp = resp.json() if 'Environments' not in json_resp or \ len(json_resp['Environments']) != 1 or \ 'GroupName' not in json_resp['Environments'][0]: logger.debug(json_resp) raise RuntimeError('Malformed Experiments retrieve response!') return json_resp['Environments'][0]['GroupName']
def __acquire_credentials(self): success = False logger.info('Logging into {0}'.format(self._hoststring)) while not success: try: if '_prompt_user_for_creds' in dir(self): print('WARNING! This method of overriding the credentials input is deprecated.') print(' Please check out the new "credential_prompt" argument on Client.login()') creds = self._prompt_user_for_creds() else: creds = self._credential_prompt.prompt() except: self._hoststring = None raise if creds is None: logger.info('User canceled login') self._hoststring = None break creds['ClientVersion'] = AuthManager.__comps_client_version creds['Password'] = base64.b64encode(creds['Password'].encode()).decode('utf-8') logger.debug('Sending auth request') resp = COMPS.Client.post("/tokens" , include_comps_auth_token=False , http_err_handle_exceptions=[302, 401] , allow_redirects=False , json=creds , verify=self._verify_certs) creds = None if resp.status_code == 302: raise RuntimeError('Error contacting login endpoint; attempting to hit \'http\' protocol instead of \'https\' ?') elif resp.status_code == 401: logger.info('Bad username/password') continue if AuthManager.__comps_auth_token_key in resp.headers: token = resp.headers[AuthManager.__comps_auth_token_key] self.__process_token(token) self.__cache_token(token) success = True else: logger.error('Error attempting to validate user credentials') def __renew_auth_token(self): success = False while not success: resp = COMPS.Client.put("/tokens" , include_comps_auth_token=False # Don't try to get another auth token, otherwise we get in an # endless loop... we will manually include the old one below... , json={ 'ClientVersion': AuthManager.__comps_client_version} , headers={ AuthManager.__comps_auth_token_key: self._auth_token } , verify=self._verify_certs) if AuthManager.__comps_auth_token_key in resp.headers: token = resp.headers[AuthManager.__comps_auth_token_key] self.__process_token(token) self.__cache_token(token) success = True else: logger.error('Error attempting to renew user credentials') def __process_token(self, token): tokensplit = token.split(',', AuthManager.__token_environments_index + 2) # 0-based + an extra so we separate from anything afterwards try: self._token_expiration = datetime.strptime(tokensplit[AuthManager.__token_expiration_index], '%Y-%m-%d-%H-%M-%S') self._group_list = tuple(tokensplit[AuthManager.__token_groups_index].split('-')) self._env_list = tuple(tokensplit[AuthManager.__token_environments_index].split('-')) # for when parsing cached token if self._username is None: self._username = tokensplit[AuthManager.__token_username_index] except (ValueError, IndexError) as e: raise_from(RuntimeError('Invalid auth token: {}'.format(self.__get_token_path())), None) tokentype = tokensplit[AuthManager.__token_tokentype_index] if tokentype == 'Auth': self._token_renewal_time = self._token_expiration - timedelta(minutes=AuthManager.__token_renewal_buffer) elif tokentype == 'System': self._token_renewal_time = self._token_expiration # system tokens can't be renewed anyway, so no point in trying to renew early else: raise RuntimeError('Unknown token type!') self._auth_token = token def __cache_token(self, token): path = self.__get_token_path() logger.debug('Caching auth token to ' + path) try: with open(path, 'w') as tf: tf.write(token) except ValueError as e: logger.error("Failure caching auth-token: {0}".format(e.message)) def __get_token_path(self): hoststring_repl = ''.join([c if c not in ':/' else '_' for c in self._hoststring]) tmppath = xdg.XDG_RUNTIME_DIR if not tmppath or not os.path.exists(tmppath): tmppath = tempfile.gettempdir() token_filename = os.path.join(tmppath, AuthManager.__token_filename_format % (hoststring_repl, getpass.getuser())) return token_filename @staticmethod def __normalize_hoststring(hoststring): hoststring_norm = hoststring.rstrip('/') if hoststring_norm.startswith('http:') and hoststring_norm.endswith(':80'): hoststring_norm = hoststring_norm[:-3] elif hoststring_norm.startswith('https:') and hoststring_norm.endswith(':443'): hoststring_norm = hoststring_norm[:-4] return hoststring_norm