diff --git a/Makefile b/Makefile index cd4f81d89..2f7c1ffda 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,10 @@ ifndef PY PY = 3 endif +FORMAT_ENFORCE_DIRS = ./mig/lib ./tests +FORMAT_EXCLUDE_REGEX = '.git|tests/data/|tests/fixture/' +FORMAT_EXCLUDE_GLOB = '.git/* tests/data/* tests/fixture/*' +FORMAT_LINE_LENGTH = 80 LOCAL_PYTHON_BIN = './envhelp/lpython' ifdef PYTHON_BIN @@ -35,7 +39,37 @@ ifneq ($(MIG_ENV),'local') @echo "unavailable outside local development environment" @exit 1 endif - $(LOCAL_PYTHON_BIN) -m autopep8 --ignore E402 -i + @make format-python + +.PHONY:format-python +format-python: + @$(LOCAL_PYTHON_BIN) -m black $(FORMAT_ENFORCE_DIRS) \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --exclude=$(FORMAT_EXCLUDE_REGEX) + @$(LOCAL_PYTHON_BIN) -m isort $(FORMAT_ENFORCE_DIRS) \ + --profile=black \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --skip-glob=$(FORMAT_EXCLUDE_GLOB) + +.PHONY: lint +lint: +ifneq ($(MIG_ENV),'local') + @echo "unavailable outside local development environment" + @exit 1 +endif + @make lint-python + +.PHONY: lint-python +lint-python: + @$(LOCAL_PYTHON_BIN) -m black $(FORMAT_ENFORCE_DIRS) \ + --check \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --exclude $(FORMAT_EXCLUDE_REGEX) + @$(LOCAL_PYTHON_BIN) -m isort $(FORMAT_ENFORCE_DIRS) \ + --check-only \ + --profile=black \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --skip-glob=$(FORMAT_EXCLUDE_GLOB) .PHONY: clean clean: diff --git a/local-requirements.txt b/local-requirements.txt index 7faf3b026..4941e1a6b 100644 --- a/local-requirements.txt +++ b/local-requirements.txt @@ -3,6 +3,8 @@ # This list is mainly used to specify addons needed for the unit tests. # We only need autopep8 on py 3 as it's used in 'make fmt' (with py3) autopep8;python_version >= "3" +black +isort # We need paramiko for the ssh unit tests # NOTE: paramiko-3.0.0 dropped python2 and python3.6 support paramiko;python_version >= "3.7" diff --git a/mig/lib/accounting.py b/mig/lib/accounting.py index 82d3f63f9..3fce1e92f 100644 --- a/mig/lib/accounting.py +++ b/mig/lib/accounting.py @@ -41,11 +41,9 @@ from mig.shared.vgrid import vgrid_list, vgrid_list_vgrids -def __init_accounting_entry(user_bytes=0, - freeze_bytes=0, - vgrid_bytes=None, - peers=None, - ext_users=None): +def __init_accounting_entry( + user_bytes=0, freeze_bytes=0, vgrid_bytes=None, peers=None, ext_users=None +): """Return new user account dict entry""" if vgrid_bytes is None: vgrid_bytes = {} @@ -54,11 +52,13 @@ def __init_accounting_entry(user_bytes=0, if ext_users is None: ext_users = {} - return {'user_bytes': user_bytes, - 'freeze_bytes': freeze_bytes, - 'vgrid_bytes': vgrid_bytes, - 'peers': peers, - 'ext_users': ext_users} + return { + "user_bytes": user_bytes, + "freeze_bytes": freeze_bytes, + "vgrid_bytes": vgrid_bytes, + "peers": peers, + "ext_users": ext_users, + } def __get_owned_vgrid(configuration, verbose=False): @@ -67,17 +67,16 @@ def __get_owned_vgrid(configuration, verbose=False): NOTE: First owner of top-vgrid is primary owner""" logger = configuration.logger result = {} - (status, vgrids) = vgrid_list_vgrids(configuration) + status, vgrids = vgrid_list_vgrids(configuration) if status: for vgrid_name in vgrids: # print("checking vgrid: %s" % check_vgrid_name) - (owners_status, owners_list) = vgrid_list(vgrid_name, - 'owners', - configuration, - recursive=True) + owners_status, owners_list = vgrid_list( + vgrid_name, "owners", configuration, recursive=True + ) # Find first non-zero owner # NOTE: Some owner files contain empty owners) - owner = '' + owner = "" if owners_status and owners_list: owner = next(ent for ent in owners_list if ent) if owner: @@ -85,8 +84,7 @@ def __get_owned_vgrid(configuration, verbose=False): owned_vgrids.append(vgrid_name) result[owner] = owned_vgrids else: - msg = "Failed to find owner for vgrid: %s" \ - % vgrid_name + msg = "Failed to find owner for vgrid: %s" % vgrid_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -108,36 +106,37 @@ def __get_peers_map(configuration, verbose=False): accepted_peers = get_accepted_peers(configuration, client_id) for ext_client_id, value in accepted_peers.items(): if not isinstance(value, dict): - msg = "Invalid peers format: %s: %s: %s" \ - % (client_id, ext_client_id, value) + msg = "Invalid peers format: %s: %s: %s" % ( + client_id, + ext_client_id, + value, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue # Map external users to their peer - ext_users = peer_result.get('ext_users', {}) + ext_users = peer_result.get("ext_users", {}) ext_users[ext_client_id] = value - peer_result['ext_users'] = ext_users + peer_result["ext_users"] = ext_users # Map peers to their external user ext_result = result.get(ext_client_id, {}) - peers = ext_result.get('peers', {}) + peers = ext_result.get("peers", {}) peers[client_id] = value - ext_result['peers'] = peers + ext_result["peers"] = peers result[ext_client_id] = ext_result result[client_id] = peer_result return result -def update_accounting(configuration, - verbose=False): +def update_accounting(configuration, verbose=False): """Update user accounting information""" logger = configuration.logger retval = True - result = {'accounting': {}, - 'quota': {}} - accounting = result['accounting'] - result['timestamp'] = int(time.time()) + result = {"accounting": {}, "quota": {}} + accounting = result["accounting"] + result["timestamp"] = int(time.time()) # Map vgrid to their primary owner msg = "Creating vgrid owners map ..." @@ -184,27 +183,24 @@ def update_accounting(configuration, quota_info_json = entry.path quota_fs = entry.name.replace(".json", "") else: - logger.debug("Skipping non quota info entry: %s" - % entry.name) + logger.debug("Skipping non quota info entry: %s" % entry.name) continue quota_info = None # Try .pck first then .json if quota_info_pck: quota_info = unpickle(quota_info_pck, configuration.logger) elif quota_info_json: - quota_info = load_json(quota_info_json, - configuration.logger, - convert_utf8=False) + quota_info = load_json( + quota_info_json, configuration.logger, convert_utf8=False + ) if not quota_info: - msg = "Failed to load quota info for FS entry: %s" \ - % entry.name + msg = "Failed to load quota info for FS entry: %s" % entry.name logger.error(msg) if verbose: print("ERROR: %s" % msg) retval = False continue - quota_basepath = os.path.join(configuration.quota_home, - quota_fs) + quota_basepath = os.path.join(configuration.quota_home, quota_fs) if not os.path.isdir(quota_basepath): msg = "Missing quota_basepath: %r" % quota_basepath logger.error(msg) @@ -212,14 +208,15 @@ def update_accounting(configuration, print("ERROR: %s" % msg) retval = False continue - quota_mtime = quota_info.get('mtime', 0) - quota_datestr = datetime.datetime.fromtimestamp(quota_mtime) \ - .strftime('%d/%m/%Y-%H:%M:%S') - result['quota'][quota_fs] = {'mtime': quota_mtime} + quota_mtime = quota_info.get("mtime", 0) + quota_datestr = datetime.datetime.fromtimestamp( + quota_mtime + ).strftime("%d/%m/%Y-%H:%M:%S") + result["quota"][quota_fs] = {"mtime": quota_mtime} # User quota - user_path = os.path.join(quota_basepath, 'user') + user_path = os.path.join(quota_basepath, "user") if not os.path.isdir(user_path): msg = "Missing quota user path: %r" % user_path logger.error(msg) @@ -228,11 +225,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s user quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - user_path) + msg = "Scanning %s user quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + user_path, + ) logger.info(msg) if verbose: print(msg) @@ -241,30 +239,34 @@ def update_accounting(configuration, for user_entry in it2: if user_entry.name.endswith(".pck"): client_id = client_dir_id( - user_entry.name.replace('.pck', '')) + user_entry.name.replace(".pck", "") + ) elif user_entry.name.endswith(".json"): client_id = client_dir_id( - user_entry.name.replace('.json', '')) + user_entry.name.replace(".json", "") + ) else: - logger.debug("Skipping non-user entry: %s" - % user_entry.name) + logger.debug( + "Skipping non-user entry: %s" % user_entry.name + ) continue user_quota_files[client_id] = user_entry.path t2 = time.time() - msg = "Scanned %s user quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - user_path, - (t2 - t1)) + msg = "Scanned %s user quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + user_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) # Vgrid quota - vgrid_path = os.path.join(quota_basepath, 'vgrid') + vgrid_path = os.path.join(quota_basepath, "vgrid") if not os.path.isdir(vgrid_path): msg = "Missing quota vgrid path: %r" % vgrid_path logger.error(msg) @@ -273,11 +275,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s vgrid quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - vgrid_path) + msg = "Scanning %s vgrid quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + vgrid_path, + ) logger.info(msg) if verbose: print(msg) @@ -286,26 +289,29 @@ def update_accounting(configuration, for vgrid_entry in it2: if vgrid_entry.name.endswith(".pck"): vgrid_name = force_native_str( - vgrid_entry.name.replace('.pck', '')) + vgrid_entry.name.replace(".pck", "") + ) elif vgrid_entry.name.endswith(".json"): vgrid_name = force_native_str( - vgrid_entry.name.replace('.json', '')) + vgrid_entry.name.replace(".json", "") + ) else: # logger.debug("Skipping non-vgrid entry: %s" # % vgrid_entry.name) continue # NOTE: sub-vgrids uses ':' # as delimiter in 'vgrid_files_writable' - vgrid_name = vgrid_name.replace(':', '/') + vgrid_name = vgrid_name.replace(":", "/") # print("%s: %s" % (vgrid_name, vgrid_entry.path)) vgrid_quota_files[vgrid_name] = vgrid_entry.path t2 = time.time() - msg = "Scanned %s vgrid quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - vgrid_path, - (t2 - t1)) + msg = "Scanned %s vgrid quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + vgrid_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) @@ -313,7 +319,7 @@ def update_accounting(configuration, # Freeze quota if configuration.site_enable_freeze: - freeze_path = os.path.join(quota_basepath, 'freeze') + freeze_path = os.path.join(quota_basepath, "freeze") if not os.path.isdir(freeze_path): msg = "Missing quota freeze path: %r" % freeze_path logger.error(msg) @@ -322,11 +328,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s freeze quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - freeze_path) + msg = "Scanning %s freeze quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + freeze_path, + ) logger.info(msg) if verbose: print(msg) @@ -335,23 +342,27 @@ def update_accounting(configuration, for freeze_entry in it2: if freeze_entry.name.endswith(".pck"): freeze_client_id = client_dir_id( - freeze_entry.name.replace('.pck', '')) + freeze_entry.name.replace(".pck", "") + ) elif freeze_entry.name.endswith(".json"): freeze_client_id = client_dir_id( - freeze_entry.name.replace('.json', '')) + freeze_entry.name.replace(".json", "") + ) else: - logger.debug("Skipping non-freeze entry: %s" - % freeze_entry.name) + logger.debug( + "Skipping non-freeze entry: %s" + % freeze_entry.name + ) continue - freeze_quota_files[freeze_client_id] \ - = freeze_entry.path + freeze_quota_files[freeze_client_id] = freeze_entry.path t2 = time.time() - msg = "Scanned %s freeze quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - freeze_path, - (t2 - t1)) + msg = "Scanned %s freeze quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + freeze_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) @@ -361,17 +372,18 @@ def update_accounting(configuration, vgrids_accounted = [] for client_id, user_quota_filepath in user_quota_files.items(): # Init user accounting - peers = peers_map.get(client_id, {}).get('peers', {}) - ext_users = peers_map.get(client_id, {}).get('ext_users', {}) - accounting[client_id] = __init_accounting_entry(peers=peers, - ext_users=ext_users) + peers = peers_map.get(client_id, {}).get("peers", {}) + ext_users = peers_map.get(client_id, {}).get("ext_users", {}) + accounting[client_id] = __init_accounting_entry( + peers=peers, ext_users=ext_users + ) # Extract user bytes - if user_quota_filepath.endswith('.pck'): + if user_quota_filepath.endswith(".pck"): user_quota = unpickle(user_quota_filepath, configuration) - elif user_quota_filepath.endswith('.json'): - user_quota = load_json(user_quota_filepath, - configuration.logger, - convert_utf8=False) + elif user_quota_filepath.endswith(".json"): + user_quota = load_json( + user_quota_filepath, configuration.logger, convert_utf8=False + ) else: msg = "Invalid user quota file: %r" % user_quota_filepath logger.error(msg) @@ -380,11 +392,13 @@ def update_accounting(configuration, retval = False continue try: - accounting[client_id]['user_bytes'] = user_quota['bytes'] + accounting[client_id]["user_bytes"] = user_quota["bytes"] except Exception as err: - accounting[client_id]['user_bytes'] = 0 - msg = "Failed to load user quota: %r, error: %s" \ - % (user_quota_filepath, err) + accounting[client_id]["user_bytes"] = 0 + msg = "Failed to load user quota: %r, error: %s" % ( + user_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -394,27 +408,28 @@ def update_accounting(configuration, # Extract vgrid bytes for user 'client_id' for vgrid_name in owned_vgrid.get(client_id, []): - vgrid_quota_filepath = vgrid_quota_files.get(vgrid_name, '') + vgrid_quota_filepath = vgrid_quota_files.get(vgrid_name, "") if not os.path.exists(vgrid_quota_filepath): if verbose: # NOTE: Legacy vgrids are accounted at by top-vgrid - vgrid_array = vgrid_name.split('/') - legacy_vgrid = os.path.join(configuration.vgrid_files_home, - vgrid_name) - if not os.path.isdir(legacy_vgrid) \ - or len(vgrid_array) == 1: - msg = "Missing quota for vgrid: %r" \ - % vgrid_name + vgrid_array = vgrid_name.split("/") + legacy_vgrid = os.path.join( + configuration.vgrid_files_home, vgrid_name + ) + if not os.path.isdir(legacy_vgrid) or len(vgrid_array) == 1: + msg = "Missing quota for vgrid: %r" % vgrid_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue - if vgrid_quota_filepath.endswith('.pck'): + if vgrid_quota_filepath.endswith(".pck"): vgrid_quota = unpickle(vgrid_quota_filepath, configuration) - elif vgrid_quota_filepath.endswith('.json'): - vgrid_quota = load_json(vgrid_quota_filepath, - configuration.logger, - convert_utf8=False) + elif vgrid_quota_filepath.endswith(".json"): + vgrid_quota = load_json( + vgrid_quota_filepath, + configuration.logger, + convert_utf8=False, + ) else: msg = "Invalid vgrid quota file: %r" % vgrid_quota_filepath logger.error(msg) @@ -423,12 +438,15 @@ def update_accounting(configuration, retval = False continue try: - accounting[client_id]['vgrid_bytes'][vgrid_name] \ - = vgrid_quota['bytes'] + accounting[client_id]["vgrid_bytes"][vgrid_name] = vgrid_quota[ + "bytes" + ] except Exception as err: - accounting[client_id]['vgrid_bytes'][vgrid_name] = 0 - msg = "Failed to load vgrid quota: %r, error: %s" \ - % (vgrid_quota_filepath, err) + accounting[client_id]["vgrid_bytes"][vgrid_name] = 0 + msg = "Failed to load vgrid quota: %r, error: %s" % ( + vgrid_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -442,13 +460,15 @@ def update_accounting(configuration, for vgrid_name in vgrid_quota_files: if vgrid_name not in vgrids_accounted: - vgridowner = '' + vgridowner = "" for owner, owned_vgrids in owned_vgrid.items(): if vgrid_name in owned_vgrids: vgridowner = owner break - msg = "no accounting for vgrid: %r, missing owner?: %r" \ - % (vgrid_name, vgridowner) + msg = "no accounting for vgrid: %r, missing owner?: %r" % ( + vgrid_name, + vgridowner, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -457,24 +477,25 @@ def update_accounting(configuration, for freeze_name, freeze_quota_filepath in freeze_quota_files.items(): # Extract client_id from legacy freeze archive format - if freeze_name.startswith('archive-'): - legacy_freeze_meta_filepath \ - = os.path.join(configuration.freeze_home, - freeze_name, - 'meta.pck') - legacy_freeze_meta = unpickle(legacy_freeze_meta_filepath, - configuration.logger) + if freeze_name.startswith("archive-"): + legacy_freeze_meta_filepath = os.path.join( + configuration.freeze_home, freeze_name, "meta.pck" + ) + legacy_freeze_meta = unpickle( + legacy_freeze_meta_filepath, configuration.logger + ) if not legacy_freeze_meta: - msg = "Missing metadata for archive: %r" \ - % freeze_name + msg = "Missing metadata for archive: %r" % freeze_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue - client_id = legacy_freeze_meta.get('CREATOR', '') + client_id = legacy_freeze_meta.get("CREATOR", "") if not client_id: - msg = "Failed to extract client_id from: %r" \ - % legacy_freeze_meta_filepath + msg = ( + "Failed to extract client_id from: %r" + % legacy_freeze_meta_filepath + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -486,12 +507,12 @@ def update_accounting(configuration, # Load freeze quota freeze_bytes = 0 - if freeze_quota_filepath.endswith('.pck'): + if freeze_quota_filepath.endswith(".pck"): freeze_quota = unpickle(freeze_quota_filepath, configuration) - elif freeze_quota_filepath.endswith('.json'): - freeze_quota = load_json(freeze_quota_filepath, - configuration.logger, - convert_utf8=False) + elif freeze_quota_filepath.endswith(".json"): + freeze_quota = load_json( + freeze_quota_filepath, configuration.logger, convert_utf8=False + ) else: msg = "Invalid freeze quota file: %r" % freeze_quota_filepath logger.error(msg) @@ -500,11 +521,13 @@ def update_accounting(configuration, retval = False continue try: - freeze_bytes = int(freeze_quota['bytes']) + freeze_bytes = int(freeze_quota["bytes"]) except Exception as err: freeze_bytes = 0 - msg = "Failed to fetch freeze quota: %r, error: %s" \ - % (freeze_quota_filepath, err) + msg = "Failed to fetch freeze quota: %r, error: %s" % ( + freeze_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -512,24 +535,27 @@ def update_accounting(configuration, continue if freeze_bytes > 0: - freeze_accounting = accounting.get(client_id, '') + freeze_accounting = accounting.get(client_id, "") if not freeze_accounting: - msg = "added missing archive user: %r : %d" \ - % (client_id, freeze_bytes) + msg = "added missing archive user: %r : %d" % ( + client_id, + freeze_bytes, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) accounting[client_id] = __init_accounting_entry() freeze_accounting = accounting[client_id] - freeze_accounting['freeze_bytes'] += freeze_bytes + freeze_accounting["freeze_bytes"] += freeze_bytes # Save accounting result - accounting_filepath = os.path.join(configuration.accounting_home, - "%s.pck" % result['timestamp']) + accounting_filepath = os.path.join( + configuration.accounting_home, "%s.pck" % result["timestamp"] + ) status = pickle(result, accounting_filepath, configuration.logger) if status: - latest = os.path.join(configuration.accounting_home, 'latest') + latest = os.path.join(configuration.accounting_home, "latest") status = make_symlink(accounting_filepath, latest, logger, force=True) if not status: retval = False @@ -546,33 +572,26 @@ def human_readable_filesize(filesize): return "0 B" try: p = int(math.floor(math.log(filesize, 2) / 10)) - return "%.3f %s" % (filesize / math.pow(1024, p), - ['B', - 'KiB', - 'MiB', - 'GiB', - 'TiB', - 'PiB', - 'EiB', - 'ZiB', - 'YiB'][p]) + return "%.3f %s" % ( + filesize / math.pow(1024, p), + ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"][p], + ) except (ValueError, TypeError, IndexError): - return 'NaN' + return "NaN" -def get_usage(configuration, - userlist=[], - timestamp=0, - verbose=False): +def get_usage(configuration, userlist=[], timestamp=0, verbose=False): """Generate and return 'storage' usage""" # Load accounting if it exists logger = configuration.logger if timestamp == 0: - accounting_filepath = os.path.join(configuration.accounting_home, - "latest") + accounting_filepath = os.path.join( + configuration.accounting_home, "latest" + ) else: - accounting_filepath = os.path.join(configuration.accounting_home, - "%s.pck" % timestamp) + accounting_filepath = os.path.join( + configuration.accounting_home, "%s.pck" % timestamp + ) data = unpickle(accounting_filepath, configuration.logger) if not data: msg = "Failed to load accounting data from: %r" % accounting_filepath @@ -581,7 +600,7 @@ def get_usage(configuration, print("ERROR: %s" % msg) return None - accounting = data.get('accounting', {}) + accounting = data.get("accounting", {}) # Do not show external users as main accounts unless requested # or if the user act as both peer and external user @@ -590,10 +609,13 @@ def get_usage(configuration, peer_users = [] skip_ext_users = [] for values in accounting.values(): - ext_users.extend(list(values.get('ext_users', {}))) - peer_users.extend(list(values.get('peers', {}))) - skip_ext_users = [user for user in ext_users - if user not in userlist and user not in peer_users] + ext_users.extend(list(values.get("ext_users", {}))) + peer_users.extend(list(values.get("peers", {}))) + skip_ext_users = [ + user + for user in ext_users + if user not in userlist and user not in peer_users + ] # Create accounting report @@ -610,7 +632,7 @@ def get_usage(configuration, # Home usage - home_bytes = values.get('user_bytes', 0) + home_bytes = values.get("user_bytes", 0) total_bytes += home_bytes home_report = "" if create_reports: @@ -619,7 +641,7 @@ def get_usage(configuration, # Freeze archive usage - freeze_bytes = values.get('freeze_bytes', 0) + freeze_bytes = values.get("freeze_bytes", 0) total_bytes += freeze_bytes freeze_report = "" if create_reports and freeze_bytes > 0: @@ -630,32 +652,34 @@ def get_usage(configuration, vgrid_report = "" vgrid_total = 0 - for vgrid_name, vgrid_bytes in values.get('vgrid_bytes', {}).items(): + for vgrid_name, vgrid_bytes in values.get("vgrid_bytes", {}).items(): vgrid_total += vgrid_bytes if create_reports: vgrid_bytes_human = human_readable_filesize(vgrid_bytes) - vgrid_report += "\n - %s: %s" \ - % (vgrid_name, vgrid_bytes_human) + vgrid_report += "\n - %s: %s" % (vgrid_name, vgrid_bytes_human) if vgrid_report: - vgrid_report = "%s usage (total: %s)%s" \ - % (configuration.site_vgrid_label, - human_readable_filesize(vgrid_total), - vgrid_report) + vgrid_report = "%s usage (total: %s)%s" % ( + configuration.site_vgrid_label, + human_readable_filesize(vgrid_total), + vgrid_report, + ) total_bytes += vgrid_total # Create account usage entry - account_usage[username] = {'total_bytes': total_bytes, - 'home_total': home_bytes, - 'vgrid_total': vgrid_total, - 'freeze_total': freeze_bytes, - 'ext_users_total': 0, - 'total_report': '', - 'home_report': home_report, - 'freeze_report': freeze_report, - 'vgrid_report': vgrid_report, - 'ext_users_report': '', - 'peers_report': ''} + account_usage[username] = { + "total_bytes": total_bytes, + "home_total": home_bytes, + "vgrid_total": vgrid_total, + "freeze_total": freeze_bytes, + "ext_users_total": 0, + "total_report": "", + "home_report": home_report, + "freeze_report": freeze_report, + "vgrid_report": vgrid_report, + "ext_users_report": "", + "peers_report": "", + } # Create external users report # NOTE: We need total bytes and therefore we need the above full report @@ -667,13 +691,12 @@ def get_usage(configuration, if userlist and username not in userlist: continue # Create ext_users report - ext_users = values.get('ext_users', {}) - peers = values.get('peers', {}) + ext_users = values.get("ext_users", {}) + peers = values.get("peers", {}) if not ext_users: continue if ext_users and peers: - msg = "User %r acts as both peer and external user" \ - % username + msg = "User %r acts as both peer and external user" % username logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -683,20 +706,25 @@ def get_usage(configuration, ext_users_report = "" ext_users_total = 0 for ext_user in ext_users: - ext_user_total_bytes = account_usage.get( - ext_user, {}).get('total_bytes', 0) + ext_user_total_bytes = account_usage.get(ext_user, {}).get( + "total_bytes", 0 + ) ext_users_total += ext_user_total_bytes ext_user_total_bytes_human = human_readable_filesize( - ext_user_total_bytes) - ext_users_report += "\n - %s: %s" % (ext_user, - ext_user_total_bytes_human) + ext_user_total_bytes + ) + ext_users_report += "\n - %s: %s" % ( + ext_user, + ext_user_total_bytes_human, + ) if ext_users_report: - ext_users_report = "External users usage (total: %s):%s" \ - % (human_readable_filesize(ext_users_total), - ext_users_report) - account_usage[username]['ext_users_total'] = ext_users_total - account_usage[username]['ext_users_report'] = ext_users_report - account_usage[username]['total_bytes'] += ext_users_total + ext_users_report = "External users usage (total: %s):%s" % ( + human_readable_filesize(ext_users_total), + ext_users_report, + ) + account_usage[username]["ext_users_total"] = ext_users_total + account_usage[username]["ext_users_report"] = ext_users_report + account_usage[username]["total_bytes"] += ext_users_total # Create peers report @@ -705,22 +733,26 @@ def get_usage(configuration, peers_report += "\n - %s" % peer if peers_report: peers_report = "Accepted by the following peer:%s" % peers_report - account_usage[username]['peers_report'] = peers_report + account_usage[username]["peers_report"] = peers_report # Create total usage report for each user for usage in account_usage.values(): - usage['total_report'] = "Total usage: %s" \ - % human_readable_filesize(usage['total_bytes']) + usage["total_report"] = "Total usage: %s" % human_readable_filesize( + usage["total_bytes"] + ) # External users are accounted for by their peer # unless the external user also act as a peer result = {} - result['timestamp'] = data.get('timestamp', 0) - result['quota'] = data.get('quota', {}) - result['accounting'] = {username: values for username, values - in account_usage.items() - if not userlist or username in userlist - and username not in skip_ext_users} + result["timestamp"] = data.get("timestamp", 0) + result["quota"] = data.get("quota", {}) + result["accounting"] = { + username: values + for username, values in account_usage.items() + if not userlist + or username in userlist + and username not in skip_ext_users + } return result diff --git a/mig/lib/events.py b/mig/lib/events.py index 6b49b999c..c6716db5c 100644 --- a/mig/lib/events.py +++ b/mig/lib/events.py @@ -433,8 +433,7 @@ def run_cron_command( _restore_env(saved_environ, os.environ) raise exc logger.info( - "(%s) done running command for %s: %s" % ( - pid, target_path, command_str) + "(%s) done running command for %s: %s" % (pid, target_path, command_str) ) # logger.debug('(%s) raw output is: %s' % (pid, output_objects)) @@ -532,8 +531,7 @@ def run_events_command( _restore_env(saved_environ, os.environ) raise exc logger.info( - "(%s) done running command for %s: %s" % ( - pid, target_path, command_str) + "(%s) done running command for %s: %s" % (pid, target_path, command_str) ) # logger.debug('(%s) raw output is: %s' % (pid, output_objects)) diff --git a/mig/lib/janitor.py b/mig/lib/janitor.py index ea3e0051e..c18590226 100644 --- a/mig/lib/janitor.py +++ b/mig/lib/janitor.py @@ -36,8 +36,11 @@ import os import time -from mig.shared.accountreq import accept_account_req, existing_user_collision, \ - reject_account_req +from mig.shared.accountreq import ( + accept_account_req, + existing_user_collision, + reject_account_req, +) from mig.shared.base import get_user_id from mig.shared.fileio import delete_file, listdir from mig.shared.pwcrypto import verify_reset_token @@ -274,7 +277,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): _logger.info("%r made an invalid account request" % client_id) # NOTE: 'invalid' is a list of validation error strings if set reason = "invalid request: %s." % ". ".join(req_invalid) - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -284,21 +287,27 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not rej_status: _logger.warning( - "failed to reject invalid %r account request: %s" % (client_id, - rej_err) + "failed to reject invalid %r account request: %s" + % (client_id, rej_err) ) else: _logger.info("rejected invalid %r account request" % client_id) elif authorized: - _logger.info("%r requested renew and authorized password change" % - client_id) + _logger.info( + "%r requested renew and authorized password change" % client_id + ) peer_id = user_dict.get("peers", [None])[0] # NOTE: let authorized reqs (with valid peer) renew even with pw change default_renew = True - if accept_account_req(req_id, configuration, peer_id, - user_copy=user_copy, admin_copy=admin_copy, - auth_type=auth_type, - default_renew=default_renew): + if accept_account_req( + req_id, + configuration, + peer_id, + user_copy=user_copy, + admin_copy=admin_copy, + auth_type=auth_type, + default_renew=default_renew, + ): _logger.info("accepted authorized %r access renew" % client_id) else: _logger.warning("failed authorized %r access renew" % client_id) @@ -313,7 +322,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): "%r requested and authorized password reset" % client_id ) peer_id = user_dict.get("peers", [None])[0] - (acc_status, acc_err) = accept_account_req( + acc_status, acc_err = accept_account_req( req_id, configuration, peer_id, @@ -324,18 +333,18 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not acc_status: _logger.warning( - "failed to accept %r password reset: %s" % (client_id, - acc_err) + "failed to accept %r password reset: %s" + % (client_id, acc_err) ) else: _logger.info("accepted %r password reset" % client_id) else: _logger.warning( - "%r requested password reset with bad token: %s" % ( - client_id, reset_token) + "%r requested password reset with bad token: %s" + % (client_id, reset_token) ) reason = "invalid password reset token" - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -345,8 +354,8 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not rej_status: _logger.warning( - "failed to reject %r password reset: %s" % (client_id, - rej_err) + "failed to reject %r password reset: %s" + % (client_id, rej_err) ) else: _logger.info("rejected %r password reset" % client_id) @@ -354,7 +363,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): # NOTE: probably should no longer happen after initial auto clean _logger.warning("%r request is now past expire" % client_id) reason = "expired request - please re-request if still relevant" - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -363,15 +372,15 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): auth_type=auth_type, ) if not rej_status: - _logger.warning("failed to reject expired %r request: %s" % - (client_id, rej_err) - ) + _logger.warning( + "failed to reject expired %r request: %s" % (client_id, rej_err) + ) else: _logger.info("rejected %r request now past expire" % client_id) elif existing_user_collision(configuration, req_dict, client_id): _logger.warning("ID collision in request from %r" % client_id) reason = "ID collision - please re-request with *existing* ID fields" - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -381,8 +390,8 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not rej_status: _logger.warning( - "failed to reject %r request with ID collision: %s" % - (client_id, rej_err) + "failed to reject %r request with ID collision: %s" + % (client_id, rej_err) ) else: _logger.info("rejected %r request with ID collision" % client_id) @@ -417,8 +426,7 @@ def manage_trivial_user_requests(configuration, now=None): continue req_id = filename req_path = os.path.join(configuration.user_pending, req_id) - _logger.debug("checking if account request in %r is trivial" % - req_path) + _logger.debug("checking if account request in %r is trivial" % req_path) req_age = now - os.path.getmtime(req_path) req_age_minutes = req_age / SECS_PER_MINUTE if req_age_minutes > MANAGE_TRIVIAL_REQ_MINUTES: @@ -428,8 +436,7 @@ def manage_trivial_user_requests(configuration, now=None): ) manage_single_req(configuration, req_id, req_path, db_path, now) handled += 1 - _logger.debug("handled %d trivial user account request action(s)" % - handled) + _logger.debug("handled %d trivial user account request action(s)" % handled) return handled @@ -474,7 +481,7 @@ def remind_and_expire_user_pending(configuration, now=None): ) user_copy = True admin_copy = True - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -483,11 +490,12 @@ def remind_and_expire_user_pending(configuration, now=None): auth_type=auth_type, ) if not rej_status: - _logger.warning("failed to expire %s request from %r: %s" % - (req_id, client_id, rej_err)) + _logger.warning( + "failed to expire %s request from %r: %s" + % (req_id, client_id, rej_err) + ) else: - _logger.info("expired %s request from %r" % (req_id, - client_id)) + _logger.info("expired %s request from %r" % (req_id, client_id)) handled += 1 _logger.debug("handled %d user account request action(s)" % handled) return handled diff --git a/mig/lib/lustrequota.py b/mig/lib/lustrequota.py index ea669b936..695461ebc 100644 --- a/mig/lib/lustrequota.py +++ b/mig/lib/lustrequota.py @@ -41,13 +41,24 @@ psutil = None from mig.shared.base import force_unicode -from mig.shared.fileio import make_symlink, makedirs_rec, pickle, save_json, \ - scandir, unpickle, walk, write_file +from mig.shared.fileio import ( + make_symlink, + makedirs_rec, + pickle, + save_json, + scandir, + unpickle, + walk, + write_file, +) from mig.shared.vgrid import vgrid_flat_name try: - from lustreclient.lfs import lfs_get_project_quota, lfs_set_project_id, \ - lfs_set_project_quota + from lustreclient.lfs import ( + lfs_get_project_quota, + lfs_set_project_id, + lfs_set_project_quota, + ) except ImportError: lfs_set_project_id = None lfs_get_project_quota = None @@ -63,16 +74,19 @@ def __get_lustre_basepath(configuration, lustre_basepath=None): valid_lustre_basepath = None for dpart in psutil.disk_partitions(all=True): if dpart.fstype == "lustre": - if lustre_basepath \ - and lustre_basepath.startswith(dpart.mountpoint) \ - and os.path.isdir(lustre_basepath): + if ( + lustre_basepath + and lustre_basepath.startswith(dpart.mountpoint) + and os.path.isdir(lustre_basepath) + ): valid_lustre_basepath = lustre_basepath break elif dpart.mountpoint.endswith(configuration.server_fqdn): valid_lustre_basepath = dpart.mountpoint else: - check_lustre_basepath = os.path.join(dpart.mountpoint, - configuration.server_fqdn) + check_lustre_basepath = os.path.join( + dpart.mountpoint, configuration.server_fqdn + ) if os.path.isdir(check_lustre_basepath): valid_lustre_basepath = check_lustre_basepath break @@ -85,8 +99,9 @@ def __get_gocryptfs_socket(configuration, gocryptfs_sock=None): otherwise return default if it exists""" valid_gocryptfs_sock = None if gocryptfs_sock is None: - gocryptfs_sock = "/var/run/gocryptfs.%s.sock" \ - % configuration.server_fqdn + gocryptfs_sock = ( + "/var/run/gocryptfs.%s.sock" % configuration.server_fqdn + ) if os.path.exists(gocryptfs_sock): gocryptfs_sock_stat = os.lstat(gocryptfs_sock) if stat.S_ISSOCK(gocryptfs_sock_stat.st_mode): @@ -95,12 +110,14 @@ def __get_gocryptfs_socket(configuration, gocryptfs_sock=None): return valid_gocryptfs_sock -def __shellexec(configuration, - command, - args=[], - stdin_str=None, - stdout_filepath=None, - stderr_filepath=None): +def __shellexec( + configuration, + command, + args=[], + stdin_str=None, + stdout_filepath=None, + stderr_filepath=None, +): """Execute shell command Returns (exit_code, stdout, stderr) of subprocess""" result = 0 @@ -116,10 +133,8 @@ def __shellexec(configuration, __args.extend(args) logger.debug("__args: %s" % __args) process = subprocess.Popen( - __args, - stdin=stdin_handle, - stdout=stdout_handle, - stderr=stderr_handle) + __args, stdin=stdin_handle, stdout=stdout_handle, stderr=stderr_handle + ) if stdin_str: process.stdin.write(stdin_str.encode()) stdout, stderr = process.communicate() @@ -145,28 +160,26 @@ def __shellexec(configuration, if stderr: stderr = force_unicode(stderr) if result == 0: - logger.debug("%s %s: rc: %s, stdout: %s, error: %s" - % (command, - " ".join(args), - rc, - stdout, - stderr)) + logger.debug( + "%s %s: rc: %s, stdout: %s, error: %s" + % (command, " ".join(args), rc, stdout, stderr) + ) else: - logger.error("shellexec: %s %s: rc: %s, stdout: %s, error: %s" - % (command, - " ".join(__args), - rc, - stdout, - stderr)) + logger.error( + "shellexec: %s %s: rc: %s, stdout: %s, error: %s" + % (command, " ".join(__args), rc, stdout, stderr) + ) return (rc, stdout, stderr) -def __set_project_id(configuration, - lustre_basepath, - quota_datapath, - quota_name, - quota_lustre_pid): +def __set_project_id( + configuration, + lustre_basepath, + quota_datapath, + quota_name, + quota_lustre_pid, +): """Set lustre project *quota_lustre_pid* Find the next *free* project id (PID) if *quota_lustre_pid* is occupied NOTE: lustre uses a global counter for project id's (PID) @@ -181,19 +194,22 @@ def __set_project_id(configuration, logger = configuration.logger next_lustre_pid = quota_lustre_pid while next_lustre_pid < max_lustre_pid: - (rc, currfiles, _, _, _) \ - = lfs_get_project_quota(lustre_basepath, next_lustre_pid) + rc, currfiles, _, _, _ = lfs_get_project_quota( + lustre_basepath, next_lustre_pid + ) if rc != 0: - logger.error("Failed to fetch quota for lustre project id: %d, %r" - % (next_lustre_pid, lustre_basepath) - + ", rc: %d" % rc) + logger.error( + "Failed to fetch quota for lustre project id: %d, %r" + % (next_lustre_pid, lustre_basepath) + + ", rc: %d" % rc + ) return -1 if currfiles == 0: break - logger.info("Skipping project id: %d" - % next_lustre_pid - + " already registered with %d files" - % currfiles) + logger.info( + "Skipping project id: %d" % next_lustre_pid + + " already registered with %d files" % currfiles + ) next_lustre_pid += 1 if next_lustre_pid == max_lustre_pid: @@ -202,22 +218,28 @@ def __set_project_id(configuration, # Set new project id - logger.info("Setting lustre project id: %d for %r: %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.info( + "Setting lustre project id: %d for %r: %r" + % (next_lustre_pid, quota_name, quota_datapath) + ) rc = lfs_set_project_id(quota_datapath, next_lustre_pid, 1) if rc != 0: - logger.error("lfs_set_project_id failed for lustre project id: %d for %r: %r" - % (next_lustre_pid, quota_name, quota_datapath) - + ", rc: %d" % rc) + logger.error( + "lfs_set_project_id failed for lustre project id: %d for %r: %r" + % (next_lustre_pid, quota_name, quota_datapath) + + ", rc: %d" % rc + ) return -1 # Dump lustre pid in quota_datapath and wait for it to appear in the quota - lustre_pid_filepath = os.path.join(quota_datapath, '.lustrepid') + lustre_pid_filepath = os.path.join(quota_datapath, ".lustrepid") status = write_file(next_lustre_pid, lustre_pid_filepath, logger) if not status: - logger.error("Failed write lustre project id: %d for %r to %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.error( + "Failed write lustre project id: %d for %r to %r" + % (next_lustre_pid, quota_name, quota_datapath) + ) return -1 # Wait for files to appear in quota before returning @@ -226,36 +248,44 @@ def __set_project_id(configuration, waiting = 0 max_waiting = 60 while files == 0 and waiting < max_waiting: - (rc, files, _, _, _) \ - = lfs_get_project_quota(lustre_basepath, next_lustre_pid) + rc, files, _, _, _ = lfs_get_project_quota( + lustre_basepath, next_lustre_pid + ) if rc != 0: files = 0 - logger.error("lfs_get_project_quota failed for:" - + " %d, %r, %r, rc: %d" - % (next_lustre_pid, quota_name, quota_datapath, rc)) + logger.error( + "lfs_get_project_quota failed for:" + + " %d, %r, %r, rc: %d" + % (next_lustre_pid, quota_name, quota_datapath, rc) + ) if files == 0: - logger.info("Waiting for lustre quota: %d: %r: %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.info( + "Waiting for lustre quota: %d: %r: %r" + % (next_lustre_pid, quota_name, quota_datapath) + ) time.sleep(1) max_waiting += 1 if waiting == max_waiting: - logger.error("Failed to fetch quota for:" - + " %d, %r, %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.error( + "Failed to fetch quota for:" + + " %d, %r, %r" % (next_lustre_pid, quota_name, quota_datapath) + ) return -1 return next_lustre_pid -def __update_quota(configuration, - lustre_basepath, - lustre_setting, - quota_name, - quota_type, - data_basefs, - gocryptfs_sock, - timestamp): +def __update_quota( + configuration, + lustre_basepath, + lustre_setting, + quota_name, + quota_type, + data_basefs, + gocryptfs_sock, + timestamp, +): """Update quota for *quota_name*, if new entry then assign lustre project id and set default quota. If existing entry then update quota settings if changed @@ -263,15 +293,17 @@ def __update_quota(configuration, """ logger = configuration.logger quota_limits_changed = False - next_lustre_pid = lustre_setting.get('next_pid', -1) + next_lustre_pid = lustre_setting.get("next_pid", -1) if next_lustre_pid == -1: - logger.error("Invalid lustre quota next_pid: %d for: %r" - % (next_lustre_pid, quota_name)) + logger.error( + "Invalid lustre quota next_pid: %d for: %r" + % (next_lustre_pid, quota_name) + ) return False # Resolve quota limit and data basepath - if quota_type == 'vgrid': + if quota_type == "vgrid": default_quota_limit = configuration.quota_vgrid_limit data_basepath = configuration.vgrid_files_writable else: @@ -279,11 +311,13 @@ def __update_quota(configuration, data_basepath = configuration.user_home if data_basepath.startswith(configuration.state_path): - rel_data_basepath = data_basepath. \ - replace(configuration.state_path, "").lstrip(os.sep) + rel_data_basepath = data_basepath.replace( + configuration.state_path, "" + ).lstrip(os.sep) else: - logger.error("Failed to resolve relative data basepath from: %r" - % data_basepath) + logger.error( + "Failed to resolve relative data basepath from: %r" % data_basepath + ) return False # Resolve quota data path @@ -291,37 +325,38 @@ def __update_quota(configuration, if configuration.quota_backend == "lustre": quota_basefs = "lustre" - quota_datapath = os.path.join(lustre_basepath, - rel_data_basepath, - quota_name) + quota_datapath = os.path.join( + lustre_basepath, rel_data_basepath, quota_name + ) elif configuration.quota_backend == "lustre-gocryptfs": quota_basefs = "fuse.gocryptfs" stdin_str = os.path.join(rel_data_basepath, quota_name) cmd = "gocryptfs-xray -encrypt-paths %s" % gocryptfs_sock - (rc, stdout, stderr) = __shellexec(configuration, - cmd, - stdin_str=stdin_str) + rc, stdout, stderr = __shellexec( + configuration, cmd, stdin_str=stdin_str + ) if rc == 0 and stdout: encoded_path = stdout.strip() - quota_datapath = os.path.join(lustre_basepath, - encoded_path) + quota_datapath = os.path.join(lustre_basepath, encoded_path) else: - logger.error("Failed to resolve encrypted path for: %r" - % quota_name - + ", rc: %d, error: %s" - % (rc, stderr)) + logger.error( + "Failed to resolve encrypted path for: %r" % quota_name + + ", rc: %d, error: %s" % (rc, stderr) + ) return False else: - logger.error("Invalid quota backend: %r" - % configuration.quota_backend) + logger.error("Invalid quota backend: %r" % configuration.quota_backend) return False # Check if valid lustre data dir if not os.path.isdir(quota_datapath): - msg = "skipping entry: %r : %r, no lustre data path: %r" \ - % (quota_type, quota_name, quota_datapath) + msg = "skipping entry: %r : %r, no lustre data path: %r" % ( + quota_type, + quota_name, + quota_datapath, + ) # NOTE: log error and return false if dir is missing # and we expect data to be on lustre or gocryoptfs) if data_basefs == quota_basefs: @@ -335,143 +370,170 @@ def __update_quota(configuration, # Load quota if it exists otherwise new quota - quota_filepath = os.path.join(configuration.quota_home, - configuration.quota_backend, - quota_type, - "%s.pck" % quota_name) + quota_filepath = os.path.join( + configuration.quota_home, + configuration.quota_backend, + quota_type, + "%s.pck" % quota_name, + ) if os.path.exists(quota_filepath): quota = unpickle(quota_filepath, logger) if not quota: - logger.error("Failed to load quota settings for: %r from %r" - % (quota_name, quota_filepath)) + logger.error( + "Failed to load quota settings for: %r from %r" + % (quota_name, quota_filepath) + ) return False else: - quota = {'lustre_pid': next_lustre_pid, - 'files': -1, - 'bytes': -1, - 'softlimit_bytes': -1, - 'hardlimit_bytes': -1, - } + quota = { + "lustre_pid": next_lustre_pid, + "files": -1, + "bytes": -1, + "softlimit_bytes": -1, + "hardlimit_bytes": -1, + } # Fetch quota lustre pid - quota_lustre_pid = quota.get('lustre_pid', -1) + quota_lustre_pid = quota.get("lustre_pid", -1) if quota_lustre_pid == -1: - logger.error("Invalid quota lustre pid: %d for %r" - % (quota_lustre_pid, quota_name)) + logger.error( + "Invalid quota lustre pid: %d for %r" + % (quota_lustre_pid, quota_name) + ) return False # If new entry then set lustre project id new_lustre_pid = -1 if quota_lustre_pid == next_lustre_pid: - new_lustre_pid = __set_project_id(configuration, - lustre_basepath, - quota_datapath, - quota_name, - quota_lustre_pid) + new_lustre_pid = __set_project_id( + configuration, + lustre_basepath, + quota_datapath, + quota_name, + quota_lustre_pid, + ) if new_lustre_pid == -1: - logger.error("Failed to set project id: %d, %r, %r" - % (new_lustre_pid, quota_name, quota_datapath)) + logger.error( + "Failed to set project id: %d, %r, %r" + % (new_lustre_pid, quota_name, quota_datapath) + ) return False - lustre_setting['next_pid'] = new_lustre_pid + 1 - quota['lustre_pid'] = quota_lustre_pid = new_lustre_pid + lustre_setting["next_pid"] = new_lustre_pid + 1 + quota["lustre_pid"] = quota_lustre_pid = new_lustre_pid # Get current quota values for lustre_pid - (rc, currfiles, currbytes, softlimit_bytes, hardlimit_bytes) \ - = lfs_get_project_quota(lustre_basepath, quota_lustre_pid) + rc, currfiles, currbytes, softlimit_bytes, hardlimit_bytes = ( + lfs_get_project_quota(lustre_basepath, quota_lustre_pid) + ) if rc != 0: - logger.error("lfs_get_project_quota failed for: %d, %r, %r" - % (quota_lustre_pid, quota_name, quota_datapath) - + ", rc: %d" % rc) + logger.error( + "lfs_get_project_quota failed for: %d, %r, %r" + % (quota_lustre_pid, quota_name, quota_datapath) + + ", rc: %d" % rc + ) return False # Update quota info if currfiles == 0 or currbytes == 0: - logger.warning("lustre_basepath: %r: pid: %d: quota_type: %s" - % (lustre_basepath, quota_lustre_pid, quota_type) - + "quota_name: %s, files: %d, bytes: %d" - % (quota_name, currfiles, currbytes)) + logger.warning( + "lustre_basepath: %r: pid: %d: quota_type: %s" + % (lustre_basepath, quota_lustre_pid, quota_type) + + "quota_name: %s, files: %d, bytes: %d" + % (quota_name, currfiles, currbytes) + ) - quota['mtime'] = timestamp - quota['files'] = currfiles - quota['bytes'] = currbytes + quota["mtime"] = timestamp + quota["files"] = currfiles + quota["bytes"] = currbytes # If new entry use default quota # and update quota if changed if new_lustre_pid > -1: quota_limits_changed = True - quota['softlimit_bytes'] = default_quota_limit - quota['hardlimit_bytes'] = default_quota_limit - elif hardlimit_bytes != quota.get('hardlimit_bytes', -1) \ - or softlimit_bytes != quota.get('softlimit_bytes', -1): + quota["softlimit_bytes"] = default_quota_limit + quota["hardlimit_bytes"] = default_quota_limit + elif hardlimit_bytes != quota.get( + "hardlimit_bytes", -1 + ) or softlimit_bytes != quota.get("softlimit_bytes", -1): quota_limits_changed = True - quota['softlimit_bytes'] = softlimit_bytes - quota['hardlimit_bytes'] = hardlimit_bytes + quota["softlimit_bytes"] = softlimit_bytes + quota["hardlimit_bytes"] = hardlimit_bytes if quota_limits_changed: - rc = lfs_set_project_quota(quota_datapath, - quota_lustre_pid, - quota['softlimit_bytes'], - quota['hardlimit_bytes'], - ) + rc = lfs_set_project_quota( + quota_datapath, + quota_lustre_pid, + quota["softlimit_bytes"], + quota["hardlimit_bytes"], + ) if rc != 0: - logger.error("Failed to set quota limit: %d/%d" - % (softlimit_bytes, - hardlimit_bytes) - + " for lustre project id: %d, %r, %r, rc: %d" - % (quota_lustre_pid, - quota_name, - quota_datapath, - rc)) + logger.error( + "Failed to set quota limit: %d/%d" + % (softlimit_bytes, hardlimit_bytes) + + " for lustre project id: %d, %r, %r, rc: %d" + % (quota_lustre_pid, quota_name, quota_datapath, rc) + ) return False # Save current quota - new_quota_basepath = os.path.join(configuration.quota_home, - configuration.quota_backend, - quota_type, - str(timestamp)) - if not os.path.exists(new_quota_basepath) \ - and not makedirs_rec(new_quota_basepath, configuration): - logger.error("Failed to create new quota base path: %r" - % new_quota_basepath) + new_quota_basepath = os.path.join( + configuration.quota_home, + configuration.quota_backend, + quota_type, + str(timestamp), + ) + if not os.path.exists(new_quota_basepath) and not makedirs_rec( + new_quota_basepath, configuration + ): + logger.error( + "Failed to create new quota base path: %r" % new_quota_basepath + ) return False - new_quota_filepath_pck = os.path.join(new_quota_basepath, - "%s.pck" % quota_name) + new_quota_filepath_pck = os.path.join( + new_quota_basepath, "%s.pck" % quota_name + ) - logger.debug("Saving: %s: %s: %s -> %r" - % (quota_type, quota_name, quota, new_quota_filepath_pck)) + logger.debug( + "Saving: %s: %s: %s -> %r" + % (quota_type, quota_name, quota, new_quota_filepath_pck) + ) status = pickle(quota, new_quota_filepath_pck, logger) if not status: - logger.error("Failed to save quota for: %r to %r" - % (quota_name, new_quota_filepath_pck)) + logger.error( + "Failed to save quota for: %r to %r" + % (quota_name, new_quota_filepath_pck) + ) return False - new_quota_filepath_json = os.path.join(new_quota_basepath, - "%s.json" % quota_name) - status = save_json(quota, - new_quota_filepath_json, - logger) + new_quota_filepath_json = os.path.join( + new_quota_basepath, "%s.json" % quota_name + ) + status = save_json(quota, new_quota_filepath_json, logger) if not status: - logger.error("Failed to save quota for: %r to %r" - % (quota_name, new_quota_filepath_json)) + logger.error( + "Failed to save quota for: %r to %r" + % (quota_name, new_quota_filepath_json) + ) return False # Create symlink to new quota - status = make_symlink(new_quota_filepath_pck, - quota_filepath, - logger, - force=True) + status = make_symlink( + new_quota_filepath_pck, quota_filepath, logger, force=True + ) if not status: - logger.error("Failed to make quota symlink for: %r: %r -> %r" - % (quota_name, new_quota_filepath_pck, quota_filepath)) + logger.error( + "Failed to make quota symlink for: %r: %r -> %r" + % (quota_name, new_quota_filepath_pck, quota_filepath) + ) return False return True @@ -483,9 +545,11 @@ def update_lustre_quota(configuration): # Check if lustreclient module was imported correctly - if lfs_set_project_id is None \ - or lfs_get_project_quota is None \ - or lfs_set_project_quota is None: + if ( + lfs_set_project_id is None + or lfs_get_project_quota is None + or lfs_set_project_quota is None + ): logger.error("Failed to import lustreclient module") return False @@ -496,11 +560,11 @@ def update_lustre_quota(configuration): lustre_basepath = __get_lustre_basepath(configuration) if lustre_basepath: - logger.debug("Using lustre basepath: %r" - % lustre_basepath) + logger.debug("Using lustre basepath: %r" % lustre_basepath) else: - logger.error("Found no valid lustre mounts for: %s" - % configuration.server_fqdn) + logger.error( + "Found no valid lustre mounts for: %s" % configuration.server_fqdn + ) return False # Get gocryptfs socket if enabled @@ -509,37 +573,38 @@ def update_lustre_quota(configuration): if configuration.quota_backend == "lustre-gocryptfs": gocryptfs_sock = __get_gocryptfs_socket(configuration) if gocryptfs_sock: - logger.debug("Using gocryptfs socket: %r" - % gocryptfs_sock) + logger.debug("Using gocryptfs socket: %r" % gocryptfs_sock) else: logger.error("Missing gocryptfs socket") return False # Load lustre quota settings - lustre_setting_filepath = os.path.join(configuration.quota_home, - '%s.pck' - % configuration.quota_backend) + lustre_setting_filepath = os.path.join( + configuration.quota_home, "%s.pck" % configuration.quota_backend + ) if os.path.exists(lustre_setting_filepath): - lustre_setting = unpickle(lustre_setting_filepath, - logger) + lustre_setting = unpickle(lustre_setting_filepath, logger) if not lustre_setting: - logger.error("Failed to load lustre quota: %r" - % lustre_setting_filepath) + logger.error( + "Failed to load lustre quota: %r" % lustre_setting_filepath + ) return False else: - lustre_setting = {'next_pid': 1, - 'mtime': 0} + lustre_setting = {"next_pid": 1, "mtime": 0} # Update quota - quota_targets = {'vgrid': {'basefs': 'lustre', - 'entries': {}, - }, - 'user': {'basefs': 'lustre', - 'entries': {}, - }, - } + quota_targets = { + "vgrid": { + "basefs": "lustre", + "entries": {}, + }, + "user": { + "basefs": "lustre", + "entries": {}, + }, + } # Resolve basefs if possible @@ -547,26 +612,29 @@ def update_lustre_quota(configuration): mountpoint = dpart.mountpoint.rstrip(os.sep) fstype = dpart.fstype if mountpoint == configuration.vgrid_files_writable.rstrip(os.sep): - logger.debug("Found basefs for vgrid data: %r : %r" - % (mountpoint, fstype)) - quota_targets['vgrid']['basefs'] = fstype + logger.debug( + "Found basefs for vgrid data: %r : %r" % (mountpoint, fstype) + ) + quota_targets["vgrid"]["basefs"] = fstype if mountpoint == configuration.user_home.rstrip(os.sep): - logger.debug("Found basefs for user data: %r : %r" - % (mountpoint, fstype)) - quota_targets['user']['basefs'] = fstype + logger.debug( + "Found basefs for user data: %r : %r" % (mountpoint, fstype) + ) + quota_targets["user"]["basefs"] = fstype # Resolve vgrids and sub-vgrids for root, dirs, _ in walk(configuration.vgrid_home, topdown=True): for dirent in dirs: vgrid_dirpath = os.path.join(root, dirent) - owners_filepath = os.path.join(vgrid_dirpath, 'owners') + owners_filepath = os.path.join(vgrid_dirpath, "owners") if os.path.isfile(owners_filepath): vgrid = vgrid_flat_name( - vgrid_dirpath[len(configuration.vgrid_home):], - configuration) + vgrid_dirpath[len(configuration.vgrid_home) :], + configuration, + ) logger.debug("Found vgrid: %r" % vgrid) - quota_targets['vgrid']['entries'][vgrid] = 1 + quota_targets["vgrid"]["entries"][vgrid] = 1 # Resolve users @@ -576,45 +644,46 @@ def update_lustre_quota(configuration): userhome = os.readlink(entry.path) # NOTE: Relative links are prefixed with 'user_home' if not userhome.startswith(os.sep): - userhome = os.path.join(configuration.user_home, - userhome) + userhome = os.path.join(configuration.user_home, userhome) else: userhome = entry.path if os.path.isdir(userhome): user = os.path.basename(userhome) else: - logger.debug("skipping non-userhome: %r (%r)" - % (userhome, entry.path)) + logger.debug( + "skipping non-userhome: %r (%r)" % (userhome, entry.path) + ) continue # NOTE: Multiple links might point to same user - quota_targets['user']['entries'][user] = True + quota_targets["user"]["entries"][user] = True # Update quotas for quota_type in quota_targets: target = quota_targets.get(quota_type, {}) - data_basefs = target.get('basefs', 'lustre') - quota_entries = target.get('entries', {}) + data_basefs = target.get("basefs", "lustre") + quota_entries = target.get("entries", {}) for quota_entry in quota_entries: - status = __update_quota(configuration, - lustre_basepath, - lustre_setting, - quota_entry, - quota_type, - data_basefs, - gocryptfs_sock, - timestamp) + status = __update_quota( + configuration, + lustre_basepath, + lustre_setting, + quota_entry, + quota_type, + data_basefs, + gocryptfs_sock, + timestamp, + ) if not status: retval = False # Save updated lustre quota settings - lustre_setting['mtime'] = timestamp - status = pickle(lustre_setting, - lustre_setting_filepath, - logger) + lustre_setting["mtime"] = timestamp + status = pickle(lustre_setting, lustre_setting_filepath, logger) if not status: - logger.error("Failed to save lustra quota settings: %r" - % lustre_setting_filepath) + logger.error( + "Failed to save lustra quota settings: %r" % lustre_setting_filepath + ) return retval diff --git a/mig/lib/quota.py b/mig/lib/quota.py index e93830d28..85b795f2c 100644 --- a/mig/lib/quota.py +++ b/mig/lib/quota.py @@ -30,20 +30,22 @@ from mig.lib.lustrequota import update_lustre_quota - -supported_quota_backends = ['lustre', 'lustre-gocryptfs'] +supported_quota_backends = ["lustre", "lustre-gocryptfs"] def update_quota(configuration): """Update quota for users and vgrids""" retval = False logger = configuration.logger - if configuration.quota_backend == 'lustre' \ - or configuration.quota_backend == 'lustre-gocryptfs': + if ( + configuration.quota_backend == "lustre" + or configuration.quota_backend == "lustre-gocryptfs" + ): retval = update_lustre_quota(configuration) else: - logger.error("quota_backend: %r not in supported_quota_backends: %r" - % (configuration.quota_backend, - supported_quota_backends)) + logger.error( + "quota_backend: %r not in supported_quota_backends: %r" + % (configuration.quota_backend, supported_quota_backends) + ) return retval diff --git a/tests/__init__.py b/tests/__init__.py index bcec2ab8a..d1e480f4d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,10 +1,14 @@ def _print_identity(): import os import sys - python_version_string = sys.version.split(' ')[0] - mig_env = os.environ.get('MIG_ENV', 'local') - print("running with MIG_ENV='%s' under Python %s" % - (mig_env, python_version_string)) + + python_version_string = sys.version.split(" ")[0] + mig_env = os.environ.get("MIG_ENV", "local") + print( + "running with MIG_ENV='%s' under Python %s" + % (mig_env, python_version_string) + ) print("") + _print_identity() diff --git a/tests/support/__init__.py b/tests/support/__init__.py index eb2c8019f..da75f1b7c 100644 --- a/tests/support/__init__.py +++ b/tests/support/__init__.py @@ -28,8 +28,6 @@ """Supporting functions for the unit test framework""" -from collections import defaultdict -from configparser import ConfigParser import difflib import errno import io @@ -40,34 +38,40 @@ import shutil import stat import sys +from collections import defaultdict +from configparser import ConfigParser from types import SimpleNamespace -from unittest import TestCase, main as testmain +from unittest import TestCase +from unittest import main as testmain +from tests.support._env import MIG_ENV, PY2 from tests.support.configsupp import FakeConfiguration from tests.support.fixturesupp import _PreparedFixture -from tests.support.suppconst import MIG_BASE, TEST_BASE, \ - TEST_DATA_DIR, TEST_OUTPUT_DIR, ENVHELP_OUTPUT_DIR +from tests.support.suppconst import ( + ENVHELP_OUTPUT_DIR, + MIG_BASE, + TEST_BASE, + TEST_DATA_DIR, + TEST_OUTPUT_DIR, +) from tests.support.usersupp import UserAssertMixin -from tests.support._env import MIG_ENV, PY2 - - # Provide access to a configuration file for the active environment. -if MIG_ENV in ('local', 'docker'): +if MIG_ENV in ("local", "docker"): # force local testconfig - _output_dir = os.path.join(MIG_BASE, 'envhelp/output') + _output_dir = os.path.join(MIG_BASE, "envhelp/output") _conf_dir_name = "testconfs-%s" % (MIG_ENV,) _conf_dir = os.path.join(_output_dir, _conf_dir_name) - _local_conf = os.path.join(_conf_dir, 'MiGserver.conf') - _config_file = os.getenv('MIG_CONF', None) + _local_conf = os.path.join(_conf_dir, "MiGserver.conf") + _config_file = os.getenv("MIG_CONF", None) if _config_file is None: - os.environ['MIG_CONF'] = _local_conf + os.environ["MIG_CONF"] = _local_conf # adjust the link through which confs are accessed to suit the environment - _conf_link = os.path.join(_output_dir, 'testconfs') + _conf_link = os.path.join(_output_dir, "testconfs") assert os.path.lexists(_conf_link) # it must already exist - os.remove(_conf_link) # blow it away + os.remove(_conf_link) # blow it away os.symlink(_conf_dir, _conf_link) # recreate it using the active MIG_BASE else: raise NotImplementedError() @@ -97,7 +101,6 @@ from tests.support.loggersupp import FakeLogger, FakeLoggerChecker from tests.support.serversupp import make_wrapped_server - # Basic global logging configuration for testing @@ -181,7 +184,7 @@ def tearDown(self): @classmethod def tearDownClass(cls): - if MIG_ENV == 'docker': + if MIG_ENV == "docker": # the permissions story wrt running inside docker containers is # such that we can end up with files from previous test runs left # around that might subsequently cause spurious permissions errors @@ -209,20 +212,24 @@ def _reset_logging(self, stream): @staticmethod def _make_configuration_instance(testcase, configuration_to_make): - if configuration_to_make == 'fakeconfig': + if configuration_to_make == "fakeconfig": return FakeConfiguration(logger=testcase.logger) - elif configuration_to_make == 'testconfig': + elif configuration_to_make == "testconfig": from mig.shared.conf import get_configuration_object - configuration = get_configuration_object(skip_log=True, - disable_auth_log=True) + + configuration = get_configuration_object( + skip_log=True, disable_auth_log=True + ) configuration.logger = testcase.logger return configuration else: raise AssertionError( - "MigTestCase: unknown configuration %r" % (configuration_to_make,)) + "MigTestCase: unknown configuration %r" + % (configuration_to_make,) + ) def _provide_configuration(self): - return 'unspecified' + return "unspecified" @property def configuration(self): @@ -233,14 +240,16 @@ def configuration(self): configuration_to_make = self._provide_configuration() - if configuration_to_make == 'unspecified': + if configuration_to_make == "unspecified": raise AssertionError( - "configuration access but testcase did not request it") + "configuration access but testcase did not request it" + ) configuration_instance = self._make_configuration_instance( - self, configuration_to_make) + self, configuration_to_make + ) - if configuration_to_make == 'testconfig': + if configuration_to_make == "testconfig": # use the paths defined by the loaded configuration to create # the directories which are expected to be present by the code os.mkdir(configuration_instance.certs_path) @@ -274,7 +283,8 @@ def assertDirEmpty(self, relative_path): """Make sure the supplied path is an empty directory""" path_kind = self.assertPathExists(relative_path) assert path_kind == "dir", "expected a directory but found %s" % ( - path_kind, ) + path_kind, + ) absolute_path = os.path.join(TEST_OUTPUT_DIR, relative_path) entries = os.listdir(absolute_path) assert not entries, "directory is not empty" @@ -283,7 +293,8 @@ def assertDirNotEmpty(self, relative_path): """Make sure the supplied path is a non-empty directory""" path_kind = self.assertPathExists(relative_path) assert path_kind == "dir", "expected a directory but found %s" % ( - path_kind, ) + path_kind, + ) absolute_path = os.path.join(TEST_OUTPUT_DIR, relative_path) entries = os.listdir(absolute_path) assert entries, "directory is empty" @@ -291,28 +302,35 @@ def assertDirNotEmpty(self, relative_path): def assertFileContentIdentical(self, file_actual, file_expected): """Make sure file_actual and file_expected are identical""" - with io.open(file_actual) as f_actual, io.open(file_expected) as f_expected: + with io.open(file_actual) as f_actual, io.open( + file_expected + ) as f_expected: lhs = f_actual.readlines() rhs = f_expected.readlines() different_lines = list(difflib.unified_diff(rhs, lhs)) try: self.assertEqual(len(different_lines), 0) except AssertionError: - raise AssertionError("""differences found between files + raise AssertionError( + """differences found between files * %s * %s included: %s - """ % ( - os.path.relpath(file_expected, MIG_BASE), - os.path.relpath(file_actual, MIG_BASE), - ''.join(different_lines))) + """ + % ( + os.path.relpath(file_expected, MIG_BASE), + os.path.relpath(file_actual, MIG_BASE), + "".join(different_lines), + ) + ) def assertFileExists(self, relative_path): """Make sure relative_path exists and is a file""" path_kind = self.assertPathExists(relative_path) assert path_kind == "file", "expected a file but found %s" % ( - path_kind, ) + path_kind, + ) return os.path.join(TEST_OUTPUT_DIR, relative_path) def assertPathExists(self, relative_path): @@ -342,13 +360,14 @@ def assertPathWithin(self, path, start=None): """Make sure path is within start directory""" if not is_path_within(path, start=start): raise AssertionError( - "path %s is not within directory %s" % (path, start)) + "path %s is not within directory %s" % (path, start) + ) @staticmethod def pretty_display_path(absolute_path): assert os.path.isabs(absolute_path) relative_path = os.path.relpath(absolute_path, start=MIG_BASE) - assert not relative_path.startswith('..') + assert not relative_path.startswith("..") return relative_path @staticmethod @@ -359,7 +378,9 @@ def _provision_test_user(testcase, distinguished_name): Note that this method, along with a number of others, are defined in the user portion of the test support libraries. """ - return UserAssertMixin._provision_test_user(testcase, distinguished_name) + return UserAssertMixin._provision_test_user( + testcase, distinguished_name + ) def is_path_within(path, start=None, _msg=None): @@ -369,7 +390,7 @@ def is_path_within(path, start=None, _msg=None): relative = os.path.relpath(path, start=start) except: return False - return not relative.startswith('..') + return not relative.startswith("..") def ensure_dirs_exist(absolute_dir): @@ -400,7 +421,7 @@ def temppath(relative_path, test_case, ensure_dir=False): # failsafe path checking that supplied paths are rooted within valid paths is_tmp_path_within_safe_dir = False - for start in (ENVHELP_OUTPUT_DIR): + for start in ENVHELP_OUTPUT_DIR: is_tmp_path_within_safe_dir = is_path_within(tmp_path, start=start) if is_tmp_path_within_safe_dir: break @@ -413,7 +434,8 @@ def temppath(relative_path, test_case, ensure_dir=False): except OSError as oserr: if oserr.errno == errno.EEXIST: raise AssertionError( - "ABORT: use of unclean output path: %s" % tmp_path) + "ABORT: use of unclean output path: %s" % tmp_path + ) return tmp_path diff --git a/tests/support/_env.py b/tests/support/_env.py index 2c71386a4..d86bbd2d3 100644 --- a/tests/support/_env.py +++ b/tests/support/_env.py @@ -2,10 +2,10 @@ import sys # expose the configured environment as a constant -MIG_ENV = os.environ.get('MIG_ENV', 'local') +MIG_ENV = os.environ.get("MIG_ENV", "local") # force the chosen environment globally -os.environ['MIG_ENV'] = MIG_ENV +os.environ["MIG_ENV"] = MIG_ENV # expose a boolean indicating whether we are executing on Python 2 -PY2 = (sys.version_info[0] == 2) +PY2 = sys.version_info[0] == 2 diff --git a/tests/support/assertover.py b/tests/support/assertover.py index 52b445ab2..3e31b7b5b 100644 --- a/tests/support/assertover.py +++ b/tests/support/assertover.py @@ -29,11 +29,13 @@ class NoBlockError(AssertionError): """Decorate AssertionError for our own convenience""" + pass class NoCasesError(AssertionError): """Decorate AssertionError for our own convenience""" + pass @@ -76,10 +78,15 @@ def __exit__(self, exc_type, exc_value, traceback): if not any(self._attempts): return True - value_lines = ["- <%r> : %s" % (attempt[0], str(attempt[1])) for - attempt in self._attempts if attempt] - raise AssertionError("assertions raised for the following values:\n%s" - % '\n'.join(value_lines)) + value_lines = [ + "- <%r> : %s" % (attempt[0], str(attempt[1])) + for attempt in self._attempts + if attempt + ] + raise AssertionError( + "assertions raised for the following values:\n%s" + % "\n".join(value_lines) + ) def record_attempt(self, attempt_info): """Record the result of a test attempt""" @@ -89,7 +96,9 @@ def to_check_callable(self): def raise_unless_consulted(): if not self._consulted: raise AssertionError( - "no examiniation made of assertion of multiple values") + "no examiniation made of assertion of multiple values" + ) + return raise_unless_consulted def assert_success(self): @@ -103,4 +112,7 @@ def _execute_block(cls, block, block_value): block.__call__(block_value) return None except Exception as blockexc: - return (block_value, blockexc,) + return ( + block_value, + blockexc, + ) diff --git a/tests/support/configsupp.py b/tests/support/configsupp.py index 0846e465d..bb8511542 100644 --- a/tests/support/configsupp.py +++ b/tests/support/configsupp.py @@ -27,20 +27,21 @@ """Configuration related details within the test support library.""" -from tests.support.loggersupp import FakeLogger - from mig.shared.compat import SimpleNamespace -from mig.shared.configuration import \ - _CONFIGURATION_ARGUMENTS, _CONFIGURATION_PROPERTIES +from mig.shared.configuration import ( + _CONFIGURATION_ARGUMENTS, + _CONFIGURATION_PROPERTIES, +) +from tests.support.loggersupp import FakeLogger def _ensure_only_configuration_keys(thedict): - """Check a dictionary contains only keys valid as Configuration properties. - """ + """Check a dictionary contains only keys valid as Configuration properties.""" unknown_keys = set(thedict.keys()) - set(_CONFIGURATION_ARGUMENTS) - assert len(unknown_keys) == 0, \ - "non-Configuration keys: %s" % (', '.join(unknown_keys),) + assert len(unknown_keys) == 0, "non-Configuration keys: %s" % ( + ", ".join(unknown_keys), + ) def _generate_namespace_kwargs(): @@ -49,7 +50,7 @@ def _generate_namespace_kwargs(): """ properties_and_defaults = dict(_CONFIGURATION_PROPERTIES) - properties_and_defaults['logger'] = None + properties_and_defaults["logger"] = None return properties_and_defaults diff --git a/tests/support/fixturesupp.py b/tests/support/fixturesupp.py index c418fc7a2..0172a3e43 100644 --- a/tests/support/fixturesupp.py +++ b/tests/support/fixturesupp.py @@ -27,12 +27,12 @@ """Fixture related details within the test support library.""" -from configparser import ConfigParser -from datetime import date, timedelta import json import os import pickle import shutil +from configparser import ConfigParser +from datetime import date, timedelta from time import mktime from types import SimpleNamespace @@ -51,30 +51,35 @@ def _fixturefile_loadrelative(fixture_name, fixture_format=None): assert fixture_format is not None, "fixture format must be specified" relative_path_with_ext = "%s.%s" % (fixture_name, fixture_format) tmp_path = os.path.join(TEST_FIXTURE_DIR, relative_path_with_ext) - assert os.path.isfile(tmp_path), \ - 'fixture named "%s" with format %s is not present: %s' % \ - (fixture_name, fixture_format, relative_path_with_ext) + assert os.path.isfile( + tmp_path + ), 'fixture named "%s" with format %s is not present: %s' % ( + fixture_name, + fixture_format, + relative_path_with_ext, + ) data = None - if fixture_format == 'binary': - with open(tmp_path, 'rb') as binfile: + if fixture_format == "binary": + with open(tmp_path, "rb") as binfile: data = binfile.read() - elif fixture_format == 'json': + elif fixture_format == "json": with open(tmp_path) as jsonfile: data = json.load(jsonfile, object_hook=_FixtureHint.object_hook) _hints_apply_from_instances_if_present(data) _hints_apply_from_fixture_ini_if_present(fixture_name, data) else: raise AssertionError( - "unsupported fixture format: %s" % (fixture_format,)) + "unsupported fixture format: %s" % (fixture_format,) + ) return data, tmp_path -def _fixturefile_normname(relative_path, prefix=''): +def _fixturefile_normname(relative_path, prefix=""): """Grab normname from relative_path and optionally add a path prefix""" - normname, _ = relative_path.split('--') + normname, _ = relative_path.split("--") if prefix: return os.path.join(prefix, normname) return normname @@ -96,6 +101,7 @@ def _fixturefile_normname(relative_path, prefix=''): # # + def _hints_apply_array_of_tuples(value, modifier): """ Convert list of lists such that its values are instead tuples. @@ -109,7 +115,7 @@ def _hints_apply_today_relative(value, modifier): Geneate a time value by applying a declared delta to today's date. """ - kind, delta = modifier.split('|') + kind, delta = modifier.split("|") if kind == "days": time_delta = timedelta(days=int(delta)) adjusted_datetime = date.today() + time_delta @@ -131,15 +137,17 @@ def _hints_apply_dict_bytes_to_strings_kv(input_dict, modifier): for k, v in input_dict.items(): key_to_use = k if isinstance(k, bytes): - key_to_use = str(k, 'utf8') + key_to_use = str(k, "utf8") if isinstance(v, dict): - output_dict[key_to_use] = _hints_apply_dict_bytes_to_strings_kv(v, modifier) + output_dict[key_to_use] = _hints_apply_dict_bytes_to_strings_kv( + v, modifier + ) continue val_to_use = v if isinstance(v, bytes): - val_to_use = str(v, 'utf8') + val_to_use = str(v, "utf8") output_dict[key_to_use] = val_to_use @@ -159,15 +167,17 @@ def _hints_apply_dict_strings_to_bytes_kv(input_dict, modifier): for k, v in input_dict.items(): key_to_use = k if isinstance(k, str): - key_to_use = bytes(k, 'utf8') + key_to_use = bytes(k, "utf8") if isinstance(v, dict): - output_dict[key_to_use] = _hints_apply_dict_strings_to_bytes_kv(v, modifier) + output_dict[key_to_use] = _hints_apply_dict_strings_to_bytes_kv( + v, modifier + ) continue val_to_use = v if isinstance(v, str): - val_to_use = bytes(v, 'utf8') + val_to_use = bytes(v, "utf8") output_dict[key_to_use] = val_to_use @@ -176,21 +186,21 @@ def _hints_apply_dict_strings_to_bytes_kv(input_dict, modifier): # hints that can be aplied without an additional modifier argument _HINTS_APPLIERS_ARGLESS = { - 'array_of_tuples': _hints_apply_array_of_tuples, - 'today_relative': _hints_apply_today_relative, - 'convert_dict_bytes_to_strings_kv': _hints_apply_dict_bytes_to_strings_kv, - 'convert_dict_strings_to_bytes_kv': _hints_apply_dict_strings_to_bytes_kv, + "array_of_tuples": _hints_apply_array_of_tuples, + "today_relative": _hints_apply_today_relative, + "convert_dict_bytes_to_strings_kv": _hints_apply_dict_bytes_to_strings_kv, + "convert_dict_strings_to_bytes_kv": _hints_apply_dict_strings_to_bytes_kv, } # hints applicable to the conversion of attributes during fixture loading _FIXTUREFILE_APPLIERS_ATTRIBUTES = { - 'array_of_tuples': _hints_apply_array_of_tuples, - 'today_relative': _hints_apply_today_relative, + "array_of_tuples": _hints_apply_array_of_tuples, + "today_relative": _hints_apply_today_relative, } # hints applied when writing the contents of a fixture as a temporary file _FIXTUREFILE_APPLIERS_ONWRITE = { - 'convert_dict_strings_to_bytes_kv': _hints_apply_dict_strings_to_bytes_kv, + "convert_dict_strings_to_bytes_kv": _hints_apply_dict_strings_to_bytes_kv, } @@ -222,7 +232,7 @@ def _load_hints_ini_for_fixture_if_present(fixture_name): pass # ensure empty required fixture to avoid extra conditionals later - for required_section in ['ATTRIBUTES']: + for required_section in ["ATTRIBUTES"]: if not hints.has_section(required_section): hints.add_section(required_section) @@ -239,10 +249,10 @@ def _hints_apply_from_fixture_ini_if_present(fixture_name, json_object): # apply any attriutes hints ahead of specified conversions such that any # key can be specified matching what is visible within the loaded fixture - for item_name, item_hint_unparsed in hints['ATTRIBUTES'].items(): + for item_name, item_hint_unparsed in hints["ATTRIBUTES"].items(): loaded_value = json_object[item_name] - item_hint_and_maybe_modifier = item_hint_unparsed.split('--') + item_hint_and_maybe_modifier = item_hint_unparsed.split("--") item_hint = item_hint_and_maybe_modifier[0] if len(item_hint_and_maybe_modifier) == 2: modifier = item_hint_and_maybe_modifier[1] @@ -267,7 +277,9 @@ def __init__(self, hint=None, modifier=None, value=None): def decode_hint(hint_obj): """Produce a value based on the properties of a hint instance.""" assert isinstance(hint_obj, _FixtureHint) - value_from_loaded_value = _FIXTUREFILE_APPLIERS_ATTRIBUTES[hint_obj.hint] + value_from_loaded_value = _FIXTUREFILE_APPLIERS_ATTRIBUTES[ + hint_obj.hint + ] return value_from_loaded_value(hint_obj.value, hint_obj.modifier) @staticmethod @@ -278,11 +290,14 @@ def object_hook(decoded_object): """ if "_FixtureHint" in decoded_object: - fixture_hint = _FixtureHint(decoded_object["hint"], decoded_object["modifier"]) + fixture_hint = _FixtureHint( + decoded_object["hint"], decoded_object["modifier"] + ) return _FixtureHint.decode_hint(fixture_hint) return decoded_object + # @@ -295,7 +310,7 @@ def fixturepath(relative_path): def _to_display_path(value): """Convert an absolute path to one to be shown as part of test output.""" display_path = os.path.relpath(value, MIG_BASE) - if not display_path.startswith('.'): + if not display_path.startswith("."): return "./" + display_path return display_path @@ -307,10 +322,9 @@ class _PreparedFixture: NO_DATA = object() - def __init__(self, testcase, - fixture_name, - fixture_format='', - fixture_data=NO_DATA): + def __init__( + self, testcase, fixture_name, fixture_format="", fixture_data=NO_DATA + ): self.testcase = testcase self.fixture_name = fixture_name self.fixture_format = fixture_format @@ -336,9 +350,12 @@ def assertAgainstFixture(self, value): if self.fixture_format: message_infix = " with format %s" % (self.fixture_format,) else: - message_infix = '' + message_infix = "" message = "value differed from fixture named %s%s\n\n%s" % ( - self.fixture_name, message_infix, raised_exception) + self.fixture_name, + message_infix, + raised_exception, + ) raise AssertionError(message) def write_to_dir(self, target_dir, output_format=None): @@ -347,42 +364,47 @@ def write_to_dir(self, target_dir, output_format=None): directory applying any onwrite hints that may be specified. """ - assert self.fixture_data is not self.NO_DATA, \ - "fixture is not populated with data" + assert ( + self.fixture_data is not self.NO_DATA + ), "fixture is not populated with data" assert os.path.isabs(target_dir) # convert fixture name (which includes the varaint) to the target file - fixture_file_target = _fixturefile_normname(self.fixture_name, prefix=target_dir) + fixture_file_target = _fixturefile_normname( + self.fixture_name, prefix=target_dir + ) output_data = self.fixture_data # now apply any onwrite conversions hints = _load_hints_ini_for_fixture_if_present(self.fixture_name) - for item_name in hints['ONWRITE']: + for item_name in hints["ONWRITE"]: if item_name not in _FIXTUREFILE_APPLIERS_ONWRITE: raise AssertionError( - "unsupported fixture conversion: %s" % (item_name,)) + "unsupported fixture conversion: %s" % (item_name,) + ) - enabled = hints.getboolean('ONWRITE', item_name) + enabled = hints.getboolean("ONWRITE", item_name) if not enabled: continue hint_fn = _FIXTUREFILE_APPLIERS_ONWRITE[item_name] output_data = hint_fn(output_data, None) - if output_format == 'binary': - with open(fixture_file_target, 'wb') as fixture_outputfile: + if output_format == "binary": + with open(fixture_file_target, "wb") as fixture_outputfile: fixture_outputfile.write(output_data) - elif output_format == 'json': - with open(fixture_file_target, 'w') as fixture_outputfile: + elif output_format == "json": + with open(fixture_file_target, "w") as fixture_outputfile: json.dump(output_data, fixture_outputfile) - elif output_format == 'pickle': - with open(fixture_file_target, 'wb') as fixture_outputfile: + elif output_format == "pickle": + with open(fixture_file_target, "wb") as fixture_outputfile: pickle.dump(output_data, fixture_outputfile) else: raise AssertionError( - "unsupported fixture format: %s" % (output_format,)) + "unsupported fixture format: %s" % (output_format,) + ) @staticmethod def from_relpath(testcase, fixture_name, fixture_format): @@ -392,11 +414,16 @@ def from_relpath(testcase, fixture_name, fixture_format): """ fixture_data, fixture_path = _fixturefile_loadrelative( - fixture_name, fixture_format) - return _PreparedFixture(testcase, fixture_name, fixture_format, fixture_data) + fixture_name, fixture_format + ) + return _PreparedFixture( + testcase, fixture_name, fixture_format, fixture_data + ) class FixtureAssertMixin: def prepareFixtureAssert(self, fixture_relpath, fixture_format=None): """Prepare to assert a value against a fixture.""" - return _PreparedFixture.from_relpath(self, fixture_relpath, fixture_format) + return _PreparedFixture.from_relpath( + self, fixture_relpath, fixture_format + ) diff --git a/tests/support/loggersupp.py b/tests/support/loggersupp.py index b1eb2c295..5de13fb8a 100644 --- a/tests/support/loggersupp.py +++ b/tests/support/loggersupp.py @@ -28,9 +28,9 @@ """Logger related details within the test support library.""" -from collections import defaultdict import os import re +from collections import defaultdict from tests.support.suppconst import MIG_BASE, TEST_BASE @@ -47,7 +47,8 @@ class FakeLogger: """ RE_UNCLOSEDFILE = re.compile( - 'unclosed file <.*? name=\'(?P.*?)\'( .*?)?>') + "unclosed file <.*? name='(?P.*?)'( .*?)?>" + ) def __init__(self): self.channels_dict = defaultdict(list) @@ -70,13 +71,19 @@ def check_empty_and_reset(self): # complain loudly (and in detail) in the case of unclosed files if len(unclosed_by_file) > 0: - messages = '\n'.join({' --> %s: line=%s, file=%s' % (fname, lineno, outname) - for fname, (lineno, outname) in unclosed_by_file.items()}) - raise RuntimeError('unclosed files encountered:\n%s' % (messages,)) - - if channels_dict['error'] and not forgive_by_channel['error']: - raise RuntimeError('errors reported to logger:\n%s' % - '\n'.join(channels_dict['error'])) + messages = "\n".join( + { + " --> %s: line=%s, file=%s" % (fname, lineno, outname) + for fname, (lineno, outname) in unclosed_by_file.items() + } + ) + raise RuntimeError("unclosed files encountered:\n%s" % (messages,)) + + if channels_dict["error"] and not forgive_by_channel["error"]: + raise RuntimeError( + "errors reported to logger:\n%s" + % "\n".join(channels_dict["error"]) + ) def forgive_errors(self): """Allow log errors for cases where they are expected""" @@ -90,26 +97,26 @@ def forgive_messages_on(self, *, channel_name=None): def debug(self, line): """Mock log action of same name""" - self._append_as('debug', line) + self._append_as("debug", line) def error(self, line): """Mock log action of same name""" - self._append_as('error', line) + self._append_as("error", line) def info(self, line): """Mock log action of same name""" - self._append_as('info', line) + self._append_as("info", line) def warning(self, line): """Mock log action of same name""" - self._append_as('warning', line) + self._append_as("warning", line) def write(self, message): """Actual write handler""" - channel, namespace, specifics = message.split(':', 2) + channel, namespace, specifics = message.split(":", 2) # ignore everything except warnings sent by the python runtime - if not (channel == 'WARNING' and namespace == 'py.warnings'): + if not (channel == "WARNING" and namespace == "py.warnings"): return filename_and_datatuple = FakeLogger.identify_unclosed_file(specifics) @@ -119,10 +126,10 @@ def write(self, message): @staticmethod def identify_unclosed_file(specifics): """Warn about unclosed files""" - filename, lineno, exc_name, message = specifics.split(':', 3) + filename, lineno, exc_name, message = specifics.split(":", 3) exc_name = exc_name.lstrip() - if exc_name != 'ResourceWarning': + if exc_name != "ResourceWarning": return matched = FakeLogger.RE_UNCLOSEDFILE.match(message.lstrip()) @@ -131,7 +138,8 @@ def identify_unclosed_file(specifics): relative_testfile = os.path.relpath(filename, start=MIG_BASE) relative_outputfile = os.path.relpath( - matched.groups('location')[0], start=TEST_BASE) + matched.groups("location")[0], start=TEST_BASE + ) return (relative_testfile, (lineno, relative_outputfile)) diff --git a/tests/support/picklesupp.py b/tests/support/picklesupp.py index 667dd4b01..262e4c50c 100644 --- a/tests/support/picklesupp.py +++ b/tests/support/picklesupp.py @@ -29,8 +29,8 @@ import pickle -from tests.support.suppconst import TEST_OUTPUT_DIR from tests.support.fixturesupp import _HINTS_APPLIERS_ARGLESS +from tests.support.suppconst import TEST_OUTPUT_DIR class PickleAssertMixin: @@ -44,7 +44,7 @@ def assertPickledFile(self, pickle_file_path, apply_hints=None): having been optionally transformed as requested by hints. """ - with open(pickle_file_path, 'rb') as picklefile: + with open(pickle_file_path, "rb") as picklefile: pickled = pickle.load(picklefile) if not apply_hints: diff --git a/tests/support/serversupp.py b/tests/support/serversupp.py index 0e0fd4b94..bdca0fa33 100644 --- a/tests/support/serversupp.py +++ b/tests/support/serversupp.py @@ -27,7 +27,8 @@ """Server threading related details within the test support library""" -from threading import Thread, Event as ThreadEvent +from threading import Event as ThreadEvent +from threading import Thread class ServerWithinThreadExecutor: @@ -50,7 +51,7 @@ def run(self): """Mimic the same method from the standard thread API""" server_args, server_kwargs = self._arguments - server_kwargs['on_start'] = lambda _: self._started.set() + server_kwargs["on_start"] = lambda _: self._started.set() self._wrapped = self._serverclass(*server_args, **server_kwargs) diff --git a/tests/support/snapshotsupp.py b/tests/support/snapshotsupp.py index 355095814..a24bb2238 100644 --- a/tests/support/snapshotsupp.py +++ b/tests/support/snapshotsupp.py @@ -28,14 +28,14 @@ import difflib import errno -import re import os +import re from tests.support.suppconst import TEST_BASE -HTML_TAG = '' -MARKER_CONTENT_BEGIN = '' -MARKER_CONTENT_END = '' +HTML_TAG = "" +MARKER_CONTENT_BEGIN = "" +MARKER_CONTENT_END = "" TEST_SNAPSHOTS_DIR = os.path.join(TEST_BASE, "snapshots") try: @@ -57,9 +57,9 @@ def _html_content_only(value): # set the index after the content marker content_start_index += len(MARKER_CONTENT_BEGIN) # we now need to remove the container div inside it ..first find it - content_start_inner_div = value.find('', content_start_inner_div) + 1 + content_start_index = value.find(">", content_start_inner_div) + 1 content_end_index = value.find(MARKER_CONTENT_END) assert content_end_index > -1, "unable to locate end of content" @@ -77,7 +77,7 @@ def _delimited_lines(value): lines = [] while from_index < last_index: - found_index = value.find('\n', from_index) + found_index = value.find("\n", from_index) if found_index == -1: break found_index += 1 @@ -93,8 +93,8 @@ def _delimited_lines(value): def _force_refresh_snapshots(): """Check whether the environment specifies snapshots should be refreshed.""" - env_refresh_snapshots = os.environ.get('REFRESH_SNAPSHOTS', 'no').lower() - return env_refresh_snapshots in ('true', 'yes', '1') + env_refresh_snapshots = os.environ.get("REFRESH_SNAPSHOTS", "no").lower() + return env_refresh_snapshots in ("true", "yes", "1") class SnapshotAssertMixin: @@ -107,7 +107,7 @@ def _snapshotsupp_compare_snapshot(self, extension, actual_content): In the case a snapshot does not exist it is saved on first invocation. """ - file_name = ''.join([self._testMethodName, ".", extension]) + file_name = "".join([self._testMethodName, ".", extension]) file_path = os.path.join(TEST_SNAPSHOTS_DIR, file_name) if not os.path.isfile(file_path) or _force_refresh_snapshots(): @@ -126,11 +126,12 @@ def _snapshotsupp_compare_snapshot(self, extension, actual_content): udiff = difflib.unified_diff( _delimited_lines(expected_content), _delimited_lines(actual_content), - 'expected', - 'actual' + "expected", + "actual", ) raise AssertionError( - "content did not match snapshot\n\n%s" % (''.join(udiff),)) + "content did not match snapshot\n\n%s" % ("".join(udiff),) + ) def assertSnapshot(self, actual_content, extension=None): """Load a snapshot corresponding to the named test and check that what @@ -148,4 +149,4 @@ def assertSnapshotOfHtmlContent(self, actual_content): """ actual_content = _html_content_only(actual_content) - self._snapshotsupp_compare_snapshot('html', actual_content) + self._snapshotsupp_compare_snapshot("html", actual_content) diff --git a/tests/support/suppconst.py b/tests/support/suppconst.py index 148303f0d..204b0ba29 100644 --- a/tests/support/suppconst.py +++ b/tests/support/suppconst.py @@ -29,11 +29,11 @@ from tests.support._env import MIG_ENV -if MIG_ENV == 'local': +if MIG_ENV == "local": # Use abspath for __file__ on Py2 _SUPPORT_DIR = os.path.dirname(os.path.abspath(__file__)) -elif MIG_ENV == 'docker': - _SUPPORT_DIR = '/usr/src/app/tests/support' +elif MIG_ENV == "docker": + _SUPPORT_DIR = "/usr/src/app/tests/support" else: raise NotImplementedError("ABORT: unsupported environment: %s" % (MIG_ENV,)) @@ -46,7 +46,8 @@ ENVHELP_OUTPUT_DIR = os.path.join(ENVHELP_DIR, "output") -if __name__ == '__main__': +if __name__ == "__main__": + def print_root_relative(prefix, path): print("%s = /%s" % (prefix, os.path.relpath(path, MIG_BASE))) diff --git a/tests/support/usersupp.py b/tests/support/usersupp.py index 65a02fa4a..849c98f92 100644 --- a/tests/support/usersupp.py +++ b/tests/support/usersupp.py @@ -33,15 +33,11 @@ import pickle from mig.shared.base import client_id_dir - from tests.support.fixturesupp import _PreparedFixture +TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" -TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' - -_FIXTURE_NAME_BY_USER_DN = { - TEST_USER_DN: 'MiG-users.db--example' -} +_FIXTURE_NAME_BY_USER_DN = {TEST_USER_DN: "MiG-users.db--example"} class UserAssertMixin: @@ -59,9 +55,9 @@ def _provision_user_db_dir(testcase): conf_user_db_home = testcase.configuration.user_db_home os.makedirs(conf_user_db_home, exist_ok=True) - user_db_file = os.path.join(conf_user_db_home, 'MiG-users.db') + user_db_file = os.path.join(conf_user_db_home, "MiG-users.db") if os.path.exists(user_db_file): - raise AssertionError('a user database file already exists') + raise AssertionError("a user database file already exists") return conf_user_db_home @@ -79,19 +75,19 @@ def _provision_test_user(testcase, distinguished_name): try: fixture_relpath = _FIXTURE_NAME_BY_USER_DN[distinguished_name] except KeyError: - raise AssertionError('supplied test user is not known as a fixture') + raise AssertionError("supplied test user is not known as a fixture") # note: this is a non-standard direct use of fixture preparation due # to this being bootstrap code and should not be used elsewhere prepared_fixture = _PreparedFixture.from_relpath( - testcase, - fixture_relpath, - fixture_format='json' + testcase, fixture_relpath, fixture_format="json" ) # write out the user database fixture containing the user - prepared_fixture.write_to_dir(conf_user_db_home, output_format='pickle') + prepared_fixture.write_to_dir(conf_user_db_home, output_format="pickle") - test_user_dir = UserAssertMixin._provision_test_user_dirs(testcase, distinguished_name) + test_user_dir = UserAssertMixin._provision_test_user_dirs( + testcase, distinguished_name + ) return test_user_dir @@ -112,12 +108,16 @@ def _provision_test_user_dirs(testcase, distinguished_name): # create the test user settings directory conf_user_settings = os.path.normpath(self.configuration.user_settings) - test_user_settings_dir = os.path.join(conf_user_settings, test_client_dir_name) + test_user_settings_dir = os.path.join( + conf_user_settings, test_client_dir_name + ) os.makedirs(test_user_settings_dir) # create an empty user settings file - test_user_settings_file = os.path.join(test_user_settings_dir, 'settings') - with open(test_user_settings_file, 'wb') as outfile: + test_user_settings_file = os.path.join( + test_user_settings_dir, "settings" + ) + with open(test_user_settings_file, "wb") as outfile: pickle.dump({}, outfile) return test_user_dir @@ -154,9 +154,11 @@ def _provision_test_users(testcase, *distinguished_names): # write out all the users we have assembled by populating an empty # fixture with their data but using a known fixture name and thus one # suitably hinted so a production format pickle file ends up on-disk - prepared_fixture = _PreparedFixture(testcase, 'MiG-users.db--example') + prepared_fixture = _PreparedFixture(testcase, "MiG-users.db--example") prepared_fixture.fixture_data = users_by_dn - prepared_fixture.write_to_dir(conf_user_db_home, output_format='pickle') + prepared_fixture.write_to_dir(conf_user_db_home, output_format="pickle") for distinguished_name in distinguished_names: - UserAssertMixin._provision_test_user_dirs(testcase, distinguished_name) + UserAssertMixin._provision_test_user_dirs( + testcase, distinguished_name + ) diff --git a/tests/support/wsgisupp.py b/tests/support/wsgisupp.py index 1105d0db8..2a6a209d2 100644 --- a/tests/support/wsgisupp.py +++ b/tests/support/wsgisupp.py @@ -27,22 +27,21 @@ """Test support library for WSGI.""" -from collections import namedtuple import codecs +from collections import namedtuple from io import BytesIO from urllib.parse import urlencode, urlparse from werkzeug.datastructures import MultiDict - # named type representing the tuple that is passed to WSGI handlers -_PreparedWsgi = namedtuple('_PreparedWsgi', ['environ', 'start_response']) +_PreparedWsgi = namedtuple("_PreparedWsgi", ["environ", "start_response"]) class FakeWsgiStartResponse: """Glue object that conforms to the same interface as the start_response() - in the WSGI specs but records the calls to it such that they can be - inspected and, for our purposes, asserted against.""" + in the WSGI specs but records the calls to it such that they can be + inspected and, for our purposes, asserted against.""" def __init__(self): self.calls = [] @@ -51,7 +50,9 @@ def __call__(self, status, headers, exc=None): self.calls.append((status, headers, exc)) -def create_wsgi_environ(configuration, wsgi_url, method='GET', query=None, headers=None, form=None): +def create_wsgi_environ( + configuration, wsgi_url, method="GET", query=None, headers=None, form=None +): """Populate the necessary variables that will constitute a valid WSGI environment given a URL to which we will make a requests under test and various other options that set up the nature of that request.""" @@ -59,21 +60,21 @@ def create_wsgi_environ(configuration, wsgi_url, method='GET', query=None, heade parsed_url = urlparse(wsgi_url) if query: - method = 'GET' + method = "GET" request_query = urlencode(query) wsgi_input = () elif form: - method = 'POST' - request_query = '' + method = "POST" + request_query = "" - body = urlencode(MultiDict(form)).encode('ascii') + body = urlencode(MultiDict(form)).encode("ascii") headers = headers or {} - if not 'Content-Type' in headers: - headers['Content-Type'] = 'application/x-www-form-urlencoded' + if not "Content-Type" in headers: + headers["Content-Type"] = "application/x-www-form-urlencoded" - headers['Content-Length'] = str(len(body)) + headers["Content-Length"] = str(len(body)) wsgi_input = BytesIO(body) else: request_query = parsed_url.query @@ -83,26 +84,27 @@ class _errors: """Internal helper to ignore wsgi.errors close method calls""" def close(self, *ars, **kwargs): - """"Simply ignore""" + """ "Simply ignore""" pass environ = {} - environ['wsgi.errors'] = _errors() - environ['wsgi.input'] = wsgi_input - environ['wsgi.url_scheme'] = parsed_url.scheme - environ['wsgi.version'] = (1, 0) - environ['MIG_CONF'] = configuration.config_file - environ['HTTP_HOST'] = parsed_url.netloc - environ['PATH_INFO'] = parsed_url.path - environ['QUERY_STRING'] = request_query - environ['REQUEST_METHOD'] = method - environ['SCRIPT_URI'] = ''.join( - ('http://', environ['HTTP_HOST'], environ['PATH_INFO'])) + environ["wsgi.errors"] = _errors() + environ["wsgi.input"] = wsgi_input + environ["wsgi.url_scheme"] = parsed_url.scheme + environ["wsgi.version"] = (1, 0) + environ["MIG_CONF"] = configuration.config_file + environ["HTTP_HOST"] = parsed_url.netloc + environ["PATH_INFO"] = parsed_url.path + environ["QUERY_STRING"] = request_query + environ["REQUEST_METHOD"] = method + environ["SCRIPT_URI"] = "".join( + ("http://", environ["HTTP_HOST"], environ["PATH_INFO"]) + ) if headers: for k, v in headers.items(): - header_key = k.replace('-', '_').upper() - if header_key.startswith('CONTENT'): + header_key = k.replace("-", "_").upper() + if header_key.startswith("CONTENT"): # Content-* headers must not be prefixed in WSGI pass else: @@ -119,15 +121,15 @@ def create_wsgi_start_response(): def prepare_wsgi(configuration, url, **kwargs): return _PreparedWsgi( create_wsgi_environ(configuration, url, **kwargs), - create_wsgi_start_response() + create_wsgi_start_response(), ) def _trigger_and_unpack_result(wsgi_result): chunks = list(wsgi_result) assert len(chunks) > 0, "invocation returned no output" - complete_value = b''.join(chunks) - decoded_value = codecs.decode(complete_value, 'utf8') + complete_value = b"".join(chunks) + decoded_value = codecs.decode(complete_value, "utf8") return decoded_value @@ -140,7 +142,7 @@ def assertWsgiResponse(self, wsgi_result, fake_wsgi, expected_status_code): content = _trigger_and_unpack_result(wsgi_result) def called_once(fake): - assert hasattr(fake, 'calls') + assert hasattr(fake, "calls") return len(fake.calls) == 1 fake_start_response = fake_wsgi.start_response diff --git a/tests/test_booleans.py b/tests/test_booleans.py index 5246197ee..a3564b8ff 100644 --- a/tests/test_booleans.py +++ b/tests/test_booleans.py @@ -2,6 +2,7 @@ from tests.support import MigTestCase, testmain + class TestBooleans(MigTestCase): def test_true(self): self.assertEqual(True, True) @@ -10,5 +11,5 @@ def test_false(self): self.assertEqual(False, False) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_install_generateconfs.py b/tests/test_mig_install_generateconfs.py index e8ed90241..c68318cd9 100644 --- a/tests/test_mig_install_generateconfs.py +++ b/tests/test_mig_install_generateconfs.py @@ -33,13 +33,13 @@ import os import sys -from tests.support import MIG_BASE, MigTestCase, testmain, cleanpath +from tests.support import MIG_BASE, MigTestCase, cleanpath, testmain def _import_generateconfs(): """Internal helper to work around non-package import location""" - sys.path.append(os.path.join(MIG_BASE, 'mig/install')) - mod = importlib.import_module('generateconfs') + sys.path.append(os.path.join(MIG_BASE, "mig/install")) + mod = importlib.import_module("generateconfs") sys.path.pop(-1) return mod @@ -51,12 +51,14 @@ def _import_generateconfs(): def create_fake_generate_confs(return_dict=None): """Fake generate confs helper""" + def _generate_confs(*args, **kwargs): _generate_confs.settings = kwargs if return_dict: return (return_dict, {}) else: return ({}, {}) + _generate_confs.settings = None return _generate_confs @@ -69,52 +71,64 @@ class MigInstallGenerateconfs__main(MigTestCase): """Unit test helper for the migrid code pointed to in class name""" def test_option_permanent_freeze(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--permanent_freeze', 'yes'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--permanent_freeze", "yes"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) self.assertEqual(exit_code, 0) def test_option_storage_protocols(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--storage_protocols', 'proto1 proto2 proto3'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--storage_protocols", "proto1 proto2 proto3"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) self.assertEqual(exit_code, 0) settings = fake_generate_confs.settings - self.assertIn('storage_protocols', settings) - self.assertEqual(settings['storage_protocols'], 'proto1 proto2 proto3') + self.assertIn("storage_protocols", settings) + self.assertEqual(settings["storage_protocols"], "proto1 proto2 proto3") def test_option_wwwserve_max_bytes(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--wwwserve_max_bytes', '43211234'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--wwwserve_max_bytes", "43211234"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) settings = fake_generate_confs.settings - self.assertIn('wwwserve_max_bytes', settings) - self.assertEqual(settings['wwwserve_max_bytes'], 43211234) + self.assertIn("wwwserve_max_bytes", settings) + self.assertEqual(settings["wwwserve_max_bytes"], 43211234) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_lib_accounting.py b/tests/test_mig_lib_accounting.py index fc0003fc4..9eb60a931 100644 --- a/tests/test_mig_lib_accounting.py +++ b/tests/test_mig_lib_accounting.py @@ -30,8 +30,11 @@ import os import pickle -from mig.lib.accounting import get_usage, human_readable_filesize, \ - update_accounting +from mig.lib.accounting import ( + get_usage, + human_readable_filesize, + update_accounting, +) from mig.shared.base import client_id_dir from mig.shared.defaults import peers_filename from tests.support import MigTestCase, ensure_dirs_exist @@ -39,71 +42,90 @@ TEST_MTIME = 1768925307 TEST_SOFTLIMIT_BYTES = 109951162777600 TEST_HARDLIMIT_BYTES = 109951162777600 -TEST_CLIENT_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@user.com' +TEST_CLIENT_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@user.com" +) TEST_CLIENT_BYTES = 206128256 -TEST_EXT_DN = '/C=DK/ST=NA/L=NA/O=PEER Org/OU=NA/CN=Test Peer/emailAddress=peer@example.com' +TEST_EXT_DN = "/C=DK/ST=NA/L=NA/O=PEER Org/OU=NA/CN=Test Peer/emailAddress=peer@example.com" TEST_EXT_BYTES = 16806128256 TEST_FREEZE_BYTES = 128256 -TEST_VGRID_NAME1 = 'TestVgrid1' +TEST_VGRID_NAME1 = "TestVgrid1" TEST_VGRID_BYTES1 = 406128256 -TEST_VGRID_NAME2 = 'TestVgrid2' +TEST_VGRID_NAME2 = "TestVgrid2" TEST_VGRID_BYTES2 = 606128256 -TEST_VGRID_NAME3 = 'TestVgrid3' +TEST_VGRID_NAME3 = "TestVgrid3" TEST_VGRID_BYTES3 = 806128256 -TEST_VGRID_TOTAL_BYTES = TEST_VGRID_BYTES1 \ - + TEST_VGRID_BYTES2 \ - + TEST_VGRID_BYTES3 -TEST_TOTAL_BYTES = TEST_CLIENT_BYTES \ - + TEST_EXT_BYTES \ - + TEST_FREEZE_BYTES \ +TEST_VGRID_TOTAL_BYTES = ( + TEST_VGRID_BYTES1 + TEST_VGRID_BYTES2 + TEST_VGRID_BYTES3 +) +TEST_TOTAL_BYTES = ( + TEST_CLIENT_BYTES + + TEST_EXT_BYTES + + TEST_FREEZE_BYTES + TEST_VGRID_TOTAL_BYTES -TEST_LUSTRE_QUOTA_INFO = {'next_pid': 192, 'mtime': TEST_MTIME} -TEST_CLIENT_USAGE = {'lustre_pid': 42, - 'files': 11, - 'bytes': TEST_CLIENT_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE1 = {'lustre_pid': 43, - 'files': 111, - 'bytes': TEST_VGRID_BYTES1, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE2 = {'lustre_pid': 44, - 'files': 222, - 'bytes': TEST_VGRID_BYTES2, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE3 = {'lustre_pid': 45, - 'files': 333, - 'bytes': TEST_VGRID_BYTES3, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_EXT_USAGE = {'lustre_pid': 46, - 'files': 1, - 'bytes': TEST_EXT_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_FREEZE_USAGE = {'lustre_pid': 47, - 'files': 1, - 'bytes': TEST_FREEZE_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_PEERS = {TEST_EXT_DN: {'kind': 'collaboration', - 'distinguished_name': TEST_EXT_DN, - 'country': 'DK', - 'label': 'TEST', - 'state': '', - 'expire': '2222-12-31', - 'full_name': 'Test Peer', - 'organization': 'PEER Org', - 'email': 'peer@example.com' - }} +) +TEST_LUSTRE_QUOTA_INFO = {"next_pid": 192, "mtime": TEST_MTIME} +TEST_CLIENT_USAGE = { + "lustre_pid": 42, + "files": 11, + "bytes": TEST_CLIENT_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE1 = { + "lustre_pid": 43, + "files": 111, + "bytes": TEST_VGRID_BYTES1, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE2 = { + "lustre_pid": 44, + "files": 222, + "bytes": TEST_VGRID_BYTES2, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE3 = { + "lustre_pid": 45, + "files": 333, + "bytes": TEST_VGRID_BYTES3, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_EXT_USAGE = { + "lustre_pid": 46, + "files": 1, + "bytes": TEST_EXT_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_FREEZE_USAGE = { + "lustre_pid": 47, + "files": 1, + "bytes": TEST_FREEZE_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_PEERS = { + TEST_EXT_DN: { + "kind": "collaboration", + "distinguished_name": TEST_EXT_DN, + "country": "DK", + "label": "TEST", + "state": "", + "expire": "2222-12-31", + "full_name": "Test Peer", + "organization": "PEER Org", + "email": "peer@example.com", + } +} class MigLibAccounting(MigTestCase): @@ -111,7 +133,7 @@ class MigLibAccounting(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up test configuration and reset state before each test""" @@ -120,15 +142,17 @@ def before_each(self): self.configuration.site_enable_quota = True self.configuration.site_enable_accounting = True - self.configuration.quota_backend = 'lustre' - - quota_basepath = os.path.join(self.configuration.quota_home, - self.configuration.quota_backend) - quota_user_path = os.path.join(quota_basepath, 'user') - quota_vgrid_path = os.path.join(quota_basepath, 'vgrid') - quota_freeze_path = os.path.join(quota_basepath, 'freeze') - test_client_peers_path = os.path.join(self.configuration.user_settings, - client_id_dir(TEST_CLIENT_DN)) + self.configuration.quota_backend = "lustre" + + quota_basepath = os.path.join( + self.configuration.quota_home, self.configuration.quota_backend + ) + quota_user_path = os.path.join(quota_basepath, "user") + quota_vgrid_path = os.path.join(quota_basepath, "vgrid") + quota_freeze_path = os.path.join(quota_basepath, "freeze") + test_client_peers_path = os.path.join( + self.configuration.user_settings, client_id_dir(TEST_CLIENT_DN) + ) ensure_dirs_exist(self.configuration.vgrid_home) ensure_dirs_exist(self.configuration.user_settings) @@ -141,61 +165,69 @@ def before_each(self): # Ensure fake vgrid and write owner - for vgrid_name in [TEST_VGRID_NAME1, - TEST_VGRID_NAME2, - TEST_VGRID_NAME3]: + for vgrid_name in [ + TEST_VGRID_NAME1, + TEST_VGRID_NAME2, + TEST_VGRID_NAME3, + ]: vgrid_home_path = os.path.join( - self.configuration.vgrid_home, vgrid_name) + self.configuration.vgrid_home, vgrid_name + ) ensure_dirs_exist(vgrid_home_path) - vgrid_owners_filepath = os.path.join(vgrid_home_path, 'owners') - with open(vgrid_owners_filepath, 'wb') as fh: + vgrid_owners_filepath = os.path.join(vgrid_home_path, "owners") + with open(vgrid_owners_filepath, "wb") as fh: fh.write(pickle.dumps([TEST_CLIENT_DN])) # Write fake quota - test_lustre_quota_info_filepath \ - = os.path.join(self.configuration.quota_home, - '%s.pck' % self.configuration.quota_backend) - with open(test_lustre_quota_info_filepath, 'wb') as fh: + test_lustre_quota_info_filepath = os.path.join( + self.configuration.quota_home, + "%s.pck" % self.configuration.quota_backend, + ) + with open(test_lustre_quota_info_filepath, "wb") as fh: fh.write(pickle.dumps(TEST_LUSTRE_QUOTA_INFO)) - quota_test_client_path \ - = os.path.join(quota_user_path, - "%s.pck" % client_id_dir(TEST_CLIENT_DN)) + quota_test_client_path = os.path.join( + quota_user_path, "%s.pck" % client_id_dir(TEST_CLIENT_DN) + ) - with open(quota_test_client_path, 'wb') as fh: + with open(quota_test_client_path, "wb") as fh: fh.write(pickle.dumps(TEST_CLIENT_USAGE)) - quot_test_vgrid_filepath1 = os.path.join(quota_vgrid_path, - "%s.pck" % TEST_VGRID_NAME1) - with open(quot_test_vgrid_filepath1, 'wb') as fh: + quot_test_vgrid_filepath1 = os.path.join( + quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME1 + ) + with open(quot_test_vgrid_filepath1, "wb") as fh: fh.write(pickle.dumps(TEST_VGRID_USAGE1)) - quot_test_vgrid_filepath2 = os.path.join(quota_vgrid_path, - "%s.pck" % TEST_VGRID_NAME2) - with open(quot_test_vgrid_filepath2, 'wb') as fh: + quot_test_vgrid_filepath2 = os.path.join( + quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME2 + ) + with open(quot_test_vgrid_filepath2, "wb") as fh: fh.write(pickle.dumps(TEST_VGRID_USAGE2)) - quot_test_vgrid_filepath3 = os.path.join(quota_vgrid_path, - "%s.pck" % TEST_VGRID_NAME3) - with open(quot_test_vgrid_filepath3, 'wb') as fh: + quot_test_vgrid_filepath3 = os.path.join( + quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME3 + ) + with open(quot_test_vgrid_filepath3, "wb") as fh: fh.write(pickle.dumps(TEST_VGRID_USAGE3)) test_client_peers_filepath = os.path.join( - test_client_peers_path, peers_filename) - with open(test_client_peers_filepath, 'wb') as fh: + test_client_peers_path, peers_filename + ) + with open(test_client_peers_filepath, "wb") as fh: fh.write(pickle.dumps(TEST_PEERS)) - quota_test_client_ext_path \ - = os.path.join(quota_user_path, - "%s.pck" % client_id_dir(TEST_EXT_DN)) - with open(quota_test_client_ext_path, 'wb') as fh: + quota_test_client_ext_path = os.path.join( + quota_user_path, "%s.pck" % client_id_dir(TEST_EXT_DN) + ) + with open(quota_test_client_ext_path, "wb") as fh: fh.write(pickle.dumps(TEST_EXT_USAGE)) - quota_test_freeze_path = os.path.join(quota_freeze_path, - "%s.pck" - % client_id_dir(TEST_CLIENT_DN)) - with open(quota_test_freeze_path, 'wb') as fh: + quota_test_freeze_path = os.path.join( + quota_freeze_path, "%s.pck" % client_id_dir(TEST_CLIENT_DN) + ) + with open(quota_test_freeze_path, "wb") as fh: fh.write(pickle.dumps(TEST_FREEZE_USAGE)) def test_accounting(self): @@ -211,36 +243,43 @@ def test_accounting(self): usage = get_usage(self.configuration) self.assertNotEqual(usage, {}) - accounting = usage.get('accounting', {}) + accounting = usage.get("accounting", {}) test_user_accounting = accounting.get(TEST_CLIENT_DN, {}) self.assertNotEqual(test_user_accounting, {}) - home_total = test_user_accounting.get('home_total', 0) + home_total = test_user_accounting.get("home_total", 0) self.assertEqual(home_total, TEST_CLIENT_BYTES) - vgrid_total = test_user_accounting.get('vgrid_total', 0) + vgrid_total = test_user_accounting.get("vgrid_total", 0) self.assertEqual(vgrid_total, TEST_VGRID_TOTAL_BYTES) - ext_users_total = test_user_accounting.get('ext_users_total', 0) + ext_users_total = test_user_accounting.get("ext_users_total", 0) self.assertEqual(ext_users_total, TEST_EXT_BYTES) - freeze_total = test_user_accounting.get('freeze_total', 0) + freeze_total = test_user_accounting.get("freeze_total", 0) self.assertEqual(freeze_total, TEST_FREEZE_BYTES) - total_bytes = test_user_accounting.get('total_bytes', 0) + total_bytes = test_user_accounting.get("total_bytes", 0) self.assertEqual(total_bytes, TEST_TOTAL_BYTES) def test_human_readable_filesize_valid(self): """Test human-friendly format helper success on valid byte sizes""" - valid = [(0, "0 B"), (42, "42.000 B"), (2**10, "1.000 KiB"), - (2**30, "1.000 GiB"), (2**50, "1.000 PiB"), - (2**89, "512.000 YiB"), (2**90 - 2**70, "1023.999 YiB")] - for (size, expect) in valid: + valid = [ + (0, "0 B"), + (42, "42.000 B"), + (2**10, "1.000 KiB"), + (2**30, "1.000 GiB"), + (2**50, "1.000 PiB"), + (2**89, "512.000 YiB"), + (2**90 - 2**70, "1023.999 YiB"), + ] + for size, expect in valid: self.assertEqual(human_readable_filesize(size), expect) def test_human_readable_filesize_invalid(self): """Test human-friendly format helper failure on invalid byte sizes""" - invalid = [(i, "NaN") for i in [False, None, "", "one", -1, 1.2, 2**90, - 2**128]] - for (size, expect) in invalid: + invalid = [ + (i, "NaN") for i in [False, None, "", "one", -1, 1.2, 2**90, 2**128] + ] + for size, expect in invalid: self.assertEqual(human_readable_filesize(size), expect) diff --git a/tests/test_mig_lib_daemon.py b/tests/test_mig_lib_daemon.py index bd8cd901c..0cd55ea4e 100644 --- a/tests/test_mig_lib_daemon.py +++ b/tests/test_mig_lib_daemon.py @@ -31,10 +31,22 @@ import signal import time -from mig.lib.daemon import _run_event, _stop_event, check_run, check_stop, \ - do_run, interruptible_sleep, register_run_handler, register_stop_handler, \ - reset_run, reset_stop, run_handler, stop_handler, stop_running, \ - unregister_signal_handlers +from mig.lib.daemon import ( + _run_event, + _stop_event, + check_run, + check_stop, + do_run, + interruptible_sleep, + register_run_handler, + register_stop_handler, + reset_run, + reset_stop, + run_handler, + stop_handler, + stop_running, + unregister_signal_handlers, +) from tests.support import FakeConfiguration, FakeLogger, MigTestCase @@ -42,19 +54,25 @@ class MigLibDaemon(MigTestCase): """Unit tests for daemon related helper functions""" # Signals registered across the tests and explicitly unregistered on init - _used_signals = [signal.SIGCONT, signal.SIGINT, signal.SIGALRM, - signal.SIGABRT, signal.SIGUSR1, signal.SIGUSR2] + _used_signals = [ + signal.SIGCONT, + signal.SIGINT, + signal.SIGALRM, + signal.SIGABRT, + signal.SIGUSR1, + signal.SIGUSR2, + ] def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up any test configuration and reset state before each test""" # Create dummy sig and frame values for isolated test use - self.sig = 'SIGNAL' - self.frame = 'FRAME' + self.sig = "SIGNAL" + self.frame = "FRAME" # Reset event states reset_run() @@ -94,7 +112,7 @@ def test_interruptible_sleep(self): max_secs = 4.2 start = time.time() signal.alarm(1) - interruptible_sleep(self.configuration, max_secs, (check_run, )) + interruptible_sleep(self.configuration, max_secs, (check_run,)) self.assertTrue(check_run()) end = time.time() self.assertTrue(end - start < max_secs) @@ -303,14 +321,17 @@ def test_concurrent_event_handling(self): def test_interruptible_sleep_immediate_break(self): """Test interruptible_sleep with immediate break condition""" + def immediate_true(): return True start = time.time() interruptible_sleep(self.configuration, 5.0, [immediate_true]) duration = time.time() - start - self.assertTrue(duration < 0.1, - "Sleep should exit immediately but took %s" % duration) + self.assertTrue( + duration < 0.1, + "Sleep should exit immediately but took %s" % duration, + ) def test_reset_event_helpers(self): """Test simple event reset helpers""" @@ -337,36 +358,41 @@ def test_unregister_signal_handlers_explicit(self): register_stop_handler(self.configuration, signal.SIGABRT) # Verify handlers were set - self.assertEqual(signal.getsignal(signal.SIGALRM).__name__, - 'run_handler') - self.assertEqual(signal.getsignal(signal.SIGABRT).__name__, - 'stop_handler') + self.assertEqual( + signal.getsignal(signal.SIGALRM).__name__, "run_handler" + ) + self.assertEqual( + signal.getsignal(signal.SIGABRT).__name__, "stop_handler" + ) # Unregister specific signals - unregister_signal_handlers(self.configuration, [signal.SIGALRM, - signal.SIGABRT]) + unregister_signal_handlers( + self.configuration, [signal.SIGALRM, signal.SIGABRT] + ) self.assertEqual(signal.getsignal(signal.SIGALRM), signal.SIG_IGN) self.assertEqual(signal.getsignal(signal.SIGABRT), signal.SIG_IGN) def test_interruptible_sleep_condition_after_interval(self): """Test interruptible_sleep break condition after one interval""" - state = {'count': 0} + state = {"count": 0} def counter_condition(): - state['count'] += 1 - return state['count'] >= 2 + state["count"] += 1 + return state["count"] >= 2 start = time.time() - interruptible_sleep(self.configuration, 5.0, [counter_condition], - nap_secs=0.1) + interruptible_sleep( + self.configuration, 5.0, [counter_condition], nap_secs=0.1 + ) duration = time.time() - start self.assertAlmostEqual(duration, 0.2, delta=0.15) def test_interruptible_sleep_maxsecs_equals_napsecs(self): """Test interruptible_sleep with max_secs exactly matching nap_secs""" start = time.time() - interruptible_sleep(self.configuration, 0.1, [lambda: False], - nap_secs=0.1) + interruptible_sleep( + self.configuration, 0.1, [lambda: False], nap_secs=0.1 + ) duration = time.time() - start self.assertAlmostEqual(duration, 0.1, delta=0.05) @@ -379,8 +405,9 @@ def faulty_condition(): self.logger.error(SLEEP_ERR) start = time.time() - interruptible_sleep(self.configuration, 0.1, [faulty_condition], - nap_secs=0.01) + interruptible_sleep( + self.configuration, 0.1, [faulty_condition], nap_secs=0.01 + ) duration = time.time() - start self.assertAlmostEqual(duration, 0.1, delta=0.05) try: @@ -402,22 +429,26 @@ def test_unregister_default_signals(self): self.assertEqual(signal.getsignal(signal.SIGCONT), signal.SIG_IGN) self.assertEqual(signal.getsignal(signal.SIGUSR2), signal.SIG_IGN) # Verify custom signals remain after default unregister - self.assertEqual(signal.getsignal(signal.SIGINT).__name__, - 'run_handler') - self.assertEqual(signal.getsignal(signal.SIGALRM).__name__, - 'stop_handler') + self.assertEqual( + signal.getsignal(signal.SIGINT).__name__, "run_handler" + ) + self.assertEqual( + signal.getsignal(signal.SIGALRM).__name__, "stop_handler" + ) def test_register_default_signal(self): """Test handler registration with default signal values""" # Run handler should default to SIGCONT register_run_handler(self.configuration) - self.assertEqual(signal.getsignal(signal.SIGCONT).__name__, - 'run_handler') + self.assertEqual( + signal.getsignal(signal.SIGCONT).__name__, "run_handler" + ) # Stop handler should default to SIGINT register_stop_handler(self.configuration) - self.assertEqual(signal.getsignal(signal.SIGINT).__name__, - 'stop_handler') + self.assertEqual( + signal.getsignal(signal.SIGINT).__name__, "stop_handler" + ) def test_reset_unregistered_signals(self): """Test unregister responds gracefully to previously unregistered signals""" @@ -440,21 +471,22 @@ def test_interruptible_sleep_break_not_callable(self): def test_interruptible_sleep_all_conditions_checked(self): """Verify all break conditions are checked each sleep interval""" - counter = {'count': 0} + counter = {"count": 0} max_checks = 3 def counter_condition(): - if counter['count'] < max_checks: - counter['count'] += 1 - return counter['count'] >= max_checks + if counter["count"] < max_checks: + counter["count"] += 1 + return counter["count"] >= max_checks start = time.time() - interruptible_sleep(self.configuration, 5.0, [counter_condition], - nap_secs=0.1) + interruptible_sleep( + self.configuration, 5.0, [counter_condition], nap_secs=0.1 + ) duration = time.time() - start # Should run for ~0.3 sec (3 naps of 0.1 sec) self.assertAlmostEqual(duration, 0.3, delta=0.15) - self.assertEqual(counter['count'], max_checks) + self.assertEqual(counter["count"], max_checks) def test_interruptible_sleep_naps_remaining(self): """Test interruptible_sleep counts down remaining naps correctly""" @@ -517,21 +549,20 @@ def test_event_state_persistence(self): def test_signal_handler_dispatch(self): """Verify signal handlers dispatch correct signals""" - test_signals = { - 'run': [signal.SIGUSR1], - 'stop': [signal.SIGUSR2] - } + test_signals = {"run": [signal.SIGUSR1], "stop": [signal.SIGUSR2]} - for func, sigs in [(register_run_handler, test_signals['run']), - (register_stop_handler, test_signals['stop'])]: + for func, sigs in [ + (register_run_handler, test_signals["run"]), + (register_stop_handler, test_signals["stop"]), + ]: for sig in sigs: func(self.configuration, sig) # Verify handler registration dispatch = signal.getsignal(sig) if func == register_run_handler: - self.assertEqual(dispatch.__name__, 'run_handler') + self.assertEqual(dispatch.__name__, "run_handler") else: - self.assertEqual(dispatch.__name__, 'stop_handler') + self.assertEqual(dispatch.__name__, "stop_handler") def test_event_set_unset_lifecycle(self): """Verify full event lifecycle""" diff --git a/tests/test_mig_lib_events.py b/tests/test_mig_lib_events.py index 96ba8c91d..49bf6d5bb 100644 --- a/tests/test_mig_lib_events.py +++ b/tests/test_mig_lib_events.py @@ -32,12 +32,28 @@ import unittest # Imports of the code under test -from mig.lib.events import _restore_env, _save_env, at_remain, cron_match, \ - get_path_expand_map, get_time_expand_map, load_atjobs, load_crontab +from mig.lib.events import ( + _restore_env, + _save_env, + at_remain, + cron_match, + get_path_expand_map, + get_time_expand_map, + load_atjobs, + load_crontab, +) from mig.lib.events import main as events_main -from mig.lib.events import parse_and_save_atjobs, parse_and_save_crontab, \ - parse_atjobs, parse_atjobs_contents, parse_crontab, \ - parse_crontab_contents, run_cron_command, run_events_command +from mig.lib.events import ( + parse_and_save_atjobs, + parse_and_save_crontab, + parse_atjobs, + parse_atjobs_contents, + parse_crontab, + parse_crontab_contents, + run_cron_command, + run_events_command, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist @@ -356,8 +372,7 @@ def test_cron_match_with_wildcards(self): ), ] for job, expected in test_cases: - self.assertEqual(cron_match( - self.configuration, now, job), expected) + self.assertEqual(cron_match(self.configuration, now, job), expected) def test_cron_match_specific_time(self): """Test cron_match rejects non-matching time""" @@ -507,7 +522,8 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ( { @@ -519,7 +535,10 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ( { @@ -531,7 +550,8 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ( { @@ -543,7 +563,10 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ] for job, now in test_cases: @@ -643,7 +666,10 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ( { @@ -655,7 +681,8 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ( { @@ -667,7 +694,10 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ( { @@ -679,7 +709,8 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ] for job, now in test_cases: @@ -979,7 +1010,13 @@ def test_at_remain_with_past_leap_second(self): microsecond=0, ) t_plus_sixtyone = now.replace( - year=2017, month=1, day=1, hour=0, minute=1, second=0, microsecond=0 + year=2017, + month=1, + day=1, + hour=0, + minute=1, + second=0, + microsecond=0, ) self.assertEqual( (t_plus_sixty - t_minus_sixty).total_seconds(), @@ -1409,8 +1446,7 @@ def test_get_path_expand_map_with_relative_path(self): trigger_path = "../relative/path/file.txt" rule = {"vgrid_name": "test", "run_as": DUMMY_USER_DN} expanded = get_path_expand_map(trigger_path, rule, "modified") - self.assertEqual(expanded["+TRIGGERPATH+"], - "../relative/path/file.txt") + self.assertEqual(expanded["+TRIGGERPATH+"], "../relative/path/file.txt") self.assertEqual(expanded["+TRIGGERFILENAME+"], "file.txt") self.assertEqual(expanded["+TRIGGERPREFIX+"], "file") self.assertEqual(expanded["+TRIGGEREXTENSION+"], ".txt") @@ -1807,8 +1843,9 @@ def test_parse_and_save_crontab(self): def test_parse_atjobs(self): """Test parsing atjobs content lines""" parsed = parse_atjobs_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_ATJOBS_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_ATJOBS_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 1) self.assertEqual(parsed[0]["command"], ["/bin/future_command"]) @@ -1816,8 +1853,9 @@ def test_parse_atjobs(self): def test_parse_atjobs_contents(self): """Test parsing atjobs content lines""" parsed = parse_atjobs_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_ATJOBS_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_ATJOBS_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 1) self.assertEqual(parsed[0]["command"], ["/bin/future_command"]) @@ -1825,8 +1863,9 @@ def test_parse_atjobs_contents(self): def test_parse_crontab(self): """Test parsing crontab content lines""" parsed = parse_crontab_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_CRONTAB_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_CRONTAB_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 2) self.assertEqual(parsed[0]["command"], ["/bin/test_command"]) @@ -1834,8 +1873,9 @@ def test_parse_crontab(self): def test_parse_crontab_contents(self): """Test parsing crontab content lines""" parsed = parse_crontab_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_CRONTAB_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_CRONTAB_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 2) self.assertEqual(parsed[0]["command"], ["/bin/test_command"]) @@ -2655,7 +2695,10 @@ def test_run_cron_command_with_invalid_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -2980,7 +3023,10 @@ def test_run_cron_command_with_here_document(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3014,7 +3060,10 @@ def test_run_cron_command_with_subshell(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3074,7 +3123,10 @@ def test_run_cron_command_with_true_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3108,7 +3160,10 @@ def test_run_cron_command_with_false_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3142,7 +3197,10 @@ def test_run_cron_command_with_null_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3176,7 +3234,10 @@ def test_run_cron_command_with_builtin_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3210,7 +3271,10 @@ def test_run_cron_command_with_alias_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3244,7 +3308,10 @@ def test_run_cron_command_with_function_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3278,7 +3345,10 @@ def test_run_cron_command_with_reserved_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3295,7 +3365,10 @@ def test_run_cron_command_with_dot_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3329,7 +3402,10 @@ def test_run_cron_command_with_colon_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3346,7 +3422,10 @@ def test_run_cron_command_with_bracket_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3387,6 +3466,7 @@ def before_each(self): def test_existing_main(self): """Wrap existing self-tests""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: diff --git a/tests/test_mig_lib_janitor.py b/tests/test_mig_lib_janitor.py index 59a522cde..2daaadbc5 100644 --- a/tests/test_mig_lib_janitor.py +++ b/tests/test_mig_lib_janitor.py @@ -32,18 +32,36 @@ import time import unittest -from mig.lib.janitor import EXPIRE_DUMMY_JOBS_DAYS, EXPIRE_REQ_DAYS, \ - EXPIRE_STATE_DAYS, EXPIRE_TWOFACTOR_DAYS, MANAGE_TRIVIAL_REQ_MINUTES, \ - REMIND_REQ_DAYS, SECS_PER_DAY, SECS_PER_HOUR, SECS_PER_MINUTE, \ - _clean_stale_state_files, _lookup_last_run, _update_last_run, \ - clean_mig_system_files, clean_no_job_helpers, \ - clean_sessid_to_mrls_link_home, clean_twofactor_sessions, \ - clean_webserver_home, handle_cache_updates, handle_janitor_tasks, \ - handle_pending_requests, handle_session_cleanup, handle_state_cleanup, \ - manage_single_req, manage_trivial_user_requests, \ - remind_and_expire_user_pending, task_triggers +from mig.lib.janitor import ( + EXPIRE_DUMMY_JOBS_DAYS, + EXPIRE_REQ_DAYS, + EXPIRE_STATE_DAYS, + EXPIRE_TWOFACTOR_DAYS, + MANAGE_TRIVIAL_REQ_MINUTES, + REMIND_REQ_DAYS, + SECS_PER_DAY, + SECS_PER_HOUR, + SECS_PER_MINUTE, + _clean_stale_state_files, + _lookup_last_run, + _update_last_run, + clean_mig_system_files, + clean_no_job_helpers, + clean_sessid_to_mrls_link_home, + clean_twofactor_sessions, + clean_webserver_home, + handle_cache_updates, + handle_janitor_tasks, + handle_pending_requests, + handle_session_cleanup, + handle_state_cleanup, + manage_single_req, + manage_trivial_user_requests, + remind_and_expire_user_pending, + task_triggers, +) from mig.shared.accountreq import save_account_request -from mig.shared.base import distinguished_name_to_user, client_id_dir +from mig.shared.base import client_id_dir, distinguished_name_to_user from mig.shared.pwcrypto import generate_reset_token from tests.support import MigTestCase, ensure_dirs_exist @@ -52,25 +70,31 @@ TEST_USER_ORG = "Test Org" TEST_USER_EMAIL = "test@example.com" # TODO: move next to support.usersupp? -TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=%s/OU=NA/CN=%s/emailAddress=%s' % \ - (TEST_USER_ORG, TEST_USER_FULLNAME, TEST_USER_EMAIL) -TEST_SKIP_EMAIL = '' +TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=%s/OU=NA/CN=%s/emailAddress=%s" % ( + TEST_USER_ORG, + TEST_USER_FULLNAME, + TEST_USER_EMAIL, +) +TEST_SKIP_EMAIL = "" # TODO: adjust password reset token helpers to handle configured services # it currently silently fails if not in migoid(c) or migcert # TEST_SERVICE = 'dummy-svc' -TEST_AUTH = TEST_SERVICE = 'migoid' -TEST_USERDB = 'MiG-users.db' -TEST_PEER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com' +TEST_AUTH = TEST_SERVICE = "migoid" +TEST_USERDB = "MiG-users.db" +TEST_PEER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com" # NOTE: these passwords are not and should not ever be used outside unit tests -TEST_MODERN_PW = 'NoSuchPassword_42' -TEST_MODERN_PW_PBKDF2 = \ +TEST_MODERN_PW = "NoSuchPassword_42" +TEST_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix" -TEST_NEW_MODERN_PW_PBKDF2 = \ +) +TEST_NEW_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf10n581pjXFHrn" -TEST_INVALID_PW_PBKDF2 = \ +) +TEST_INVALID_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf1rn1pjX0n58FH" +) # NOTE: tokens always should contain a multiple of 4 chars -INVALID_TEST_TOKEN = 'THIS_RESET_TOKEN_WAS_NEVER_VALID' +INVALID_TEST_TOKEN = "THIS_RESET_TOKEN_WAS_NEVER_VALID" class MigLibJanitor(MigTestCase): @@ -78,11 +102,11 @@ class MigLibJanitor(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" - def _prepare_test_file(self, path, times=None, content='test'): + def _prepare_test_file(self, path, times=None, content="test"): """Prepare file in path with optional times for timestamp""" - with open(path, 'w') as fp: + with open(path, "w") as fp: fp.write(content) os.utime(path, times) @@ -94,8 +118,9 @@ def before_each(self): self.configuration.site_login_methods.append(TEST_AUTH) # Prevent admin email during reject, etc. self.configuration.admin_email = TEST_SKIP_EMAIL - self.user_db_path = os.path.join(self.configuration.user_db_home, - TEST_USERDB) + self.user_db_path = os.path.join( + self.configuration.user_db_home, TEST_USERDB + ) # Create fake fs layout matching real systems ensure_dirs_exist(self.configuration.user_pending) ensure_dirs_exist(self.configuration.user_db_home) @@ -110,8 +135,9 @@ def before_each(self): ensure_dirs_exist(self.configuration.sessid_to_mrsl_link_home) ensure_dirs_exist(self.configuration.mrsl_files_dir) ensure_dirs_exist(self.configuration.resource_pending) - dummy_job = os.path.join(self.configuration.user_home, - "no_grid_jobs_in_grid_scheduler") + dummy_job = os.path.join( + self.configuration.user_home, "no_grid_jobs_in_grid_scheduler" + ) ensure_dirs_exist(dummy_job) # Prepare user DB with a single dummy user for all tests @@ -124,20 +150,20 @@ def before_each(self): def test_last_run_bookkeeping(self): """Register a last run timestamp and check it""" expect = -1 - stamp = _lookup_last_run(self.configuration, 'janitor_task') + stamp = _lookup_last_run(self.configuration, "janitor_task") self.assertEqual(stamp, expect) expect = 42 - stamp = _update_last_run(self.configuration, 'janitor_task', expect) + stamp = _update_last_run(self.configuration, "janitor_task", expect) self.assertEqual(stamp, expect) expect = time.time() - stamp = _update_last_run(self.configuration, 'janitor_task', expect) + stamp = _update_last_run(self.configuration, "janitor_task", expect) self.assertEqual(stamp, expect) def test_clean_mig_system_files(self): """Test clean_mig system files helper""" test_time = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 - valid_filenames = ['fresh.log', 'current.tmp'] - stale_filenames = ['tmp_expired.txt', 'no_grid_jobs.123'] + valid_filenames = ["fresh.log", "current.tmp"] + stale_filenames = ["tmp_expired.txt", "no_grid_jobs.123"] for name in valid_filenames + stale_filenames: path = os.path.join(self.configuration.mig_system_files, name) self._prepare_test_file(path, (test_time, test_time)) @@ -145,8 +171,10 @@ def test_clean_mig_system_files(self): handled = clean_mig_system_files(self.configuration) self.assertEqual(handled, len(stale_filenames)) - self.assertEqual(len(os.listdir(self.configuration.mig_system_files)), - len(valid_filenames)) + self.assertEqual( + len(os.listdir(self.configuration.mig_system_files)), + len(valid_filenames), + ) for name in valid_filenames: path = os.path.join(self.configuration.mig_system_files, name) self.assertTrue(os.path.exists(path)) @@ -158,8 +186,8 @@ def test_clean_webserver_home(self): """Test clean webserver files helper""" stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 test_dir = self.configuration.webserver_home - valid_filename = 'fresh.log' - stale_filename = 'stale.log' + valid_filename = "fresh.log" + stale_filename = "stale.log" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -175,10 +203,11 @@ def test_clean_webserver_home(self): def test_clean_no_job_helpers(self): """Test clean dummy job helper files""" stale_stamp = time.time() - EXPIRE_DUMMY_JOBS_DAYS * SECS_PER_DAY - 1 - test_dir = os.path.join(self.configuration.user_home, - "no_grid_jobs_in_grid_scheduler") - valid_filename = 'alive.txt' - stale_filename = 'expired.txt' + test_dir = os.path.join( + self.configuration.user_home, "no_grid_jobs_in_grid_scheduler" + ) + valid_filename = "alive.txt" + stale_filename = "expired.txt" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -195,8 +224,8 @@ def test_clean_twofactor_sessions(self): """Test clean twofactor sessions""" stale_stamp = time.time() - EXPIRE_TWOFACTOR_DAYS * SECS_PER_DAY - 1 test_dir = self.configuration.twofactor_home - valid_filename = 'current' - stale_filename = 'expired' + valid_filename = "current" + stale_filename = "expired" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -213,8 +242,8 @@ def test_clean_sessid_to_mrls_link_home(self): """Test clean session MRSL link files""" stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 test_dir = self.configuration.sessid_to_mrsl_link_home - valid_filename = 'active_session_link' - stale_filename = 'expired_session_link' + valid_filename = "active_session_link" + stale_filename = "expired_session_link" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -232,12 +261,14 @@ def test_handle_state_cleanup(self): # Create a stale file in each location to clean up stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 mig_path = os.path.join( - self.configuration.mig_system_files, 'tmpAbCd1234') - web_path = os.path.join(self.configuration.webserver_home, 'stale.txt') + self.configuration.mig_system_files, "tmpAbCd1234" + ) + web_path = os.path.join(self.configuration.webserver_home, "stale.txt") empty_job_path = os.path.join( - os.path.join(self.configuration.user_home, - "no_grid_jobs_in_grid_scheduler"), - 'sleep.job' + os.path.join( + self.configuration.user_home, "no_grid_jobs_in_grid_scheduler" + ), + "sleep.job", ) stale_paths = [mig_path, web_path, empty_job_path] for path in stale_paths: @@ -252,12 +283,17 @@ def test_handle_state_cleanup(self): def test_handle_session_cleanup(self): """Test combined session cleanup""" - stale_stamp = time.time() - max(EXPIRE_STATE_DAYS, - EXPIRE_TWOFACTOR_DAYS) * SECS_PER_DAY - 1 + stale_stamp = ( + time.time() + - max(EXPIRE_STATE_DAYS, EXPIRE_TWOFACTOR_DAYS) * SECS_PER_DAY + - 1 + ) session_path = os.path.join( - self.configuration.sessid_to_mrsl_link_home, 'expired.txt') + self.configuration.sessid_to_mrsl_link_home, "expired.txt" + ) twofactor_path = os.path.join( - self.configuration.twofactor_home, 'expired.txt') + self.configuration.twofactor_home, "expired.txt" + ) test_paths = [session_path, twofactor_path] for path in test_paths: os.makedirs(os.path.dirname(path), exist_ok=True) @@ -271,17 +307,17 @@ def test_handle_session_cleanup(self): def test_manage_pending_user_request(self): """Test pending user request management""" - req_id = 'req_id' + req_id = "req_id" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) @@ -293,24 +329,25 @@ def test_manage_pending_user_request(self): os.utime(req_path, (req_age, req_age)) # Need user DB and path to simulate existing user - user_dir = os.path.join(self.configuration.user_home, - client_id_dir(TEST_USER_DN)) + user_dir = os.path.join( + self.configuration.user_home, client_id_dir(TEST_USER_DN) + ) os.makedirs(user_dir, exist_ok=True) handled = manage_trivial_user_requests(self.configuration) self.assertEqual(handled, 1) def test_expire_user_pending(self): """Test pending user request expiration reminders""" - req_id = 'expired_req' + req_id = "expired_req" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) saved, req_path = save_account_request(self.configuration, req_dict) @@ -333,42 +370,46 @@ def test_handle_pending_requests(self): """Test combined request handling""" # Create requests (valid, expired) valid_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) - saved, valid_req_path = save_account_request(self.configuration, - valid_dict) + saved, valid_req_path = save_account_request( + self.configuration, valid_dict + ) self.assertTrue(saved, "failed to save valid req") self.assertDirNotEmpty(self.configuration.user_pending) valid_id = os.path.basename(valid_req_path) - expired_id = 'expired_req' + expired_id = "expired_req" expired_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } saved, expired_req_path = save_account_request( - self.configuration, expired_dict) + self.configuration, expired_dict + ) self.assertTrue(saved, "failed to save expired req") expired_id = os.path.basename(expired_req_path) # Make just one old enough to expire expire_time = time.time() - EXPIRE_REQ_DAYS * SECS_PER_DAY - 1 - os.utime(os.path.join(self.configuration.user_pending, expired_id), - (expire_time, expire_time)) + os.utime( + os.path.join(self.configuration.user_pending, expired_id), + (expire_time, expire_time), + ) # NOTE: when using real user mail we currently hit send email errors. # We forgive those errors here and only check any known warnings. @@ -383,25 +424,29 @@ def test_handle_janitor_tasks_full(self): """Test full janitor task scheduler""" # Prepare environment with pending tasks of each kind mig_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 - mig_path = os.path.join(self.configuration.mig_system_files, - 'tmp-stale.txt') - two_path = os.path.join(self.configuration.twofactor_home, 'stale.txt') + mig_path = os.path.join( + self.configuration.mig_system_files, "tmp-stale.txt" + ) + two_path = os.path.join(self.configuration.twofactor_home, "stale.txt") two_stamp = time.time() - EXPIRE_TWOFACTOR_DAYS * SECS_PER_DAY - 1 - stale_tests = ((mig_path, mig_stamp), (two_path, two_stamp), ) - for (stale_path, stale_stamp) in stale_tests: + stale_tests = ( + (mig_path, mig_stamp), + (two_path, two_stamp), + ) + for stale_path, stale_stamp in stale_tests: self._prepare_test_file(stale_path, (stale_stamp, stale_stamp)) self.assertTrue(os.path.exists(stale_path)) - req_id = 'expired_request' + req_id = "expired_request" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) saved, req_path = save_account_request(self.configuration, req_dict) @@ -410,8 +455,10 @@ def test_handle_janitor_tasks_full(self): req_id = os.path.basename(req_path) # Make request very old req_age = time.time() - EXPIRE_REQ_DAYS * SECS_PER_DAY - 1 - os.utime(os.path.join(self.configuration.user_pending, req_id), - (req_age, req_age)) + os.utime( + os.path.join(self.configuration.user_pending, req_id), + (req_age, req_age), + ) # Set no last run timestamps to trigger all tasks now = time.time() @@ -426,26 +473,26 @@ def test_handle_janitor_tasks_full(self): handled = handle_janitor_tasks(self.configuration, now=now) # self.assertEqual(handled, 3) # state+session+requests self.assertEqual(handled, 5) # state+session+3*request - for (stale_path, _) in stale_tests: + for stale_path, _ in stale_tests: self.assertFalse(os.path.exists(stale_path), stale_path) def test__clean_stale_state_files(self): """Test core stale state file cleaner helper""" - test_dir = self.temppath('stale_state_test', ensure_dir=True) - patterns = ['tmp_*', 'session_*'] + test_dir = self.temppath("stale_state_test", ensure_dir=True) + patterns = ["tmp_*", "session_*"] # Create test files (fresh, expired, unexpired, non-matching) test_remove = [ - ('tmp_expired.txt', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), - ('session_old.dat', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), + ("tmp_expired.txt", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), + ("session_old.dat", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), ] test_keep = [ - ('tmp_fresh.txt', -1), - ('session_valid.dat', 0), - ('other_file.log', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), + ("tmp_fresh.txt", -1), + ("session_valid.dat", 0), + ("other_file.log", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), ] - for (name, age_diff) in test_keep + test_remove: + for name, age_diff in test_keep + test_remove: path = os.path.join(test_dir, name) stamp = time.time() - age_diff self._prepare_test_file(path, (stamp, stamp)) @@ -457,27 +504,27 @@ def test__clean_stale_state_files(self): patterns, EXPIRE_STATE_DAYS, time.time(), - include_dotfiles=False + include_dotfiles=False, ) self.assertEqual(handled, 2) # tmp_expired.txt + session_old.dat - for (name, _) in test_keep: + for name, _ in test_keep: path = os.path.join(test_dir, name) self.assertTrue(os.path.exists(path)) - for (name, _) in test_remove: + for name, _ in test_remove: path = os.path.join(test_dir, name) self.assertFalse(os.path.exists(path)) def test_manage_single_req_invalid(self): """Test request handling for invalid request""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'invalid': ['Missing required field: organization'], - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'password_hash': TEST_MODERN_PW_PBKDF2, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "invalid": ["Missing required field: organization"], + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "password_hash": TEST_MODERN_PW_PBKDF2, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -486,17 +533,18 @@ def test_manage_single_req_invalid(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='INFO') as log_capture: + with self.assertLogs(level="INFO") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('invalid account request' in msg - for msg in log_capture.output)) + self.assertTrue( + any("invalid account request" in msg for msg in log_capture.output) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed to clean invalid req for %s" % req_path) @@ -504,24 +552,24 @@ def test_manage_single_req_invalid(self): def test_manage_single_req_expired_token(self): """Test request handling with expired reset token""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, + "email": TEST_USER_EMAIL, + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, } # Mimic proper but old expired token timestamp = 42 # IMPORTANT: we can't use a fixed token here due to dynamic crypto seed - req_dict['reset_token'] = generate_reset_token(self.configuration, - req_dict, TEST_SERVICE, - timestamp) + req_dict["reset_token"] = generate_reset_token( + self.configuration, req_dict, TEST_SERVICE, timestamp + ) # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -529,17 +577,21 @@ def test_manage_single_req_expired_token(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('reject expired reset token' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "reject expired reset token" in msg + for msg in log_capture.output + ) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed to clean token req for %s" % req_path) @@ -548,52 +600,57 @@ def test_manage_single_req_expired_token(self): def test_manage_single_req_invalid_token(self): """Test request handling with invalid reset token""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() - SECS_PER_DAY, + "email": TEST_USER_EMAIL, + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() - SECS_PER_DAY, } # Inject known invalid reset token - req_dict['reset_token'] = INVALID_TEST_TOKEN + req_dict["reset_token"] = INVALID_TEST_TOKEN # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('reset with bad token' in msg - for msg in log_capture.output)) - self.assertFalse(os.path.exists(req_path), - "Failed to clean token req for %s" % req_path) + self.assertTrue( + any("reset with bad token" in msg for msg in log_capture.output) + ) + self.assertFalse( + os.path.exists(req_path), + "Failed to clean token req for %s" % req_path, + ) def test_manage_single_req_collision(self): """Test request handling with existing user collision""" # Create collision with the already provisioned user with TEST_USER_DN changed_full_name = "Changed Test Name" req_dict = { - 'client_id': TEST_USER_DN.replace(TEST_USER_FULLNAME, - changed_full_name), - 'distinguished_name': TEST_USER_DN.replace(TEST_USER_FULLNAME, - changed_full_name), - 'auth': [TEST_AUTH], - 'full_name': changed_full_name, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, + "client_id": TEST_USER_DN.replace( + TEST_USER_FULLNAME, changed_full_name + ), + "distinguished_name": TEST_USER_DN.replace( + TEST_USER_FULLNAME, changed_full_name + ), + "auth": [TEST_AUTH], + "full_name": changed_full_name, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, # NOTE: we need original email here to cause collision - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -602,16 +659,17 @@ def test_manage_single_req_collision(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), + ) + self.assertTrue( + any("ID collision" in msg for msg in log_capture.output) ) - self.assertTrue(any('ID collision' in msg - for msg in log_capture.output)) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed cleanup collision for %s" % req_path) @@ -619,20 +677,20 @@ def test_manage_single_req_collision(self): def test_manage_single_req_auth_change(self): """Test request handling with auth password change""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password': '', - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, + "email": TEST_USER_EMAIL, + "password": "", + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, } # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 - req_dict['authorized'] = True + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["authorized"] = True saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -640,19 +698,20 @@ def test_manage_single_req_auth_change(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='INFO') as log_capture: + with self.assertLogs(level="INFO") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue( - any('accepted' in msg for msg in log_capture.output)) - self.assertFalse(os.path.exists(req_path), - "Failed to clean token req for %s" % req_path) + self.assertTrue(any("accepted" in msg for msg in log_capture.output)) + self.assertFalse( + os.path.exists(req_path), + "Failed to clean token req for %s" % req_path, + ) def test_handle_cache_updates_stub(self): """Test handle_cache_updates placeholder returns zero""" @@ -662,7 +721,7 @@ def test_handle_cache_updates_stub(self): def test_janitor_update_timestamps(self): """Test task trigger timestamp updates in janitor""" now = time.time() - task = 'test-task' + task = "test-task" # Initial state stamp = _lookup_last_run(self.configuration, task) @@ -678,24 +737,24 @@ def test_janitor_update_timestamps(self): def test__clean_stale_state_files_edge(self): """Test state file cleaner with special cases""" - test_dir = self.temppath('edge_case_test', ensure_dir=True) + test_dir = self.temppath("edge_case_test", ensure_dir=True) # Dot file - dot_path = os.path.join(test_dir, '.hidden.tmp') + dot_path = os.path.join(test_dir, ".hidden.tmp") stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 self._prepare_test_file(dot_path, (stamp, stamp)) # Directory - dir_path = os.path.join(test_dir, 'subdir') + dir_path = os.path.join(test_dir, "subdir") os.makedirs(dir_path) handled = _clean_stale_state_files( self.configuration, test_dir, - ['*'], + ["*"], EXPIRE_STATE_DAYS, time.time(), - include_dotfiles=False + include_dotfiles=False, ) self.assertEqual(handled, 0) @@ -703,46 +762,50 @@ def test__clean_stale_state_files_edge(self): handled = _clean_stale_state_files( self.configuration, test_dir, - ['*'], + ["*"], EXPIRE_STATE_DAYS, time.time(), - include_dotfiles=True + include_dotfiles=True, ) self.assertEqual(handled, 1) @unittest.skip("TODO: enable once unpickling error handling is improved") def test_manage_single_req_corrupted_file(self): """Test manage_single_req with corrupted request file""" - req_id = 'corrupted_req' + req_id = "corrupted_req" req_path = os.path.join(self.configuration.user_pending, req_id) - with open(req_path, 'w') as fp: - fp.write('invalid pickle content') + with open(req_path, "w") as fp: + fp.write("invalid pickle content") - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('Failed to load request from' in msg - or 'Could not load saved request' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "Failed to load request from" in msg + or "Could not load saved request" in msg + for msg in log_capture.output + ) + ) self.assertFalse(os.path.exists(req_path)) def test_manage_single_req_nonexistent_userdb(self): """Test manage_single_req with missing user database""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -750,38 +813,39 @@ def test_manage_single_req_nonexistent_userdb(self): # Remove user database os.remove(self.user_db_path) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('Failed to load user DB' in msg - for msg in log_capture.output)) + self.assertTrue( + any("Failed to load user DB" in msg for msg in log_capture.output) + ) def test_verify_reset_token_failure_logging(self): """Test token verification failure creates proper log entries""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, # Future expiration + "email": TEST_USER_EMAIL, + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, # Future expiration } timestamp = time.time() # Now change to another pw hash and generate invalid token from it - req_dict['password_hash'] = TEST_INVALID_PW_PBKDF2 - req_dict['reset_token'] = generate_reset_token(self.configuration, - req_dict, TEST_SERVICE, - timestamp) + req_dict["password_hash"] = TEST_INVALID_PW_PBKDF2 + req_dict["reset_token"] = generate_reset_token( + self.configuration, req_dict, TEST_SERVICE, timestamp + ) saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -790,17 +854,18 @@ def test_verify_reset_token_failure_logging(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('wrong hash' in msg.lower() - for msg in log_capture.output)) + self.assertTrue( + any("wrong hash" in msg.lower() for msg in log_capture.output) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed cleanup invalid token for %s" % req_path) @@ -808,23 +873,24 @@ def test_verify_reset_token_failure_logging(self): def test_verify_reset_token_success(self): """Test token verification success with valid token""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'email': TEST_USER_EMAIL, - 'password': '', - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, # Future expiration + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "email": TEST_USER_EMAIL, + "password": "", + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, # Future expiration } timestamp = time.time() - reset_token = generate_reset_token(self.configuration, req_dict, - TEST_SERVICE, timestamp) - req_dict['reset_token'] = reset_token + reset_token = generate_reset_token( + self.configuration, req_dict, TEST_SERVICE, timestamp + ) + req_dict["reset_token"] = reset_token # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -832,17 +898,18 @@ def test_verify_reset_token_success(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='INFO') as log_capture: + with self.assertLogs(level="INFO") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('accepted' in msg.lower() - for msg in log_capture.output)) + self.assertTrue( + any("accepted" in msg.lower() for msg in log_capture.output) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed cleanup invalid token for %s" % req_path) @@ -851,23 +918,22 @@ def test_remind_and_expire_edge_cases(self): """Test request expiration with exact boundary timestamps""" now = time.time() test_cases = [ - ('exact_remind', now - REMIND_REQ_DAYS * SECS_PER_DAY), - ('exact_expire', now - EXPIRE_REQ_DAYS * SECS_PER_DAY), + ("exact_remind", now - REMIND_REQ_DAYS * SECS_PER_DAY), + ("exact_expire", now - EXPIRE_REQ_DAYS * SECS_PER_DAY), ] - for (req_id, mtime) in test_cases: + for req_id, mtime in test_cases: req_path = os.path.join(self.configuration.user_pending, req_id) req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "email": TEST_USER_EMAIL, } - saved, req_path = save_account_request( - self.configuration, req_dict) + saved, req_path = save_account_request(self.configuration, req_dict) os.utime(req_path, (mtime, mtime)) # NOTE: when using real user mail we currently hit send email errors. @@ -884,29 +950,43 @@ def test_handle_janitor_tasks_time_thresholds(self): """Test janitor task frequency thresholds""" now = time.time() - self.assertEqual(_lookup_last_run( - self.configuration, "state-cleanup"), -1) - self.assertEqual(_lookup_last_run( - self.configuration, "session-cleanup"), -1) - self.assertEqual(_lookup_last_run( - self.configuration, "pending-reqs"), -1) - self.assertEqual(_lookup_last_run( - self.configuration, "cache-updates"), -1) + self.assertEqual( + _lookup_last_run(self.configuration, "state-cleanup"), -1 + ) + self.assertEqual( + _lookup_last_run(self.configuration, "session-cleanup"), -1 + ) + self.assertEqual( + _lookup_last_run(self.configuration, "pending-reqs"), -1 + ) + self.assertEqual( + _lookup_last_run(self.configuration, "cache-updates"), -1 + ) # Test all tasks EXCEPT cache-updates are past threshold last_state_cleanup = now - SECS_PER_DAY - 3 last_session_cleanup = now - SECS_PER_HOUR - 3 last_pending_reqs = now - SECS_PER_MINUTE - 3 last_cache_update = now - SECS_PER_MINUTE + 10 # Not expired - task_triggers.update({'state-cleanup': last_state_cleanup, - 'session-cleanup': last_session_cleanup, - 'pending-reqs': last_pending_reqs, - 'cache-updates': last_cache_update}) - self.assertEqual(_lookup_last_run( - self.configuration, "state-cleanup"), last_state_cleanup) - self.assertEqual(_lookup_last_run( - self.configuration, "session-cleanup"), last_session_cleanup) - self.assertEqual(_lookup_last_run( - self.configuration, "cache-updates"), last_cache_update) + task_triggers.update( + { + "state-cleanup": last_state_cleanup, + "session-cleanup": last_session_cleanup, + "pending-reqs": last_pending_reqs, + "cache-updates": last_cache_update, + } + ) + self.assertEqual( + _lookup_last_run(self.configuration, "state-cleanup"), + last_state_cleanup, + ) + self.assertEqual( + _lookup_last_run(self.configuration, "session-cleanup"), + last_session_cleanup, + ) + self.assertEqual( + _lookup_last_run(self.configuration, "cache-updates"), + last_cache_update, + ) # TODO: handled does NOT count no action runs - add dummies to handle? handled = handle_janitor_tasks(self.configuration, now=now) @@ -914,26 +994,32 @@ def test_handle_janitor_tasks_time_thresholds(self): self.assertEqual(handled, 0) # ran with nothing to do # Verify last run timestamps updated - self.assertEqual(_lookup_last_run( - self.configuration, "state-cleanup"), now) - self.assertEqual(_lookup_last_run( - self.configuration, "session-cleanup"), now) - self.assertEqual(_lookup_last_run( - self.configuration, "pending-reqs"), now) - self.assertEqual(_lookup_last_run( - self.configuration, "cache-updates"), last_cache_update) + self.assertEqual( + _lookup_last_run(self.configuration, "state-cleanup"), now + ) + self.assertEqual( + _lookup_last_run(self.configuration, "session-cleanup"), now + ) + self.assertEqual( + _lookup_last_run(self.configuration, "pending-reqs"), now + ) + self.assertEqual( + _lookup_last_run(self.configuration, "cache-updates"), + last_cache_update, + ) @unittest.skip("TODO: enable once cleaner has improved error handling") def test_clean_stale_files_nonexistent_dir(self): """Test state cleaner with invalid directory path""" - target_dir = os.path.join(self.configuration.mig_system_files, - "non_existing_dir") + target_dir = os.path.join( + self.configuration.mig_system_files, "non_existing_dir" + ) handled = _clean_stale_state_files( self.configuration, target_dir, ["*"], EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 0) @@ -947,13 +1033,13 @@ def test_clean_stale_files_permission_error(self): stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 self._prepare_test_file(test_path, (stamp, stamp)) - with self.assertLogs(level='ERROR'): + with self.assertLogs(level="ERROR"): handled = _clean_stale_state_files( self.configuration, test_dir, ["*"], EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 0) @@ -976,14 +1062,14 @@ def test_handle_empty_pending_dir(self): def test_janitor_task_cleanup_after_reject(self): """Verify proper cleanup after request rejection""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'invalid': ['Test intentional invalid'], - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "invalid": ["Test intentional invalid"], + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -1000,7 +1086,7 @@ def test_janitor_task_cleanup_after_reject(self): req_id, req_path, self.user_db_path, - time.time() + time.time(), ) # TODO: enable check for removed req once skip email allows it @@ -1009,15 +1095,15 @@ def test_janitor_task_cleanup_after_reject(self): def test_cleaner_with_multiple_patterns(self): """Test state cleaner with multiple filename patterns""" - test_dir = self.temppath('multi_pattern_test', ensure_dir=True) - clean_patterns = ['*.tmp', '*.log', 'temp*'] + test_dir = self.temppath("multi_pattern_test", ensure_dir=True) + clean_patterns = ["*.tmp", "*.log", "temp*"] clean_pairs = [ - ('should_keep_recent.log', EXPIRE_STATE_DAYS - 1), - ('should_remove_stale.tmp', EXPIRE_STATE_DAYS + 1), - ('should_keep_other.pck', EXPIRE_STATE_DAYS + 1) + ("should_keep_recent.log", EXPIRE_STATE_DAYS - 1), + ("should_remove_stale.tmp", EXPIRE_STATE_DAYS + 1), + ("should_keep_other.pck", EXPIRE_STATE_DAYS + 1), ] - for (name, age_days) in clean_pairs: + for name, age_days in clean_pairs: path = os.path.join(test_dir, name) stamp = time.time() - age_days * SECS_PER_DAY self._prepare_test_file(path, (stamp, stamp)) @@ -1028,15 +1114,18 @@ def test_cleaner_with_multiple_patterns(self): test_dir, clean_patterns, EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 1) - self.assertTrue(os.path.exists( - os.path.join(test_dir, 'should_keep_recent.log'))) - self.assertFalse(os.path.exists( - os.path.join(test_dir, 'should_remove_stale.tmp'))) - self.assertTrue(os.path.exists( - os.path.join(test_dir, 'should_keep_other.pck'))) + self.assertTrue( + os.path.exists(os.path.join(test_dir, "should_keep_recent.log")) + ) + self.assertFalse( + os.path.exists(os.path.join(test_dir, "should_remove_stale.tmp")) + ) + self.assertTrue( + os.path.exists(os.path.join(test_dir, "should_keep_other.pck")) + ) def test_absent_jobs_flag(self): """Test clean_no_job_helpers with site_enable_jobs disabled""" diff --git a/tests/test_mig_lib_quota.py b/tests/test_mig_lib_quota.py index e4057b5d1..c99b0a920 100644 --- a/tests/test_mig_lib_quota.py +++ b/tests/test_mig_lib_quota.py @@ -29,6 +29,7 @@ # Imports of the code under test from mig.lib.quota import update_quota + # Imports required for the unit tests themselves from tests.support import MigTestCase @@ -38,7 +39,7 @@ class MigLibQouta(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up test configuration and reset state before each test""" @@ -47,7 +48,9 @@ def before_each(self): def test_invalid_quota_backend(self): """Test invalid quota_backend in configuration""" self.configuration.quota_backend = "NEVERNEVER" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: update_quota(self.configuration) - self.assertTrue("'NEVERNEVER' not in supported_quota_backends:" in msg - for msg in log_capture.output) + self.assertTrue( + "'NEVERNEVER' not in supported_quota_backends:" in msg + for msg in log_capture.output + ) diff --git a/tests/test_mig_lib_xgicore.py b/tests/test_mig_lib_xgicore.py index e63e22b7d..106b08eea 100644 --- a/tests/test_mig_lib_xgicore.py +++ b/tests/test_mig_lib_xgicore.py @@ -30,9 +30,8 @@ import os import sys -from tests.support import MigTestCase, FakeConfiguration, testmain - from mig.lib.xgicore import * +from tests.support import FakeConfiguration, MigTestCase, testmain class MigLibXgicore__get_output_format(MigTestCase): @@ -42,28 +41,32 @@ def test_default_when_missing(self): """Test that default output_format is returned when not set.""" expected = "html" user_args = {} - actual = get_output_format(FakeConfiguration(), user_args, - default_format=expected) - self.assertEqual(actual, expected, - "mismatch in default output_format") + actual = get_output_format( + FakeConfiguration(), user_args, default_format=expected + ) + self.assertEqual(actual, expected, "mismatch in default output_format") def test_get_single_requested_format(self): """Test that the requested output_format is returned.""" expected = "file" - user_args = {'output_format': [expected]} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) def test_get_first_requested_format(self): """Test that first requested output_format is returned.""" expected = "file" - user_args = {'output_format': [expected, 'BOGUS']} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected, "BOGUS"]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) class MigLibXgicore__override_output_format(MigTestCase): @@ -74,29 +77,35 @@ def test_unchanged_without_override(self): expected = "html" user_args = {} out_objs = [] - actual = override_output_format(FakeConfiguration(), user_args, - out_objs, expected) - self.assertEqual(actual, expected, - "mismatch in unchanged output_format") + actual = override_output_format( + FakeConfiguration(), user_args, out_objs, expected + ) + self.assertEqual( + actual, expected, "mismatch in unchanged output_format" + ) def test_get_single_requested_format(self): """Test that the requested output_format is returned if overriden.""" expected = "file" - user_args = {'output_format': [expected]} - out_objs = [{'object_type': 'start', 'override_format': True}] - actual = override_output_format(FakeConfiguration(), user_args, - out_objs, 'OVERRIDE') - self.assertEqual(actual, expected, - "mismatch in overriden output_format") + user_args = {"output_format": [expected]} + out_objs = [{"object_type": "start", "override_format": True}] + actual = override_output_format( + FakeConfiguration(), user_args, out_objs, "OVERRIDE" + ) + self.assertEqual( + actual, expected, "mismatch in overriden output_format" + ) def test_get_first_requested_format(self): """Test that first requested output_format is returned if overriden.""" expected = "file" - user_args = {'output_format': [expected, 'BOGUS']} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected, "BOGUS"]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) class MigLibXgicore__fill_start_headers(MigTestCase): @@ -105,35 +114,41 @@ class MigLibXgicore__fill_start_headers(MigTestCase): def test_unchanged_when_set(self): """Test that existing valid start entry is returned as-is.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream'), - ('Content-Size', 42)] - expected = {'object_type': 'start', 'headers': headers} - out_objs = [expected, {'object_type': 'binary', 'data': 42*b'0'}] + headers = [ + ("Content-Type", "application/octet-stream"), + ("Content-Size", 42), + ] + expected = {"object_type": "start", "headers": headers} + out_objs = [expected, {"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in unchanged start entry") + self.assertEqual(actual, expected, "mismatch in unchanged start entry") def test_headers_added_when_missing(self): """Test that start entry headers are added if missing.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream')] - minimal_start = {'object_type': 'start'} - expected = {'object_type': 'start', 'headers': headers} - out_objs = [minimal_start, {'object_type': 'binary', 'data': 42*b'0'}] + headers = [("Content-Type", "application/octet-stream")] + minimal_start = {"object_type": "start"} + expected = {"object_type": "start", "headers": headers} + out_objs = [ + minimal_start, + {"object_type": "binary", "data": 42 * b"0"}, + ] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in auto initialized start entry") + self.assertEqual( + actual, expected, "mismatch in auto initialized start entry" + ) def test_start_added_when_missing(self): """Test that start entry is added if missing.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream')] - expected = {'object_type': 'start', 'headers': headers} - out_objs = [{'object_type': 'binary', 'data': 42*b'0'}] + headers = [("Content-Type", "application/octet-stream")] + expected = {"object_type": "start", "headers": headers} + out_objs = [{"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in auto initialized start entry") + self.assertEqual( + actual, expected, "mismatch in auto initialized start entry" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_accountreq.py b/tests/test_mig_shared_accountreq.py index ecba42ff3..0b7a6919d 100644 --- a/tests/test_mig_shared_accountreq.py +++ b/tests/test_mig_shared_accountreq.py @@ -35,9 +35,11 @@ # Imports of the code under test import mig.shared.accountreq as accountreq + # Imports required for the unit test wrapping from mig.shared.base import distinguished_name_to_user, fill_distinguished_name from mig.shared.defaults import keyword_auto + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain from tests.support.fixturesupp import FixtureAssertMixin @@ -47,8 +49,8 @@ class MigSharedAccountreq__peers(MigTestCase, FixtureAssertMixin): """Unit tests for peers related functions within the accountreq module""" - TEST_PEER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com' - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_PEER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com" + TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" @property def user_settings_dir(self): @@ -60,40 +62,46 @@ def user_pending_dir(self): def _load_saved_peer(self, absolute_path): self.assertPathWithin(absolute_path, start=self.user_pending_dir) - with open(absolute_path, 'rb') as pickle_file: + with open(absolute_path, "rb") as pickle_file: value = pickle.load(pickle_file) def _string_if_bytes(value): if isinstance(value, bytes): - return str(value, 'utf8') + return str(value, "utf8") else: return value - return {_string_if_bytes(x): _string_if_bytes(y) - for x, y in value.items()} + + return { + _string_if_bytes(x): _string_if_bytes(y) for x, y in value.items() + } def _peer_dict_from_fixture(self): - prepared_fixture = self.prepareFixtureAssert("peer_user_dict", - fixture_format="json") + prepared_fixture = self.prepareFixtureAssert( + "peer_user_dict", fixture_format="json" + ) fixture_data = prepared_fixture.fixture_data assert fixture_data["distinguished_name"] == self.TEST_PEER_DN return fixture_data - def _record_peer_acceptance(self, test_client_dir_name, - peer_distinguished_name): - """Fabricate a peer acceptance record in a particular user settings dir. - """ + def _record_peer_acceptance( + self, test_client_dir_name, peer_distinguished_name + ): + """Fabricate a peer acceptance record in a particular user settings dir.""" test_user_accepted_peers_file = os.path.join( - self.user_settings_dir, test_client_dir_name, "peers") + self.user_settings_dir, test_client_dir_name, "peers" + ) expire_tomorrow = datetime.date.today() + datetime.timedelta(days=1) - with open(test_user_accepted_peers_file, "wb") as \ - test_user_accepted_peers: - pickle.dump({peer_distinguished_name: - {'expire': str(expire_tomorrow)}}, - test_user_accepted_peers) + with open( + test_user_accepted_peers_file, "wb" + ) as test_user_accepted_peers: + pickle.dump( + {peer_distinguished_name: {"expire": str(expire_tomorrow)}}, + test_user_accepted_peers, + ) def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): ensure_dirs_exist(self.configuration.user_cache) @@ -108,8 +116,9 @@ def test_a_new_peer(self): self.assertDirEmpty(self.configuration.user_pending) request_dict = self._peer_dict_from_fixture() - success, _ = accountreq.save_account_request(self.configuration, - request_dict) + success, _ = accountreq.save_account_request( + self.configuration, request_dict + ) # check that we have an output directory now absolute_files = self.assertDirNotEmpty(self.user_pending_dir) @@ -131,10 +140,11 @@ def test_listing_peers(self): # check the fabricated peer was listed # sadly listing returns _relative_ dirs peer_temp_file_name = listing[0] - peer_pickle_file = os.path.join(self.user_pending_dir, - peer_temp_file_name) + peer_pickle_file = os.path.join( + self.user_pending_dir, peer_temp_file_name + ) peer_pickle = self._load_saved_peer(peer_pickle_file) - self.assertEqual(peer_pickle['distinguished_name'], self.TEST_PEER_DN) + self.assertEqual(peer_pickle["distinguished_name"], self.TEST_PEER_DN) def test_peer_acceptance(self): test_client_dir = self._provision_test_user(self, self.TEST_USER_DN) @@ -142,17 +152,18 @@ def test_peer_acceptance(self): self._record_peer_acceptance(test_client_dir_name, self.TEST_PEER_DN) self.assertDirEmpty(self.user_pending_dir) request_dict = self._peer_dict_from_fixture() - success, req_path = accountreq.save_account_request(self.configuration, - request_dict) + success, req_path = accountreq.save_account_request( + self.configuration, request_dict + ) arranged_req_id = os.path.basename(req_path) # NOTE: when using real user mail we currently hit send email errors. # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - success, message = accountreq.accept_account_req(arranged_req_id, - self.configuration, - keyword_auto) + success, message = accountreq.accept_account_req( + arranged_req_id, self.configuration, keyword_auto + ) self.assertTrue(success) @@ -160,29 +171,44 @@ def test_peer_acceptance(self): class MigSharedAccountreq__filters(MigTestCase, UserAssertMixin): """Unit tests for filter related functions within the accountreq module""" - TEST_SERVICE = 'migoid' - TEST_INTERNAL_DN = '/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=Test Name/emailAddress=test@local.org' - TEST_EXTERNAL_DN = '/C=DK/ST=NA/L=NA/O=External Org/OU=NA/CN=Test User/emailAddress=test@external.org' - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' - TEST_ADMIN_DN = '/C=DK/ST=NA/L=NA/O=DIKU/OU=NA/CN=Test Admin/emailAddress=siteadm@di.ku.dk' - - TEST_INT_PW = 'PW74deb6609F109f504d' - TEST_EXT_PW = 'PW174db6509F109e1531' - TEST_USER_PW = 'foobar' - TEST_INT_PW_HASH = 'PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$epib2rEg/HYTQZFnCp7hmIGZ6rzHnViy' - TEST_EXT_PW_HASH = 'PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$TQZFnCp7hmIGZ6ep2rEg/HYrzHnVyiib' - TEST_USER_PW_HASH = 'PBKDF2$sha256$10000$/TkhLk4yMGf6XhaY$7HUeQ9iwCkE4YMQAaCd+ZdrN+y8EzkJH' - - TEST_INTERNAL_EMAILS = ['john.doe@science.ku.dk', 'abc123@ku.dk', - 'john.doe@a.b.c.ku.dk'] - TEST_EXTERNAL_EMAILS = ['john@doe.org', 'a@b.c.org', 'a@ku.dk.com', - 'a@sci.ku.dk.org', 'a@diku.dk', 'a@nbi.dk'] - TEST_EXTERNAL_EMAIL_PATTERN = r'^.+(?@') + "connection_string": "user:password@db.example.com", + "other_field": "some_value", } + subst_map = {"connection_string": (r":.*@", r":@")} masked = mask_creds(user_dict, subst_map=subst_map) - self.assertEqual(masked['connection_string'], - 'user:@db.example.com') - self.assertEqual(masked['other_field'], 'some_value') + self.assertEqual( + masked["connection_string"], "user:@db.example.com" + ) + self.assertEqual(masked["other_field"], "some_value") def test_mask_creds_no_maskable_fields(self): """Test mask_creds with a dictionary containing no maskable fields.""" - user_dict = {'username': 'test', 'role': 'user'} + user_dict = {"username": "test", "role": "user"} masked = mask_creds(user_dict) self.assertEqual(user_dict, masked) @@ -230,55 +275,58 @@ def test_mask_creds_empty_dict(self): def test_mask_creds_csrf_field(self): """Test that the default csrf_field is masked.""" - user_dict = {csrf_field: 'some_csrf_token', 'other': 'value'} + user_dict = {csrf_field: "some_csrf_token", "other": "value"} masked = mask_creds(user_dict) - self.assertEqual(masked[csrf_field], '**HIDDEN**') - self.assertEqual(masked['other'], 'value') + self.assertEqual(masked[csrf_field], "**HIDDEN**") + self.assertEqual(masked["other"], "value") def test_extract_field_exists(self): """Test extracting an existing field from a distinguished name.""" - self.assertEqual(extract_field(TEST_USER_ID, 'full_name'), - TEST_FULL_NAME) - self.assertEqual(extract_field(TEST_USER_ID, 'organization'), - TEST_ORGANIZATION) - self.assertEqual(extract_field(TEST_USER_ID, 'country'), TEST_COUNTRY) - self.assertEqual(extract_field(TEST_USER_ID, 'email'), TEST_EMAIL) + self.assertEqual( + extract_field(TEST_USER_ID, "full_name"), TEST_FULL_NAME + ) + self.assertEqual( + extract_field(TEST_USER_ID, "organization"), TEST_ORGANIZATION + ) + self.assertEqual(extract_field(TEST_USER_ID, "country"), TEST_COUNTRY) + self.assertEqual(extract_field(TEST_USER_ID, "email"), TEST_EMAIL) def test_extract_field_not_exists(self): """Test extracting a non-existent field returns None.""" - self.assertIsNone(extract_field(TEST_USER_ID, 'missing')) - self.assertIsNone(extract_field(TEST_USER_ID, 'dummy')) + self.assertIsNone(extract_field(TEST_USER_ID, "missing")) + self.assertIsNone(extract_field(TEST_USER_ID, "dummy")) def test_extract_field_with_na_value(self): """Test extracting a field with 'NA' value, which should be an empty string.""" - self.assertEqual(extract_field('/C=DK/DUMMY=NA/CN=TEST', 'DUMMY'), '') + self.assertEqual(extract_field("/C=DK/DUMMY=NA/CN=TEST", "DUMMY"), "") def test_extract_field_custom_field(self): """Test extracting a custom (non-standard) field.""" - self.assertEqual(extract_field('/C=DK/DUMMY=proj1/CN=Test', 'DUMMY'), - 'proj1') + self.assertEqual( + extract_field("/C=DK/DUMMY=proj1/CN=Test", "DUMMY"), "proj1" + ) def test_extract_field_empty_dn(self): """Test extracting from an empty distinguished name.""" - self.assertIsNone(extract_field("", 'full_name')) + self.assertIsNone(extract_field("", "full_name")) def test_extract_field_malformed_dn(self): """Test extracting from a malformed distinguished name.""" dn_empty_val = "/C=US/O=/CN=John Doe" - self.assertEqual(extract_field(dn_empty_val, 'organization'), '') + self.assertEqual(extract_field(dn_empty_val, "organization"), "") dn_no_equals = "/C=US/O/CN=John Doe" - self.assertIsNone(extract_field(dn_no_equals, 'organization')) + self.assertIsNone(extract_field(dn_no_equals, "organization")) def test_distinguished_name_to_user_basic(self): """Test basic conversion from distinguished name to user dictionary.""" user_dict = distinguished_name_to_user(TEST_USER_ID) expected = { - 'distinguished_name': TEST_USER_ID, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, + "distinguished_name": TEST_USER_ID, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, } self.assertEqual(user_dict, expected) @@ -287,12 +335,12 @@ def test_distinguished_name_to_user_with_na(self): dn = "%s/dummy=NA" % TEST_USER_ID user_dict = distinguished_name_to_user(dn) expected = { - 'distinguished_name': dn, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'dummy': '' + "distinguished_name": dn, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "dummy": "", } self.assertEqual(user_dict, expected) @@ -301,28 +349,29 @@ def test_distinguished_name_to_user_with_custom_field(self): dn = "%s/dummy=proj1" % TEST_USER_ID user_dict = distinguished_name_to_user(dn) expected = { - 'distinguished_name': dn, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'dummy': 'proj1' + "distinguished_name": dn, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "dummy": "proj1", } self.assertEqual(user_dict, expected) def test_distinguished_name_to_user_empty_and_malformed(self): """Test behavior with empty and malformed distinguished names.""" # Empty DN - self.assertEqual(distinguished_name_to_user(""), - {'distinguished_name': ''}) + self.assertEqual( + distinguished_name_to_user(""), {"distinguished_name": ""} + ) # Malformed part (no '=') dn_malformed = "/C=US/O/CN=John Doe" user_dict_malformed = distinguished_name_to_user(dn_malformed) expected_malformed = { - 'distinguished_name': dn_malformed, - 'country': 'US', - 'full_name': TEST_FULL_NAME + "distinguished_name": dn_malformed, + "country": "US", + "full_name": TEST_FULL_NAME, } self.assertEqual(user_dict_malformed, expected_malformed) @@ -330,45 +379,49 @@ def test_distinguished_name_to_user_empty_and_malformed(self): dn_empty_val = "/C=DK/O=/CN=John Doe" user_dict_empty_val = distinguished_name_to_user(dn_empty_val) expected_empty_val = { - 'distinguished_name': dn_empty_val, - 'country': TEST_COUNTRY, - 'organization': '', - 'full_name': TEST_FULL_NAME + "distinguished_name": dn_empty_val, + "country": TEST_COUNTRY, + "organization": "", + "full_name": TEST_FULL_NAME, } self.assertEqual(user_dict_empty_val, expected_empty_val) def test_fill_distinguished_name_from_fields(self): """Test filling distinguished_name from other user fields.""" user = { - 'full_name': 'Jane Doe', - 'organization': 'Test Corp', - 'country': TEST_COUNTRY, - 'email': 'jane.doe@example.com' + "full_name": "Jane Doe", + "organization": "Test Corp", + "country": TEST_COUNTRY, + "email": "jane.doe@example.com", } fill_distinguished_name(user) - expected_dn = "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" \ - "/emailAddress=jane.doe@example.com" - self.assertEqual(user['distinguished_name'], expected_dn) + expected_dn = ( + "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" + "/emailAddress=jane.doe@example.com" + ) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_distinguished_name_with_gdp(self): """Test filling distinguished_name with a GDP project field.""" user = { - 'full_name': 'Jane Doe', - 'organization': 'Test Corp', - 'country': TEST_COUNTRY, - gdp_distinguished_field: 'project_x' + "full_name": "Jane Doe", + "organization": "Test Corp", + "country": TEST_COUNTRY, + gdp_distinguished_field: "project_x", } fill_distinguished_name(user) - expected_dn = "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" \ - "/emailAddress=NA/GDP=project_x" - self.assertEqual(user['distinguished_name'], expected_dn) + expected_dn = ( + "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" + "/emailAddress=NA/GDP=project_x" + ) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_distinguished_name_already_exists(self): """Test that an existing distinguished_name is not overwritten.""" user = { - 'distinguished_name': TEST_USER_ID, - 'full_name': 'Jane Doe', - 'country': 'US' + "distinguished_name": TEST_USER_ID, + "full_name": "Jane Doe", + "country": "US", } original_user = user.copy() returned_user = fill_distinguished_name(user) @@ -380,24 +433,21 @@ def test_fill_distinguished_name_empty_user(self): user = {} fill_distinguished_name(user) expected_dn = "/C=NA/ST=NA/L=NA/O=NA/OU=NA/CN=NA/emailAddress=NA" - self.assertEqual(user['distinguished_name'], expected_dn) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_user_completes_dict(self): """Test that fill_user adds missing fields and preserves existing ones.""" - user = { - 'full_name': TEST_FULL_NAME, - 'extra_field': 'extra_value' - } + user = {"full_name": TEST_FULL_NAME, "extra_field": "extra_value"} fill_user(user) # Check that existing values are preserved - self.assertEqual(user['full_name'], TEST_FULL_NAME) - self.assertEqual(user['extra_field'], 'extra_value') + self.assertEqual(user["full_name"], TEST_FULL_NAME) + self.assertEqual(user["extra_field"], "extra_value") # Check that missing standard fields are added with empty strings - self.assertEqual(user['organization'], '') - self.assertEqual(user['country'], '') + self.assertEqual(user["organization"], "") + self.assertEqual(user["country"], "") # Check that all standard keys are present for key, _ in cert_field_order: @@ -409,7 +459,7 @@ def test_fill_user_with_empty_dict(self): fill_user(user) self.assertEqual(len(user), len(cert_field_order)) for key, _ in cert_field_order: - self.assertEqual(user[key], '') + self.assertEqual(user[key], "") def test_fill_user_modifies_in_place_and_returns_self(self): """Test that fill_user modifies the dictionary in-place and returns @@ -421,158 +471,171 @@ def test_fill_user_modifies_in_place_and_returns_self(self): def test_canonical_user_transformations(self): """Test canonical_user applies all transformations correctly.""" user_dict = { - 'full_name': ' john doe ', - 'email': 'John@DoE.ORG', - 'country': 'dk', - 'state': 'vt', - 'organization': ' Test Org ', - 'extra_field': 'should be removed', - 'id': 123 + "full_name": " john doe ", + "email": "John@DoE.ORG", + "country": "dk", + "state": "vt", + "organization": " Test Org ", + "extra_field": "should be removed", + "id": 123, } - limit_fields = ['full_name', 'email', - 'country', 'state', 'organization', 'id'] + limit_fields = [ + "full_name", + "email", + "country", + "state", + "organization", + "id", + ] canonical = canonical_user(self.configuration, user_dict, limit_fields) expected = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'country': TEST_COUNTRY, - 'state': TEST_STATE, - 'organization': TEST_ORGANIZATION, - 'id': 123 + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "country": TEST_COUNTRY, + "state": TEST_STATE, + "organization": TEST_ORGANIZATION, + "id": 123, } self.assertEqual(canonical, expected) - self.assertNotIn('extra_field', canonical) + self.assertNotIn("extra_field", canonical) def test_canonical_user_with_peers_legacy(self): """Test canonical_user_with_peers with legacy peers list""" - self.configuration.site_peers_explicit_fields = ['email', 'full_name'] + self.configuration.site_peers_explicit_fields = ["email", "full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers": [ + "/C=DK/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_email'], - 'alice@example.com, bob@example.com') - self.assertEqual(canonical['peers_full_name'], 'Alice, Bob') + self.assertEqual( + canonical["peers_email"], "alice@example.com, bob@example.com" + ) + self.assertEqual(canonical["peers_full_name"], "Alice, Bob") def test_canonical_user_with_peers_explicit(self): """Test canonical_user_with_peers with explicit peers fields""" - self.configuration.site_peers_explicit_fields = ['email', 'full_name'] + self.configuration.site_peers_explicit_fields = ["email", "full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers_email': 'custom@example.com', - 'peers_full_name': 'Custom Name', - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers_email": "custom@example.com", + "peers_full_name": "Custom Name", + "peers": [ + "/C=DK/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_email'], 'custom@example.com') - self.assertEqual(canonical['peers_full_name'], 'Custom Name') + self.assertEqual(canonical["peers_email"], "custom@example.com") + self.assertEqual(canonical["peers_full_name"], "Custom Name") def test_canonical_user_with_peers_mixed(self): """Test canonical_user_with_peers with mixed explicit and legacy peers""" self.configuration.site_peers_explicit_fields = [ - 'email', 'organization'] + "email", + "organization", + ] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers_organization': TEST_ORGANIZATION, - 'peers': [ - '/C=DK/O=Legacy Org/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers_organization": TEST_ORGANIZATION, + "peers": [ + "/C=DK/O=Legacy Org/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email', 'organization'] + limit_fields = ["full_name", "email", "organization"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) # Explicit field should be preserved - self.assertEqual(canonical['peers_organization'], 'Test Org') + self.assertEqual(canonical["peers_organization"], "Test Org") # Legacy peers should be converted for email - self.assertEqual(canonical['peers_email'], - 'alice@example.com, bob@example.com') + self.assertEqual( + canonical["peers_email"], "alice@example.com, bob@example.com" + ) def test_canonical_user_with_peers_empty(self): """Test canonical_user_with_peers with no peers data""" - self.configuration.site_peers_explicit_fields = ['email'] - user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL - } - limit_fields = ['full_name', 'email'] + self.configuration.site_peers_explicit_fields = ["email"] + user_dict = {"full_name": TEST_FULL_NAME, "email": TEST_EMAIL} + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertNotIn('peers_email', canonical) - self.assertNotIn('peers', canonical) + self.assertNotIn("peers_email", canonical) + self.assertNotIn("peers", canonical) def test_canonical_user_with_peers_no_explicit_fields(self): """Test canonical_user_with_peers with no peer fields configured""" self.configuration.site_peers_explicit_fields = [] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': 'john@example.com', - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": "john@example.com", + "peers": ["/C=DK/CN=Alice/emailAddress=alice@example.com"], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) # Should not create any peer fields - self.assertNotIn('peers_email', canonical) - self.assertNotIn('peers_full_name', canonical) + self.assertNotIn("peers_email", canonical) + self.assertNotIn("peers_full_name", canonical) def test_canonical_user_with_peers_special_chars(self): """Test canonical_user_with_peers handles special characters in DNs""" - self.configuration.site_peers_explicit_fields = ['full_name'] + self.configuration.site_peers_explicit_fields = ["full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'peers': [ - '/C=DK/CN=Jérôme Müller', - '/C=DK/CN=O‘‘Reilly', - '/C=DK/CN=Alice "Ace" Smith' - ] + "full_name": TEST_FULL_NAME, + "peers": [ + "/C=DK/CN=Jérôme Müller", + "/C=DK/CN=O‘‘Reilly", + '/C=DK/CN=Alice "Ace" Smith', + ], } - limit_fields = ['full_name'] + limit_fields = ["full_name"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_full_name'], - 'Jérôme Müller, O‘‘Reilly, Alice "Ace" Smith') + self.assertEqual( + canonical["peers_full_name"], + 'Jérôme Müller, O‘‘Reilly, Alice "Ace" Smith', + ) def test_canonical_user_unicode_name(self): """Test canonical_user with unicode characters in full_name.""" # Using a name that title() might mess up without unicode conversion - user_dict = {'full_name': u'josé de la vega'} - limit_fields = ['full_name'] + user_dict = {"full_name": "josé de la vega"} + limit_fields = ["full_name"] canonical = canonical_user(self.configuration, user_dict, limit_fields) - self.assertEqual(canonical['full_name'], u'José De La Vega') + self.assertEqual(canonical["full_name"], "José De La Vega") def test_canonical_user_empty_input(self): """Test canonical_user with empty inputs.""" self.assertEqual(canonical_user(self.configuration, {}, []), {}) - self.assertEqual(canonical_user(self.configuration, {'a': 1}, []), {}) - self.assertEqual(canonical_user(self.configuration, {}, ['a']), {}) + self.assertEqual(canonical_user(self.configuration, {"a": 1}, []), {}) + self.assertEqual(canonical_user(self.configuration, {}, ["a"]), {}) def test_generate_https_urls_single_method_cgi(self): """Test generate_https_urls with a single method and cgi-bin.""" - self.configuration.site_login_methods = ['migcert'] + self.configuration.site_login_methods = ["migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" expected = "https://mig.cert/cgi-bin/script.py" result = generate_https_urls(self.configuration, template, {}) @@ -581,7 +644,7 @@ def test_generate_https_urls_single_method_cgi(self): def test_generate_https_urls_single_method_wsgi(self): """Test generate_https_urls with a single method and wsgi-bin.""" self.configuration.site_enable_wsgi = True - self.configuration.site_login_methods = ['migcert'] + self.configuration.site_login_methods = ["migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" expected = "https://mig.cert/wsgi-bin/script.py" result = generate_https_urls(self.configuration, template, {}) @@ -590,29 +653,32 @@ def test_generate_https_urls_single_method_wsgi(self): def test_generate_https_urls_multiple_methods(self): """Test generate_https_urls with multiple methods.""" template = "%(auto_base)s/%(auto_bin)s/script.py" - self.configuration.site_login_methods = ['migcert', 'extoidc'] + self.configuration.site_login_methods = ["migcert", "extoidc"] result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://mig.cert/cgi-bin/script.py" expected_url2 = "https://ext.oidc/cgi-bin/script.py" expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_generate_https_urls_with_helper_dict(self): """Test generate_https_urls with a helper_dict.""" - self.configuration.site_login_methods = ['extoid'] + self.configuration.site_login_methods = ["extoid"] template = "%(auto_base)s/%(auto_bin)s/%(script)s" - helper = {'script': 'login.py'} + helper = {"script": "login.py"} result = generate_https_urls(self.configuration, template, helper) self.assertEqual(result, "https://ext.oid/cgi-bin/login.py") def test_generate_https_urls_method_enabled_but_url_missing(self): """Test that methods with no configured URL are skipped.""" self.configuration.migserver_https_ext_cert_url = "" # URL is missing - self.configuration.site_login_methods = ['migcert', 'extcert'] + self.configuration.site_login_methods = ["migcert", "extcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) self.assertEqual(result, "https://mig.cert/cgi-bin/script.py") @@ -626,7 +692,7 @@ def test_generate_https_urls_no_methods_enabled(self): def test_generate_https_urls_respects_order(self): """Test that the order of site_login_methods is respected.""" - self.configuration.site_login_methods = ['extoidc', 'migcert'] + self.configuration.site_login_methods = ["extoidc", "migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://ext.oidc/cgi-bin/script.py" @@ -634,14 +700,20 @@ def test_generate_https_urls_respects_order(self): expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_generate_https_urls_avoids_duplicates(self): """Test that duplicate URLs are not generated.""" self.configuration.site_login_methods = [ - 'migcert', 'extoidc', 'migcert'] + "migcert", + "extoidc", + "migcert", + ] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://mig.cert/cgi-bin/script.py" @@ -649,22 +721,35 @@ def test_generate_https_urls_avoids_duplicates(self): expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_auth_type_description_all(self): """Test auth_type_description returns full dict when requested""" from mig.shared.defaults import keyword_all + result = auth_type_description(self.configuration, keyword_all) - expected_keys = ['migoid', 'migoidc', 'migcert', 'extoid', 'extoidc', - 'extcert'] + expected_keys = [ + "migoid", + "migoidc", + "migcert", + "extoid", + "extoidc", + "extcert", + ] self.assertEqual(sorted(result.keys()), sorted(expected_keys)) def test_auth_type_description_individual(self): """Test auth_type_description returns expected strings for each type""" - from mig.shared.defaults import AUTH_CERTIFICATE, AUTH_OPENID_CONNECT, \ - AUTH_OPENID_V2 + from mig.shared.defaults import ( + AUTH_CERTIFICATE, + AUTH_OPENID_CONNECT, + AUTH_OPENID_V2, + ) # Setup titles in configuration self.configuration.user_mig_oid_title = "MiG OpenID" @@ -673,31 +758,37 @@ def test_auth_type_description_individual(self): self.configuration.user_ext_cert_title = "External Certificate" test_cases = [ - ('migoid', 'MiG OpenID %s' % AUTH_OPENID_V2), - ('migoidc', 'MiG OpenID %s' % AUTH_OPENID_CONNECT), - ('migcert', 'MiG Certificate %s' % AUTH_CERTIFICATE), - ('extoid', 'External OpenID %s' % AUTH_OPENID_V2), - ('extoidc', 'External OpenID %s' % AUTH_OPENID_CONNECT), - ('extcert', 'External Certificate %s' % AUTH_CERTIFICATE), + ("migoid", "MiG OpenID %s" % AUTH_OPENID_V2), + ("migoidc", "MiG OpenID %s" % AUTH_OPENID_CONNECT), + ("migcert", "MiG Certificate %s" % AUTH_CERTIFICATE), + ("extoid", "External OpenID %s" % AUTH_OPENID_V2), + ("extoidc", "External OpenID %s" % AUTH_OPENID_CONNECT), + ("extcert", "External Certificate %s" % AUTH_CERTIFICATE), ] - for (auth_type, expected) in test_cases: + for auth_type, expected in test_cases: result = auth_type_description(self.configuration, auth_type) self.assertEqual(result, expected) def test_auth_type_description_unknown(self): """Test auth_type_description returns 'UNKNOWN' for invalid types""" - self.assertEqual(auth_type_description(self.configuration, 'invalid'), - 'UNKNOWN') - self.assertEqual(auth_type_description( - self.configuration, ''), 'UNKNOWN') - self.assertEqual(auth_type_description( - self.configuration, None), 'UNKNOWN') + self.assertEqual( + auth_type_description(self.configuration, "invalid"), "UNKNOWN" + ) + self.assertEqual( + auth_type_description(self.configuration, ""), "UNKNOWN" + ) + self.assertEqual( + auth_type_description(self.configuration, None), "UNKNOWN" + ) def test_auth_type_description_empty_titles(self): """Test auth_type_description handles empty titles in configuration""" - from mig.shared.defaults import AUTH_CERTIFICATE, AUTH_OPENID_CONNECT, \ - AUTH_OPENID_V2 + from mig.shared.defaults import ( + AUTH_CERTIFICATE, + AUTH_OPENID_CONNECT, + AUTH_OPENID_V2, + ) self.configuration.user_mig_oid_title = "" self.configuration.user_mig_cert_title = "" @@ -705,15 +796,15 @@ def test_auth_type_description_empty_titles(self): self.configuration.user_ext_cert_title = "" test_cases = [ - ('migoid', ' %s' % AUTH_OPENID_V2), - ('migoidc', ' %s' % AUTH_OPENID_CONNECT), - ('migcert', ' %s' % AUTH_CERTIFICATE), - ('extoid', ' %s' % AUTH_OPENID_V2), - ('extoidc', ' %s' % AUTH_OPENID_CONNECT), - ('extcert', ' %s' % AUTH_CERTIFICATE), + ("migoid", " %s" % AUTH_OPENID_V2), + ("migoidc", " %s" % AUTH_OPENID_CONNECT), + ("migcert", " %s" % AUTH_CERTIFICATE), + ("extoid", " %s" % AUTH_OPENID_V2), + ("extoidc", " %s" % AUTH_OPENID_CONNECT), + ("extcert", " %s" % AUTH_CERTIFICATE), ] - for (auth_type, expected) in test_cases: + for auth_type, expected in test_cases: result = auth_type_description(self.configuration, auth_type) self.assertEqual(result, expected) @@ -721,11 +812,16 @@ def test_allow_script_gdp_enabled_anonymous_allowed(self): """Test allow_script with GDP enabled, anonymous user, and script allowed.""" self.configuration.site_enable_gdp = True - script_name = valid_gdp_anon_scripts[0] if valid_gdp_anon_scripts \ - else 'allowed_script.py' # Use a valid script or a default + script_name = ( + valid_gdp_anon_scripts[0] + if valid_gdp_anon_scripts + else "allowed_script.py" + ) # Use a valid script or a default if not valid_gdp_anon_scripts: - print("WARNING: valid_gdp_anon_scripts is empty. Using " - "'allowed_script.py' which may cause a test failure.") + print( + "WARNING: valid_gdp_anon_scripts is empty. Using " + "'allowed_script.py' which may cause a test failure." + ) allow, msg = allow_script(self.configuration, script_name, None) self.assertTrue(allow) self.assertEqual(msg, "") @@ -734,29 +830,41 @@ def test_allow_script_gdp_enabled_anonymous_disallowed(self): """Test allow_script with GDP enabled, anonymous user, and script disallowed.""" self.configuration.site_enable_gdp = True - script_name = 'disallowed_script.py' + script_name = "disallowed_script.py" # Ensure the script is not in valid_gdp_anon_scripts if script_name in valid_gdp_anon_scripts: valid_gdp_anon_scripts.remove(script_name) allow, msg = allow_script(self.configuration, script_name, None) self.assertFalse(allow) - self.assertEqual(msg, "anonoymous access to functionality disabled " - "by site configuration!") + self.assertEqual( + msg, + "anonoymous access to functionality disabled " + "by site configuration!", + ) def test_allow_script_gdp_enabled_authenticated_allowed(self): """Test allow_script with GDP enabled, authenticated user, and script allowed.""" self.configuration.site_enable_gdp = True - script_name = valid_gdp_auth_scripts[0] if valid_gdp_auth_scripts \ - else valid_gdp_anon_scripts[0] if valid_gdp_anon_scripts \ - else 'allowed_script.py' + script_name = ( + valid_gdp_auth_scripts[0] + if valid_gdp_auth_scripts + else ( + valid_gdp_anon_scripts[0] + if valid_gdp_anon_scripts + else "allowed_script.py" + ) + ) if not valid_gdp_auth_scripts and not valid_gdp_anon_scripts: - print("WARNING: valid_gdp_auth_scripts and " - "valid_gdp_anon_scripts are empty. Using " - "'allowed_script.py' which may cause a test failure.") + print( + "WARNING: valid_gdp_auth_scripts and " + "valid_gdp_anon_scripts are empty. Using " + "'allowed_script.py' which may cause a test failure." + ) allow, msg = allow_script( - self.configuration, script_name, 'test_client') + self.configuration, script_name, "test_client" + ) self.assertTrue(allow) self.assertEqual(msg, "") @@ -764,7 +872,7 @@ def test_allow_script_gdp_enabled_authenticated_disallowed(self): """Test allow_script with GDP enabled, authenticated user, and script disallowed.""" self.configuration.site_enable_gdp = True - script_name = 'disallowed_script.py' + script_name = "disallowed_script.py" # Ensure the script is not in valid_gdp_auth_scripts or # valid_gdp_anon_scripts @@ -774,98 +882,108 @@ def test_allow_script_gdp_enabled_authenticated_disallowed(self): valid_gdp_anon_scripts.remove(script_name) allow, msg = allow_script( - self.configuration, script_name, 'test_client') + self.configuration, script_name, "test_client" + ) self.assertFalse(allow) - self.assertEqual(msg, "all access to functionality disabled by site " - "configuration!") + self.assertEqual( + msg, + "all access to functionality disabled by site " "configuration!", + ) def test_allow_script_gdp_disabled(self): """Test allow_script with GDP disabled.""" self.configuration.site_enable_gdp = False - allow, msg = allow_script(self.configuration, 'any_script.py', - 'test_client') + allow, msg = allow_script( + self.configuration, "any_script.py", "test_client" + ) self.assertTrue(allow) self.assertEqual(msg, "") def test_allow_script_gdp_disabled_anonymous(self): """Test allow_script with GDP disabled and anonymous user.""" self.configuration.site_enable_gdp = False - allow, msg = allow_script(self.configuration, 'any_script.py', None) + allow, msg = allow_script(self.configuration, "any_script.py", None) self.assertTrue(allow) self.assertEqual(msg, "") def test_requested_page_normal(self): """Test requested_page with basic environment""" fake_env = { - 'SCRIPT_NAME': '/cgi-bin/home.py', - 'REQUEST_URI': '/cgi-bin/home.py' + "SCRIPT_NAME": "/cgi-bin/home.py", + "REQUEST_URI": "/cgi-bin/home.py", } - self.assertEqual(requested_page(fake_env), '/cgi-bin/home.py') + self.assertEqual(requested_page(fake_env), "/cgi-bin/home.py") def test_requested_page_name_only(self): """Test requested_page with name_only argument""" fake_env = { - 'BACKEND_NAME': 'search.py', - 'PATH_INFO': '/cgi-bin/search.py/path' + "BACKEND_NAME": "search.py", + "PATH_INFO": "/cgi-bin/search.py/path", } - result = requested_page(fake_env, name_only=True, fallback='fallback') - self.assertEqual(result, 'search.py') + result = requested_page(fake_env, name_only=True, fallback="fallback") + self.assertEqual(result, "search.py") def test_requested_page_strip_extension(self): """Test requested_page with strip_ext argument""" - fake_env = {'REQUEST_URI': '/cgi-bin/file.py?query=val'} + fake_env = {"REQUEST_URI": "/cgi-bin/file.py?query=val"} result = requested_page(fake_env, strip_ext=True) - self.assertEqual(result, '/cgi-bin/file') + self.assertEqual(result, "/cgi-bin/file") def test_requested_page_priority(self): """Test environment variable priority order""" _init_env = { - 'BACKEND_NAME': 'backend.py', - 'SCRIPT_URL': '/cgi-bin/script_url.py', - 'SCRIPT_URI': 'https://host/cgi-bin/script_uri.py', - 'PATH_INFO': '/cgi-bin/path_info.py', - 'REQUEST_URI': '/cgi-bin/req_uri.py' + "BACKEND_NAME": "backend.py", + "SCRIPT_URL": "/cgi-bin/script_url.py", + "SCRIPT_URI": "https://host/cgi-bin/script_uri.py", + "PATH_INFO": "/cgi-bin/path_info.py", + "REQUEST_URI": "/cgi-bin/req_uri.py", } - priority_order = ['BACKEND_NAME', 'SCRIPT_URL', 'SCRIPT_URI', - 'PATH_INFO', 'REQUEST_URI'] + priority_order = [ + "BACKEND_NAME", + "SCRIPT_URL", + "SCRIPT_URI", + "PATH_INFO", + "REQUEST_URI", + ] for var in priority_order: # Reset fake_env each time fake_env = dict([pair for pair in _init_env.items()]) current_env = {var: fake_env[var]} - if var != 'SCRIPT_URI': + if var != "SCRIPT_URI": expected = fake_env[var] else: - expected = 'https://host/cgi-bin/script_uri.py' + expected = "https://host/cgi-bin/script_uri.py" result = requested_page(current_env) - self.assertEqual(result, expected, - "failed priority for %s" % var) + self.assertEqual(result, expected, "failed priority for %s" % var) # Remove higher priority variables one by one - for higher_var in priority_order[:priority_order.index(var)]: + for higher_var in priority_order[: priority_order.index(var)]: del fake_env[higher_var] result = requested_page(fake_env) - self.assertEqual(result, fake_env[var], - "failed fallthrough to %s" % var) + self.assertEqual( + result, fake_env[var], "failed fallthrough to %s" % var + ) def test_requested_page_unsafe_filter(self): """Test unsafe character filtering""" test_cases = [ - ('/cgi-bin/unsafe.py' - fake_env = {'REQUEST_URI': dangerous} + dangerous = "/cgi-bin/unsafe.py" + fake_env = {"REQUEST_URI": dangerous} unsafe_result = requested_page(fake_env, include_unsafe=True) self.assertEqual(unsafe_result, dangerous) @@ -874,39 +992,39 @@ def test_requested_page_include_unsafe(self): def test_requested_page_query_stripping(self): """Test removal of query parameters""" - test_input = '/cgi-bin/script.py?query=value¶m=data' - fake_env = {'REQUEST_URI': test_input} + test_input = "/cgi-bin/script.py?query=value¶m=data" + fake_env = {"REQUEST_URI": test_input} result = requested_page(fake_env) - self.assertEqual(result, '/cgi-bin/script.py') + self.assertEqual(result, "/cgi-bin/script.py") def test_requested_page_fallback(self): """Test fallback to default""" fake_env = {} - fallback = 'special.py' + fallback = "special.py" result = requested_page(fake_env, fallback=fallback) self.assertEqual(result, fallback) def test_requested_page_fallback_despite_os_environ_value(self): """Test fallback to default""" fake_env = {} - fallback = 'special.py' - os.environ['BACKEND_NAME'] = 'BOGUS' + fallback = "special.py" + os.environ["BACKEND_NAME"] = "BOGUS" result = requested_page(fake_env, fallback=fallback) - del os.environ['BACKEND_NAME'] + del os.environ["BACKEND_NAME"] self.assertEqual(result, fallback) def test_requested_url_base_normal(self): """Test requested_url_base with basic complete URL""" - fake_env = {'SCRIPT_URI': 'https://example.com/path/to/script.py'} + fake_env = {"SCRIPT_URI": "https://example.com/path/to/script.py"} result = requested_url_base(fake_env) - expected = 'https://example.com' + expected = "https://example.com" self.assertEqual(result, expected) def test_requested_url_base_custom_field(self): """Test requested_url_base with custom uri_field parameter""" - fake_env = {'CUSTOM_FIELD_URI': 'http://server.org:8001/base/'} - result = requested_url_base(fake_env, uri_field='CUSTOM_FIELD_URI') - expected = 'http://server.org:8001' + fake_env = {"CUSTOM_FIELD_URI": "http://server.org:8001/base/"} + result = requested_url_base(fake_env, uri_field="CUSTOM_FIELD_URI") + expected = "http://server.org:8001" self.assertEqual(result, expected) # TODO: adjust tested function to bail out on missing uri_field @@ -915,56 +1033,55 @@ def test_requested_url_base_missing(self): """Test requested_url_base when uri_field not present""" fake_env = {} result = requested_url_base(fake_env) - self.assertEqual(result, '') + self.assertEqual(result, "") def test_requested_url_base_safe_filter(self): """Test unsafe character filtering in url base""" - test_url = 'https://user:pass@evil.com/' - fake_env = {'SCRIPT_URI': test_url} + test_url = "https://user:pass@evil.com/" + fake_env = {"SCRIPT_URI": test_url} safe_result = requested_url_base(fake_env) - expected_safe = 'https://user:passevil.com' + expected_safe = "https://user:passevil.com" self.assertEqual(safe_result, expected_safe) def test_requested_url_base_include_unsafe(self): """Test include_unsafe argument behavior""" - test_url = 'http://[::1]?' - fake_env = {'SCRIPT_URI': test_url} + test_url = "http://[::1]?" + fake_env = {"SCRIPT_URI": test_url} unsafe_result = requested_url_base(fake_env, include_unsafe=True) - self.assertEqual(unsafe_result, 'http://[::1]?') + self.assertEqual(unsafe_result, "http://[::1]?") safe_result = requested_url_base(fake_env, include_unsafe=False) - self.assertEqual(safe_result, 'http://::1markup') + self.assertEqual(safe_result, "http://::1markup") safe_result = requested_url_base(fake_env) - self.assertEqual(safe_result, 'http://::1markup') + self.assertEqual(safe_result, "http://::1markup") def test_requested_url_base_split_valid_edge_cases(self): """Test URL base splitting on valid edge cases""" test_cases = [ - ('https://site.com', 'https://site.com'), - ('http://a/single/slash', 'http://a'), - ('file:///absolute/path', 'file://'), - ('invalid.proto://double/slash', 'invalid.proto://double') + ("https://site.com", "https://site.com"), + ("http://a/single/slash", "http://a"), + ("file:///absolute/path", "file://"), + ("invalid.proto://double/slash", "invalid.proto://double"), ] - for (input_url, expected) in test_cases: - fake_env = {'SCRIPT_URI': input_url} + for input_url, expected in test_cases: + fake_env = {"SCRIPT_URI": input_url} result = requested_url_base(fake_env) - self.assertEqual(result, expected, - "failed for %s" % input_url) + self.assertEqual(result, expected, "failed for %s" % input_url) # TODO: adjust function to bail out on invalid URLs and enable next @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_split_invalid_edge_cases(self): """Test URL base splitting on invalid edge cases""" test_cases = [ - ('', ''), - ('/', '/'), - ('/single', '/single'), - ('/double/slash', '/double/slash'), - ('invalid.proto:/1st/2nd/slash', 'invalid.proto:/1st/2nd/slash'), - ('invalid.proto://double/slash', 'invalid.proto://double') + ("", ""), + ("/", "/"), + ("/single", "/single"), + ("/double/slash", "/double/slash"), + ("invalid.proto:/1st/2nd/slash", "invalid.proto:/1st/2nd/slash"), + ("invalid.proto://double/slash", "invalid.proto://double"), ] - for (input_url, expected) in test_cases: - fake_env = {'SCRIPT_URI': input_url} + for input_url, expected in test_cases: + fake_env = {"SCRIPT_URI": input_url} try: result = requested_url_base(fake_env) except ValueError: @@ -975,7 +1092,7 @@ def test_requested_url_base_split_invalid_edge_cases(self): @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_relative_path(self): """Test relative paths in URL""" - fake_env = {'SCRIPT_URI': '/cgi-bin/script.py'} + fake_env = {"SCRIPT_URI": "/cgi-bin/script.py"} try: result = requested_url_base(fake_env) except ValueError: @@ -986,10 +1103,10 @@ def test_requested_url_base_relative_path(self): @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_special_chars(self): """Test handling of special characters in URL""" - test_url = 'http://üñîçøðê.net/path' - fake_env = {'SCRIPT_URI': test_url} + test_url = "http://üñîçøðê.net/path" + fake_env = {"SCRIPT_URI": test_url} result = requested_url_base(fake_env) - self.assertEqual(result, 'http://üñîçøðê.net') + self.assertEqual(result, "http://üñîçøðê.net") def test_verify_local_url_direct_match(self): """Test verify_local_url with direct match to known site URL""" @@ -1005,7 +1122,9 @@ def test_verify_local_url_subpath_match(self): def test_verify_local_url_public_alias(self): """Test verify_local_url with public alias domain""" - self.configuration.migserver_public_alias_url = "https://grid.example.org" + self.configuration.migserver_public_alias_url = ( + "https://grid.example.org" + ) test_url = "https://grid.example.org/cgi-bin/file.py" self.assertTrue(verify_local_url(self.configuration, test_url)) @@ -1017,103 +1136,122 @@ def test_verify_local_url_absolute_path(self): def test_verify_local_url_relative_path(self): """Test verify_local_url with relative path""" test_url = "subdir/script.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_external_domain(self): """Test verify_local_url rejects external domains""" test_url = "https://evil.com/malicious.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_invalid_url(self): """Test verify_local_url rejects invalid/malformed URLs""" test_url = "javascript:alert('xss')" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_missing_https(self): """Test verify_local_url with HTTP when only HTTPS supported""" test_url = "http://mig.cert/cgi-bin/home.py" self.configuration.migserver_https_mig_cert_url = "https://mig.cert" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_different_port(self): """Test verify_local_url rejects same domain with different port""" self.configuration.migserver_https_ext_cert_url = "https://ext.cert:443" test_url = "https://ext.cert:444/cgi-bin/file.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_invisible_path_file(self): """Test invisible_path detects names in invisible files""" - invisible_filename = '.htaccess' - visible_filename = 'visible.txt' + invisible_filename = ".htaccess" + visible_filename = "visible.txt" # Test with invisible filename - self.assertTrue(invisible_path('/some/path/%s' % invisible_filename)) + self.assertTrue(invisible_path("/some/path/%s" % invisible_filename)) self.assertTrue(invisible_path(invisible_filename)) self.assertTrue(invisible_path(invisible_filename, True)) # Test with visible filename self.assertFalse(invisible_path(visible_filename)) - self.assertFalse(invisible_path('/some/path/%s' % visible_filename)) + self.assertFalse(invisible_path("/some/path/%s" % visible_filename)) def test_invisible_path_dir(self): """Test invisible_path detects paths in invisible dir""" - invisible_dirname = '.vgridscm' - visible_dirname = 'somedir' + invisible_dirname = ".vgridscm" + visible_dirname = "somedir" # Test different forms of invisible directory path self.assertTrue(invisible_path(invisible_dirname)) - self.assertTrue(invisible_path('/%s' % invisible_dirname)) - self.assertTrue(invisible_path('/parent/%s' % invisible_dirname)) - self.assertTrue(invisible_path('%s/sub' % invisible_dirname)) - self.assertTrue(invisible_path('/%s/file' % invisible_dirname)) + self.assertTrue(invisible_path("/%s" % invisible_dirname)) + self.assertTrue(invisible_path("/parent/%s" % invisible_dirname)) + self.assertTrue(invisible_path("%s/sub" % invisible_dirname)) + self.assertTrue(invisible_path("/%s/file" % invisible_dirname)) # Test visible directory self.assertFalse(invisible_path(visible_dirname)) - self.assertFalse(invisible_path('/%s' % visible_dirname)) - self.assertFalse(invisible_path('/parent/%s' % visible_dirname)) + self.assertFalse(invisible_path("/%s" % visible_dirname)) + self.assertFalse(invisible_path("/parent/%s" % visible_dirname)) def test_invisible_path_vgrid_exception(self): """Test allow_vgrid_scripts excludes valid vgrid xgi scripts""" - invisible_dirname = '.vgridscm' - vgrid_script = '.vgridscm/cgi-bin/hgweb.cgi' + invisible_dirname = ".vgridscm" + vgrid_script = ".vgridscm/cgi-bin/hgweb.cgi" test_patterns = [ - '/%s/%s' % (invisible_dirname, vgrid_script), - '/root/%s/sub/%s' % (invisible_dirname, vgrid_script), - '/%s/prefix%ssuffix' % (invisible_dirname, vgrid_script), - '/%s/similar_script.py' % invisible_dirname, - '/path/to/%s' % vgrid_script - + "/%s/%s" % (invisible_dirname, vgrid_script), + "/root/%s/sub/%s" % (invisible_dirname, vgrid_script), + "/%s/prefix%ssuffix" % (invisible_dirname, vgrid_script), + "/%s/similar_script.py" % invisible_dirname, + "/path/to/%s" % vgrid_script, ] test_expects = [False, False, False, True, False] test_iter = zip(test_patterns, test_expects) - for (i, (path, expected)) in enumerate(test_iter): + for i, (path, expected) in enumerate(test_iter): self.assertEqual( invisible_path(path, allow_vgrid_scripts=True), expected, "test case %d: path %r should %sbe invisible with scripts" - % (i, path, "" if expected else "not ") + % (i, path, "" if expected else "not "), ) # Should still be invisible without exception flag @@ -1122,23 +1260,24 @@ def test_invisible_path_vgrid_exception(self): invisible_path(path, allow_vgrid_scripts=False), expect_no_exception, "test case %d: path %r should %sbe invisible without scripts" - % (i, path, "" if expect_no_exception else "not ") + % (i, path, "" if expect_no_exception else "not "), ) def test_invisible_path_edge_cases(self): """Test invisible_path handles edge cases""" from mig.shared.defaults import _user_invisible_dirs + invisible_dirname = _user_invisible_dirs[0] # Empty path - self.assertFalse(invisible_path('')) - self.assertFalse(invisible_path('', allow_vgrid_scripts=True)) + self.assertFalse(invisible_path("")) + self.assertFalse(invisible_path("", allow_vgrid_scripts=True)) # Root path - self.assertFalse(invisible_path('/')) + self.assertFalse(invisible_path("/")) # Path that only contains invisible directory substring - substring_path = '/prefix%ssuffix/file' % invisible_dirname + substring_path = "/prefix%ssuffix/file" % invisible_dirname self.assertFalse(invisible_path(substring_path)) def test_client_alias(self): @@ -1176,21 +1315,21 @@ def test_get_short_id_with_gdp(self): def test_get_user_id_x509_format(self): """Test get_user_id returns DN for X509 format""" self.configuration.site_user_id_format = "X509" - user = {'distinguished_name': TEST_USER_ID} + user = {"distinguished_name": TEST_USER_ID} result = get_user_id(self.configuration, user) self.assertEqual(result, TEST_USER_ID) def test_get_user_id_uuid_format(self): """Test get_user_id returns UUID when configured""" self.configuration.site_user_id_format = "UUID" - user = {'unique_id': "123e4567-e89b-12d3-a456-426614174000"} + user = {"unique_id": "123e4567-e89b-12d3-a456-426614174000"} result = get_user_id(self.configuration, user) self.assertEqual(result, "123e4567-e89b-12d3-a456-426614174000") def test_get_client_id(self): """Test get_client_id extracts DN from user dict""" test_dn = "/C=US/CN=Alice" - user = {'distinguished_name': test_dn, 'other': 'field'} + user = {"distinguished_name": test_dn, "other": "field"} result = get_client_id(user) self.assertEqual(result, test_dn) @@ -1204,14 +1343,8 @@ def test_hexlify_unhexlify_roundtrip(self): def test_is_gdp_user_detection(self): """Test is_gdp_user detects GDP project presence""" - self.assertTrue(is_gdp_user( - self.configuration, - "/GDP_PROJ=12345" - )) - self.assertFalse(is_gdp_user( - self.configuration, - "/CN=Regular User" - )) + self.assertTrue(is_gdp_user(self.configuration, "/GDP_PROJ=12345")) + self.assertFalse(is_gdp_user(self.configuration, "/CN=Regular User")) def test_sandbox_resource_identification(self): """Test sandbox_resource identifies sandboxes""" @@ -1235,8 +1368,8 @@ def test_invisible_dir_detection(self): def test_requested_backend_extraction(self): """Test requested_backend extracts backend name from environ""" test_env = { - 'BACKEND_NAME': '/cgi-bin/fileman.py', - 'PATH_TRANSLATED': '/wsgi-bin/fileman.py' + "BACKEND_NAME": "/cgi-bin/fileman.py", + "PATH_TRANSLATED": "/wsgi-bin/fileman.py", } result = requested_backend(test_env) self.assertEqual(result, "fileman") @@ -1258,15 +1391,16 @@ def test_get_xgi_bin_wsgi_vs_cgi(self): """Test get_xgi_bin returns correct script bin based on config""" # Test WSGI enabled self.configuration.site_enable_wsgi = True - self.assertEqual(get_xgi_bin(self.configuration), 'wsgi-bin') + self.assertEqual(get_xgi_bin(self.configuration), "wsgi-bin") # Test WSGI disabled self.configuration.site_enable_wsgi = False - self.assertEqual(get_xgi_bin(self.configuration), 'cgi-bin') + self.assertEqual(get_xgi_bin(self.configuration), "cgi-bin") # Test legacy force - self.assertEqual(get_xgi_bin(self.configuration, force_legacy=True), - 'cgi-bin') + self.assertEqual( + get_xgi_bin(self.configuration, force_legacy=True), "cgi-bin" + ) def test_valid_dir_input(self): """Test valid_dir_input prevents path traversal attempts""" @@ -1277,11 +1411,11 @@ def test_valid_dir_input(self): ("../illegal", False), ("/absolute", False), ] - for (relative_path, expected) in test_cases: + for relative_path, expected in test_cases: self.assertEqual( valid_dir_input(base, relative_path), expected, - "failed for %s" % relative_path + "failed for %s" % relative_path, ) def test_user_base_dir(self): @@ -1306,32 +1440,56 @@ def test_brief_list(self): # List longer than max_entries gets shortened long_list = list(range(15)) - expected_long = [0, 1, 2, 3, 4, ' ... shortened ... ', 10, 11, 12, - 13, 14] + expected_long = [ + 0, + 1, + 2, + 3, + 4, + " ... shortened ... ", + 10, + 11, + 12, + 13, + 14, + ] self.assertEqual(brief_list(long_list), expected_long) # Custom max_entries with odd number custom_odd_list = list(range(10)) - expected_odd = [0, 1, 2, 3, ' ... shortened ... ', 6, 7, 8, 9] + expected_odd = [0, 1, 2, 3, " ... shortened ... ", 6, 7, 8, 9] self.assertEqual(brief_list(custom_odd_list, 9), expected_odd) # Range objects should be handled properly input_range = range(20) - expected_range = [0, 1, 2, 3, 4, ' ... shortened ... ', 15, 16, 17, - 18, 19] + expected_range = [ + 0, + 1, + 2, + 3, + 4, + " ... shortened ... ", + 15, + 16, + 17, + 18, + 19, + ] self.assertEqual(brief_list(input_range), expected_range) # Edge case - max_entries=2 - self.assertEqual(brief_list([1, 2, 3, 4], 2), - [1, ' ... shortened ... ', 4]) + self.assertEqual( + brief_list([1, 2, 3, 4], 2), [1, " ... shortened ... ", 4] + ) # Very small max_entries - self.assertEqual(brief_list([1, 2, 3, 4], 3), - [1, ' ... shortened ... ', 4]) + self.assertEqual( + brief_list([1, 2, 3, 4], 3), [1, " ... shortened ... ", 4] + ) # Non-integer input - str_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - expected_str = ['a', 'b', 'c', ' ... shortened ... ', 'e', 'f', 'g'] + str_list = ["a", "b", "c", "d", "e", "f", "g"] + expected_str = ["a", "b", "c", " ... shortened ... ", "e", "f", "g"] self.assertEqual(brief_list(str_list, 7), str_list) # At max_entries # TODO: fix tested function to handle these and enable test @@ -1339,12 +1497,14 @@ def test_brief_list(self): def test_brief_list_edge_cases(self): """Test brief_list helper function for compact list on edge cases""" # Edge case - max_entries=1 - self.assertEqual(brief_list([1, 2, 3], 1), [' ... shortened ... ']) + self.assertEqual(brief_list([1, 2, 3], 1), [" ... shortened ... "]) # Edge case - even short number of max_entries - str_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - self.assertEqual(brief_list(str_list, 6), - ['a', 'b', ' ... shortened ... ', 'f', 'g']) + str_list = ["a", "b", "c", "d", "e", "f", "g"] + self.assertEqual( + brief_list(str_list, 6), + ["a", "b", " ... shortened ... ", "f", "g"], + ) class TestMigSharedBase__legacy(MigTestCase): @@ -1353,15 +1513,17 @@ class TestMigSharedBase__legacy(MigTestCase): # TODO: migrate all legacy self-check functionality into the above? def test_existing_main(self): """Run built-in self-tests and check output""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % - (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): diff --git a/tests/test_mig_shared_cloud.py b/tests/test_mig_shared_cloud.py index 241a79970..00e473b99 100644 --- a/tests/test_mig_shared_cloud.py +++ b/tests/test_mig_shared_cloud.py @@ -34,63 +34,74 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.cloud import cloud_load_instance, cloud_save_instance, \ - allowed_cloud_images - -DUMMY_USER = 'dummy-user' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +from mig.shared.cloud import ( + allowed_cloud_images, + cloud_load_instance, + cloud_save_instance, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) + +DUMMY_USER = "dummy-user" +DUMMY_SETTINGS_DIR = "dummy_user_settings" DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) DUMMY_CLOUD = "CLOUD" -DUMMY_FLAVOR = 'openstack' -DUMMY_LABEL = 'dummy-label' -DUMMY_IMAGE = 'dummy-image' -DUMMY_HEX_ID = 'deadbeef-dead-beef-dead-beefdeadbeef' - -DUMMY_CLOUD_SPEC = {'service_title': 'CLOUDTITLE', 'service_name': 'CLOUDNAME', - 'service_desc': 'A Cloud for migrid site', - 'service_provider_flavor': 'openstack', - 'service_hosts': 'https://myopenstack-cloud.org:5000/v3', - 'service_rules_of_conduct': 'rules-of-conduct.pdf', - 'service_max_user_instances': '0', - 'service_max_user_instances_map': {DUMMY_USER: '1'}, - 'service_allowed_images': DUMMY_IMAGE, - 'service_allowed_images_map': {DUMMY_USER: 'ALL'}, - 'service_user_map': {DUMMY_IMAGE, 'user'}, - 'service_image_alias_map': {DUMMY_IMAGE.lower(): - DUMMY_IMAGE}, - 'service_flavor_id': DUMMY_HEX_ID, - 'service_flavor_id_map': {DUMMY_USER: DUMMY_HEX_ID}, - 'service_network_id': DUMMY_HEX_ID, - 'service_key_id_map': {}, - 'service_sec_group_id': DUMMY_HEX_ID, - 'service_floating_network_id': DUMMY_HEX_ID, - 'service_availability_zone': 'myopenstack', - 'service_jumphost_address': 'jumphost.somewhere.org', - 'service_jumphost_user': 'cloud', - 'service_jumphost_manage_keys_script': - 'cloud_manage_keys.py', - 'service_jumphost_manage_keys_coding': 'base16', - 'service_network_id_map': {}, - 'service_sec_group_id_map': {}, - 'service_floating_network_id_map': {}, - 'service_availability_zone_map': {}, - 'service_jumphost_address_map': {}, - 'service_jumphost_user_map': {}} -DUMMY_CONF = FakeConfiguration(user_settings=DUMMY_SETTINGS_PATH, - site_cloud_access=[('distinguished_name', '.*')], - cloud_services=[DUMMY_CLOUD_SPEC]) - -DUMMY_INSTANCE_ID = '%s:%s:%s' % (DUMMY_USER, DUMMY_LABEL, DUMMY_HEX_ID) +DUMMY_FLAVOR = "openstack" +DUMMY_LABEL = "dummy-label" +DUMMY_IMAGE = "dummy-image" +DUMMY_HEX_ID = "deadbeef-dead-beef-dead-beefdeadbeef" + +DUMMY_CLOUD_SPEC = { + "service_title": "CLOUDTITLE", + "service_name": "CLOUDNAME", + "service_desc": "A Cloud for migrid site", + "service_provider_flavor": "openstack", + "service_hosts": "https://myopenstack-cloud.org:5000/v3", + "service_rules_of_conduct": "rules-of-conduct.pdf", + "service_max_user_instances": "0", + "service_max_user_instances_map": {DUMMY_USER: "1"}, + "service_allowed_images": DUMMY_IMAGE, + "service_allowed_images_map": {DUMMY_USER: "ALL"}, + "service_user_map": {DUMMY_IMAGE, "user"}, + "service_image_alias_map": {DUMMY_IMAGE.lower(): DUMMY_IMAGE}, + "service_flavor_id": DUMMY_HEX_ID, + "service_flavor_id_map": {DUMMY_USER: DUMMY_HEX_ID}, + "service_network_id": DUMMY_HEX_ID, + "service_key_id_map": {}, + "service_sec_group_id": DUMMY_HEX_ID, + "service_floating_network_id": DUMMY_HEX_ID, + "service_availability_zone": "myopenstack", + "service_jumphost_address": "jumphost.somewhere.org", + "service_jumphost_user": "cloud", + "service_jumphost_manage_keys_script": "cloud_manage_keys.py", + "service_jumphost_manage_keys_coding": "base16", + "service_network_id_map": {}, + "service_sec_group_id_map": {}, + "service_floating_network_id_map": {}, + "service_availability_zone_map": {}, + "service_jumphost_address_map": {}, + "service_jumphost_user_map": {}, +} +DUMMY_CONF = FakeConfiguration( + user_settings=DUMMY_SETTINGS_PATH, + site_cloud_access=[("distinguished_name", ".*")], + cloud_services=[DUMMY_CLOUD_SPEC], +) + +DUMMY_INSTANCE_ID = "%s:%s:%s" % (DUMMY_USER, DUMMY_LABEL, DUMMY_HEX_ID) DUMMY_INSTANCE_DICT = { DUMMY_INSTANCE_ID: { - 'INSTANCE_LABEL': DUMMY_LABEL, - 'INSTANCE_IMAGE': DUMMY_IMAGE, - 'INSTANCE_ID': DUMMY_INSTANCE_ID, - 'IMAGE_ID': DUMMY_IMAGE, - 'CREATED_TIMESTAMP': "%d" % time.time(), - 'USER_CERT': DUMMY_USER + "INSTANCE_LABEL": DUMMY_LABEL, + "INSTANCE_IMAGE": DUMMY_IMAGE, + "INSTANCE_ID": DUMMY_INSTANCE_ID, + "IMAGE_ID": DUMMY_IMAGE, + "CREATED_TIMESTAMP": "%d" % time.time(), + "USER_CERT": DUMMY_USER, } } @@ -102,38 +113,46 @@ def test_cloud_save_load(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) cleanpath(DUMMY_SETTINGS_DIR, self) - save_status = cloud_save_instance(DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, - DUMMY_LABEL, DUMMY_INSTANCE_DICT) + save_status = cloud_save_instance( + DUMMY_CONF, + DUMMY_USER, + DUMMY_CLOUD, + DUMMY_LABEL, + DUMMY_INSTANCE_DICT, + ) self.assertTrue(save_status) - saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, - '%s.state' % DUMMY_CLOUD) + saved_path = os.path.join( + DUMMY_SETTINGS_PATH, DUMMY_USER, "%s.state" % DUMMY_CLOUD + ) self.assertTrue(os.path.exists(saved_path)) - instance = cloud_load_instance(DUMMY_CONF, DUMMY_USER, - DUMMY_CLOUD, DUMMY_LABEL) + instance = cloud_load_instance( + DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, DUMMY_LABEL + ) # NOTE: instance should be a non-empty dict at this point self.assertTrue(isinstance(instance, dict)) # print(instance) self.assertTrue(DUMMY_INSTANCE_ID in instance) instance_dict = instance[DUMMY_INSTANCE_ID] - self.assertEqual(instance_dict['INSTANCE_LABEL'], DUMMY_LABEL) - self.assertEqual(instance_dict['INSTANCE_IMAGE'], DUMMY_IMAGE) - self.assertEqual(instance_dict['INSTANCE_ID'], DUMMY_INSTANCE_ID) - self.assertEqual(instance_dict['IMAGE_ID'], DUMMY_IMAGE) - self.assertEqual(instance_dict['USER_CERT'], DUMMY_USER) + self.assertEqual(instance_dict["INSTANCE_LABEL"], DUMMY_LABEL) + self.assertEqual(instance_dict["INSTANCE_IMAGE"], DUMMY_IMAGE) + self.assertEqual(instance_dict["INSTANCE_ID"], DUMMY_INSTANCE_ID) + self.assertEqual(instance_dict["IMAGE_ID"], DUMMY_IMAGE) + self.assertEqual(instance_dict["USER_CERT"], DUMMY_USER) - @unittest.skip('Work in progress - currently requires remote openstack') + @unittest.skip("Work in progress - currently requires remote openstack") def test_cloud_allowed_images(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) cleanpath(DUMMY_SETTINGS_DIR, self) - allowed_images = allowed_cloud_images(DUMMY_CONF, DUMMY_USER, - DUMMY_CLOUD, DUMMY_FLAVOR) + allowed_images = allowed_cloud_images( + DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, DUMMY_FLAVOR + ) self.assertTrue(isinstance(allowed_images, list)) print(allowed_images) self.assertTrue(DUMMY_IMAGE in allowed_images) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_compat.py b/tests/test_mig_shared_compat.py index 05f2de2bc..faad707f7 100644 --- a/tests/test_mig_shared_compat.py +++ b/tests/test_mig_shared_compat.py @@ -31,13 +31,13 @@ import os import sys +from mig.shared.compat import PY2, ensure_native_string from tests.support import MigTestCase, testmain -from mig.shared.compat import PY2, ensure_native_string +DUMMY_BYTECHARS = b"DEADBEEF" +DUMMY_BYTESRAW = binascii.unhexlify("DEADBEEF") # 4 bytes +DUMMY_UNICODE = "UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®" -DUMMY_BYTECHARS = b'DEADBEEF' -DUMMY_BYTESRAW = binascii.unhexlify('DEADBEEF') # 4 bytes -DUMMY_UNICODE = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®' class MigSharedCompat__ensure_native_string(MigTestCase): """Unit test helper for the migrid code pointed to in class name""" @@ -45,7 +45,7 @@ class MigSharedCompat__ensure_native_string(MigTestCase): def test_char_bytes_conversion(self): actual = ensure_native_string(DUMMY_BYTECHARS) self.assertIs(type(actual), str) - self.assertEqual(actual, 'DEADBEEF') + self.assertEqual(actual, "DEADBEEF") def test_raw_bytes_conversion(self): with self.assertRaises(UnicodeDecodeError): @@ -60,5 +60,5 @@ def test_unicode_conversion(self): self.assertEqual(actual, DUMMY_UNICODE) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_configuration.py b/tests/test_mig_shared_configuration.py index 1108ea3e0..c02b33559 100644 --- a/tests/test_mig_shared_configuration.py +++ b/tests/test_mig_shared_configuration.py @@ -31,16 +31,23 @@ import os import unittest -from tests.support import MigTestCase, TEST_DATA_DIR, PY2, testmain +from mig.shared.configuration import ( + _CONFIGURATION_ARGUMENTS, + _CONFIGURATION_PROPERTIES, + Configuration, +) +from tests.support import PY2, TEST_DATA_DIR, MigTestCase, testmain from tests.support.fixturesupp import FixtureAssertMixin -from mig.shared.configuration import Configuration, \ - _CONFIGURATION_ARGUMENTS, _CONFIGURATION_PROPERTIES - def _to_dict(obj): - return {k: v for k, v in inspect.getmembers(obj) - if not (k.startswith('__') or inspect.ismethod(v) or inspect.isfunction(v))} + return { + k: v + for k, v in inspect.getmembers(obj) + if not ( + k.startswith("__") or inspect.ismethod(v) or inspect.isfunction(v) + ) + } class MigSharedConfiguration__static_definitions(MigTestCase): @@ -50,8 +57,9 @@ def test_consistent_parameters(self): configuration_defaults_keys = set(_CONFIGURATION_PROPERTIES.keys()) mismatched = _CONFIGURATION_ARGUMENTS - configuration_defaults_keys - self.assertEqual(len(mismatched), 0, - "configuration defaults do not match arguments") + self.assertEqual( + len(mismatched), 0, "configuration defaults do not match arguments" + ) class MigSharedConfiguration__loaded_configurations(MigTestCase): @@ -59,19 +67,23 @@ class MigSharedConfiguration__loaded_configurations(MigTestCase): def test_argument_new_user_default_ui_is_replaced(self): test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) - self.assertEqual(configuration.new_user_default_ui, 'V3') + self.assertEqual(configuration.new_user_default_ui, "V3") def test_argument_storage_protocols(self): test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # TODO: add a test to cover filtering of a mix of valid+invalid protos # self.assertEqual(configuration.storage_protocols, ['xxx', 'yyy', 'zzz']) @@ -81,90 +93,110 @@ def test_argument_storage_protocols(self): def test_argument_wwwserve_max_bytes(self): test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.wwwserve_max_bytes, 43211234) def test_argument_include_sections(self): """Test that include_sections path default is set""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) - self.assertEqual(configuration.include_sections, - '/home/mig/mig/server/MiGserver.d') + self.assertEqual( + configuration.include_sections, "/home/mig/mig/server/MiGserver.d" + ) def test_argument_custom_include_sections(self): """Test that include_sections path override is correctly applied""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") self.assertTrue(os.path.isdir(test_conf_section_dir)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) - self.assertEqual(configuration.include_sections, - test_conf_section_dir) + self.assertEqual(configuration.include_sections, test_conf_section_dir) def test_argument_include_sections_quota(self): """Test that QUOTA conf section overrides are correctly applied""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'quota.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "quota.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) - self.assertEqual(configuration.quota_backend, 'dummy') + self.assertEqual(configuration.quota_backend, "dummy") self.assertEqual(configuration.quota_user_limit, 4242) self.assertEqual(configuration.quota_vgrid_limit, 4242424242) def test_argument_include_sections_cloud_misty(self): """Test that CLOUD_MISTY conf section is correctly applied""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'cloud_misty.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "cloud_misty.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.cloud_services, list) self.assertTrue(configuration.cloud_services) self.assertIsInstance(configuration.cloud_services[0], dict) - self.assertTrue(configuration.cloud_services[0].get('service_name', - False)) - self.assertEqual(configuration.cloud_services[0]['service_name'], - 'MISTY') - self.assertEqual(configuration.cloud_services[0]['service_desc'], - 'MISTY service') - self.assertEqual(configuration.cloud_services[0]['service_provider_flavor'], - 'nostack') + self.assertTrue( + configuration.cloud_services[0].get("service_name", False) + ) + self.assertEqual( + configuration.cloud_services[0]["service_name"], "MISTY" + ) + self.assertEqual( + configuration.cloud_services[0]["service_desc"], "MISTY service" + ) + self.assertEqual( + configuration.cloud_services[0]["service_provider_flavor"], + "nostack", + ) def test_argument_include_sections_global_accepted(self): """Test that peripheral GLOBAL conf overrides are accepted (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'global.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "global.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertEqual(configuration.admin_email, "admin@somewhere.org") @@ -176,93 +208,105 @@ def test_argument_include_sections_global_accepted(self): def test_argument_include_sections_global_rejected(self): """Test that core GLOBAL conf overrides are rejected (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'global.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "global.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Run through the snippet values and check that override didn't succeed # and then that default is left set. The former _could_ be left out but # is kept explicit for clarity in case something breaks by changes. - self.assertNotEqual(configuration.include_sections, '/tmp/MiGserver.d') + self.assertNotEqual(configuration.include_sections, "/tmp/MiGserver.d") self.assertEqual(configuration.include_sections, test_conf_section_dir) - self.assertNotEqual(configuration.mig_path, '/tmp/mig/mig') - self.assertEqual(configuration.mig_path, '/home/mig/mig') - self.assertNotEqual(configuration.logfile, '/tmp/mig.log') - self.assertEqual(configuration.logfile, 'mig.log') - self.assertNotEqual(configuration.loglevel, 'warning') - self.assertEqual(configuration.loglevel, 'info') - self.assertNotEqual(configuration.server_fqdn, 'somewhere.org') - self.assertEqual(configuration.server_fqdn, '') - self.assertNotEqual(configuration.migserver_public_url, - 'https://somewhere.org') - self.assertEqual(configuration.migserver_public_url, '') - self.assertNotEqual(configuration.migserver_https_sid_url, - 'https://somewhere.org') - self.assertEqual(configuration.migserver_https_sid_url, '') - self.assertNotEqual(configuration.user_openid_address, 'somewhere.org') - self.assertNotEqual(configuration.user_openid_address, 'somewhere.org') - self.assertEqual(configuration.user_openid_address, '') + self.assertNotEqual(configuration.mig_path, "/tmp/mig/mig") + self.assertEqual(configuration.mig_path, "/home/mig/mig") + self.assertNotEqual(configuration.logfile, "/tmp/mig.log") + self.assertEqual(configuration.logfile, "mig.log") + self.assertNotEqual(configuration.loglevel, "warning") + self.assertEqual(configuration.loglevel, "info") + self.assertNotEqual(configuration.server_fqdn, "somewhere.org") + self.assertEqual(configuration.server_fqdn, "") + self.assertNotEqual( + configuration.migserver_public_url, "https://somewhere.org" + ) + self.assertEqual(configuration.migserver_public_url, "") + self.assertNotEqual( + configuration.migserver_https_sid_url, "https://somewhere.org" + ) + self.assertEqual(configuration.migserver_https_sid_url, "") + self.assertNotEqual(configuration.user_openid_address, "somewhere.org") + self.assertNotEqual(configuration.user_openid_address, "somewhere.org") + self.assertEqual(configuration.user_openid_address, "") self.assertNotEqual(configuration.user_openid_port, 4242) self.assertEqual(configuration.user_openid_port, 8443) - self.assertNotEqual(configuration.user_openid_key, '/tmp/openid.key') - self.assertEqual(configuration.user_openid_key, '') - self.assertNotEqual(configuration.user_openid_log, '/tmp/openid.log') - self.assertEqual(configuration.user_openid_log, - '/home/mig/state/log/openid.log') + self.assertNotEqual(configuration.user_openid_key, "/tmp/openid.key") + self.assertEqual(configuration.user_openid_key, "") + self.assertNotEqual(configuration.user_openid_log, "/tmp/openid.log") + self.assertEqual( + configuration.user_openid_log, "/home/mig/state/log/openid.log" + ) def test_argument_include_sections_site_accepted(self): """Test that peripheral SITE conf overrides are accepted (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'site.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "site.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) - self.assertEqual(configuration.short_title, 'ACME Site') - self.assertEqual(configuration.new_user_default_ui, 'V3') - self.assertEqual(configuration.site_password_legacy_policy, 'MEDIUM') - self.assertEqual(configuration.site_support_text, - 'Custom support text') - self.assertEqual(configuration.site_privacy_text, - 'Custom privacy text') - self.assertEqual(configuration.site_peers_notice, - 'Custom peers notice') - self.assertEqual(configuration.site_peers_contact_hint, - 'Custom peers contact hint') + self.assertEqual(configuration.short_title, "ACME Site") + self.assertEqual(configuration.new_user_default_ui, "V3") + self.assertEqual(configuration.site_password_legacy_policy, "MEDIUM") + self.assertEqual(configuration.site_support_text, "Custom support text") + self.assertEqual(configuration.site_privacy_text, "Custom privacy text") + self.assertEqual(configuration.site_peers_notice, "Custom peers notice") + self.assertEqual( + configuration.site_peers_contact_hint, "Custom peers contact hint" + ) self.assertIsInstance(configuration.site_freeze_admins, list) self.assertTrue(len(configuration.site_freeze_admins) == 1) - self.assertTrue('BOFH' in configuration.site_freeze_admins) - self.assertEqual(configuration.site_freeze_to_tape, - 'Custom freeze to tape') - self.assertEqual(configuration.site_freeze_doi_text, - 'Custom freeze doi text') - self.assertEqual(configuration.site_freeze_doi_url, - 'https://somewhere.org/mint-doi') - self.assertEqual(configuration.site_freeze_doi_url_field, - 'archiveurl') + self.assertTrue("BOFH" in configuration.site_freeze_admins) + self.assertEqual( + configuration.site_freeze_to_tape, "Custom freeze to tape" + ) + self.assertEqual( + configuration.site_freeze_doi_text, "Custom freeze doi text" + ) + self.assertEqual( + configuration.site_freeze_doi_url, "https://somewhere.org/mint-doi" + ) + self.assertEqual(configuration.site_freeze_doi_url_field, "archiveurl") def test_argument_include_sections_site_rejected(self): """Test that core SITE conf overrides are rejected (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'site.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "site.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertEqual(configuration.site_enable_openid, False) @@ -279,56 +323,63 @@ def test_argument_include_sections_site_rejected(self): def test_argument_include_sections_with_invalid_conf_filename(self): """Test that conf snippet with missing .conf extension gets ignored""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'dummy') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join(test_conf_section_dir, "dummy") self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Conf only contains SETTINGS section which is ignored due to mismatch self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.language, list) - self.assertFalse('Pig Latin' in configuration.language) - self.assertEqual(configuration.language, ['English']) + self.assertFalse("Pig Latin" in configuration.language) + self.assertEqual(configuration.language, ["English"]) def test_argument_include_sections_with_section_name_mismatch(self): """Test that conf section must match filename""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'section-mismatch.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "section-mismatch.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Conf only contains SETTINGS section which is ignored due to mismatch self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.language, list) - self.assertFalse('Pig Latin' in configuration.language) - self.assertEqual(configuration.language, ['English']) + self.assertFalse("Pig Latin" in configuration.language) + self.assertEqual(configuration.language, ["English"]) def test_argument_include_sections_multi_ignores_other_sections(self): """Test that conf section must match filename and others are ignored""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'multi.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "multi.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Conf contains MULTI and SETTINGS sections and latter must be ignored self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.language, list) - self.assertFalse('Spanglish' in configuration.language) - self.assertEqual(configuration.language, ['English']) + self.assertFalse("Spanglish" in configuration.language) + self.assertEqual(configuration.language, ["English"]) # TODO: rename file to valid section name we can check and enable next? # self.assertEqual(configuration.multi, 'blabla') @@ -339,15 +390,16 @@ class MigSharedConfiguration__new_instance(MigTestCase, FixtureAssertMixin): @unittest.skipIf(PY2, "Python 3 only") def test_default_object(self): prepared_fixture = self.prepareFixtureAssert( - 'mig_shared_configuration--new', fixture_format='json') + "mig_shared_configuration--new", fixture_format="json" + ) configuration = Configuration(None) # TODO: the following work-around default values set for these on the # instance that no longer make total sense but fiddling with them # is better as a follow-up. - configuration.certs_path = '/some/place/certs' - configuration.state_path = '/some/place/state' - configuration.mig_path = '/some/place/mig' + configuration.certs_path = "/some/place/certs" + configuration.state_path = "/some/place/state" + configuration.mig_path = "/some/place/mig" actual_values = _to_dict(configuration) @@ -358,11 +410,11 @@ def test_object_isolation(self): configuration_2 = Configuration(None) # change one of the configuration objects - configuration_1.default_page.append('foobar') + configuration_1.default_page.append("foobar") # check the other was not affected - self.assertEqual(configuration_2.default_page, ['']) + self.assertEqual(configuration_2.default_page, [""]) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_fileio.py b/tests/test_mig_shared_fileio.py index f43f013e8..277da3a7c 100644 --- a/tests/test_mig_shared_fileio.py +++ b/tests/test_mig_shared_fileio.py @@ -35,52 +35,53 @@ # Imports of the code under test import mig.shared.fileio as fileio + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -DUMMY_BYTES = binascii.unhexlify('DEADBEEF') # 4 bytes +DUMMY_BYTES = binascii.unhexlify("DEADBEEF") # 4 bytes DUMMY_BYTES_LENGTH = 4 -DUMMY_UNICODE = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®' +DUMMY_UNICODE = "UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®" DUMMY_UNICODE_LENGTH = len(DUMMY_UNICODE) -DUMMY_TEXT = 'dummy' -DUMMY_TWICE = 'dummy - dummy' -DUMMY_TESTDIR = 'fileio' -DUMMY_SUBDIR = 'subdir' -DUMMY_FILE_ONE = 'file1.txt' -DUMMY_FILE_TWO = 'file2.txt' -DUMMY_FILE_MISSING = 'missing.txt' -DUMMY_FILE_RO = 'readonly.txt' -DUMMY_FILE_WO = 'writeonly.txt' -DUMMY_FILE_RW = 'readwrite.txt' -DUMMY_DIRECTORY_NESTED = 'nested/dir/structure' -DUMMY_DIRECTORY_EMPTY = 'empty_dir' -DUMMY_DIRECTORY_MOVE_SRC = 'move_dir_src' -DUMMY_DIRECTORY_MOVE_DST = 'move_dir_dst' -DUMMY_DIRECTORY_REMOVE = 'remove_dir' -DUMMY_DIRECTORY_CHECKACCESS = 'check_access' -DUMMY_DIRECTORY_MAKEDIRSREC = 'makedirs_rec' -DUMMY_DIRECTORY_COPYRECSRC = 'copy_dir_src' -DUMMY_DIRECTORY_COPYRECDST = 'copy_dir_dst' -DUMMY_DIRECTORY_REMOVEREC = 'remove_rec' +DUMMY_TEXT = "dummy" +DUMMY_TWICE = "dummy - dummy" +DUMMY_TESTDIR = "fileio" +DUMMY_SUBDIR = "subdir" +DUMMY_FILE_ONE = "file1.txt" +DUMMY_FILE_TWO = "file2.txt" +DUMMY_FILE_MISSING = "missing.txt" +DUMMY_FILE_RO = "readonly.txt" +DUMMY_FILE_WO = "writeonly.txt" +DUMMY_FILE_RW = "readwrite.txt" +DUMMY_DIRECTORY_NESTED = "nested/dir/structure" +DUMMY_DIRECTORY_EMPTY = "empty_dir" +DUMMY_DIRECTORY_MOVE_SRC = "move_dir_src" +DUMMY_DIRECTORY_MOVE_DST = "move_dir_dst" +DUMMY_DIRECTORY_REMOVE = "remove_dir" +DUMMY_DIRECTORY_CHECKACCESS = "check_access" +DUMMY_DIRECTORY_MAKEDIRSREC = "makedirs_rec" +DUMMY_DIRECTORY_COPYRECSRC = "copy_dir_src" +DUMMY_DIRECTORY_COPYRECDST = "copy_dir_dst" +DUMMY_DIRECTORY_REMOVEREC = "remove_rec" # File/dir paths for move/copy operations -DUMMY_FILE_MOVE_SRC = 'move_src' -DUMMY_FILE_MOVE_DST = 'move_dst' -DUMMY_FILE_COPY_SRC = 'copy_src' -DUMMY_FILE_COPY_DST = 'copy_dst' -DUMMY_FILE_WRITECHUNK = 'write_chunk' -DUMMY_FILE_WRITEFILE = 'write_file' -DUMMY_FILE_WRITEFILELINES = 'write_file_lines' -DUMMY_FILE_READFILE = 'read_file' -DUMMY_FILE_READFILELINES = 'read_file_lines' -DUMMY_FILE_READHEADLINES = 'read_head_lines' -DUMMY_FILE_READTAILLINES = 'read_tail_lines' -DUMMY_FILE_DELETEFILE = 'delete_file' -DUMMY_FILE_GETFILESIZE = 'get_file_size' -DUMMY_FILE_MAKESYMLINKSRC = 'link_src' -DUMMY_FILE_MAKESYMLINKDST = 'link_target' -DUMMY_FILE_DELETESYMLINKSRC = 'link_src' -DUMMY_FILE_DELETESYMLINKDST = 'link_target' -DUMMY_FILE_TOUCH = 'touch_file' +DUMMY_FILE_MOVE_SRC = "move_src" +DUMMY_FILE_MOVE_DST = "move_dst" +DUMMY_FILE_COPY_SRC = "copy_src" +DUMMY_FILE_COPY_DST = "copy_dst" +DUMMY_FILE_WRITECHUNK = "write_chunk" +DUMMY_FILE_WRITEFILE = "write_file" +DUMMY_FILE_WRITEFILELINES = "write_file_lines" +DUMMY_FILE_READFILE = "read_file" +DUMMY_FILE_READFILELINES = "read_file_lines" +DUMMY_FILE_READHEADLINES = "read_head_lines" +DUMMY_FILE_READTAILLINES = "read_tail_lines" +DUMMY_FILE_DELETEFILE = "delete_file" +DUMMY_FILE_GETFILESIZE = "get_file_size" +DUMMY_FILE_MAKESYMLINKSRC = "link_src" +DUMMY_FILE_MAKESYMLINKDST = "link_target" +DUMMY_FILE_DELETESYMLINKSRC = "link_src" +DUMMY_FILE_DELETESYMLINKDST = "link_target" +DUMMY_FILE_TOUCH = "touch_file" assert isinstance(DUMMY_BYTES, bytes) @@ -90,12 +91,13 @@ class MigSharedFileio__temporary_umask(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_file_one = os.path.join(self.tmp_base, DUMMY_FILE_ONE) try: @@ -121,63 +123,63 @@ def before_each(self): def test_creates_new_file_with_temporary_umask_777(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o777): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o000) def test_creates_new_file_with_temporary_umask_277(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o277): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o400) def test_creates_new_file_with_temporary_umask_227(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o227): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o440) def test_creates_new_file_with_temporary_umask_077(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o077): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o600) def test_creates_new_file_with_temporary_umask_027(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o027): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o640) def test_creates_new_file_with_temporary_umask_007(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o007): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o660) def test_creates_new_file_with_temporary_umask_022(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o022): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o644) def test_creates_new_file_with_temporary_umask_002(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o002): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o664) def test_creates_new_file_with_temporary_umask_000(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o000): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o666) @@ -257,8 +259,9 @@ def test_restores_original_umask_after_exit(self): current_umask = os.umask(original_umask) # Retrieve and reset # Cleanup: Restore environment os.umask(current_umask) - self.assertEqual(current_umask, 0o022, - "Failed to restore original umask") + self.assertEqual( + current_umask, 0o022, "Failed to restore original umask" + ) finally: os.umask(original_umask) # Ensure cleanup @@ -267,20 +270,20 @@ def test_nested_temporary_umask(self): original_umask = os.umask(0o022) try: with fileio.temporary_umask(0o027): # Outer context - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode1 = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode1, 0o640) # 666 & ~027 = 640 with fileio.temporary_umask(0o077): # Inner context - open(self.tmp_file_two, 'w').close() + open(self.tmp_file_two, "w").close() mode2 = os.stat(self.tmp_file_two).st_mode & 0o777 self.assertEqual(mode2, 0o600) # 666 & ~077 # Back to outer context umask os.remove(self.tmp_file_one) - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode1_after = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode1_after, 0o640) # 666 & ~027 # Back to original umask - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode_original = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode_original, 0o640) # 666 & ~002 finally: @@ -309,7 +312,7 @@ def test_restores_umask_after_exception(self): def test_umask_does_not_affect_existing_files(self): """Test temporary_umask doesn't modify existing file permissions""" - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() os.chmod(self.tmp_file_one, 0o644) # Explicit permissions with fileio.temporary_umask(0o077): # Shouldn't affect existing file # Change permissions inside context @@ -324,12 +327,13 @@ class MigSharedFileio__write_chunk(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_WRITECHUNK) @@ -338,16 +342,18 @@ def test_return_false_on_invalid_data(self): self.logger.forgive_errors() # NOTE: we make sure to disable any forced stringification here - did_succeed = fileio.write_chunk(self.tmp_path, 1234, 0, self.logger, - force_string=False) + did_succeed = fileio.write_chunk( + self.tmp_path, 1234, 0, self.logger, force_string=False + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_offset(self): """Test write_chunk returns False with negative offset value""" self.logger.forgive_errors() - did_succeed = fileio.write_chunk(self.tmp_path, DUMMY_BYTES, -42, - self.logger) + did_succeed = fileio.write_chunk( + self.tmp_path, DUMMY_BYTES, -42, self.logger + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_dir(self): @@ -368,7 +374,7 @@ def test_store_bytes(self): """Test write_chunk stores byte data correctly at offset 0""" fileio.write_chunk(self.tmp_path, DUMMY_BYTES, 0, self.logger) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) @@ -379,42 +385,52 @@ def test_store_bytes_at_offset(self): fileio.write_chunk(self.tmp_path, DUMMY_BYTES, offset, self.logger) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH + offset) - self.assertEqual(content[0:3], bytearray([0, 0, 0]), - "expected a hole was left") + self.assertEqual( + content[0:3], bytearray([0, 0, 0]), "expected a hole was left" + ) self.assertEqual(content[3:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_bytes_in_text_mode(self): """Test write_chunk stores byte data in text mode""" - fileio.write_chunk(self.tmp_path, DUMMY_BYTES, 0, self.logger, - mode="r+") + fileio.write_chunk( + self.tmp_path, DUMMY_BYTES, 0, self.logger, mode="r+" + ) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode(self): """Test write_chunk stores unicode data in text mode""" - fileio.write_chunk(self.tmp_path, DUMMY_UNICODE, 0, self.logger, - mode='r+') + fileio.write_chunk( + self.tmp_path, DUMMY_UNICODE, 0, self.logger, mode="r+" + ) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode_in_binary_mode(self): """Test write_chunk stores unicode data in binary mode""" - fileio.write_chunk(self.tmp_path, DUMMY_UNICODE, 0, self.logger, - mode='r+b') + fileio.write_chunk( + self.tmp_path, DUMMY_UNICODE, 0, self.logger, mode="r+b" + ) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) @@ -425,12 +441,13 @@ class MigSharedFileio__write_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) # NOTE: we inject sub-directory to test with missing and existing self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) @@ -441,24 +458,25 @@ def test_return_false_on_invalid_data(self): self.logger.forgive_errors() # NOTE: we make sure to disable any forced stringification here - did_succeed = fileio.write_file(1234, self.tmp_path, self.logger, - force_string=False) + did_succeed = fileio.write_file( + 1234, self.tmp_path, self.logger, force_string=False + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_dir(self): """Test write_file returns False when path is a directory""" self.logger.forgive_errors() ensure_dirs_exist(self.tmp_path) - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, - self.logger) + did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger) self.assertFalse(did_succeed) def test_return_false_on_missing_dir(self): """Test write_file returns False on missing parent dir""" self.logger.forgive_errors() - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, - self.logger, make_parent=False) + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, make_parent=False + ) self.assertFalse(did_succeed) def test_creates_directory(self): @@ -466,7 +484,7 @@ def test_creates_directory(self): # TODO: temporarily use empty string to avoid any byte/unicode issues # did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, # self.logger) - did_succeed = fileio.write_file('', self.tmp_path, self.logger) + did_succeed = fileio.write_file("", self.tmp_path, self.logger) self.assertTrue(did_succeed) path_kind = self.assertPathExists(self.tmp_path) @@ -475,49 +493,59 @@ def test_creates_directory(self): # TODO: replace next test once we have auto adjust mode in write helper def test_store_bytes_with_manual_adjust_mode(self): """Test write_file stores byte data in with manual adjust mode call""" - mode = 'w' + mode = "w" mode = fileio._auto_adjust_mode(DUMMY_BYTES, mode) - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger, - mode=mode) + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, mode=mode + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_bytes_in_text_mode(self): """Test write_file stores byte data when opening in text mode""" - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger, - mode="w") + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, mode="w" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode(self): """Test write_file stores unicode string when opening in text mode""" - did_succeed = fileio.write_file(DUMMY_UNICODE, self.tmp_path, - self.logger, mode='w') + did_succeed = fileio.write_file( + DUMMY_UNICODE, self.tmp_path, self.logger, mode="w" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode_in_binary_mode(self): """Test write_file handles unicode strings when opening in binary mode""" - did_succeed = fileio.write_file(DUMMY_UNICODE, self.tmp_path, - self.logger, mode='wb') + did_succeed = fileio.write_file( + DUMMY_UNICODE, self.tmp_path, self.logger, mode="wb" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) @@ -528,12 +556,13 @@ class MigSharedFileio__write_file_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) # NOTE: we inject sub-directory to test with missing and existing self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) @@ -542,8 +571,7 @@ def before_each(self): def test_write_lines(self): """Test write_file_lines writes lines to a file""" test_lines = ["line1\n", "line2\n", "line3"] - result = fileio.write_file_lines( - test_lines, self.tmp_path, self.logger) + result = fileio.write_file_lines(test_lines, self.tmp_path, self.logger) self.assertTrue(result) # Verify with read_file_lines @@ -559,8 +587,7 @@ def test_invalid_data(self): def test_creates_directory(self): """Test write_file_lines creates parent directory when needed""" test_lines = ["test line"] - result = fileio.write_file_lines( - test_lines, self.tmp_path, self.logger) + result = fileio.write_file_lines(test_lines, self.tmp_path, self.logger) self.assertTrue(result) path_kind = self.assertPathExists(self.tmp_path) @@ -571,14 +598,16 @@ def test_return_false_on_invalid_dir(self): self.logger.forgive_errors() ensure_dirs_exist(self.tmp_path) result = fileio.write_file_lines( - [DUMMY_TEXT], self.tmp_path, self.logger) + [DUMMY_TEXT], self.tmp_path, self.logger + ) self.assertFalse(result) def test_return_false_on_missing_dir(self): """Test write_file_lines fails when parent directory missing""" self.logger.forgive_errors() - result = fileio.write_file_lines([DUMMY_TEXT], self.tmp_path, self.logger, - make_parent=False) + result = fileio.write_file_lines( + [DUMMY_TEXT], self.tmp_path, self.logger, make_parent=False + ) self.assertFalse(result) @@ -587,40 +616,43 @@ class MigSharedFileio__read_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READFILE) def test_reads_bytes(self): """Test read_file returns byte content with binary mode""" - with open(self.tmp_path, 'wb') as fh: + with open(self.tmp_path, "wb") as fh: fh.write(DUMMY_BYTES) - content = fileio.read_file(self.tmp_path, self.logger, mode='rb') + content = fileio.read_file(self.tmp_path, self.logger, mode="rb") self.assertEqual(content, DUMMY_BYTES) def test_reads_text(self): """Test read_file returns text with text mode""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write(DUMMY_UNICODE) - content = fileio.read_file(self.tmp_path, self.logger, mode='r') + content = fileio.read_file(self.tmp_path, self.logger, mode="r") self.assertEqual(content, DUMMY_UNICODE) def test_allows_missing_file(self): """Test read_file returns None with allow_missing=True""" content = fileio.read_file( - 'missing.txt', self.logger, allow_missing=True) + "missing.txt", self.logger, allow_missing=True + ) self.assertIsNone(content) def test_reports_missing_file(self): """Test read_file returns None with allow_missing=False""" self.logger.forgive_errors() content = fileio.read_file( - 'missing.txt', self.logger, allow_missing=False) + "missing.txt", self.logger, allow_missing=False + ) self.assertIsNone(content) def test_handles_directory_path(self): @@ -636,31 +668,32 @@ class MigSharedFileio__read_file_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READFILELINES) def test_returns_empty_list_for_empty_file(self): """Test read_file_lines returns empty list for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_file_lines(self.tmp_path, self.logger) self.assertEqual(lines, []) def test_reads_lines_from_file(self): """Test read_file_lines returns lines from text file""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3") lines = fileio.read_file_lines(self.tmp_path, self.logger) self.assertEqual(lines, ["line1\n", "line2\n", "line3"]) def test_none_for_missing_file(self): self.logger.forgive_errors() - lines = fileio.read_file_lines('missing.txt', self.logger) + lines = fileio.read_file_lines("missing.txt", self.logger) self.assertIsNone(lines) @@ -669,18 +702,19 @@ class MigSharedFileio__get_file_size(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_GETFILESIZE) def test_returns_file_size(self): """Test get_file_size returns correct file size""" - with open(self.tmp_path, 'wb') as fh: + with open(self.tmp_path, "wb") as fh: fh.write(DUMMY_BYTES) size = fileio.get_file_size(self.tmp_path, self.logger) self.assertEqual(size, DUMMY_BYTES_LENGTH) @@ -688,7 +722,7 @@ def test_returns_file_size(self): def test_handles_missing_file(self): """Test get_file_size returns -1 for missing file""" self.logger.forgive_errors() - size = fileio.get_file_size('missing.txt', self.logger) + size = fileio.get_file_size("missing.txt", self.logger) self.assertEqual(size, -1) def test_handles_directory(self): @@ -707,18 +741,19 @@ class MigSharedFileio__delete_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_DELETEFILE) def test_deletes_existing_file(self): """Test delete_file removes existing file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() result = fileio.delete_file(self.tmp_path, self.logger) self.assertTrue(result) self.assertFalse(os.path.exists(self.tmp_path)) @@ -726,15 +761,16 @@ def test_deletes_existing_file(self): def test_handles_missing_file_with_allow_missing(self): """Test delete_file succeeds with allow_missing=True""" result = fileio.delete_file( - 'missing.txt', self.logger, allow_missing=True) + "missing.txt", self.logger, allow_missing=True + ) self.assertTrue(result) def test_false_for_missing_file_without_allow_missing(self): """Test delete_file returns False with allow_missing=False""" self.logger.forgive_errors() - result = fileio.delete_file('missing.txt', - self.logger, - allow_missing=False) + result = fileio.delete_file( + "missing.txt", self.logger, allow_missing=False + ) self.assertFalse(result) @@ -743,39 +779,40 @@ class MigSharedFileio__read_head_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READHEADLINES) def test_reads_requested_lines(self): """Test read_head_lines returns requested number of lines""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3\nline4") lines = fileio.read_head_lines(self.tmp_path, 2, self.logger) self.assertEqual(lines, ["line1\n", "line2\n"]) def test_returns_all_lines_when_requested_more(self): """Test read_head_lines returns all lines when file has fewer""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2") lines = fileio.read_head_lines(self.tmp_path, 5, self.logger) self.assertEqual(lines, ["line1\n", "line2"]) def test_returns_empty_list_for_empty_file(self): """Test read_head_lines returns empty for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_head_lines(self.tmp_path, 3, self.logger) self.assertEqual(lines, []) def test_empty_for_missing_file(self): """Test read_head_lines returns [] for missing file""" self.logger.forgive_errors() - lines = fileio.read_head_lines('missing.txt', 3, self.logger) + lines = fileio.read_head_lines("missing.txt", 3, self.logger) self.assertEqual(lines, []) @@ -784,39 +821,40 @@ class MigSharedFileio__read_tail_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READTAILLINES) def test_reads_requested_lines(self): """Test read_tail_lines returns requested number of lines""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3\nline4") lines = fileio.read_tail_lines(self.tmp_path, 2, self.logger) self.assertEqual(lines, ["line3\n", "line4"]) def test_returns_all_lines_when_requested_more(self): """Test read_tail_lines returns all lines when file has fewer""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2") lines = fileio.read_tail_lines(self.tmp_path, 5, self.logger) self.assertEqual(lines, ["line1\n", "line2"]) def test_returns_empty_list_for_empty_file(self): """Test read_tail_lines returns empty for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_tail_lines(self.tmp_path, 3, self.logger) self.assertEqual(lines, []) def test_empty_for_missing_file(self): """Test read_tail_lines returns [] for missing file""" self.logger.forgive_errors() - lines = fileio.read_tail_lines('missing.txt', 3, self.logger) + lines = fileio.read_tail_lines("missing.txt", 3, self.logger) self.assertEqual(lines, []) @@ -825,49 +863,52 @@ class MigSharedFileio__make_symlink(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) ensure_dirs_exist(self.tmp_dir) self.tmp_link = os.path.join(self.tmp_dir, DUMMY_FILE_MAKESYMLINKSRC) - self.tmp_target = os.path.join(self.tmp_dir, - DUMMY_FILE_MAKESYMLINKDST) - with open(self.tmp_target, 'w') as fh: + self.tmp_target = os.path.join(self.tmp_dir, DUMMY_FILE_MAKESYMLINKDST) + with open(self.tmp_target, "w") as fh: fh.write(DUMMY_TEXT) def test_creates_symlink(self): """Test make_symlink creates working symlink""" result = fileio.make_symlink( - self.tmp_target, self.tmp_link, self.logger) + self.tmp_target, self.tmp_link, self.logger + ) self.assertTrue(result) self.assertTrue(os.path.islink(self.tmp_link)) self.assertEqual(os.readlink(self.tmp_link), self.tmp_target) def test_force_overwrites_existing_link(self): """Test make_symlink force replaces existing link""" - os.symlink('/dummy', self.tmp_link) - result = fileio.make_symlink(self.tmp_target, self.tmp_link, - self.logger, force=True) + os.symlink("/dummy", self.tmp_link) + result = fileio.make_symlink( + self.tmp_target, self.tmp_link, self.logger, force=True + ) self.assertTrue(result) self.assertEqual(os.readlink(self.tmp_link), self.tmp_target) def test_fails_on_existing_link_without_force(self): """Test make_symlink fails on existing link without force""" self.logger.forgive_errors() - os.symlink('/dummy', self.tmp_link) - result = fileio.make_symlink(self.tmp_target, self.tmp_link, self.logger, - force=False) + os.symlink("/dummy", self.tmp_link) + result = fileio.make_symlink( + self.tmp_target, self.tmp_link, self.logger, force=False + ) self.assertFalse(result) def test_handles_nonexistent_target(self): """Test make_symlink still creates broken symlink""" self.logger.forgive_errors() - broken_target = self.tmp_target + '-nonexistent' + broken_target = self.tmp_target + "-nonexistent" result = fileio.make_symlink(broken_target, self.tmp_link, self.logger) self.assertTrue(result) self.assertTrue(os.path.islink(self.tmp_link)) @@ -879,20 +920,21 @@ class MigSharedFileio__delete_symlink(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) ensure_dirs_exist(self.tmp_dir) - self.tmp_link = os.path.join(self.tmp_dir, - DUMMY_FILE_DELETESYMLINKSRC) - self.tmp_target = os.path.join(self.tmp_dir, - DUMMY_FILE_DELETESYMLINKDST) - with open(self.tmp_target, 'w') as fh: + self.tmp_link = os.path.join(self.tmp_dir, DUMMY_FILE_DELETESYMLINKSRC) + self.tmp_target = os.path.join( + self.tmp_dir, DUMMY_FILE_DELETESYMLINKDST + ) + with open(self.tmp_target, "w") as fh: fh.write(DUMMY_TEXT) def create_symlink(self, target=None, link=None): @@ -915,33 +957,36 @@ def test_handles_missing_file_with_allow_missing(self): # First make sure file doesn't exist if os.path.exists(self.tmp_link): os.remove(self.tmp_link) - result = fileio.delete_symlink(self.tmp_link, self.logger, - allow_missing=True) + result = fileio.delete_symlink( + self.tmp_link, self.logger, allow_missing=True + ) self.assertTrue(result) def test_handles_missing_symlink_without_allow_missing(self): """Test delete_symlink fails with allow_missing=False""" self.logger.forgive_errors() - result = fileio.delete_symlink('missing_symlink', self.logger, - allow_missing=False) + result = fileio.delete_symlink( + "missing_symlink", self.logger, allow_missing=False + ) self.assertFalse(result) @unittest.skip("TODO: implement check in tested function and enable again") def test_rejects_regular_file(self): """Test delete_symlink returns False when path is a regular file""" - with open(self.tmp_link, 'w') as fh: + with open(self.tmp_link, "w") as fh: fh.write(DUMMY_TEXT) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: result = fileio.delete_symlink(self.tmp_link, self.logger) self.assertFalse(result) - self.assertTrue(any('Could not remove' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Could not remove" in msg for msg in log_capture.output) + ) def test_deletes_broken_symlink(self): """Test delete_symlink removes broken symlink""" # Create broken symlink - broken_target = self.tmp_target + '-nonexistent' + broken_target = self.tmp_target + "-nonexistent" self.create_symlink(broken_target) self.assertTrue(os.path.islink(self.tmp_link)) # Now delete it @@ -954,12 +999,13 @@ class MigSharedFileio__touch(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_TOUCH) @@ -971,11 +1017,13 @@ def test_creates_new_file(self): self.assertTrue(os.path.exists(self.tmp_path)) self.assertTrue(os.path.isfile(self.tmp_path)) - @unittest.skip("TODO: fix invalid open 'r+w' in tested function and enable again") + @unittest.skip( + "TODO: fix invalid open 'r+w' in tested function and enable again" + ) def test_updates_timestamp_on_existing_file(self): """Test touch updates timestamp on existing file""" # Create initial file - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write(DUMMY_TEXT) orig_mtime = os.path.getmtime(self.tmp_path) time.sleep(0.1) @@ -984,7 +1032,9 @@ def test_updates_timestamp_on_existing_file(self): new_mtime = os.path.getmtime(self.tmp_path) self.assertNotEqual(orig_mtime, new_mtime) - @unittest.skip("TODO: fix handling of directory in tested function and enable again") + @unittest.skip( + "TODO: fix handling of directory in tested function and enable again" + ) def test_succeeds_on_directory(self): """Test touch succeeds for existing directory and updates timestamp""" ensure_dirs_exist(self.tmp_path) @@ -999,7 +1049,7 @@ def test_succeeds_on_directory(self): def test_fails_on_missing_parent(self): """Test touch fails when parent directory doesn't exist""" self.logger.forgive_errors() - nested_path = os.path.join(self.tmp_path, 'missing', DUMMY_FILE_ONE) + nested_path = os.path.join(self.tmp_path, "missing", DUMMY_FILE_ONE) result = fileio.touch(nested_path, self.configuration) self.assertFalse(result) self.assertFalse(os.path.exists(nested_path)) @@ -1010,12 +1060,13 @@ class MigSharedFileio__remove_dir(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVE) # NOTE: we prepare tmp_path as directory here @@ -1032,7 +1083,7 @@ def test_fails_on_nonempty_directory(self): """Test remove_dir returns False for non-empty directory""" self.logger.forgive_errors() # Add a file to the directory - with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) result = fileio.remove_dir(self.tmp_path, self.configuration) self.assertFalse(result) @@ -1043,7 +1094,7 @@ def test_fails_on_file(self): self.logger.forgive_errors() # Add a file to the directory file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with open(file_path, 'w') as fh: + with open(file_path, "w") as fh: fh.write(DUMMY_TEXT) result = fileio.remove_dir(file_path, self.configuration) self.assertFalse(result) @@ -1055,12 +1106,13 @@ class MigSharedFileio__remove_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVEREC) # Create a nested directory structure with files @@ -1069,10 +1121,11 @@ def before_each(self): # └── subdir/ # └── file2.txt ensure_dirs_exist(os.path.join(self.tmp_path, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_path, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_path, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_removes_directory_recursively(self): @@ -1085,7 +1138,7 @@ def test_removes_directory_recursively(self): def test_removes_directory_recursively_with_symlink(self): """Test remove_rec removes directory and contents with symlink""" link_src = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - link_dst = os.path.join(self.tmp_path, DUMMY_FILE_ONE + '.lnk') + link_dst = os.path.join(self.tmp_path, DUMMY_FILE_ONE + ".lnk") os.symlink(link_src, link_dst) self.assertTrue(os.path.exists(self.tmp_path)) result = fileio.remove_rec(self.tmp_path, self.configuration) @@ -1095,7 +1148,7 @@ def test_removes_directory_recursively_with_symlink(self): def test_removes_directory_recursively_with_broken_symlink(self): """Test remove_rec removes directory and contents with broken symlink""" link_src = os.path.join(self.tmp_path, DUMMY_FILE_MISSING) - link_dst = os.path.join(self.tmp_path, DUMMY_FILE_MISSING + '.lnk') + link_dst = os.path.join(self.tmp_path, DUMMY_FILE_MISSING + ".lnk") os.symlink(link_src, link_dst) self.assertTrue(os.path.exists(self.tmp_path)) result = fileio.remove_rec(self.tmp_path, self.configuration) @@ -1115,11 +1168,12 @@ def test_removes_directory_recursively_despite_readonly(self): def test_rejects_regular_file(self): """Test remove_rec returns False when path is a regular file""" file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: result = fileio.remove_rec(file_path, self.configuration) self.assertFalse(result) - self.assertTrue(any('Could not remove' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Could not remove" in msg for msg in log_capture.output) + ) self.assertTrue(os.path.exists(file_path)) @@ -1128,22 +1182,24 @@ class MigSharedFileio__move_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_FILE_MOVE_SRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_FILE_MOVE_DST) - with open(self.tmp_src, 'w') as fh: + with open(self.tmp_src, "w") as fh: fh.write(DUMMY_TEXT) def test_moves_file(self): """Test move_file successfully moves a file""" - success, msg = fileio.move_file(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_file( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) self.assertFalse(os.path.exists(self.tmp_src)) @@ -1152,13 +1208,14 @@ def test_moves_file(self): def test_overwrites_existing_destination(self): """Test move_file overwrites existing destination file""" # Create initial destination file - with open(self.tmp_dst, 'w') as fh: + with open(self.tmp_dst, "w") as fh: fh.write(DUMMY_TWICE) - success, msg = fileio.move_file(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_file( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) - with open(self.tmp_dst, 'r') as fh: + with open(self.tmp_dst, "r") as fh: content = fh.read() self.assertEqual(content, DUMMY_TEXT) @@ -1168,12 +1225,13 @@ class MigSharedFileio__move_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVE) self.tmp_src = os.path.join(self.tmp_base, DUMMY_DIRECTORY_MOVE_SRC) @@ -1184,43 +1242,54 @@ def before_each(self): # └── subdir/ # └── file2.txt ensure_dirs_exist(os.path.join(self.tmp_src, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_src, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_src, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_moves_directory_recursively(self): """Test move_rec moves directory and contents""" - result = fileio.move_rec(self.tmp_src, self.tmp_dst, - self.configuration) + result = fileio.move_rec(self.tmp_src, self.tmp_dst, self.configuration) self.assertTrue(result) self.assertFalse(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) # Verify structure - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, - DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, DUMMY_SUBDIR, - DUMMY_FILE_TWO))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO) + ) + ) def test_extends_existing_destination(self): """Test move_rec extends existing destination directory""" # Create initial destination with some content ensure_dirs_exist(os.path.join(self.tmp_dst, DUMMY_TESTDIR)) - success, msg = fileio.move_rec(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_rec( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) # Verify structure with new src subdir and existing dir new_sub = os.path.basename(DUMMY_DIRECTORY_MOVE_SRC) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, new_sub, - DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, new_sub, - DUMMY_SUBDIR, - DUMMY_FILE_TWO))) - self.assertTrue(os.path.exists( - os.path.join(self.tmp_dst, DUMMY_TESTDIR))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, new_sub, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join( + self.tmp_dst, new_sub, DUMMY_SUBDIR, DUMMY_FILE_TWO + ) + ) + ) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_TESTDIR)) + ) class MigSharedFileio__copy_file(MigTestCase): @@ -1228,23 +1297,25 @@ class MigSharedFileio__copy_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_FILE_COPY_SRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_FILE_COPY_DST) - with open(self.tmp_src, 'w') as fh: + with open(self.tmp_src, "w") as fh: fh.write(DUMMY_TEXT) def test_copies_file(self): """Test copy_file successfully copies a file""" result = fileio.copy_file( - self.tmp_src, self.tmp_dst, self.configuration) + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(result) self.assertTrue(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) @@ -1252,12 +1323,13 @@ def test_copies_file(self): def test_overwrites_existing_destination(self): """Test copy_file overwrites existing destination file""" # Create initial destination file - with open(self.tmp_dst, 'w') as fh: + with open(self.tmp_dst, "w") as fh: fh.write(DUMMY_TWICE) result = fileio.copy_file( - self.tmp_src, self.tmp_dst, self.configuration) + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(result) - with open(self.tmp_dst, 'r') as fh: + with open(self.tmp_dst, "r") as fh: content = fh.read() self.assertEqual(content, DUMMY_TEXT) @@ -1267,36 +1339,41 @@ class MigSharedFileio__copy_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_DIRECTORY_COPYRECSRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_DIRECTORY_COPYRECDST) # Create a nested directory structure with files ensure_dirs_exist(self.tmp_src) ensure_dirs_exist(os.path.join(self.tmp_src, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_src, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_src, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_copies_directory_recursively(self): """Test copy_rec copies directory and contents""" - result = fileio.copy_rec( - self.tmp_src, self.tmp_dst, self.configuration) + result = fileio.copy_rec(self.tmp_src, self.tmp_dst, self.configuration) self.assertTrue(result) self.assertTrue(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) # Verify structure - self.assertTrue(os.path.exists(os.path.join( - self.tmp_dst, DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join( - self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO) + ) + ) class MigSharedFileio__check_empty_dir(MigTestCase): @@ -1304,20 +1381,20 @@ class MigSharedFileio__check_empty_dir(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.empty_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_EMPTY) - self.nonempty_path = os.path.join( - self.tmp_base, DUMMY_DIRECTORY_NESTED) + self.nonempty_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_NESTED) ensure_dirs_exist(self.empty_path) # Create non-empty directory structure ensure_dirs_exist(self.nonempty_path) - with open(os.path.join(self.nonempty_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.nonempty_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) def test_returns_true_for_empty(self): @@ -1340,20 +1417,21 @@ class MigSharedFileio__makedirs_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) - self.tmp_path = os.path.join(self.tmp_base, - DUMMY_DIRECTORY_MAKEDIRSREC) + self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_MAKEDIRSREC) def test_creates_directory_path(self): """Test makedirs_rec creates nested directories""" - nested_path = os.path.join(self.tmp_path, DUMMY_TESTDIR, DUMMY_SUBDIR, - DUMMY_TESTDIR) + nested_path = os.path.join( + self.tmp_path, DUMMY_TESTDIR, DUMMY_SUBDIR, DUMMY_TESTDIR + ) result = fileio.makedirs_rec(nested_path, self.configuration) self.assertTrue(result) self.assertTrue(os.path.exists(nested_path)) @@ -1370,7 +1448,7 @@ def test_fails_for_file_path(self): # Create a file at the path ensure_dirs_exist(self.tmp_path) file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with open(file_path, 'w') as fh: + with open(file_path, "w") as fh: fh.write(DUMMY_TEXT) result = fileio.makedirs_rec(file_path, self.configuration) self.assertFalse(result) @@ -1381,26 +1459,26 @@ class MigSharedFileio__check_access(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) - self.tmp_dir = os.path.join(self.tmp_base, - DUMMY_DIRECTORY_CHECKACCESS) + self.tmp_dir = os.path.join(self.tmp_base, DUMMY_DIRECTORY_CHECKACCESS) ensure_dirs_exist(self.tmp_dir) self.writeonly_file = os.path.join(self.tmp_dir, DUMMY_FILE_WO) self.readonly_file = os.path.join(self.tmp_dir, DUMMY_FILE_RO) self.readwrite_file = os.path.join(self.tmp_dir, DUMMY_FILE_RW) # Create test files with different permissions - with open(self.writeonly_file, 'w') as fh: + with open(self.writeonly_file, "w") as fh: fh.write(DUMMY_TEXT) - with open(self.readonly_file, 'w') as fh: + with open(self.readonly_file, "w") as fh: fh.write(DUMMY_TEXT) - with open(self.readwrite_file, 'w') as fh: + with open(self.readwrite_file, "w") as fh: fh.write(DUMMY_TEXT) # Set permissions @@ -1414,14 +1492,13 @@ def test_check_read_access_file(self): """Test check_read_access with readable file""" self.assertTrue(fileio.check_read_access(self.readwrite_file)) self.assertTrue(fileio.check_read_access(self.readonly_file)) - self.assertTrue(fileio.check_read_access(self.tmp_dir, - parent_dir=True)) + self.assertTrue(fileio.check_read_access(self.tmp_dir, parent_dir=True)) # Super-user has access to read and write all files! if os.getuid() == 0: self.assertTrue(fileio.check_read_access(self.writeonly_file)) else: self.assertFalse(fileio.check_read_access(self.writeonly_file)) - self.assertFalse(fileio.check_read_access('/invalid/path')) + self.assertFalse(fileio.check_read_access("/invalid/path")) def test_check_write_access_file(self): """Test check_write_access with writable file""" @@ -1432,7 +1509,7 @@ def test_check_write_access_file(self): self.assertTrue(fileio.check_write_access(self.readonly_file)) else: self.assertFalse(fileio.check_write_access(self.readonly_file)) - self.assertFalse(fileio.check_write_access('/invalid/path')) + self.assertFalse(fileio.check_write_access("/invalid/path")) def test_check_read_access_with_parent(self): """Test check_read_access with parent_dir True""" @@ -1448,78 +1525,108 @@ def test_check_write_access_with_parent(self): def test_check_readable(self): """Test check_readable wrapper function""" - self.assertTrue(fileio.check_readable(self.configuration, - self.readwrite_file)) - self.assertTrue(fileio.check_readable(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_readable(self.configuration, self.readwrite_file) + ) + self.assertTrue( + fileio.check_readable(self.configuration, self.readonly_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_readable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_readable(self.configuration, self.writeonly_file) + ) else: - self.assertFalse(fileio.check_readable(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readable(self.configuration, - '/invalid/path')) + self.assertFalse( + fileio.check_readable(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readable(self.configuration, "/invalid/path") + ) def test_check_writable(self): """Test check_writable wrapper function""" - self.assertTrue(fileio.check_writable(self.configuration, - self.readwrite_file)) - self.assertTrue(fileio.check_writable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_writable(self.configuration, self.readwrite_file) + ) + self.assertTrue( + fileio.check_writable(self.configuration, self.writeonly_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_writable(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_writable(self.configuration, self.readonly_file) + ) else: - self.assertFalse(fileio.check_writable(self.configuration, - self.readonly_file)) - self.assertFalse(fileio.check_writable(self.configuration, - "/no/such/file")) + self.assertFalse( + fileio.check_writable(self.configuration, self.readonly_file) + ) + self.assertFalse( + fileio.check_writable(self.configuration, "/no/such/file") + ) def test_check_readonly(self): """Test check_readonly wrapper function""" # Super-user has access to read and write all files! if os.getuid() == 0: # Test with read-only file path - self.assertFalse(fileio.check_readonly(self.configuration, - self.readonly_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readonly_file) + ) # Test with writable file - self.assertFalse(fileio.check_readonly(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readonly(self.configuration, - self.readwrite_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readwrite_file) + ) else: # Test with read-only file path - self.assertTrue(fileio.check_readonly(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_readonly(self.configuration, self.readonly_file) + ) # Test with writable file - self.assertFalse(fileio.check_readonly(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readonly(self.configuration, - self.readwrite_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readwrite_file) + ) def test_check_readwritable(self): """Test check_readwritable wrapper function""" - self.assertTrue(fileio.check_readwritable(self.configuration, - self.readwrite_file)) + self.assertTrue( + fileio.check_readwritable(self.configuration, self.readwrite_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_readwritable(self.configuration, - self.readonly_file)) - self.assertTrue(fileio.check_readwritable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_readwritable( + self.configuration, self.readonly_file + ) + ) + self.assertTrue( + fileio.check_readwritable( + self.configuration, self.writeonly_file + ) + ) else: - self.assertFalse(fileio.check_readwritable(self.configuration, - self.readonly_file)) - self.assertFalse(fileio.check_readwritable(self.configuration, - self.writeonly_file)) - - self.assertFalse(fileio.check_readwritable(self.configuration, - "/invalid/file")) + self.assertFalse( + fileio.check_readwritable( + self.configuration, self.readonly_file + ) + ) + self.assertFalse( + fileio.check_readwritable( + self.configuration, self.writeonly_file + ) + ) + + self.assertFalse( + fileio.check_readwritable(self.configuration, "/invalid/file") + ) def test_special_cases(self): """Test various special cases for access checks""" @@ -1533,11 +1640,13 @@ def test_special_cases(self): self.assertFalse(fileio.check_write_access(missing_path)) # Check with custom follow_symlink=False - self.assertTrue(fileio.check_read_access(self.readwrite_file, - follow_symlink=False)) - self.assertTrue(fileio.check_read_access(self.tmp_dir, True, - follow_symlink=False)) + self.assertTrue( + fileio.check_read_access(self.readwrite_file, follow_symlink=False) + ) + self.assertTrue( + fileio.check_read_access(self.tmp_dir, True, follow_symlink=False) + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_filemarks.py b/tests/test_mig_shared_filemarks.py index a640deba9..41a3a95c8 100644 --- a/tests/test_mig_shared_filemarks.py +++ b/tests/test_mig_shared_filemarks.py @@ -36,11 +36,12 @@ # Imports of the code under test from mig.shared.filemarks import get_filemark, reset_filemark, update_filemark + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -TEST_MARKS_DIR = 'TestMarks' -TEST_MARKS_FILE = 'file.mark' +TEST_MARKS_DIR = "TestMarks" +TEST_MARKS_FILE = "file.mark" class TestMigSharedFilemarks(MigTestCase): @@ -48,7 +49,7 @@ class TestMigSharedFilemarks(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def _prepare_mark_for_test(self, mark_name=None, timestamp=None): """Prepare test for mark_name with timestamp in default location""" @@ -57,7 +58,7 @@ def _prepare_mark_for_test(self, mark_name=None, timestamp=None): if timestamp is None: timestamp = time.time() self.marks_path = os.path.join(self.marks_base, mark_name) - open(self.marks_path, 'w').close() + open(self.marks_path, "w").close() os.utime(self.marks_path, (timestamp, timestamp)) return timestamp @@ -69,8 +70,9 @@ def _verify_mark_after_test(self, mark_name, timestamp): def before_each(self): """Setup fake configuration and temp dir before each test.""" - self.marks_base = os.path.join(self.configuration.mig_system_run, - TEST_MARKS_DIR) + self.marks_base = os.path.join( + self.configuration.mig_system_run, TEST_MARKS_DIR + ) ensure_dirs_exist(self.marks_base) self.marks_path = os.path.join(self.marks_base, TEST_MARKS_FILE) @@ -78,8 +80,9 @@ def test_update_filemark_create(self): """Test update_filemark creates mark file with timestamp""" timestamp = 4242 self.assertFalse(os.path.isfile(self.marks_path)) - update_result = update_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE, timestamp) + update_result = update_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE, timestamp + ) self.assertTrue(update_result) self.assertTrue(os.path.isfile(self.marks_path)) self.assertEqual(os.path.getmtime(self.marks_path), timestamp) @@ -89,8 +92,9 @@ def test_update_filemark_timestamp(self): timestamp = 424242 self._prepare_mark_for_test(TEST_MARKS_FILE, 4242) - update_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE, timestamp) + update_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE, timestamp + ) self.assertTrue(os.path.isfile(self.marks_path)) self.assertEqual(os.path.getmtime(self.marks_path), timestamp) @@ -98,8 +102,9 @@ def test_update_filemark_delete(self): """Test update_filemark deletes mark files with negative timestamp""" self._prepare_mark_for_test(TEST_MARKS_FILE) - delete_result = update_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE, -1) + delete_result = update_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE, -1 + ) self.assertTrue(delete_result) self.assertFalse(os.path.exists(self.marks_path)) @@ -108,23 +113,26 @@ def test_get_filemark_existing(self): timestamp = 4242 self._prepare_mark_for_test(TEST_MARKS_FILE, timestamp) - retrieved = get_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + retrieved = get_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertEqual(retrieved, timestamp) def test_get_filemark_missing(self): """Test get_filemark returns None for missing mark files""" self.assertFalse(os.path.isfile(self.marks_path)) - retrieved = get_filemark(self.configuration, self.marks_base, - 'missing.mark') + retrieved = get_filemark( + self.configuration, self.marks_base, "missing.mark" + ) self.assertIsNone(retrieved) def test_reset_filemark_single(self): """Test reset_filemark updates single mark timestamp to 0""" self._prepare_mark_for_test(TEST_MARKS_FILE) - reset_result = reset_filemark(self.configuration, self.marks_base, - [TEST_MARKS_FILE]) + reset_result = reset_filemark( + self.configuration, self.marks_base, [TEST_MARKS_FILE] + ) self.assertTrue(reset_result) self._verify_mark_after_test(TEST_MARKS_FILE, 0) @@ -133,18 +141,20 @@ def test_reset_filemark_delete(self): """Test reset_filemark deletes marks with delete=True""" self._prepare_mark_for_test(TEST_MARKS_FILE) - reset_result = reset_filemark(self.configuration, self.marks_base, - [TEST_MARKS_FILE], delete=True) + reset_result = reset_filemark( + self.configuration, self.marks_base, [TEST_MARKS_FILE], delete=True + ) self.assertTrue(reset_result) - retrieved = get_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + retrieved = get_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertIsNone(retrieved) self.assertFalse(os.path.exists(self.marks_path)) def test_reset_filemark_all(self): """Test reset_filemark resets all marks when mark_list=None""" - marks = ['mark1', 'mark2', 'mark3'] + marks = ["mark1", "mark2", "mark3"] for mark in marks: self._prepare_mark_for_test(mark) @@ -157,15 +167,22 @@ def test_reset_filemark_all(self): def test_update_filemark_fails_when_file_prevents_directory(self): """Test update_filemark fails when file prevents create directory""" # Create a file in the way to prevent subdir creation - self._prepare_mark_for_test('obstruct') - - with self.assertLogs(level='ERROR') as log_capture: - result = update_filemark(self.configuration, self.marks_base, - os.path.join('obstruct', 'test.mark'), - time.time()) + self._prepare_mark_for_test("obstruct") + + with self.assertLogs(level="ERROR") as log_capture: + result = update_filemark( + self.configuration, + self.marks_base, + os.path.join("obstruct", "test.mark"), + time.time(), + ) self.assertFalse(result) - self.assertTrue(any('in the way' in msg or 'could not create' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "in the way" in msg or "could not create" in msg + for msg in log_capture.output + ) + ) @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user") def test_update_filemark_directory_perms_failure(self): @@ -173,13 +190,17 @@ def test_update_filemark_directory_perms_failure(self): # Create a read-only parent directory to prevent subdir creation os.chmod(self.marks_base, stat.S_IRUSR) # Remove write permissions - with self.assertLogs(level='ERROR') as log_capture: - result = update_filemark(self.configuration, self.marks_base, - os.path.join('noaccess', 'test.mark'), - time.time()) + with self.assertLogs(level="ERROR") as log_capture: + result = update_filemark( + self.configuration, + self.marks_base, + os.path.join("noaccess", "test.mark"), + time.time(), + ) self.assertFalse(result) - self.assertTrue(any('Permission denied' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Permission denied" in msg for msg in log_capture.output) + ) @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user") def test_get_filemark_permission_denied(self): @@ -188,8 +209,9 @@ def test_get_filemark_permission_denied(self): # Remove read permissions through parent dir os.chmod(self.marks_base, 0) - result = get_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + result = get_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertIsNone(result) # Restore permissions so cleanup works os.chmod(self.marks_base, stat.S_IRWXU) @@ -198,20 +220,23 @@ def test_reset_filemark_string_mark_list(self): """Test reset_filemark handles single string mark_list""" self._prepare_mark_for_test(TEST_MARKS_FILE) - reset_result = reset_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + reset_result = reset_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertTrue(reset_result) self._verify_mark_after_test(TEST_MARKS_FILE, 0) def test_reset_filemark_invalid_mark_list(self): """Test reset_filemark fails with invalid mark_list type""" - with self.assertLogs(level='ERROR') as log_capture: - reset_result = reset_filemark(self.configuration, self.marks_base, - {'invalid': 'type'}) + with self.assertLogs(level="ERROR") as log_capture: + reset_result = reset_filemark( + self.configuration, self.marks_base, {"invalid": "type"} + ) self.assertFalse(reset_result) - self.assertTrue(any('invalid mark list' in msg for msg in - log_capture.output)) + self.assertTrue( + any("invalid mark list" in msg for msg in log_capture.output) + ) def test_reset_filemark_all_missing_dir(self): """Test reset_filemark handles missing directory when mark_list=None""" @@ -222,41 +247,48 @@ def test_reset_filemark_all_missing_dir(self): @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user") def test_reset_filemark_partial_perms_failure(self): """Test reset_filemark with partial failure due to permissions""" - valid_mark = 'valid.mark' - invalid_mark = 'invalid.mark' + valid_mark = "valid.mark" + invalid_mark = "invalid.mark" invalid_path = os.path.join(self.marks_base, invalid_mark) # Create both marks but remove access to the latter self._prepare_mark_for_test(valid_mark) self._prepare_mark_for_test(invalid_mark) os.chmod(invalid_path, stat.S_IRUSR) # Remove write permissions - with self.assertLogs(level='ERROR') as log_capture: - reset_result = reset_filemark(self.configuration, self.marks_base, - [valid_mark, invalid_mark]) + with self.assertLogs(level="ERROR") as log_capture: + reset_result = reset_filemark( + self.configuration, self.marks_base, [valid_mark, invalid_mark] + ) self.assertFalse(reset_result) # Should fail due to partial failure - self.assertTrue(any('Permission denied' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Permission denied" in msg for msg in log_capture.output) + ) self._verify_mark_after_test(valid_mark, 0) def test_reset_filemark_partial_file_prevents_directory_failure(self): """Test reset_filemark with partial failure due to a file in the way""" - valid_mark = 'valid.mark' - invalid_mark = os.path.join('obstruct', 'invalid.mark') + valid_mark = "valid.mark" + invalid_mark = os.path.join("obstruct", "invalid.mark") # Create valid mark and a file to prevent the invalid mark self._prepare_mark_for_test(valid_mark) # Create a file in the way to prevent subdir creation - self._prepare_mark_for_test('obstruct') + self._prepare_mark_for_test("obstruct") - with self.assertLogs(level='ERROR') as log_capture: - reset_result = reset_filemark(self.configuration, self.marks_base, - [valid_mark, invalid_mark]) + with self.assertLogs(level="ERROR") as log_capture: + reset_result = reset_filemark( + self.configuration, self.marks_base, [valid_mark, invalid_mark] + ) self.assertFalse(reset_result) # Should fail due to partial failure - self.assertTrue(any('in the way' in msg or 'could not create' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "in the way" in msg or "could not create" in msg + for msg in log_capture.output + ) + ) self._verify_mark_after_test(valid_mark, 0) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_functionality_cat.py b/tests/test_mig_shared_functionality_cat.py index 619c3dd4e..a8081c105 100644 --- a/tests/test_mig_shared_functionality_cat.py +++ b/tests/test_mig_shared_functionality_cat.py @@ -28,17 +28,25 @@ """Unit tests of the MiG functionality file implementing the cat backend""" from __future__ import print_function + import importlib import os import shutil import sys import unittest -from tests.support import MIG_BASE, PY2, TEST_DATA_DIR, MigTestCase, testmain, \ - temppath, ensure_dirs_exist - from mig.shared.base import client_id_dir -from mig.shared.functionality.cat import _main as submain, main as realmain +from mig.shared.functionality.cat import _main as submain +from mig.shared.functionality.cat import main as realmain +from tests.support import ( + MIG_BASE, + PY2, + TEST_DATA_DIR, + MigTestCase, + ensure_dirs_exist, + temppath, + testmain, +) def create_http_environ(configuration): @@ -47,139 +55,169 @@ def create_http_environ(configuration): """ environ = {} - environ['MIG_CONF'] = configuration.config_file - environ['HTTP_HOST'] = 'localhost' - environ['PATH_INFO'] = '/' - environ['REMOTE_ADDR'] = '127.0.0.1' - environ['SCRIPT_URI'] = ''.join(('https://', environ['HTTP_HOST'], - environ['PATH_INFO'])) + environ["MIG_CONF"] = configuration.config_file + environ["HTTP_HOST"] = "localhost" + environ["PATH_INFO"] = "/" + environ["REMOTE_ADDR"] = "127.0.0.1" + environ["SCRIPT_URI"] = "".join( + ("https://", environ["HTTP_HOST"], environ["PATH_INFO"]) + ) return environ def _only_output_objects(output_objects, with_object_type=None): - return [o for o in output_objects if o['object_type'] == with_object_type] + return [o for o in output_objects if o["object_type"] == with_object_type] class MigSharedFunctionalityCat(MigTestCase): """Wrap unit tests for the corresponding module""" - TEST_CLIENT_ID = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_CLIENT_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): - self.test_user_dir = self._provision_test_user(self, self.TEST_CLIENT_ID) + self.test_user_dir = self._provision_test_user( + self, self.TEST_CLIENT_ID + ) self.test_environ = create_http_environ(self.configuration) def assertSingleOutputObject(self, output_objects, with_object_type=None): assert with_object_type is not None - found_objects = _only_output_objects(output_objects, - with_object_type=with_object_type) + found_objects = _only_output_objects( + output_objects, with_object_type=with_object_type + ) self.assertEqual(len(found_objects), 1) return found_objects[0] def test_file_serving_a_single_file_match(self): - with open(os.path.join(self.test_user_dir, 'foobar.txt'), 'w'): + with open(os.path.join(self.test_user_dir, "foobar.txt"), "w"): pass payload = { - 'path': ['foobar.txt'], + "path": ["foobar.txt"], } - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) # NOTE: start entry with headers and actual content self.assertEqual(len(output_objects), 2) - self.assertSingleOutputObject(output_objects, - with_object_type='file_output') + self.assertSingleOutputObject( + output_objects, with_object_type="file_output" + ) def test_file_serving_at_limit(self): test_binary_file = os.path.realpath( - os.path.join(TEST_DATA_DIR, 'loading.gif')) + os.path.join(TEST_DATA_DIR, "loading.gif") + ) test_binary_file_size = os.stat(test_binary_file).st_size - with open(test_binary_file, 'rb') as fh_test_file: + with open(test_binary_file, "rb") as fh_test_file: test_binary_file_data = fh_test_file.read() - shutil.copyfile(test_binary_file, os.path.join( - self.test_user_dir, 'loading.gif')) + shutil.copyfile( + test_binary_file, os.path.join(self.test_user_dir, "loading.gif") + ) payload = { - 'output_format': ['file'], - 'path': ['loading.gif'], + "output_format": ["file"], + "path": ["loading.gif"], } self.configuration.wwwserve_max_bytes = test_binary_file_size - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) self.assertEqual(len(output_objects), 2) - relevant_obj = self.assertSingleOutputObject(output_objects, - with_object_type='file_output') - self.assertEqual(len(relevant_obj['lines']), 1) - self.assertEqual(relevant_obj['lines'][0], test_binary_file_data) + relevant_obj = self.assertSingleOutputObject( + output_objects, with_object_type="file_output" + ) + self.assertEqual(len(relevant_obj["lines"]), 1) + self.assertEqual(relevant_obj["lines"][0], test_binary_file_data) def test_file_serving_over_limit_without_storage_protocols(self): - test_binary_file = os.path.realpath(os.path.join(TEST_DATA_DIR, - 'loading.gif')) + test_binary_file = os.path.realpath( + os.path.join(TEST_DATA_DIR, "loading.gif") + ) test_binary_file_size = os.stat(test_binary_file).st_size - with open(test_binary_file, 'rb') as fh_test_file: + with open(test_binary_file, "rb") as fh_test_file: test_binary_file_data = fh_test_file.read() - shutil.copyfile(test_binary_file, os.path.join(self.test_user_dir, - 'loading.gif')) + shutil.copyfile( + test_binary_file, os.path.join(self.test_user_dir, "loading.gif") + ) payload = { - 'output_format': ['file'], - 'path': ['loading.gif'], + "output_format": ["file"], + "path": ["loading.gif"], } # NOTE: override default storage_protocols to empty in this test self.configuration.storage_protocols = [] self.configuration.wwwserve_max_bytes = test_binary_file_size - 1 - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) # NOTE: start entry with headers and actual error message self.assertEqual(len(output_objects), 2) - relevant_obj = self.assertSingleOutputObject(output_objects, - with_object_type='error_text') - self.assertEqual(relevant_obj['text'], - "Site configuration prevents web serving contents " - "bigger than 3896 bytes") + relevant_obj = self.assertSingleOutputObject( + output_objects, with_object_type="error_text" + ) + self.assertEqual( + relevant_obj["text"], + "Site configuration prevents web serving contents " + "bigger than 3896 bytes", + ) def test_file_serving_over_limit_with_storage_protocols_sftp(self): - test_binary_file = os.path.realpath(os.path.join(TEST_DATA_DIR, - 'loading.gif')) + test_binary_file = os.path.realpath( + os.path.join(TEST_DATA_DIR, "loading.gif") + ) test_binary_file_size = os.stat(test_binary_file).st_size - with open(test_binary_file, 'rb') as fh_test_file: + with open(test_binary_file, "rb") as fh_test_file: test_binary_file_data = fh_test_file.read() - shutil.copyfile(test_binary_file, os.path.join(self.test_user_dir, - 'loading.gif')) + shutil.copyfile( + test_binary_file, os.path.join(self.test_user_dir, "loading.gif") + ) payload = { - 'output_format': ['file'], - 'path': ['loading.gif'], + "output_format": ["file"], + "path": ["loading.gif"], } - self.configuration.storage_protocols = ['sftp'] + self.configuration.storage_protocols = ["sftp"] self.configuration.wwwserve_max_bytes = test_binary_file_size - 1 - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) # NOTE: start entry with headers and actual error message - relevant_obj = self.assertSingleOutputObject(output_objects, - with_object_type='error_text') - self.assertEqual(relevant_obj['text'], - "Site configuration prevents web serving contents " - "bigger than 3896 bytes - please use better " - "alternatives (SFTP) to retrieve large data") + relevant_obj = self.assertSingleOutputObject( + output_objects, with_object_type="error_text" + ) + self.assertEqual( + relevant_obj["text"], + "Site configuration prevents web serving contents " + "bigger than 3896 bytes - please use better " + "alternatives (SFTP) to retrieve large data", + ) @unittest.skipIf(PY2, "Python 3 only") def test_main_passes_environ(self): @@ -187,17 +225,21 @@ def test_main_passes_environ(self): result = realmain(self.TEST_CLIENT_ID, {}, self.test_environ) except Exception as unexpectedexc: raise AssertionError( - "saw unexpected exception: %s" % (unexpectedexc,)) + "saw unexpected exception: %s" % (unexpectedexc,) + ) - (output_objects, status) = result - self.assertEqual(status[1], 'Client error') + output_objects, status = result + self.assertEqual(status[1], "Client error") - error_text_objects = _only_output_objects(output_objects, - with_object_type='error_text') + error_text_objects = _only_output_objects( + output_objects, with_object_type="error_text" + ) relevant_obj = error_text_objects[2] self.assertEqual( - relevant_obj['text'], 'Input arguments were rejected - not allowed for this script!') + relevant_obj["text"], + "Input arguments were rejected - not allowed for this script!", + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_functionality_datatransfer.py b/tests/test_mig_shared_functionality_datatransfer.py index 3fe9e5fdc..6887fd5a1 100644 --- a/tests/test_mig_shared_functionality_datatransfer.py +++ b/tests/test_mig_shared_functionality_datatransfer.py @@ -28,18 +28,19 @@ """Unit tests of the MiG functionality file implementing the datatransfer backend""" from __future__ import print_function + import os import mig.shared.returnvalues as returnvalues -from mig.shared.defaults import CSRF_MINIMAL from mig.shared.base import client_id_dir -from mig.shared.functionality.datatransfer import _main as submain, main as realmain - +from mig.shared.defaults import CSRF_MINIMAL +from mig.shared.functionality.datatransfer import _main as submain +from mig.shared.functionality.datatransfer import main as realmain from tests.support import ( MigTestCase, - testmain, - temppath, ensure_dirs_exist, + temppath, + testmain, ) @@ -66,25 +67,27 @@ def _only_output_objects(output_objects, with_object_type=None): class MigSharedFunctionalityDataTransfer(MigTestCase): """Wrap unit tests for the corresponding module""" - TEST_CLIENT_ID = ( - "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" - ) + TEST_CLIENT_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" def _provide_configuration(self): return "testconfig" def before_each(self): - self.test_user_dir = self._provision_test_user(self, self.TEST_CLIENT_ID) + self.test_user_dir = self._provision_test_user( + self, self.TEST_CLIENT_ID + ) self.test_environ = create_http_environ(self.configuration) def test_default_disabled_site_transfer(self): self.assertFalse(self.configuration.site_enable_transfers) result = realmain(self.TEST_CLIENT_ID, {}, self.test_environ) - (output_objects, status) = result + output_objects, status = result self.assertEqual(status, returnvalues.OK) - text_objects = _only_output_objects(output_objects, with_object_type="text") + text_objects = _only_output_objects( + output_objects, with_object_type="text" + ) self.assertEqual(len(text_objects), 1) self.assertIn("text", text_objects[0]) text_object = text_objects[0]["text"] @@ -95,7 +98,7 @@ def test_show_action_enabled_site_transfer(self): payload = {"action": ["show"]} self.configuration.site_enable_transfers = True - (output_objects, status) = submain( + output_objects, status = submain( self.configuration, self.logger, client_id=self.TEST_CLIENT_ID, @@ -105,17 +108,22 @@ def test_show_action_enabled_site_transfer(self): self.assertEqual(status, returnvalues.OK) # We don't expect any text messages here - text_objects = _only_output_objects(output_objects, with_object_type="text") + text_objects = _only_output_objects( + output_objects, with_object_type="text" + ) self.assertEqual(len(text_objects), 0) def test_deltransfer_without_transfer_id(self): non_existing_transfer_id = "non-existing-transfer-id" - payload = {"action": ["deltransfer"], "transfer_id": [non_existing_transfer_id]} + payload = { + "action": ["deltransfer"], + "transfer_id": [non_existing_transfer_id], + } self.configuration.site_enable_transfers = True self.configuration.site_csrf_protection = CSRF_MINIMAL self.test_environ["REQUEST_METHOD"] = "post" - (output_objects, status) = submain( + output_objects, status = submain( self.configuration, self.logger, client_id=self.TEST_CLIENT_ID, @@ -129,7 +137,8 @@ def test_deltransfer_without_transfer_id(self): ) self.assertEqual(len(error_text_objects), 1) self.assertEqual( - error_text_objects[0]["text"], "existing transfer_id is required for delete" + error_text_objects[0]["text"], + "existing transfer_id is required for delete", ) def test_redotransfer_without_transfer_id(self): @@ -142,7 +151,7 @@ def test_redotransfer_without_transfer_id(self): self.configuration.site_csrf_protection = CSRF_MINIMAL self.test_environ["REQUEST_METHOD"] = "post" - (output_objects, status) = submain( + output_objects, status = submain( self.configuration, self.logger, client_id=self.TEST_CLIENT_ID, diff --git a/tests/test_mig_shared_install.py b/tests/test_mig_shared_install.py index b06f3c928..e67793879 100644 --- a/tests/test_mig_shared_install.py +++ b/tests/test_mig_shared_install.py @@ -27,21 +27,28 @@ """Unit tests for the migrid module pointed to in the filename""" -from past.builtins import basestring import binascii -from configparser import ConfigParser, NoSectionError, NoOptionError import difflib import io import os import pwd import sys +from configparser import ConfigParser, NoOptionError, NoSectionError -from tests.support import MIG_BASE, TEST_OUTPUT_DIR, MigTestCase, \ - testmain, temppath, cleanpath, is_path_within -from tests.support.fixturesupp import fixturepath +from past.builtins import basestring from mig.shared.defaults import keyword_auto from mig.shared.install import determine_timezone, generate_confs +from tests.support import ( + MIG_BASE, + TEST_OUTPUT_DIR, + MigTestCase, + cleanpath, + is_path_within, + temppath, + testmain, +) +from tests.support.fixturesupp import fixturepath class DummyPwInfo: @@ -70,40 +77,43 @@ class MigSharedInstall__determine_timezone(MigTestCase): def test_determines_tz_utc_fallback(self): timezone = determine_timezone( - _environ={}, _path_exists=lambda _: False, _print=noop) + _environ={}, _path_exists=lambda _: False, _print=noop + ) - self.assertEqual(timezone, 'UTC') + self.assertEqual(timezone, "UTC") def test_determines_tz_via_environ(self): - example_environ = { - 'TZ': 'Example/Enviromnent' - } + example_environ = {"TZ": "Example/Enviromnent"} timezone = determine_timezone(_environ=example_environ) - self.assertEqual(timezone, 'Example/Enviromnent') + self.assertEqual(timezone, "Example/Enviromnent") def test_determines_tz_via_localtime(self): def exists_localtime(value): - saw_call = value == '/etc/localtime' + saw_call = value == "/etc/localtime" exists_localtime.was_called = saw_call return saw_call + exists_localtime.was_called = False timezone = determine_timezone( - _environ={}, _path_exists=exists_localtime) + _environ={}, _path_exists=exists_localtime + ) self.assertTrue(exists_localtime.was_called) self.assertIsNotNone(timezone) def test_determines_tz_via_timedatectl(self): def exists_timedatectl(value): - saw_call = value == '/usr/bin/timedatectl' + saw_call = value == "/usr/bin/timedatectl" exists_timedatectl.was_called = saw_call return saw_call + exists_timedatectl.was_called = False timezone = determine_timezone( - _environ={}, _path_exists=exists_timedatectl, _print=noop) + _environ={}, _path_exists=exists_timedatectl, _print=noop + ) self.assertTrue(exists_timedatectl.was_called) self.assertIsNotNone(timezone) @@ -131,46 +141,48 @@ def assertConfigKey(self, generated, section, key, expected): self.assertEqual(actual, expected) def test_creates_output_directory_and_adds_active_symlink(self): - symlink_path = temppath('confs', self) - cleanpath('confs-foobar', self) + symlink_path = temppath("confs", self) + cleanpath("confs-foobar", self) - generate_confs(self.output_path, destination_suffix='-foobar') + generate_confs(self.output_path, destination_suffix="-foobar") - path_kind = self.assertPathExists('confs-foobar') + path_kind = self.assertPathExists("confs-foobar") self.assertEqual(path_kind, "dir") - path_kind = self.assertPathExists('confs') + path_kind = self.assertPathExists("confs") self.assertEqual(path_kind, "symlink") def test_creates_output_directory_and_repairs_active_symlink(self): - expected_generated_dir = cleanpath('confs-foobar', self) - symlink_path = temppath('confs', self) - nowhere_path = temppath('confs-nowhere', self) + expected_generated_dir = cleanpath("confs-foobar", self) + symlink_path = temppath("confs", self) + nowhere_path = temppath("confs-nowhere", self) # arrange pre-existing symlink pointing nowhere os.symlink(nowhere_path, symlink_path) - generate_confs(self.output_path, destination_suffix='-foobar') + generate_confs(self.output_path, destination_suffix="-foobar") generated_dir = os.path.realpath(symlink_path) self.assertEqual(generated_dir, expected_generated_dir) - def test_creates_output_directory_containing_a_standard_local_configuration(self): + def test_creates_output_directory_containing_a_standard_local_configuration( + self, + ): fixture_dir = fixturepath("confs-stdlocal") - expected_generated_dir = cleanpath('confs-stdlocal', self) - symlink_path = temppath('confs', self) + expected_generated_dir = cleanpath("confs-stdlocal", self) + symlink_path = temppath("confs", self) generate_confs( self.output_path, - destination_suffix='-stdlocal', - user='testuser', - group='testgroup', - mig_code='/home/mig/mig', - mig_certs='/home/mig/certs', - mig_state='/home/mig/state', - timezone='Test/Place', - crypto_salt='_TEST_CRYPTO_SALT'.zfill(32), - digest_salt='_TEST_DIGEST_SALT'.zfill(32), - seafile_secret='_test-seafile-secret='.zfill(44), - seafile_ccnetid='_TEST_SEAFILE_CCNETID'.zfill(40), + destination_suffix="-stdlocal", + user="testuser", + group="testgroup", + mig_code="/home/mig/mig", + mig_certs="/home/mig/certs", + mig_state="/home/mig/state", + timezone="Test/Place", + crypto_salt="_TEST_CRYPTO_SALT".zfill(32), + digest_salt="_TEST_DIGEST_SALT".zfill(32), + seafile_secret="_test-seafile-secret=".zfill(44), + seafile_ccnetid="_TEST_SEAFILE_CCNETID".zfill(40), _getpwnam=create_dummy_gpwnam(4321, 1234), ) @@ -193,10 +205,11 @@ def test_kwargs_for_paths_auto(self): def capture_defaulted(*args, **kwargs): capture_defaulted.kwargs = kwargs return args[0] + capture_defaulted.kwargs = None - (options, _) = generate_confs( - '/some/arbitrary/path', + options, _ = generate_confs( + "/some/arbitrary/path", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=capture_defaulted, _writefiles=noop, @@ -204,123 +217,139 @@ def capture_defaulted(*args, **kwargs): ) defaulted = capture_defaulted.kwargs - self.assertPathWithin(defaulted['mig_certs'], MIG_BASE) - self.assertPathWithin(defaulted['mig_state'], MIG_BASE) + self.assertPathWithin(defaulted["mig_certs"], MIG_BASE) + self.assertPathWithin(defaulted["mig_state"], MIG_BASE) def test_creates_output_files_with_datasafety(self): fixture_dir = fixturepath("confs-stdlocal") - expected_generated_dir = cleanpath('confs-stdlocal', self) - symlink_path = temppath('confs', self) + expected_generated_dir = cleanpath("confs-stdlocal", self) + symlink_path = temppath("confs", self) generate_confs( self.output_path, destination=symlink_path, - destination_suffix='-stdlocal', - datasafety_link='TEST_DATASAFETY_LINK', - datasafety_text='TEST_DATASAFETY_TEXT', + destination_suffix="-stdlocal", + datasafety_link="TEST_DATASAFETY_LINK", + datasafety_text="TEST_DATASAFETY_TEXT", _getpwnam=create_dummy_gpwnam(4321, 1234), ) - actual_file = self.assertFileExists('confs-stdlocal/MiGserver.conf') + actual_file = self.assertFileExists("confs-stdlocal/MiGserver.conf") self.assertConfigKey( - actual_file, 'SITE', 'datasafety_link', expected='TEST_DATASAFETY_LINK') + actual_file, + "SITE", + "datasafety_link", + expected="TEST_DATASAFETY_LINK", + ) self.assertConfigKey( - actual_file, 'SITE', 'datasafety_text', expected='TEST_DATASAFETY_TEXT') + actual_file, + "SITE", + "datasafety_text", + expected="TEST_DATASAFETY_TEXT", + ) def test_creates_output_files_with_permanent_freeze(self): fixture_dir = fixturepath("confs-stdlocal") - expected_generated_dir = cleanpath('confs-stdlocal', self) - symlink_path = temppath('confs', self) + expected_generated_dir = cleanpath("confs-stdlocal", self) + symlink_path = temppath("confs", self) - for arg_val in ('yes', 'no', 'foo bar baz'): + for arg_val in ("yes", "no", "foo bar baz"): generate_confs( self.output_path, destination=symlink_path, - destination_suffix='-stdlocal', + destination_suffix="-stdlocal", permanent_freeze=arg_val, _getpwnam=create_dummy_gpwnam(4321, 1234), ) - actual_file = self.assertFileExists('confs-stdlocal/MiGserver.conf') + actual_file = self.assertFileExists("confs-stdlocal/MiGserver.conf") self.assertConfigKey( - actual_file, 'SITE', 'permanent_freeze', expected=arg_val) + actual_file, "SITE", "permanent_freeze", expected=arg_val + ) def test_options_for_source_auto(self): - (options, _) = generate_confs( - '/some/arbitrary/path', + options, _ = generate_confs( + "/some/arbitrary/path", source=keyword_auto, _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - expected_template_dir = os.path.join(MIG_BASE, 'mig/install') + expected_template_dir = os.path.join(MIG_BASE, "mig/install") - self.assertEqual(options['template_dir'], expected_template_dir) + self.assertEqual(options["template_dir"], expected_template_dir) def test_options_for_source_relative(self): - (options, _) = generate_confs( - '/current/working/directory/mig/install', - source='.', + options, _ = generate_confs( + "/current/working/directory/mig/install", + source=".", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['template_dir'], - '/current/working/directory/mig/install') + self.assertEqual( + options["template_dir"], "/current/working/directory/mig/install" + ) def test_options_for_destination_auto(self): - (options, _) = generate_confs( - '/some/arbitrary/path', + options, _ = generate_confs( + "/some/arbitrary/path", destination=keyword_auto, - destination_suffix='_suffix', + destination_suffix="_suffix", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['destination_link'], - '/some/arbitrary/path/confs') - self.assertEqual(options['destination_dir'], - '/some/arbitrary/path/confs_suffix') + self.assertEqual( + options["destination_link"], "/some/arbitrary/path/confs" + ) + self.assertEqual( + options["destination_dir"], "/some/arbitrary/path/confs_suffix" + ) def test_options_for_destination_relative(self): - (options, _) = generate_confs( - '/current/working/directory', - destination='generate-confs', - destination_suffix='_suffix', + options, _ = generate_confs( + "/current/working/directory", + destination="generate-confs", + destination_suffix="_suffix", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['destination_link'], - '/current/working/directory/generate-confs') - self.assertEqual(options['destination_dir'], - '/current/working/directory/generate-confs_suffix') + self.assertEqual( + options["destination_link"], + "/current/working/directory/generate-confs", + ) + self.assertEqual( + options["destination_dir"], + "/current/working/directory/generate-confs_suffix", + ) def test_options_for_destination_absolute(self): - (options, _) = generate_confs( - '/current/working/directory', - destination='/some/other/place/confs', - destination_suffix='_suffix', + options, _ = generate_confs( + "/current/working/directory", + destination="/some/other/place/confs", + destination_suffix="_suffix", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['destination_link'], - '/some/other/place/confs') - self.assertEqual(options['destination_dir'], - '/some/other/place/confs_suffix') + self.assertEqual(options["destination_link"], "/some/other/place/confs") + self.assertEqual( + options["destination_dir"], "/some/other/place/confs_suffix" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_jupyter.py b/tests/test_mig_shared_jupyter.py index 3eb53bf63..461282df8 100644 --- a/tests/test_mig_shared_jupyter.py +++ b/tests/test_mig_shared_jupyter.py @@ -34,9 +34,14 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain from mig.shared.jupyter import gen_openid_template +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) def noop(*args): @@ -48,7 +53,8 @@ class MigSharedJupyter(MigTestCase): def test_jupyter_gen_openid_template_openid_auth(self): filled = gen_openid_template( - "/some-jupyter-url", "MyDefine", "OpenID", _print=noop) + "/some-jupyter-url", "MyDefine", "OpenID", _print=noop + ) expected = """ @@ -63,7 +69,8 @@ def test_jupyter_gen_openid_template_openid_auth(self): def test_jupyter_gen_openid_template_oidc_auth(self): filled = gen_openid_template( - "/some-jupyter-url", "MyDefine", "openid-connect", _print=noop) + "/some-jupyter-url", "MyDefine", "openid-connect", _print=noop + ) expected = """ @@ -79,26 +86,26 @@ def test_jupyter_gen_openid_template_oidc_auth(self): def test_jupyter_gen_openid_template_invalid_url_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template(None, "MyDefine", - "openid-connect") + filled = gen_openid_template(None, "MyDefine", "openid-connect") def test_jupyter_gen_openid_template_invalid_define_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template("/some-jupyter-url", None, - "no-such-auth-type") + filled = gen_openid_template( + "/some-jupyter-url", None, "no-such-auth-type" + ) def test_jupyter_gen_openid_template_missing_auth_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template("/some-jupyter-url", "MyDefine", - None) + filled = gen_openid_template("/some-jupyter-url", "MyDefine", None) def test_jupyter_gen_openid_template_invalid_auth_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template("/some-jupyter-url", "MyDefine", - "no-such-auth-type") + filled = gen_openid_template( + "/some-jupyter-url", "MyDefine", "no-such-auth-type" + ) # TODO: add more coverage of module -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_localfile.py b/tests/test_mig_shared_localfile.py index 62a5112aa..be9718bd7 100644 --- a/tests/test_mig_shared_localfile.py +++ b/tests/test_mig_shared_localfile.py @@ -27,21 +27,19 @@ """Unit tests for the migrid module pointed to in the filename""" -from contextlib import contextmanager import errno import fcntl import os import sys +from contextlib import contextmanager -sys.path.append(os.path.realpath( - os.path.join(os.path.dirname(__file__), ".."))) - -from tests.support import MigTestCase, temppath, testmain +sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))) -from mig.shared.serverfile import LOCK_EX from mig.shared.localfile import LocalFile +from mig.shared.serverfile import LOCK_EX +from tests.support import MigTestCase, temppath, testmain -DUMMY_FILE = 'some_file' +DUMMY_FILE = "some_file" @contextmanager @@ -72,7 +70,7 @@ def assertPathLockedExclusive(self, file_path): # we were errantly able to acquire a lock, mark errored reraise = AssertionError("RERAISE_MUST_UNLOCK") except Exception as maybe_err: - if getattr(maybe_err, 'errno', None) == errno.EAGAIN: + if getattr(maybe_err, "errno", None) == errno.EAGAIN: # this is the expected exception - the logic tried to lock # a file that was (as we intended) already locked, meaning # this assertion has succeeded so we do not need to raise @@ -83,17 +81,19 @@ def assertPathLockedExclusive(self, file_path): if reraise is not None: # if marked errored and locked, cleanup the lock we acquired but shouldn't - if str(reraise) == 'RERAISE_MUST_UNLOCK': + if str(reraise) == "RERAISE_MUST_UNLOCK": fcntl.flock(conflicting_f, fcntl.LOCK_NB | fcntl.LOCK_UN) # raise a user-friendly error to avoid nested raise raise AssertionError( - "expected locked file: %s" % self.pretty_display_path(file_path)) + "expected locked file: %s" + % self.pretty_display_path(file_path) + ) def test_localfile_locking(self): some_file = temppath(DUMMY_FILE, self) - with managed_localfile(LocalFile(some_file, 'w')) as lfd: + with managed_localfile(LocalFile(some_file, "w")) as lfd: lfd.lock(LOCK_EX) self.assertEqual(lfd.get_lock_mode(), LOCK_EX) @@ -101,5 +101,5 @@ def test_localfile_locking(self): self.assertPathLockedExclusive(some_file) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_pwcrypto.py b/tests/test_mig_shared_pwcrypto.py index 784c6963e..42a7440a2 100644 --- a/tests/test_mig_shared_pwcrypto.py +++ b/tests/test_mig_shared_pwcrypto.py @@ -32,65 +32,87 @@ import sys import unittest -from tests.support import MigTestCase, FakeConfiguration, \ - cleanpath, temppath, testmain - -from mig.shared.defaults import POLICY_NONE, POLICY_WEAK, POLICY_MEDIUM, \ - POLICY_HIGH, POLICY_MODERN, POLICY_CUSTOM, PASSWORD_POLICIES -from mig.shared.pwcrypto import main as pwcrypto_main +from mig.shared.defaults import ( + PASSWORD_POLICIES, + POLICY_CUSTOM, + POLICY_HIGH, + POLICY_MEDIUM, + POLICY_MODERN, + POLICY_NONE, + POLICY_WEAK, +) from mig.shared.pwcrypto import * +from mig.shared.pwcrypto import main as pwcrypto_main +from tests.support import ( + FakeConfiguration, + MigTestCase, + cleanpath, + temppath, + testmain, +) DUMMY_USER = "dummy-user" DUMMY_ID = "dummy-id" # NOTE: these passwords are not and should not ever be used outside unit tests -DUMMY_WEAK_PW = 'foobar' -DUMMY_MEDIUM_PW = 'QZFnCp7h' -DUMMY_HIGH_PW = 'QZFnp7I-GZ' -DUMMY_MODERN_PW = 'QZFnCp7hmI1G' -DUMMY_GENERATED_PW = '7hmI1GnCpQZF' +DUMMY_WEAK_PW = "foobar" +DUMMY_MEDIUM_PW = "QZFnCp7h" +DUMMY_HIGH_PW = "QZFnp7I-GZ" +DUMMY_MODERN_PW = "QZFnCp7hmI1G" +DUMMY_GENERATED_PW = "7hmI1GnCpQZF" DUMMY_WEAK_PW_MD5 = "3858f62230ac3c915f300c664312c63f" -DUMMY_WEAK_PW_SHA256 = \ +DUMMY_WEAK_PW_SHA256 = ( "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2" -DUMMY_WEAK_PW_PBKDF2 = \ +) +DUMMY_WEAK_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$epib2rEg/HYTQZFnCp7hmIGZ6rzHnViy" -DUMMY_MEDIUM_PW_PBKDF2 = \ +) +DUMMY_MEDIUM_PW_PBKDF2 = ( "PBKDF2$sha256$10000$ebQHnDX1rzY9Rizb$0vUJ9/4ThhsN4cRaKYmOj4N0YKEsozTr" -DUMMY_HIGH_PW_PBKDF2 = \ +) +DUMMY_HIGH_PW_PBKDF2 = ( "PBKDF2$sha256$10000$HR+KcqLyQe3v0WSk$CtxMAomi8JHiI7gWc/PH5Ey00zW1Now3" +) DUMMY_MODERN_PW_MD5 = "a06d169a171ef7d4383b212457162d93" -DUMMY_MODERN_PW_SHA256 = \ +DUMMY_MODERN_PW_SHA256 = ( "d293dcb9762c87641ea1decbfe76d84ec51b13d6a1e688cdf1a838eebc5bb1a9" -DUMMY_MODERN_PW_PBKDF2 = \ +) +DUMMY_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf10n58FHrn1pjX" -DUMMY_MODERN_PW_DIGEST = \ - "DIGEST$custom$CONFSALT$64756D6D792D7265616C6D3A64756D6D792D7520DE71261F96A2FE48A67DD0877F2A2C" +) +DUMMY_MODERN_PW_DIGEST = "DIGEST$custom$CONFSALT$64756D6D792D7265616C6D3A64756D6D792D7520DE71261F96A2FE48A67DD0877F2A2C" DUMMY_MODERN_DIGEST_SCRAMBLE = "53BB031C1F96A2FE48A67DD0877F2A2C" DUMMY_MODERN_PW_SCRAMBLE = "53BB031C1F96A2FE48A67DD0877F2A2C" -DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED = b'xRsT1qHmiM3xqDjuvFuxqQ==.g4-Gt83uRrdvVWwX0SF1iMza3NyKJbp2sEYVkw==.ICAgIG1pZ3JpZCBhdXRoZW50aWNhdGVkMjA1MDAxMDE=' -DUMMY_MODERN_PW_RESET_TOKEN = b'gAAAAABo63hYqeHE7Db93pMxWn1sWzj2Z-6td2UhA5gKYa4KV096ndV-AO0pp6hrR9jXKcwWAouLCMiNC0BRudeCAYHoBii15lTRbP9b7JzvJjeusbidjRxqcJg0om6bbtSK1Rz_RBTq_jhdAk4v-7PccWlZ15dVJ4j-X3X4zSsBWIOR5y6Y-bA=' -DUMMY_METHOD = 'dummy-method' -DUMMY_OPERATION = 'dummy-operation' -DUMMY_ARGS = {'dummy-key': 'dummy-val'} -DUMMY_CSRF_TOKEN = '351cc47e0cd5c155fa4c4d3d0a6f1ee8f20eeb293ba13d59ede9d2a789687d3d' -DUMMY_CSRF_TRUST_TOKEN = '466c0bacd045a060a201c4e08c749c2e19743613422e0381ab0a57706c9fa2b8' -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED = b"xRsT1qHmiM3xqDjuvFuxqQ==.g4-Gt83uRrdvVWwX0SF1iMza3NyKJbp2sEYVkw==.ICAgIG1pZ3JpZCBhdXRoZW50aWNhdGVkMjA1MDAxMDE=" +DUMMY_MODERN_PW_RESET_TOKEN = b"gAAAAABo63hYqeHE7Db93pMxWn1sWzj2Z-6td2UhA5gKYa4KV096ndV-AO0pp6hrR9jXKcwWAouLCMiNC0BRudeCAYHoBii15lTRbP9b7JzvJjeusbidjRxqcJg0om6bbtSK1Rz_RBTq_jhdAk4v-7PccWlZ15dVJ4j-X3X4zSsBWIOR5y6Y-bA=" +DUMMY_METHOD = "dummy-method" +DUMMY_OPERATION = "dummy-operation" +DUMMY_ARGS = {"dummy-key": "dummy-val"} +DUMMY_CSRF_TOKEN = ( + "351cc47e0cd5c155fa4c4d3d0a6f1ee8f20eeb293ba13d59ede9d2a789687d3d" +) +DUMMY_CSRF_TRUST_TOKEN = ( + "466c0bacd045a060a201c4e08c749c2e19743613422e0381ab0a57706c9fa2b8" +) +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" # TODO: adjust password reset token helpers to handle configured services # it currently silently fails if not in migoid(c) or migcert # DUMMY_SERVICE = 'dummy-svc' -DUMMY_SERVICE = 'migoid' -DUMMY_REALM = 'dummy-realm' -DUMMY_PATH = 'dummy-path' -DUMMY_PATH_MD5 = 'd19033877452e8c217d3cddebbc37419' -DUMMY_SALT = b'53BB031C4ECCE4900BD64AB8EA361B6B' -DUMMY_ENTROPY = b'\xd2\x93\xdc\xb9v,\x87d\x1e\xa1\xde\xcb\xfev\xd8N\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9' -DUMMY_FERNET_KEY = 'NDg3OTcyNzE1NTQ2Nzc3ODYxNjc0NjRFRDZGMjNFQzY=' -DUMMY_AESGCM_KEY = b'48797271554677786167464ED6F23EC6' -DUMMY_AESGCM_STATIC_IV = b'\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9' -DUMMY_AESGCM_AAD_PREFIX = b'\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa9' -DUMMY_AESGCM_AAD = b' \xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa920500101' +DUMMY_SERVICE = "migoid" +DUMMY_REALM = "dummy-realm" +DUMMY_PATH = "dummy-path" +DUMMY_PATH_MD5 = "d19033877452e8c217d3cddebbc37419" +DUMMY_SALT = b"53BB031C4ECCE4900BD64AB8EA361B6B" +DUMMY_ENTROPY = b"\xd2\x93\xdc\xb9v,\x87d\x1e\xa1\xde\xcb\xfev\xd8N\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9" +DUMMY_FERNET_KEY = "NDg3OTcyNzE1NTQ2Nzc3ODYxNjc0NjRFRDZGMjNFQzY=" +DUMMY_AESGCM_KEY = b"48797271554677786167464ED6F23EC6" +DUMMY_AESGCM_STATIC_IV = ( + b"\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9" +) +DUMMY_AESGCM_AAD_PREFIX = b"\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa9" +DUMMY_AESGCM_AAD = b" \xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa920500101" # NOTE: we avoid any percent expansion values of actual date here to freeze AAD -DUMMY_FIXED_TIMESTAMP = '20500101' +DUMMY_FIXED_TIMESTAMP = "20500101" class MigSharedPwCrypto(MigTestCase): @@ -99,13 +121,15 @@ class MigSharedPwCrypto(MigTestCase): def before_each(self): test_user_home = temppath(DUMMY_HOME_DIR, self, ensure_dir=True) test_user_settings = cleanpath( - DUMMY_SETTINGS_DIR, self, ensure_dir=True) + DUMMY_SETTINGS_DIR, self, ensure_dir=True + ) # make two requisite root folders for the dummy user os.mkdir(os.path.join(test_user_home, DUMMY_USER)) os.mkdir(os.path.join(test_user_settings, DUMMY_USER)) # now create a configuration self.dummy_conf = FakeConfiguration( - user_home=test_user_home, user_settings=test_user_settings, + user_home=test_user_home, + user_settings=test_user_settings, site_password_policy="%s:12" % POLICY_MODERN, site_password_legacy_policy=POLICY_MEDIUM, site_password_cracklib=False, @@ -120,9 +144,11 @@ def before_each(self): # 'FakeConfiguration' has no 'site_password_legacy_policy' member # (no-member) unless we explicitly (re-)init it here self.dummy_conf.site_password_legacy_policy = getattr( - self.dummy_conf, 'site_password_legacy_policy', POLICY_NONE) - self.assertEqual(self.dummy_conf.site_password_legacy_policy, - POLICY_MEDIUM) + self.dummy_conf, "site_password_legacy_policy", POLICY_NONE + ) + self.assertEqual( + self.dummy_conf.site_password_legacy_policy, POLICY_MEDIUM + ) def test_best_crypt_salt(self): """Test selection of best salt based on salt availability in @@ -131,13 +157,13 @@ def test_best_crypt_salt(self): expected = DUMMY_SALT actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "best crypt salt not found") - self.dummy_conf.site_crypto_salt = '' + self.dummy_conf.site_crypto_salt = "" actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "2nd best crypt salt not found") - self.dummy_conf.site_password_salt = '' + self.dummy_conf.site_password_salt = "" actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "3rd best crypt salt not found") - self.dummy_conf.site_digest_salt = '' + self.dummy_conf.site_digest_salt = "" actual = None try: actual = best_crypt_salt(self.dummy_conf) @@ -154,10 +180,10 @@ def test_password_requirements(self): self.assertEqual(expected[2], result[2], "failed pw req errors") expected = (8, 3, []) result = password_requirements( - self.dummy_conf.site_password_legacy_policy) + self.dummy_conf.site_password_legacy_policy + ) self.assertEqual(expected[0], result[0], "failed legacy pw req chars") - self.assertEqual(expected[1], result[1], - "failed legacy pw req classes") + self.assertEqual(expected[1], result[1], "failed legacy pw req classes") self.assertEqual(expected[2], result[2], "failed legacy pw req errors") def test_parse_password_policy(self): @@ -174,56 +200,60 @@ def test_parse_password_policy(self): def test_assure_password_strength(self): """Test assure password strength""" try: - allow_weak = assure_password_strength(self.dummy_conf, - DUMMY_WEAK_PW) + allow_weak = assure_password_strength( + self.dummy_conf, DUMMY_WEAK_PW + ) except ValueError as vae: allow_weak = False self.assertFalse(allow_weak, "allowed weak pw") try: - allow_weak = assure_password_strength(self.dummy_conf, - DUMMY_WEAK_PW, - allow_legacy=True) + allow_weak = assure_password_strength( + self.dummy_conf, DUMMY_WEAK_PW, allow_legacy=True + ) except ValueError as vae: allow_weak = False self.assertFalse(allow_weak, "allowed weak pw with legacy") # NOTE: only allow medium with legacy try: - allow_medium = assure_password_strength(self.dummy_conf, - DUMMY_MEDIUM_PW) + allow_medium = assure_password_strength( + self.dummy_conf, DUMMY_MEDIUM_PW + ) except ValueError as vae: allow_medium = False self.assertFalse(allow_medium, "allowed medium pw without legacy") try: - allow_medium = assure_password_strength(self.dummy_conf, - DUMMY_MEDIUM_PW, - allow_legacy=True) + allow_medium = assure_password_strength( + self.dummy_conf, DUMMY_MEDIUM_PW, allow_legacy=True + ) except ValueError as vae: allow_medium = False self.assertTrue(allow_medium, "refused medium pw with legacy") # NOTE: only allow high with legacy - not long enough for modern try: - allow_high = assure_password_strength(self.dummy_conf, - DUMMY_HIGH_PW) + allow_high = assure_password_strength( + self.dummy_conf, DUMMY_HIGH_PW + ) except ValueError as vae: allow_high = False self.assertFalse(allow_high, "allowed high pw without legacy") try: - allow_high = assure_password_strength(self.dummy_conf, - DUMMY_HIGH_PW, - allow_legacy=True) + allow_high = assure_password_strength( + self.dummy_conf, DUMMY_HIGH_PW, allow_legacy=True + ) except ValueError as vae: allow_high = False self.assertTrue(allow_high, "refused high pw with legacy") try: - allow_modern = assure_password_strength(self.dummy_conf, - DUMMY_MODERN_PW) + allow_modern = assure_password_strength( + self.dummy_conf, DUMMY_MODERN_PW + ) except ValueError as vae: allow_modern = False self.assertTrue(allow_modern, "refused modern pw") try: - allow_modern = assure_password_strength(self.dummy_conf, - DUMMY_MODERN_PW, - allow_legacy=True) + allow_modern = assure_password_strength( + self.dummy_conf, DUMMY_MODERN_PW, allow_legacy=True + ) except ValueError as vae: allow_modern = False self.assertTrue(allow_modern, "refused modern pw with legacy") @@ -278,7 +308,7 @@ def test_make_hash_fixed_seed(self): random seed. """ expected = DUMMY_MODERN_PW_PBKDF2 - actual = make_hash(DUMMY_MODERN_PW, _urandom=lambda vlen: b'0' * vlen) + actual = make_hash(DUMMY_MODERN_PW, _urandom=lambda vlen: b"0" * vlen) self.assertEqual(actual, expected, "mismatch hashing string") def test_make_hash_constant_string(self): @@ -286,17 +316,25 @@ def test_make_hash_constant_string(self): random seed. I.e. the value may differ across interpreter invocations but remains constant in same interpreter. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) self.assertEqual(first, second, "same seed hashing is not constant") def test_check_hash_reject_weak(self): """Test basic hash checking of a constant weak complexity password""" expected = DUMMY_WEAK_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_WEAK_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_WEAK_PW, + expected, + strict_policy=True, + ) self.assertFalse(result, "check hash should fail on weak pw") def test_check_hash_reject_medium_without_legacy(self): @@ -304,9 +342,15 @@ def test_check_hash_reject_medium_without_legacy(self): without legacy password support. """ expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertFalse(result, "check hash strict should fail on medium pw") def test_check_hash_accept_medium_with_legacy(self): @@ -314,9 +358,15 @@ def test_check_hash_accept_medium_with_legacy(self): with legacy password support. """ expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True, - allow_legacy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + allow_legacy=True, + ) self.assertTrue(result, "check hash with legacy must accept medium pw") def test_check_hash_accept_high(self): @@ -325,9 +375,15 @@ def test_check_hash_accept_high(self): """ expected = DUMMY_HIGH_PW_PBKDF2 self.dummy_conf.site_password_policy = POLICY_HIGH - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_HIGH_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_HIGH_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertTrue(result, "check hash must accept high complexity pw") def test_check_hash_accept_modern(self): @@ -335,32 +391,57 @@ def test_check_hash_accept_modern(self): without legacy password support. """ expected = DUMMY_MODERN_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertTrue(result, "check hash must accept modern complexity pw") def test_check_hash_fixed(self): """Test basic hash checking of a fixed string""" expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + ) self.assertFalse(result, "check hash should reject medium pw") - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=False, - allow_legacy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=False, + allow_legacy=True, + ) self.assertTrue(result, "check hash failed medium pw when not strict") expected = DUMMY_MODERN_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + strict_policy=True, + ) self.assertTrue(result, "check hash failed modern pw") def test_check_hash_random(self): """Test basic hashing and hash checking of a random string""" random_pw = generate_random_password(self.dummy_conf) expected = make_hash(random_pw) - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, expected) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, random_pw, expected + ) self.assertTrue(result, "mismatch in random hash and check") def test_make_hash_variation(self): @@ -368,10 +449,12 @@ def test_make_hash_variation(self): I.e. the value likely remains constant in same interpreter but differs across interpreter invocations. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen]) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen] + ) self.assertNotEqual(first, second, "varying seed hashing is constant") def test_check_hash_despite_variation(self): @@ -379,15 +462,19 @@ def test_check_hash_despite_variation(self): I.e. the hash value differs across interpreter invocations but testing the same password against each succeeds. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen]) - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, first) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen] + ) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, DUMMY_MODERN_PW, first + ) self.assertTrue(result, "mismatch in 1st random password hash check") - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, second) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, DUMMY_MODERN_PW, second + ) self.assertTrue(result, "mismatch in 2nd random password hash check") def test_scramble_digest_fixed(self): @@ -405,16 +492,23 @@ def test_unscramble_digest_fixed(self): def test_make_digest_fixed(self): """Test basic digest of a fixed string""" expected = DUMMY_MODERN_PW_DIGEST - result = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) + result = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) self.assertEqual(expected, result, "mismatch in fixed digest") def test_check_digest_fixed(self): """Test basic digest checking of a fixed string""" expected = DUMMY_MODERN_PW_DIGEST - result = check_digest(self.dummy_conf, DUMMY_SERVICE, DUMMY_REALM, - DUMMY_USER, DUMMY_MODERN_PW, expected, - DUMMY_SALT) + result = check_digest( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_REALM, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + DUMMY_SALT, + ) self.assertTrue(result, "mismatch in fixed digest check") def test_check_digest_random(self): @@ -422,8 +516,15 @@ def test_check_digest_random(self): random_pw = generate_random_password(self.dummy_conf) random_salt = base64.b16encode(os.urandom(16)) expected = make_digest(DUMMY_REALM, DUMMY_USER, random_pw, random_salt) - result = check_digest(self.dummy_conf, DUMMY_SERVICE, DUMMY_REALM, - DUMMY_USER, random_pw, expected, random_salt) + result = check_digest( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_REALM, + DUMMY_USER, + random_pw, + expected, + random_salt, + ) self.assertTrue(result, "mismatch in random digest check") def test_digest_constant_string(self): @@ -431,10 +532,12 @@ def test_digest_constant_string(self): random seed. I.e. the value may differ across interpreter invocations but remains constant in same interpreter. """ - first = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) - second = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) + first = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) + second = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) self.assertEqual(first, second, "basic digest is not constant") def test_scramble_password_fixed(self): @@ -458,8 +561,14 @@ def test_make_scramble_fixed(self): def test_check_scramble_fixed(self): """Test basic scramble checking of a fixed string""" expected = DUMMY_MODERN_PW_SCRAMBLE - result = check_scramble(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, DUMMY_SALT) + result = check_scramble( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + DUMMY_SALT, + ) self.assertTrue(result, "mismatch in fixed scramble check") def test_check_scramble_random(self): @@ -467,8 +576,14 @@ def test_check_scramble_random(self): random_pw = generate_random_password(self.dummy_conf) random_salt = base64.b16encode(os.urandom(16)) expected = make_scramble(DUMMY_MODERN_PW, random_salt) - result = check_scramble(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, random_salt) + result = check_scramble( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + random_salt, + ) self.assertTrue(result, "mismatch in random scramble check") def test_scramble_constant_string(self): @@ -480,8 +595,7 @@ def test_scramble_constant_string(self): self.assertEqual(first, second, "basic scramble is not constant") def test_prepare_fernet_key(self): - """Test basic fernet secret key preparation on a fixed string. - """ + """Test basic fernet secret key preparation on a fixed string.""" expected = DUMMY_FERNET_KEY result = prepare_fernet_key(self.dummy_conf) self.assertEqual(expected, result, "failed prepare fernet key") @@ -494,8 +608,7 @@ def test_fernet_encrypt_decrypt(self): self.assertEqual(random_pw, result, "failed fernet enc+dec") def test_prepare_aesgcm_key(self): - """Test basic aesgcm secret key preparation on a fixed string. - """ + """Test basic aesgcm secret key preparation on a fixed string.""" expected = DUMMY_AESGCM_KEY result = prepare_aesgcm_key(self.dummy_conf) self.assertEqual(expected, result, "failed prepare aesgcm key") @@ -520,8 +633,11 @@ def test_prepare_aesgcm_aad_fixed(self): entropy and date value. """ expected = DUMMY_AESGCM_AAD - result = prepare_aesgcm_aad(self.dummy_conf, DUMMY_AESGCM_AAD_PREFIX, - aad_stamp=DUMMY_FIXED_TIMESTAMP) + result = prepare_aesgcm_aad( + self.dummy_conf, + DUMMY_AESGCM_AAD_PREFIX, + aad_stamp=DUMMY_FIXED_TIMESTAMP, + ) self.assertEqual(expected, result, "failed prepare aesgcm aad") def test_aesgcm_encrypt_static_iv_fixed(self): @@ -529,9 +645,12 @@ def test_aesgcm_encrypt_static_iv_fixed(self): initialization vector and date helper. """ expected = DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED - result = aesgcm_encrypt_password(self.dummy_conf, DUMMY_MODERN_PW, - init_vector=DUMMY_AESGCM_STATIC_IV, - aad_stamp=DUMMY_FIXED_TIMESTAMP) + result = aesgcm_encrypt_password( + self.dummy_conf, + DUMMY_MODERN_PW, + init_vector=DUMMY_AESGCM_STATIC_IV, + aad_stamp=DUMMY_FIXED_TIMESTAMP, + ) self.assertEqual(expected, result, "failed fixed aesgcm static iv enc") def test_aesgcm_decrypt_static_iv_fixed(self): @@ -539,9 +658,11 @@ def test_aesgcm_decrypt_static_iv_fixed(self): initialization vector. """ expected = DUMMY_MODERN_PW - result = aesgcm_decrypt_password(self.dummy_conf, - DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED, - init_vector=DUMMY_AESGCM_STATIC_IV) + result = aesgcm_decrypt_password( + self.dummy_conf, + DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED, + init_vector=DUMMY_AESGCM_STATIC_IV, + ) self.assertEqual(expected, result, "failed fixed aesgcm static iv den") def test_aesgcm_encrypt_decrypt_static_iv(self): @@ -551,10 +672,12 @@ def test_aesgcm_encrypt_decrypt_static_iv(self): random_pw = generate_random_password(self.dummy_conf) entropy = make_safe_hash(random_pw, False) static_iv = prepare_aesgcm_iv(self.dummy_conf, iv_entropy=entropy) - expected = aesgcm_encrypt_password(self.dummy_conf, random_pw, - init_vector=static_iv) - result = aesgcm_decrypt_password(self.dummy_conf, expected, - init_vector=static_iv) + expected = aesgcm_encrypt_password( + self.dummy_conf, random_pw, init_vector=static_iv + ) + result = aesgcm_decrypt_password( + self.dummy_conf, expected, init_vector=static_iv + ) self.assertEqual(random_pw, result, "failed aesgcm static iv enc+dec") def test_make_encrypt_decrypt(self): @@ -582,97 +705,127 @@ def test_check_encrypt(self): """Test basic password simple encrypt and decrypt on a random string""" random_pw = generate_random_password(self.dummy_conf) # IMPORTANT: only aesgcm_static generates constant enc value! - encrypted = make_encrypt(self.dummy_conf, random_pw, - algo="fernet") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='fernet') + encrypted = make_encrypt(self.dummy_conf, random_pw, algo="fernet") + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="fernet", + ) self.assertFalse(result, "invalid match in fernet encrypt check") encrypted = make_encrypt(self.dummy_conf, random_pw, algo="aesgcm") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='aesgcm') + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="aesgcm", + ) self.assertFalse(result, "invalid match in aesgcm encrypt check") - encrypted = make_encrypt(self.dummy_conf, random_pw, - algo="aesgcm_static") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='aesgcm_static') + encrypted = make_encrypt( + self.dummy_conf, random_pw, algo="aesgcm_static" + ) + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="aesgcm_static", + ) self.assertTrue(result, "mismatch in aesgcm_static encrypt check") def test_assure_reset_supported(self): """Test basic password reset token check for a fixed user and auth""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 - result = assure_reset_supported(self.dummy_conf, dummy_user, - DUMMY_SERVICE) + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 + result = assure_reset_supported( + self.dummy_conf, dummy_user, DUMMY_SERVICE + ) self.assertTrue(result, "failed assure reset supported") # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_generate_reset_token_fixed(self): """Test basic password reset token generate for a fixed string""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 timestamp = 42 expected = DUMMY_MODERN_PW_RESET_TOKEN - result = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) - self.assertEqual(expected, result, - "failed generate password reset token") + result = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) + self.assertEqual( + expected, result, "failed generate password reset token" + ) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_parse_reset_token_fixed(self): """Test basic password reset token parse for a fixed string""" timestamp = 42 - result = parse_reset_token(self.dummy_conf, - DUMMY_MODERN_PW_RESET_TOKEN, - DUMMY_SERVICE) + result = parse_reset_token( + self.dummy_conf, DUMMY_MODERN_PW_RESET_TOKEN, DUMMY_SERVICE + ) self.assertEqual(result[0], timestamp, "failed parse token time") - self.assertEqual(result[1], DUMMY_MODERN_PW_PBKDF2, - "failed parse token hash") + self.assertEqual( + result[1], DUMMY_MODERN_PW_PBKDF2, "failed parse token hash" + ) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_verify_reset_token_fixed(self): """Test basic password reset token verify for a fixed string""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 timestamp = 42 - result = verify_reset_token(self.dummy_conf, dummy_user, - DUMMY_MODERN_PW_RESET_TOKEN, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, + dummy_user, + DUMMY_MODERN_PW_RESET_TOKEN, + DUMMY_SERVICE, + timestamp, + ) self.assertTrue(result, "failed password reset token handling") def test_password_reset_token_generate_and_verify(self): """Test basic password reset token generate and verify helper""" random_pw = generate_random_password(self.dummy_conf) hashed_pw = make_hash(random_pw) - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = hashed_pw + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = hashed_pw timestamp = 42 - expected = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) + expected = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) parsed = parse_reset_token(self.dummy_conf, expected, DUMMY_SERVICE) self.assertEqual(parsed[0], timestamp, "failed parse token time") self.assertEqual(parsed[1], hashed_pw, "failed parse token hash") - result = verify_reset_token(self.dummy_conf, dummy_user, expected, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, dummy_user, expected, DUMMY_SERVICE, timestamp + ) self.assertTrue(result, "failed password reset token handling") def test_password_reset_token_verify_expired(self): """Test basic password reset token verify failure after it expired""" random_pw = generate_random_password(self.dummy_conf) hashed_pw = make_hash(random_pw) - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = hashed_pw + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = hashed_pw timestamp = 42 - expected = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) + expected = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) parsed = parse_reset_token(self.dummy_conf, expected, DUMMY_SERVICE) self.assertEqual(parsed[0], timestamp, "failed parse token time") self.assertEqual(parsed[1], hashed_pw, "failed parse token hash") timestamp = 4242 - result = verify_reset_token(self.dummy_conf, dummy_user, expected, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, dummy_user, expected, DUMMY_SERVICE, timestamp + ) self.assertFalse(result, "failed password reset token expiry check") def test_make_csrf_token_fixed(self): @@ -680,8 +833,9 @@ def test_make_csrf_token_fixed(self): client id. """ expected = DUMMY_CSRF_TOKEN - result = make_csrf_token(self.dummy_conf, DUMMY_METHOD, - DUMMY_OPERATION, DUMMY_ID) + result = make_csrf_token( + self.dummy_conf, DUMMY_METHOD, DUMMY_OPERATION, DUMMY_ID + ) self.assertEqual(expected, result, "failed make csrf token") def test_make_csrf_trust_token_fixed(self): @@ -689,8 +843,13 @@ def test_make_csrf_trust_token_fixed(self): client id and args. """ expected = DUMMY_CSRF_TRUST_TOKEN - result = make_csrf_trust_token(self.dummy_conf, DUMMY_METHOD, - DUMMY_OPERATION, DUMMY_ARGS, DUMMY_ID) + result = make_csrf_trust_token( + self.dummy_conf, + DUMMY_METHOD, + DUMMY_OPERATION, + DUMMY_ARGS, + DUMMY_ID, + ) self.assertEqual(expected, result, "failed make csrf trust token") def test_generate_random_password(self): @@ -705,8 +864,9 @@ def test_generate_random_password_fixed_seed(self): """Test basic generate password is constant with fixed random seed""" expected = DUMMY_GENERATED_PW result = generate_random_password(self.dummy_conf) - self.assertEqual(expected, result, - "failed generate password with fixed seed") + self.assertEqual( + expected, result, "failed generate password with fixed seed" + ) # TODO: migrate remaining inline checks from module here instead def test_existing_main(self): @@ -715,9 +875,11 @@ def raise_on_error_exit(exit_code): if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -726,5 +888,5 @@ def record_last_print(value): pwcrypto_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_safeeval.py b/tests/test_mig_shared_safeeval.py index 4bec4aa90..c7e827905 100644 --- a/tests/test_mig_shared_safeeval.py +++ b/tests/test_mig_shared_safeeval.py @@ -30,13 +30,11 @@ import os import sys -from tests.support import MigTestCase, testmain - from mig.shared.safeeval import * - +from tests.support import MigTestCase, testmain PWD_STR = os.getcwd() -PWD_BYTES = PWD_STR.encode('utf8') +PWD_BYTES = PWD_STR.encode("utf8") class MigSharedSafeeval(MigTestCase): @@ -44,49 +42,60 @@ class MigSharedSafeeval(MigTestCase): def test_subprocess_call(self): """Check that pwd call without args succeeds""" - retval = subprocess_call(['pwd'], stdout=subprocess_pipe) + retval = subprocess_call(["pwd"], stdout=subprocess_pipe) self.assertEqual(retval, 0, "unexpected subprocess call pwd retval") def test_subprocess_call_invalid(self): """Check that pwd call with invalid arg fails""" - retval = subprocess_call(['pwd', '-h'], stderr=subprocess_pipe) - self.assertNotEqual(retval, 0, - "unexpected subprocess call nosuchcommand retval") + retval = subprocess_call(["pwd", "-h"], stderr=subprocess_pipe) + self.assertNotEqual( + retval, 0, "unexpected subprocess call nosuchcommand retval" + ) def test_subprocess_check_output(self): """Check that pwd command output matches getcwd as bytes""" - data = subprocess_check_output(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_pipe).strip() - self.assertEqual(data, PWD_BYTES, - "mismatch in subprocess check pwd output") + data = subprocess_check_output( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_pipe + ).strip() + self.assertEqual( + data, PWD_BYTES, "mismatch in subprocess check pwd output" + ) def test_subprocess_check_output_text(self): """Check that pwd command output matches getcwd as string""" - data = subprocess_check_output(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_pipe, - text=True).strip() - self.assertEqual(data, PWD_STR, - "mismatch in subprocess check pwd output") + data = subprocess_check_output( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_pipe, text=True + ).strip() + self.assertEqual( + data, PWD_STR, "mismatch in subprocess check pwd output" + ) def test_subprocess_popen(self): """Check that pwd popen output matches getcwd as bytes""" - proc = subprocess_popen(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_stdout) + proc = subprocess_popen( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_stdout + ) retval = proc.wait() data = proc.stdout.read().strip() - self.assertEqual(data, PWD_BYTES, - "mismatch in subprocess popen pwd output") + self.assertEqual( + data, PWD_BYTES, "mismatch in subprocess popen pwd output" + ) def test_subprocess_popen_text(self): """Check that pwd popen output matches getcwd as string""" orig = os.getcwd() - proc = subprocess_popen(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_stdout, text=True) + proc = subprocess_popen( + ["pwd"], + stdout=subprocess_pipe, + stderr=subprocess_stdout, + text=True, + ) retval = proc.wait() data = proc.stdout.read().strip() - self.assertEqual(data, PWD_STR, - "mismatch in subprocess popen pwd output") + self.assertEqual( + data, PWD_STR, "mismatch in subprocess popen pwd output" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_safeinput.py b/tests/test_mig_shared_safeinput.py index 4a01dddd8..15e8b91d6 100644 --- a/tests/test_mig_shared_safeinput.py +++ b/tests/test_mig_shared_safeinput.py @@ -30,15 +30,26 @@ import base64 import codecs import sys + from past.builtins import basestring, unicode +from mig.shared.safeinput import ( + VALID_NAME_CHARACTERS, + InputException, + filter_commonname, +) +from mig.shared.safeinput import main as safeinput_main +from mig.shared.safeinput import ( + valid_alphanumeric, + valid_base_url, + valid_commonname, + valid_complex_url, + valid_path, + valid_printable, + valid_url, +) from tests.support import MigTestCase, testmain -from mig.shared.safeinput import main as safeinput_main, InputException, \ - filter_commonname, valid_alphanumeric, valid_commonname, valid_path, \ - valid_printable, valid_base_url, valid_url, valid_complex_url, \ - VALID_NAME_CHARACTERS - PY2 = sys.version_info[0] == 2 @@ -46,18 +57,18 @@ def as_string_of_unicode(value): assert isinstance(value, basestring) if not is_string_of_unicode(value): assert PY2, "unreachable unless Python 2" - return unicode(codecs.decode(value, 'utf8')) + return unicode(codecs.decode(value, "utf8")) return value def is_string_of_unicode(value): - return type(value) == type(u'') + return type(value) == type("") def _hex_wrap(val): """Insert a clearly marked hex representation of val""" # Please keep aligned with helper in mig/shared/functionality/autocreate.py - return ".X%s" % base64.b16encode(val.encode('utf8')).decode('utf8') + return ".X%s" % base64.b16encode(val.encode("utf8")).decode("utf8") class TestMigSharedSafeInput(MigTestCase): @@ -69,7 +80,7 @@ class TestMigSharedSafeInput(MigTestCase): PRINTABLE_CHARS = "abc123!@#" ACCENTED_VALID = "Renée Müller" ACCENTED_INVALID_EXOTIC = "Źaćâř" - DECOMPOSED_UNICODE = u"å" # a + combining ring above + DECOMPOSED_UNICODE = "å" # a + combining ring above # Commonname specific test constants APOSTROPHE_FULL_NAME = "John O'Connor" @@ -77,26 +88,28 @@ class TestMigSharedSafeInput(MigTestCase): APOSTROPHE_FULL_NAME_HEX = "John O.X27Connor" COMMONNAME_PERMITTED = ( - 'Firstname Lastname', - 'Test Æøå', - 'Test Überh4x0r', - 'Harry S. Truman', - u'Unicode æøå') + "Firstname Lastname", + "Test Æøå", + "Test Überh4x0r", + "Harry S. Truman", + "Unicode æøå", + ) COMMONNAME_PROHIBITED = ( "Invalid D'Angelo", - 'Test Maybe Invalid Źacãŕ', - 'Test Invalid ?', - 'Test HTML Invalid ') + "Test Maybe Invalid Źacãŕ", + "Test Invalid ?", + "Test HTML Invalid ", + ) - BASE_URL = 'https://www.migrid.org' - REGULAR_URL = 'https://www.migrid.org/wsgi-bin/ls.py?path=README&flags=v' - COMPLEX_URL = 'https://www.migrid.org/abc123@some.org/ls.py?path=R+D#HERE' - INVALID_URL = 'https://www.migrid.org/¾½§' + BASE_URL = "https://www.migrid.org" + REGULAR_URL = "https://www.migrid.org/wsgi-bin/ls.py?path=README&flags=v" + COMPLEX_URL = "https://www.migrid.org/abc123@some.org/ls.py?path=R+D#HERE" + INVALID_URL = "https://www.migrid.org/¾½§" def _provide_configuration(self): """Provide test configuration""" - return 'testconfig' + return "testconfig" def test_commonname_valid(self): """Test valid_commonname with acceptable and prohibited names""" @@ -130,7 +143,7 @@ def test_commonname_filter(self): self.assertTrue(len(filtered_cn) < len(test_cn_unicode)) # With default skip all chars in filtered_cn must be in original overlap = [i for i in filtered_cn if i in test_cn_unicode] - self.assertEqual(''.join(overlap), filtered_cn) + self.assertEqual("".join(overlap), filtered_cn) def test_commonname_filter_hexlify_illegal(self): """Test filter_commonname with hex encoding of illegal chars""" @@ -145,21 +158,23 @@ def test_commonname_filter_hexlify_illegal(self): filtered_cn = filter_commonname(test_cn, illegal_handler=_hex_wrap) # Invalid should be replaced with hexlify illegal_handler self.assertNotEqual(filtered_cn, test_cn_unicode) - self.assertIn('.X', filtered_cn) + self.assertIn(".X", filtered_cn) self.assertTrue(len(filtered_cn) > len(test_cn_unicode)) def test_filter_commonname_apostrophe_name_skip_illegal(self): """Test apostrophe handling with skip illegal_handler""" - result = filter_commonname(self.APOSTROPHE_FULL_NAME, - illegal_handler=None) + result = filter_commonname( + self.APOSTROPHE_FULL_NAME, illegal_handler=None + ) self.assertNotEqual(result, self.APOSTROPHE_FULL_NAME) self.assertNotIn("'", result) self.assertEqual(result, self.APOSTROPHE_FULL_NAME_SKIP) def test_filter_commonname_apostrophe_name_hexlify_illegal(self): """Test apostrophe handling with hex encode illegal_handler""" - result = filter_commonname(self.APOSTROPHE_FULL_NAME, - illegal_handler=_hex_wrap) + result = filter_commonname( + self.APOSTROPHE_FULL_NAME, illegal_handler=_hex_wrap + ) self.assertNotEqual(result, self.APOSTROPHE_FULL_NAME) self.assertNotIn("'", result) self.assertEqual(result, self.APOSTROPHE_FULL_NAME_HEX) @@ -283,14 +298,17 @@ class TestMigSharedSafeInput__legacy(MigTestCase): # TODO: migrate all legacy self-check functionality into the above? def test_existing_main(self): """Run built-in self-tests and check output""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -300,5 +318,5 @@ def record_last_print(value): safeinput_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_serial.py b/tests/test_mig_shared_serial.py index c577e0b4e..c42dc0d8b 100644 --- a/tests/test_mig_shared_serial.py +++ b/tests/test_mig_shared_serial.py @@ -30,12 +30,17 @@ import os import sys +from mig.shared.serial import * from tests.support import MigTestCase, temppath, testmain -from mig.shared.serial import * class BasicSerial(MigTestCase): - BASIC_OBJECT = {'abc': 123, 'def': 'def', 'ghi': 42.0, 'accented': 'TéstÆøå'} + BASIC_OBJECT = { + "abc": 123, + "def": "def", + "ghi": 42.0, + "accented": "TéstÆøå", + } def test_pickle_string(self): orig = BasicSerial.BASIC_OBJECT @@ -49,5 +54,6 @@ def test_pickle_file(self): data = load(tmp_path) self.assertEqual(data, orig, "mismatch pickling string") -if __name__ == '__main__': + +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_settings.py b/tests/test_mig_shared_settings.py index 87b9e02e1..dabcbbc9e 100644 --- a/tests/test_mig_shared_settings.py +++ b/tests/test_mig_shared_settings.py @@ -32,23 +32,31 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.settings import load_settings, update_settings, \ - parse_and_save_settings +from mig.shared.settings import ( + load_settings, + parse_and_save_settings, + update_settings, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) DUMMY_USER = "dummy-user" -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_SETTINGS_DIR = "dummy_user_settings" DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) -DUMMY_SYSTEM_FILES_DIR = 'dummy_system_files' +DUMMY_SYSTEM_FILES_DIR = "dummy_system_files" DUMMY_SYSTEM_FILES_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SYSTEM_FILES_DIR) -DUMMY_TMP_DIR = 'dummy_tmp' -DUMMY_TMP_FILE = 'settings.mRSL' +DUMMY_TMP_DIR = "dummy_tmp" +DUMMY_TMP_FILE = "settings.mRSL" DUMMY_TMP_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_TMP_DIR) DUMMY_MRSL_PATH = os.path.join(DUMMY_TMP_PATH, DUMMY_TMP_FILE) -DUMMY_USER_INTERFACE = ['V3', 'V42'] -DUMMY_DEFAULT_UI = 'V42' +DUMMY_USER_INTERFACE = ["V3", "V42"] +DUMMY_DEFAULT_UI = "V42" DUMMY_INIT_MRSL = """ ::EMAIL:: john@doe.org @@ -65,10 +73,12 @@ ::SITE_USER_MENU:: people """ -DUMMY_CONF = FakeConfiguration(user_settings=DUMMY_SETTINGS_PATH, - mig_system_files=DUMMY_SYSTEM_FILES_PATH, - user_interface=DUMMY_USER_INTERFACE, - new_user_default_ui=DUMMY_DEFAULT_UI) +DUMMY_CONF = FakeConfiguration( + user_settings=DUMMY_SETTINGS_PATH, + mig_system_files=DUMMY_SYSTEM_FILES_PATH, + user_interface=DUMMY_USER_INTERFACE, + new_user_default_ui=DUMMY_DEFAULT_UI, +) class MigSharedSettings(MigTestCase): @@ -82,28 +92,31 @@ def test_settings_save_load(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, 'settings') + saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, "settings") self.assertTrue(os.path.exists(saved_path)) settings = load_settings(DUMMY_USER, DUMMY_CONF) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(settings, dict)) - self.assertEqual(settings['EMAIL'], ['john@doe.org']) - self.assertEqual(settings['SITE_USER_MENU'], - ['sharelinks', 'people', 'peers']) + self.assertEqual(settings["EMAIL"], ["john@doe.org"]) + self.assertEqual( + settings["SITE_USER_MENU"], ["sharelinks", "people", "peers"] + ) # NOTE: we no longer auto save default values for optional vars for key in settings.keys(): - self.assertTrue(key in ['EMAIL', 'SITE_USER_MENU']) + self.assertTrue(key in ["EMAIL", "SITE_USER_MENU"]) # Any saved USER_INTERFACE value must match configured default if set - self.assertEqual(settings.get('USER_INTERFACE', DUMMY_DEFAULT_UI), - DUMMY_DEFAULT_UI) + self.assertEqual( + settings.get("USER_INTERFACE", DUMMY_DEFAULT_UI), DUMMY_DEFAULT_UI + ) def test_settings_replace(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) @@ -113,25 +126,27 @@ def test_settings_replace(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_UPDATE_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) updated = load_settings(DUMMY_USER, DUMMY_CONF) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(updated, dict)) - self.assertEqual(updated['EMAIL'], ['jane@doe.org']) - self.assertEqual(updated['SITE_USER_MENU'], ['people']) + self.assertEqual(updated["EMAIL"], ["jane@doe.org"]) + self.assertEqual(updated["SITE_USER_MENU"], ["people"]) def test_update_settings(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) @@ -141,20 +156,21 @@ def test_update_settings(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - changes = {'EMAIL': ['john@doe.org', 'jane@doe.org']} + changes = {"EMAIL": ["john@doe.org", "jane@doe.org"]} defaults = {} updated = update_settings(DUMMY_USER, DUMMY_CONF, changes, defaults) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(updated, dict)) - self.assertEqual(updated['EMAIL'], ['john@doe.org', 'jane@doe.org']) + self.assertEqual(updated["EMAIL"], ["john@doe.org", "jane@doe.org"]) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_ssh.py b/tests/test_mig_shared_ssh.py index 7aa4ba3ec..954bd8c1a 100644 --- a/tests/test_mig_shared_ssh.py +++ b/tests/test_mig_shared_ssh.py @@ -32,10 +32,18 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.ssh import supported_pub_key_parsers, parse_pub_key, \ - generate_ssh_rsa_key_pair +from mig.shared.ssh import ( + generate_ssh_rsa_key_pair, + parse_pub_key, + supported_pub_key_parsers, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) class MigSharedSsh(MigTestCase): @@ -45,11 +53,11 @@ def test_ssh_key_generate_and_parse(self): parsers = supported_pub_key_parsers() # NOTE: should return a non-empty dict of algos and parsers self.assertTrue(parsers) - self.assertTrue('ssh-rsa' in parsers) + self.assertTrue("ssh-rsa" in parsers) # Generate common sized keys and parse the result for keysize in (2048, 3072, 4096): - (priv_key, pub_key) = generate_ssh_rsa_key_pair(size=keysize) + priv_key, pub_key = generate_ssh_rsa_key_pair(size=keysize) self.assertTrue(priv_key) self.assertTrue(pub_key) @@ -57,15 +65,16 @@ def test_ssh_key_generate_and_parse(self): try: parsed = parse_pub_key(pub_key) except ValueError as vae: - #print("Error in parsing pub key: %r" % vae) + # print("Error in parsing pub key: %r" % vae) parsed = None self.assertIsNotNone(parsed) - (priv_key, pub_key) = generate_ssh_rsa_key_pair(size=keysize, - encode_utf8=True) + priv_key, pub_key = generate_ssh_rsa_key_pair( + size=keysize, encode_utf8=True + ) self.assertTrue(priv_key) self.assertTrue(pub_key) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_transferfunctions.py b/tests/test_mig_shared_transferfunctions.py index fd8c7d5cc..53665259a 100644 --- a/tests/test_mig_shared_transferfunctions.py +++ b/tests/test_mig_shared_transferfunctions.py @@ -31,16 +31,27 @@ import sys import tempfile -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, temppath, testmain -from mig.shared.transferfunctions import get_transfers_path, \ - load_data_transfers, create_data_transfer, delete_data_transfer, \ - lock_data_transfers, unlock_data_transfers +from mig.shared.transferfunctions import ( + create_data_transfer, + delete_data_transfer, + get_transfers_path, + load_data_transfers, + lock_data_transfers, + unlock_data_transfers, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + temppath, + testmain, +) DUMMY_USER = "dummy-user" DUMMY_ID = "dummy-id" -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" def noop(*args, **kwargs): @@ -55,13 +66,15 @@ class MigSharedTransferfunctions(MigTestCase): def before_each(self): test_user_home = temppath(DUMMY_HOME_DIR, self, ensure_dir=True) test_user_settings = cleanpath( - DUMMY_SETTINGS_DIR, self, ensure_dir=True) + DUMMY_SETTINGS_DIR, self, ensure_dir=True + ) # make two requisite root folders for the dummy user os.mkdir(os.path.join(test_user_home, DUMMY_USER)) os.mkdir(os.path.join(test_user_settings, DUMMY_USER)) # now create a configuration - self.dummy_conf = FakeConfiguration(user_home=test_user_home, - user_settings=test_user_settings) + self.dummy_conf = FakeConfiguration( + user_home=test_user_home, user_settings=test_user_settings + ) def test_transfers_basic_locking_shared(self): dummy_conf = self.dummy_conf @@ -82,9 +95,11 @@ def test_transfers_basic_locking_ro_to_rw_exclusive(self): # Non-blocking exclusive locking of shared lock must fail ro_lock = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) rw_lock = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) self.assertTrue(ro_lock) self.assertFalse(rw_lock) @@ -99,9 +114,11 @@ def test_transfers_basic_locking_exclusive(self): rw_lock = lock_data_transfers(transfers_path, exclusive=True) # Non-blocking repeated shared or exclusive locking must fail ro_lock_again = lock_data_transfers( - transfers_path, exclusive=False, blocking=False) + transfers_path, exclusive=False, blocking=False + ) rw_lock_again = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) self.assertTrue(rw_lock) self.assertFalse(ro_lock_again) @@ -112,18 +129,19 @@ def test_transfers_basic_locking_exclusive(self): def test_create_and_delete_transfer(self): dummy_conf = self.dummy_conf - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}) + success, out = create_data_transfer( + dummy_conf, DUMMY_USER, {"transfer_id": DUMMY_ID} + ) self.assertTrue(success and DUMMY_ID in out) - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and transfers.get(DUMMY_ID, None)) - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID) + success, out = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID) self.assertTrue(success and out == DUMMY_ID) - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and transfers.get(DUMMY_ID, None) is None) @@ -131,38 +149,48 @@ def test_transfers_shared_read_locking(self): dummy_conf = self.dummy_conf transfers_path = get_transfers_path(dummy_conf, DUMMY_USER) # Init a dummy transfer to read and delete later - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) # take a shared ro lock up front ro_lock = lock_data_transfers(transfers_path, exclusive=False) # cases: - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and DUMMY_ID in transfers) # Create with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Delete with repeated locking should fail - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=True, blocking=False) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=True, blocking=False + ) self.assertFalse(success) # Verify unchanged - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and DUMMY_ID in transfers) # Unlock all to leave critical section and allow clean up unlock_data_transfers(ro_lock) # Delete with locking should be fine again - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=True) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=True + ) self.assertTrue(success and out == DUMMY_ID) def test_transfers_exclusive_write_locking(self): @@ -174,40 +202,51 @@ def test_transfers_exclusive_write_locking(self): # cases: # Non-blocking load with repeated locking should fail - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER, - do_lock=True, blocking=False) + success, transfers = load_data_transfers( + dummy_conf, DUMMY_USER, do_lock=True, blocking=False + ) self.assertFalse(success) # Load without repeated locking should be fine - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER, - do_lock=False) + success, transfers = load_data_transfers( + dummy_conf, DUMMY_USER, do_lock=False + ) self.assertTrue(success) # Non-blocking create with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Create without repeated locking should be fine - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=False) + success, out = create_data_transfer( + dummy_conf, DUMMY_USER, {"transfer_id": DUMMY_ID}, do_lock=False + ) self.assertTrue(success) # Non-blocking delete with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Delete without repeated locking should be fine - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=False) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=False + ) self.assertTrue(success) unlock_data_transfers(rw_lock) -if __name__ == '__main__': +if __name__ == "__main__": testmain(failfast=True) diff --git a/tests/test_mig_shared_url.py b/tests/test_mig_shared_url.py index 1f029775d..b81cf9799 100644 --- a/tests/test_mig_shared_url.py +++ b/tests/test_mig_shared_url.py @@ -32,8 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import MigTestCase, FakeConfiguration, testmain from mig.shared.url import _get_site_urls, check_local_site_url +from tests.support import FakeConfiguration, MigTestCase, testmain def _generate_dynamic_site_urls(url_base_list): @@ -42,40 +42,58 @@ def _generate_dynamic_site_urls(url_base_list): """ site_urls = [] for url_base in url_base_list: - site_urls += ['%s' % url_base, '%s/' % url_base, - '%s/wsgi-bin/home.py' % url_base, - '%s/wsgi-bin/logout.py' % url_base, - '%s/wsgi-bin/logout.py?return_url=' % url_base, - '%s/wsgi-bin/logout.py?return_url=%s' % (url_base, - ENC_URL) - ] + site_urls += [ + "%s" % url_base, + "%s/" % url_base, + "%s/wsgi-bin/home.py" % url_base, + "%s/wsgi-bin/logout.py" % url_base, + "%s/wsgi-bin/logout.py?return_url=" % url_base, + "%s/wsgi-bin/logout.py?return_url=%s" % (url_base, ENC_URL), + ] return site_urls -DUMMY_CONF = FakeConfiguration(migserver_http_url='http://myfqdn.org', - migserver_https_url='https://myfqdn.org', - migserver_https_mig_cert_url='', - migserver_https_ext_cert_url='', - migserver_https_mig_oid_url='', - migserver_https_ext_oid_url='', - migserver_https_mig_oidc_url='', - migserver_https_ext_oidc_url='', - migserver_https_sid_url='', - migserver_public_url='', - migserver_public_alias_url='') -ENC_URL = 'https%3A%2F%2Fsomewhere.org%2Fsub%0A' +DUMMY_CONF = FakeConfiguration( + migserver_http_url="http://myfqdn.org", + migserver_https_url="https://myfqdn.org", + migserver_https_mig_cert_url="", + migserver_https_ext_cert_url="", + migserver_https_mig_oid_url="", + migserver_https_ext_oid_url="", + migserver_https_mig_oidc_url="", + migserver_https_ext_oidc_url="", + migserver_https_sid_url="", + migserver_public_url="", + migserver_public_alias_url="", +) +ENC_URL = "https%3A%2F%2Fsomewhere.org%2Fsub%0A" # Static site-anchored and dynamic full URLs for local site URL check success -LOCAL_SITE_URLS = ['', 'abc', 'abc.txt', '/', '/bla', '/bla#anchor', - '/bla/', '/bla/#anchor', '/bla/bla', '/bla/bla/bla', - '//bla//', './bla', './bla/', './bla/bla', - './bla/bla/bla', 'logout.py', 'logout.py?bla=', - '/cgi-sid/logout.py', '/cgi-sid/logout.py?bla=bla', - '/cgi-sid/logout.py?return_url=%s' % ENC_URL, - ] +LOCAL_SITE_URLS = [ + "", + "abc", + "abc.txt", + "/", + "/bla", + "/bla#anchor", + "/bla/", + "/bla/#anchor", + "/bla/bla", + "/bla/bla/bla", + "//bla//", + "./bla", + "./bla/", + "./bla/bla", + "./bla/bla/bla", + "logout.py", + "logout.py?bla=", + "/cgi-sid/logout.py", + "/cgi-sid/logout.py?bla=bla", + "/cgi-sid/logout.py?return_url=%s" % ENC_URL, +] LOCAL_BASE_URLS = _get_site_urls(DUMMY_CONF) LOCAL_SITE_URLS += _generate_dynamic_site_urls(LOCAL_SITE_URLS) # Dynamic full URLs for local site URL check failure -REMOTE_BASE_URLS = ['https://someevilsite.com', 'ftp://someevilsite.com'] +REMOTE_BASE_URLS = ["https://someevilsite.com", "ftp://someevilsite.com"] REMOTE_SITE_URLS = _generate_dynamic_site_urls(REMOTE_BASE_URLS) @@ -85,15 +103,19 @@ class BasicUrl(MigTestCase): def test_valid_local_site_urls(self): """Check known valid static and dynamic URLs""" for url in LOCAL_SITE_URLS: - self.assertTrue(check_local_site_url(DUMMY_CONF, url), - "Local site url should succeed for %s" % url) + self.assertTrue( + check_local_site_url(DUMMY_CONF, url), + "Local site url should succeed for %s" % url, + ) def test_invalid_local_site_urls(self): """Check known invalid URLs""" for url in REMOTE_SITE_URLS: - self.assertFalse(check_local_site_url(DUMMY_CONF, url), - "Local site url should fail for %s" % url) + self.assertFalse( + check_local_site_url(DUMMY_CONF, url), + "Local site url should fail for %s" % url, + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_useradm.py b/tests/test_mig_shared_useradm.py index e5d1a3b16..be04ff21b 100644 --- a/tests/test_mig_shared_useradm.py +++ b/tests/test_mig_shared_useradm.py @@ -38,62 +38,88 @@ from past.builtins import basestring # Imports required for the unit test wrapping -from mig.shared.defaults import DEFAULT_USER_ID_FORMAT, htaccess_filename, \ - keyword_auto +from mig.shared.defaults import ( + DEFAULT_USER_ID_FORMAT, + htaccess_filename, + keyword_auto, +) + # Imports of the code under test -from mig.shared.useradm import _ensure_dirs_needed_for_userdb, \ - assure_current_htaccess, create_user +from mig.shared.useradm import ( + _ensure_dirs_needed_for_userdb, + assure_current_htaccess, + create_user, +) + # Imports required for the unit tests themselves -from tests.support import MIG_BASE, TEST_OUTPUT_DIR, FakeConfiguration, \ - MigTestCase, cleanpath, ensure_dirs_exist, is_path_within, testmain +from tests.support import ( + MIG_BASE, + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + ensure_dirs_exist, + is_path_within, + testmain, +) from tests.support.fixturesupp import FixtureAssertMixin from tests.support.picklesupp import PickleAssertMixin -DUMMY_USER = 'dummy-user' -DUMMY_STALE_USER = 'dummy-stale-user' -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' -DUMMY_MRSL_FILES_DIR = 'dummy_mrsl_files' -DUMMY_RESOURCE_PENDING_DIR = 'dummy_resource_pending' -DUMMY_CACHE_DIR = 'dummy_user_cache' +DUMMY_USER = "dummy-user" +DUMMY_STALE_USER = "dummy-stale-user" +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" +DUMMY_MRSL_FILES_DIR = "dummy_mrsl_files" +DUMMY_RESOURCE_PENDING_DIR = "dummy_resource_pending" +DUMMY_CACHE_DIR = "dummy_user_cache" DUMMY_HOME_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_HOME_DIR) DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) DUMMY_MRSL_FILES_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_MRSL_FILES_DIR) -DUMMY_RESOURCE_PENDING_PATH = os.path.join(TEST_OUTPUT_DIR, - DUMMY_RESOURCE_PENDING_DIR) +DUMMY_RESOURCE_PENDING_PATH = os.path.join( + TEST_OUTPUT_DIR, DUMMY_RESOURCE_PENDING_DIR +) DUMMY_CACHE_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_CACHE_DIR) -DUMMY_USER_DICT = {'distinguished_name': DUMMY_USER, - 'short_id': '%s@my.org' % DUMMY_USER} -DUMMY_REL_HTACCESS_PATH = os.path.join(DUMMY_HOME_DIR, DUMMY_USER, - htaccess_filename) -DUMMY_HTACCESS_PATH = DUMMY_REL_HTACCESS_PATH.replace(DUMMY_HOME_DIR, - DUMMY_HOME_PATH) -DUMMY_REL_HTACCESS_BACKUP_PATH = os.path.join(DUMMY_CACHE_DIR, DUMMY_USER, - "%s.old" % htaccess_filename) +DUMMY_USER_DICT = { + "distinguished_name": DUMMY_USER, + "short_id": "%s@my.org" % DUMMY_USER, +} +DUMMY_REL_HTACCESS_PATH = os.path.join( + DUMMY_HOME_DIR, DUMMY_USER, htaccess_filename +) +DUMMY_HTACCESS_PATH = DUMMY_REL_HTACCESS_PATH.replace( + DUMMY_HOME_DIR, DUMMY_HOME_PATH +) +DUMMY_REL_HTACCESS_BACKUP_PATH = os.path.join( + DUMMY_CACHE_DIR, DUMMY_USER, "%s.old" % htaccess_filename +) DUMMY_HTACCESS_BACKUP_PATH = DUMMY_REL_HTACCESS_BACKUP_PATH.replace( - DUMMY_CACHE_DIR, DUMMY_CACHE_PATH) + DUMMY_CACHE_DIR, DUMMY_CACHE_PATH +) DUMMY_REQUIRE_USER = 'require user "%s"' % DUMMY_USER DUMMY_REQUIRE_STALE_USER = 'require user "%s"' % DUMMY_STALE_USER -DUMMY_CONF = FakeConfiguration(user_home=DUMMY_HOME_PATH, - user_settings=DUMMY_SETTINGS_PATH, - user_cache=DUMMY_CACHE_PATH, - mrsl_files_dir=DUMMY_MRSL_FILES_PATH, - resource_pending=DUMMY_RESOURCE_PENDING_PATH, - site_user_id_format=DEFAULT_USER_ID_FORMAT, - short_title='dummysite', - support_email='support@dummysite.org', - user_openid_providers=['dummyoidprovider.org'], - ) - - -class TestMigSharedUsedadm_create_user(MigTestCase, - FixtureAssertMixin, - PickleAssertMixin): +DUMMY_CONF = FakeConfiguration( + user_home=DUMMY_HOME_PATH, + user_settings=DUMMY_SETTINGS_PATH, + user_cache=DUMMY_CACHE_PATH, + mrsl_files_dir=DUMMY_MRSL_FILES_PATH, + resource_pending=DUMMY_RESOURCE_PENDING_PATH, + site_user_id_format=DEFAULT_USER_ID_FORMAT, + short_title="dummysite", + support_email="support@dummysite.org", + user_openid_providers=["dummyoidprovider.org"], +) + + +class TestMigSharedUsedadm_create_user( + MigTestCase, FixtureAssertMixin, PickleAssertMixin +): """Coverage of useradm create_user function.""" - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" TEST_USER_DN_GDP = "%s/GDP" % (TEST_USER_DN,) - TEST_USER_PASSWORD_HASH = 'PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix' + TEST_USER_PASSWORD_HASH = ( + "PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix" + ) def before_each(self): configuration = self.configuration @@ -101,73 +127,82 @@ def before_each(self): _ensure_dirs_needed_for_userdb(self.configuration) self.expected_user_db_home = os.path.normpath( - configuration.user_db_home) + configuration.user_db_home + ) self.expected_user_db_file = os.path.join( - self.expected_user_db_home, 'MiG-users.db') + self.expected_user_db_home, "MiG-users.db" + ) ensure_dirs_exist(self.configuration.mig_system_files) def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_user_db_is_created(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "user@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['password'] = "password" - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "user@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["password"] = "password" + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) # presence of user home path_kind = MigTestCase._absolute_path_kind(self.expected_user_db_home) - self.assertEqual(path_kind, 'dir') + self.assertEqual(path_kind, "dir") # presence of user db path_kind = MigTestCase._absolute_path_kind(self.expected_user_db_file) - self.assertEqual(path_kind, 'file') + self.assertEqual(path_kind, "file") def test_user_creation_records_a_user(self): def _adjust_user_dict_for_compare(user_obj): obj = dict(user_obj) - obj['created'] = 9999999999.9999999 - obj['expire'] = 9999999999.9999999 - obj['unique_id'] = '__UNIQUE_ID__' + obj["created"] = 9999999999.9999999 + obj["expire"] = 9999999999.9999999 + obj["unique_id"] = "__UNIQUE_ID__" return obj expected_user_id = self.TEST_USER_DN expected_user_password_hash = self.TEST_USER_PASSWORD_HASH user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = expected_user_password_hash - - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) - - pickled = self.assertPickledFile(self.expected_user_db_file, - apply_hints=['convert_dict_bytes_to_strings_kv']) + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = expected_user_password_hash + + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) + + pickled = self.assertPickledFile( + self.expected_user_db_file, + apply_hints=["convert_dict_bytes_to_strings_kv"], + ) self.assertIn(expected_user_id, pickled) - prepared = self.prepareFixtureAssert('MiG-users.db--example', - fixture_format='json') + prepared = self.prepareFixtureAssert( + "MiG-users.db--example", fixture_format="json" + ) # TODO: remove resetting the handful of keys here # this is done to allow the comparision to succeed actual_user_object = _adjust_user_dict_for_compare( - pickled[expected_user_id]) + pickled[expected_user_id] + ) expected_user_object = _adjust_user_dict_for_compare( - prepared.fixture_data[expected_user_id]) + prepared.fixture_data[expected_user_id] + ) self.maxDiff = None self.assertEqual(actual_user_object, expected_user_object) @@ -176,95 +211,117 @@ def test_user_creation_records_a_user_with_gdp(self): self.configuration.site_enable_gdp = True user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = self.TEST_USER_PASSWORD_HASH + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = self.TEST_USER_PASSWORD_HASH # explicitly setting set a DN suffixed user DN to force GDP - user_dict['distinguished_name'] = self.TEST_USER_DN_GDP + user_dict["distinguished_name"] = self.TEST_USER_DN_GDP try: - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) except: self.assertFalse(True, "should not be reached") def test_user_creation_and_renew_records_a_user(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = self.TEST_USER_PASSWORD_HASH + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = self.TEST_USER_PASSWORD_HASH try: - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) except: self.assertFalse(True, "should not be reached") try: - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) except: self.assertFalse(True, "should not be reached") def test_user_creation_fails_in_renew_when_locked(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = self.TEST_USER_PASSWORD_HASH + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = self.TEST_USER_PASSWORD_HASH # explicitly setting set a DN suffixed user DN to force GDP - user_dict['distinguished_name'] = self.TEST_USER_DN_GDP - user_dict['status'] = "locked" + user_dict["distinguished_name"] = self.TEST_USER_DN_GDP + user_dict["status"] = "locked" try: - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) except: self.assertFalse(True, "should not be reached") def test_user_creation_with_id_collission_fails(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "user@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['password'] = "password" - user_dict['distinguished_name'] = self.TEST_USER_DN + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "user@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["password"] = "password" + user_dict["distinguished_name"] = self.TEST_USER_DN try: - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) except: self.assertFalse(True, "should not be reached") # NOTE: reset distinguished_name and introduce an ID conflict to test - del user_dict['distinguished_name'] - user_dict['organization'] = "Another Org" + del user_dict["distinguished_name"] + user_dict["organization"] = "Another Org" with self.assertRaises(Exception): - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) class MigSharedUseradm__assure_current_htaccess(MigTestCase): @@ -289,16 +346,15 @@ def assertHtaccessRequireUserClause(self, generated, expected): with io.open(generated) as htaccess_file: generated = htaccess_file.read() - generated_lines = generated.split('\n') + generated_lines = generated.split("\n") if not expected in generated_lines: raise AssertionError("no such require user line: %s" % expected) def test_skips_accounts_without_short_id(self): user_dict = {} user_dict.update(DUMMY_USER_DICT) - del user_dict['short_id'] - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + del user_dict["short_id"] + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) try: path_kind = self.assertPathExists(DUMMY_REL_HTACCESS_PATH) @@ -310,8 +366,7 @@ def test_skips_accounts_without_short_id(self): def test_creates_missing_htaccess_file(self): user_dict = {} user_dict.update(DUMMY_USER_DICT) - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) path_kind = self.assertPathExists(DUMMY_REL_HTACCESS_PATH) # File should exist here and be valid @@ -320,26 +375,26 @@ def test_creates_missing_htaccess_file(self): # Backup file should exist here and be empty self.assertEqual(path_kind, "file") - self.assertHtaccessRequireUserClause(DUMMY_HTACCESS_PATH, - DUMMY_REQUIRE_USER) + self.assertHtaccessRequireUserClause( + DUMMY_HTACCESS_PATH, DUMMY_REQUIRE_USER + ) def test_repairs_existing_stale_htaccess_file(self): user_dict = {} user_dict.update(DUMMY_USER_DICT) # Fake stale user ID directly through DN - user_dict['distinguished_name'] = DUMMY_STALE_USER - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + user_dict["distinguished_name"] = DUMMY_STALE_USER + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) # Verify stale - self.assertHtaccessRequireUserClause(DUMMY_HTACCESS_PATH, - DUMMY_REQUIRE_STALE_USER) + self.assertHtaccessRequireUserClause( + DUMMY_HTACCESS_PATH, DUMMY_REQUIRE_STALE_USER + ) # Reset stale user ID and retry user_dict = {} user_dict.update(DUMMY_USER_DICT) - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) path_kind = self.assertPathExists(DUMMY_REL_HTACCESS_PATH) # File should exist here and be valid @@ -348,9 +403,10 @@ def test_repairs_existing_stale_htaccess_file(self): # Backup file should exist here and be empty self.assertEqual(path_kind, "file") - self.assertHtaccessRequireUserClause(DUMMY_HTACCESS_PATH, - DUMMY_REQUIRE_USER) + self.assertHtaccessRequireUserClause( + DUMMY_HTACCESS_PATH, DUMMY_REQUIRE_USER + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_userdb.py b/tests/test_mig_shared_userdb.py index 2efc524fc..083f14092 100644 --- a/tests/test_mig_shared_userdb.py +++ b/tests/test_mig_shared_userdb.py @@ -34,17 +34,26 @@ # Imports required for the unit test wrapping from mig.shared.base import distinguished_name_to_user from mig.shared.fileio import delete_file -from mig.shared.serial import loads, dumps +from mig.shared.serial import dumps, loads + # Imports of the code under test -from mig.shared.userdb import default_db_path, load_user_db, load_user_dict, \ - lock_user_db, save_user_db, save_user_dict, unlock_user_db, \ - update_user_dict +from mig.shared.userdb import ( + default_db_path, + load_user_db, + load_user_dict, + lock_user_db, + save_user_db, + save_user_dict, + unlock_user_db, + update_user_dict, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -TEST_USER_ID = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' -THIS_USER_ID = '/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=This User/emailAddress=this.user@here.org' -OTHER_USER_ID = '/C=DK/ST=NA/L=NA/O=Other Org/OU=NA/CN=Other User/emailAddress=other.user@there.org' +TEST_USER_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" +THIS_USER_ID = "/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=This User/emailAddress=this.user@here.org" +OTHER_USER_ID = "/C=DK/ST=NA/L=NA/O=Other Org/OU=NA/CN=Other User/emailAddress=other.user@there.org" class TestMigSharedUserDB(MigTestCase): @@ -52,7 +61,7 @@ class TestMigSharedUserDB(MigTestCase): def _provide_configuration(self): """Get test configuration""" - return 'testconfig' + return "testconfig" # Helper methods def _generate_sample_db(self, content=None): @@ -60,7 +69,7 @@ def _generate_sample_db(self, content=None): if content is None: sample_db = { TEST_USER_ID: distinguished_name_to_user(TEST_USER_ID), - THIS_USER_ID: distinguished_name_to_user(THIS_USER_ID) + THIS_USER_ID: distinguished_name_to_user(THIS_USER_ID), } else: sample_db = content @@ -78,10 +87,12 @@ def before_each(self): """Set up test configuration and reset user DB paths""" ensure_dirs_exist(self.configuration.user_db_home) ensure_dirs_exist(self.configuration.mig_server_home) - self.user_db_path = os.path.join(self.configuration.user_db_home, - "MiG-users.db") - self.legacy_db_path = os.path.join(self.configuration.mig_server_home, - "MiG-users.db") + self.user_db_path = os.path.join( + self.configuration.user_db_home, "MiG-users.db" + ) + self.legacy_db_path = os.path.join( + self.configuration.mig_server_home, "MiG-users.db" + ) # Clear any existing test DBs if os.path.exists(self.user_db_path): @@ -95,15 +106,15 @@ def before_each(self): def test_default_db_path(self): """Test default_db_path returns correct path structure""" - expected = os.path.join(self.configuration.user_db_home, - "MiG-users.db") + expected = os.path.join(self.configuration.user_db_home, "MiG-users.db") result = default_db_path(self.configuration) self.assertEqual(result, expected) # Test legacy path fallback - self.configuration.user_db_home = '/no-such-dir' - expected_legacy = os.path.join(self.configuration.mig_server_home, - "MiG-users.db") + self.configuration.user_db_home = "/no-such-dir" + expected_legacy = os.path.join( + self.configuration.mig_server_home, "MiG-users.db" + ) result = default_db_path(self.configuration) self.assertEqual(result, expected_legacy) @@ -149,8 +160,7 @@ def test_load_user_db_direct(self): def test_load_user_db_missing(self): """Test loading missing user database""" - db_path = os.path.join( - self.configuration.user_db_home, "no-such-db.db") + db_path = os.path.join(self.configuration.user_db_home, "no-such-db.db") try: loaded = load_user_db(db_path) except Exception: @@ -190,8 +200,9 @@ def test_load_user_dict_missing(self): """Test loading non-existent user from DB""" self._create_sample_db() try: - loaded = load_user_dict(self.logger, "no-such-user", - self.user_db_path) + loaded = load_user_dict( + self.logger, "no-such-user", self.user_db_path + ) except Exception: loaded = None self.assertIsNone(loaded) @@ -200,8 +211,9 @@ def test_load_user_dict_existing(self): """Test loading existing user from DB""" sample_db = self._create_sample_db() try: - test_user_data = load_user_dict(self.logger, TEST_USER_ID, - self.user_db_path) + test_user_data = load_user_dict( + self.logger, TEST_USER_ID, self.user_db_path + ) except Exception: test_user_data = None self.assertEqual(test_user_data, sample_db[TEST_USER_ID]) @@ -209,8 +221,9 @@ def test_load_user_dict_existing(self): def test_save_user_dict_new_user(self): """Test saving new user to database""" other_user = distinguished_name_to_user(OTHER_USER_ID) - save_status = save_user_dict(self.logger, OTHER_USER_ID, - other_user, self.user_db_path) + save_status = save_user_dict( + self.logger, OTHER_USER_ID, other_user, self.user_db_path + ) self.assertTrue(save_status) with open(self.user_db_path, "rb") as fh: @@ -223,8 +236,9 @@ def test_save_user_dict_update(self): sample_db = self._create_sample_db() changed = distinguished_name_to_user(THIS_USER_ID) changed.update({"Organization": "UPDATED", "new_field": "ADDED"}) - save_status = save_user_dict(self.logger, THIS_USER_ID, - changed, self.user_db_path) + save_status = save_user_dict( + self.logger, THIS_USER_ID, changed, self.user_db_path + ) self.assertTrue(save_status) with open(self.user_db_path, "rb") as fh: @@ -235,9 +249,12 @@ def test_save_user_dict_update(self): def test_update_user_dict(self): """Test update_user_dict with partial changes""" sample_db = self._create_sample_db() - updated = update_user_dict(self.logger, THIS_USER_ID, - {"Organization": "CHANGED"}, - self.user_db_path) + updated = update_user_dict( + self.logger, + THIS_USER_ID, + {"Organization": "CHANGED"}, + self.user_db_path, + ) self.assertEqual(updated["Organization"], "CHANGED") with open(self.user_db_path, "rb") as fh: @@ -249,8 +266,12 @@ def test_update_user_dict_requirements(self): """Test update_user_dict with invalid user ID""" self.logger.forgive_errors() try: - result = update_user_dict(self.logger, "no-such-user", - {"field": "test"}, self.user_db_path) + result = update_user_dict( + self.logger, + "no-such-user", + {"field": "test"}, + self.user_db_path, + ) except Exception: result = None self.assertIsNone(result) @@ -274,6 +295,7 @@ def delayed_load(): return loaded import threading + delayed_thread = threading.Thread(target=delayed_load) delayed_thread.start() time.sleep(0.2) @@ -308,7 +330,8 @@ def test_load_user_db_pickle_abi(self): def test_lock_user_db_invalid_path(self): """Test locking on non-existent database path""" invalid_path = os.path.join( - self.configuration.user_db_home, "missing", "MiG-users.db") + self.configuration.user_db_home, "missing", "MiG-users.db" + ) flock = lock_user_db(invalid_path) self.assertIsNone(flock) @@ -322,7 +345,7 @@ def test_unlock_user_db_invalid(self): def test_load_user_db_corrupted(self): """Test loading corrupted user database""" - with open(self.user_db_path, 'w') as fh: + with open(self.user_db_path, "w") as fh: fh.write("invalid pickle content") with self.assertRaises(Exception): load_user_db(self.user_db_path) @@ -354,8 +377,9 @@ def test_save_user_dict_invalid_id(self): """Test saving user with invalid characters in ID""" invalid_id = "../../invalid.user" user_dict = distinguished_name_to_user(TEST_USER_ID) - save_status = save_user_dict(self.logger, invalid_id, - user_dict, self.user_db_path) + save_status = save_user_dict( + self.logger, invalid_id, user_dict, self.user_db_path + ) self.assertFalse(save_status) def test_update_user_dict_empty_changes(self): @@ -364,8 +388,9 @@ def test_update_user_dict_empty_changes(self): self.logger.forgive_errors() sample_db = self._create_sample_db() original = sample_db[THIS_USER_ID].copy() - updated = update_user_dict(self.logger, THIS_USER_ID, {}, - self.user_db_path) + updated = update_user_dict( + self.logger, THIS_USER_ID, {}, self.user_db_path + ) self.assertEqual(updated, original) # TODO: adjust API to allow enabling the next test @@ -395,5 +420,5 @@ def test_load_user_db_allows_concurrent_read_access(self): unlock_user_db(flock) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_userio.py b/tests/test_mig_shared_userio.py index daa69444a..c17d49b0b 100644 --- a/tests/test_mig_shared_userio.py +++ b/tests/test_mig_shared_userio.py @@ -29,11 +29,11 @@ import os import sys -from past.builtins import basestring, unicode -from tests.support import MigTestCase, testmain +from past.builtins import basestring, unicode from mig.shared.userio import main as userio_main +from tests.support import MigTestCase, testmain class MigSharedUserIO(MigTestCase): @@ -45,9 +45,11 @@ def raise_on_error_exit(exit_code): if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -56,5 +58,5 @@ def record_last_print(value): userio_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_vgrid.py b/tests/test_mig_shared_vgrid.py index 465b74390..89250898d 100644 --- a/tests/test_mig_shared_vgrid.py +++ b/tests/test_mig_shared_vgrid.py @@ -35,41 +35,69 @@ # Imports required for the unit test wrapping from mig.shared.base import client_id_dir from mig.shared.serial import dump + # Imports of the code under test -from mig.shared.vgrid import get_vgrid_workflow_jobs, legacy_main, \ - vgrid_add_entities, vgrid_add_members, vgrid_add_owners, \ - vgrid_add_resources, vgrid_add_workflow_jobs, vgrid_allow_restrict_write, \ - vgrid_exists, vgrid_flat_name, vgrid_is_default, vgrid_is_member, \ - vgrid_is_owner, vgrid_is_owner_or_member, vgrid_is_trigger, vgrid_list, \ - vgrid_list_parents, vgrid_list_subvgrids, vgrid_list_vgrids, \ - vgrid_match_resources, vgrid_nest_sep, vgrid_remove_entities, \ - vgrid_restrict_write, vgrid_set_entities, vgrid_set_members, \ - vgrid_set_owners, vgrid_set_workflow_jobs, vgrid_settings +from mig.shared.vgrid import ( + get_vgrid_workflow_jobs, + legacy_main, + vgrid_add_entities, + vgrid_add_members, + vgrid_add_owners, + vgrid_add_resources, + vgrid_add_workflow_jobs, + vgrid_allow_restrict_write, + vgrid_exists, + vgrid_flat_name, + vgrid_is_default, + vgrid_is_member, + vgrid_is_owner, + vgrid_is_owner_or_member, + vgrid_is_trigger, + vgrid_list, + vgrid_list_parents, + vgrid_list_subvgrids, + vgrid_list_vgrids, + vgrid_match_resources, + vgrid_nest_sep, + vgrid_remove_entities, + vgrid_restrict_write, + vgrid_set_entities, + vgrid_set_members, + vgrid_set_owners, + vgrid_set_workflow_jobs, + vgrid_settings, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain class TestMigSharedVgrid(MigTestCase): """Unit tests for vgrid helpers""" + # Standard user IDs following X.500 DN format - TEST_OWNER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/'\ - 'emailAddress=owner@example.com' - TEST_MEMBER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/'\ - 'emailAddress=member@example.com' - TEST_OUTSIDER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/'\ - 'emailAddress=outsider@example.com' - TEST_RESOURCE_DN = 'test.example.org' - TEST_OWNER_DIR = \ - '+C=DK+ST=NA+L=NA+O=Test_Org+OU=NA+CN=Test_Owner+'\ - 'emailAddress=owner@example.com' - TEST_JOB_ID = '12345667890' + TEST_OWNER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/" + "emailAddress=owner@example.com" + ) + TEST_MEMBER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/" + "emailAddress=member@example.com" + ) + TEST_OUTSIDER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/" + "emailAddress=outsider@example.com" + ) + TEST_RESOURCE_DN = "test.example.org" + TEST_OWNER_DIR = ( + "+C=DK+ST=NA+L=NA+O=Test_Org+OU=NA+CN=Test_Owner+" + "emailAddress=owner@example.com" + ) + TEST_JOB_ID = "12345667890" def _provide_configuration(self): """Return configuration to use""" - return 'testconfig' + return "testconfig" def before_each(self): """Create test environment for vgrid tests""" @@ -83,30 +111,33 @@ def before_each(self): ensure_dirs_exist(self.configuration.workflows_home) ensure_dirs_exist(self.configuration.workflows_db_home) ensure_dirs_exist(self.configuration.mig_system_files) - self.configuration.site_vgrid_label = 'VGridLabel' - self.configuration.vgrid_owners = 'owners.pck' - self.configuration.vgrid_members = 'members.pck' - self.configuration.vgrid_resources = 'resources.pck' - self.configuration.vgrid_settings = 'settings.pck' - self.configuration.vgrid_workflow_job_queue = 'jobqueue.pck' + self.configuration.site_vgrid_label = "VGridLabel" + self.configuration.vgrid_owners = "owners.pck" + self.configuration.vgrid_members = "members.pck" + self.configuration.vgrid_resources = "resources.pck" + self.configuration.vgrid_settings = "settings.pck" + self.configuration.vgrid_workflow_job_queue = "jobqueue.pck" self.configuration.site_enable_workflows = True # Default vgrid for comparison - self.default_vgrid = 'Generic' + self.default_vgrid = "Generic" # Create test vgrid structure using ensure_dirs_exist - self.test_vgrid = 'testvgrid' + self.test_vgrid = "testvgrid" self.test_vgrid_path = os.path.join( - self.configuration.vgrid_home, self.test_vgrid) + self.configuration.vgrid_home, self.test_vgrid + ) ensure_dirs_exist(self.test_vgrid_path) - vgrid_add_owners(self.configuration, self.test_vgrid, - [self.TEST_OWNER_DN]) + vgrid_add_owners( + self.configuration, self.test_vgrid, [self.TEST_OWNER_DN] + ) vgrid_add_members(self.configuration, self.test_vgrid, []) # Nested sub-VGrid - self.test_subvgrid = os.path.join(self.test_vgrid, 'subvgrid') + self.test_subvgrid = os.path.join(self.test_vgrid, "subvgrid") self.test_subvgrid_path = os.path.join( - self.configuration.vgrid_home, self.test_subvgrid) + self.configuration.vgrid_home, self.test_subvgrid + ) ensure_dirs_exist(self.test_subvgrid_path) vgrid_set_owners(self.configuration, self.test_vgrid, []) vgrid_set_members(self.configuration, self.test_vgrid, []) @@ -114,32 +145,32 @@ def before_each(self): def test_vgrid_is_default_match(self): """Test default vgrid detection with match""" - self.assertTrue(vgrid_is_default('Generic')) + self.assertTrue(vgrid_is_default("Generic")) self.assertTrue(vgrid_is_default(None)) - self.assertTrue(vgrid_is_default('')) + self.assertTrue(vgrid_is_default("")) def test_vgrid_is_default_mismatch(self): """Test default vgrid detection with mismatch""" self.assertFalse(vgrid_is_default(self.test_vgrid)) self.assertFalse(vgrid_is_default(self.test_subvgrid)) - self.assertFalse(vgrid_is_default('MiG')) + self.assertFalse(vgrid_is_default("MiG")) def test_vgrid_exists_for_existing(self): """Test vgrid existence checks for existing vgrids""" - self.assertTrue(vgrid_exists(self.configuration, 'Generic')) + self.assertTrue(vgrid_exists(self.configuration, "Generic")) self.assertTrue(vgrid_exists(self.configuration, None)) - self.assertTrue(vgrid_exists(self.configuration, '')) + self.assertTrue(vgrid_exists(self.configuration, "")) self.assertTrue(vgrid_exists(self.configuration, self.test_vgrid)) self.assertTrue(vgrid_exists(self.configuration, self.test_subvgrid)) def test_vgrid_exists_for_missing(self): """Test vgrid existence checks for missing vgrids""" - self.assertFalse(vgrid_exists(self.configuration, 'no_such_vgrid')) - self.assertFalse(vgrid_exists(self.configuration, 'no_such_vgrid/sub')) + self.assertFalse(vgrid_exists(self.configuration, "no_such_vgrid")) + self.assertFalse(vgrid_exists(self.configuration, "no_such_vgrid/sub")) # Parent exists but not child - yet, vgrid_exists defaults to recursive # and allow_missing so it will return True for ALL subvgrids of vgrids. - missing_child = os.path.join(self.test_subvgrid, 'missing_child') + missing_child = os.path.join(self.test_subvgrid, "missing_child") self.assertTrue(vgrid_exists(self.configuration, missing_child)) def test_vgrid_list_vgrids(self): @@ -152,15 +183,17 @@ def test_vgrid_list_vgrids(self): self.assertIn(self.test_subvgrid, all_vgrids) # Exclude default - status, no_default = vgrid_list_vgrids(self.configuration, - include_default=False) + status, no_default = vgrid_list_vgrids( + self.configuration, include_default=False + ) self.assertTrue(status) self.assertNotIn(self.default_vgrid, no_default) self.assertIn(self.test_vgrid, no_default) # Root filtering - status, root_vgrids = vgrid_list_vgrids(self.configuration, - root_vgrid=self.test_vgrid) + status, root_vgrids = vgrid_list_vgrids( + self.configuration, root_vgrid=self.test_vgrid + ) self.assertTrue(status) self.assertIn(self.test_subvgrid, root_vgrids) self.assertNotIn(self.test_vgrid, root_vgrids) @@ -168,134 +201,185 @@ def test_vgrid_list_vgrids(self): def test_vgrid_add_remove_entities(self): """Test entity management in vgrid""" # Test adding owner - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "owners", [self.TEST_OWNER_DN] + ) self.assertTrue(added, msg) time.sleep(0.1) # Ensure timestamp changes # Verify existence - self.assertTrue(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) + self.assertTrue( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) # Test removal without and with allow empty in turn - removed, msg = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN], False) + removed, msg = vgrid_remove_entities( + self.configuration, + self.test_vgrid, + "owners", + [self.TEST_OWNER_DN], + False, + ) self.assertFalse(removed, msg) - self.assertTrue(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) - removed, msg = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN], True) + self.assertTrue( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) + removed, msg = vgrid_remove_entities( + self.configuration, + self.test_vgrid, + "owners", + [self.TEST_OWNER_DN], + True, + ) self.assertTrue(removed, msg) - self.assertFalse(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) + self.assertFalse( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) def test_vgrid_settings_inheritance(self): """Test vgrid setting inheritance""" # Parent settings (MUST include required vgrid_name field) parent_settings = [ - ('vgrid_name', self.test_vgrid), - ('write_shared_files', 'None') + ("vgrid_name", self.test_vgrid), + ("write_shared_files", "None"), ] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', parent_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "settings", parent_settings + ) self.assertTrue(added, msg) # Verify inheritance - status, settings = vgrid_settings(self.test_subvgrid, self.configuration, - recursive=True, as_dict=True) + status, settings = vgrid_settings( + self.test_subvgrid, + self.configuration, + recursive=True, + as_dict=True, + ) self.assertTrue(status) - self.assertEqual(settings.get('write_shared_files'), 'None') + self.assertEqual(settings.get("write_shared_files"), "None") # Verify vgrid_name is preserved - self.assertEqual(settings['vgrid_name'], self.test_subvgrid) + self.assertEqual(settings["vgrid_name"], self.test_subvgrid) def test_vgrid_permission_checks(self): """Test owner/member permission verification""" # Setup owners and members - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "owners", [self.TEST_OWNER_DN] + ) self.assertTrue(added, msg) time.sleep(0.1) - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'members', [self.TEST_MEMBER_DN]) + added, msg = vgrid_add_entities( + self.configuration, + self.test_vgrid, + "members", + [self.TEST_MEMBER_DN], + ) self.assertTrue(added, msg) time.sleep(0.1) # Verify owner permissions - self.assertTrue(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) - self.assertTrue(vgrid_is_owner_or_member(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) + self.assertTrue( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) + self.assertTrue( + vgrid_is_owner_or_member( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) # Verify member permissions - self.assertTrue(vgrid_is_member(self.test_vgrid, self.TEST_MEMBER_DN, - self.configuration)) - self.assertTrue(vgrid_is_owner_or_member(self.test_vgrid, self.TEST_MEMBER_DN, - self.configuration)) + self.assertTrue( + vgrid_is_member( + self.test_vgrid, self.TEST_MEMBER_DN, self.configuration + ) + ) + self.assertTrue( + vgrid_is_owner_or_member( + self.test_vgrid, self.TEST_MEMBER_DN, self.configuration + ) + ) # Verify non-member - self.assertFalse(vgrid_is_owner_or_member(self.test_vgrid, self.TEST_OUTSIDER_DN, - self.configuration)) + self.assertFalse( + vgrid_is_owner_or_member( + self.test_vgrid, self.TEST_OUTSIDER_DN, self.configuration + ) + ) def test_workflow_job_management(self): """Test workflow job queue handling""" job_entry = { - 'client_id': self.TEST_OWNER_DN, - 'job_id': self.TEST_JOB_ID + "client_id": self.TEST_OWNER_DN, + "job_id": self.TEST_JOB_ID, } - job_dir = os.path.join(self.configuration.mrsl_files_dir, - self.TEST_OWNER_DIR) + job_dir = os.path.join( + self.configuration.mrsl_files_dir, self.TEST_OWNER_DIR + ) ensure_dirs_exist(job_dir) - job_path = os.path.join(job_dir, '%s.mRSL' % self.TEST_JOB_ID) - dump({'job_id': self.TEST_JOB_ID, 'EXECUTE': 'uptime'}, job_path) + job_path = os.path.join(job_dir, "%s.mRSL" % self.TEST_JOB_ID) + dump({"job_id": self.TEST_JOB_ID, "EXECUTE": "uptime"}, job_path) # Add job - status, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'jobqueue', [job_entry]) + status, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "jobqueue", [job_entry] + ) self.assertTrue(status, msg) # TODO: adjust function to consistent return API? tuple vs list now. # List jobs - result = get_vgrid_workflow_jobs(self.configuration, - self.test_vgrid, True) + result = get_vgrid_workflow_jobs( + self.configuration, self.test_vgrid, True + ) if isinstance(result, tuple): status, msg = result jobs = [] else: - status, msg = True, '' + status, msg = True, "" jobs = result self.assertTrue(status) self.assertEqual(len(jobs), 1) - self.assertEqual(jobs[0]['job_id'], self.TEST_JOB_ID) + self.assertEqual(jobs[0]["job_id"], self.TEST_JOB_ID) # TODO: adjust function to consistent return API? tuple vs list now. # Remove job - result = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'jobqueue', [job_entry], True) + result = vgrid_remove_entities( + self.configuration, self.test_vgrid, "jobqueue", [job_entry], True + ) if isinstance(result, tuple): status, msg = result jobs = [] else: - status, msg = True, '' + status, msg = True, "" jobs = result self.assertTrue(status, msg) # TODO: adjust function to consistent return API? tuple vs list now. # Verify removal - result = get_vgrid_workflow_jobs(self.configuration, - self.test_vgrid, True) + result = get_vgrid_workflow_jobs( + self.configuration, self.test_vgrid, True + ) if isinstance(result, tuple): status, msg = result jobs = [] else: - status, msg = True, '' + status, msg = True, "" jobs = result self.assertTrue(status) self.assertEqual(len(jobs), 0) def test_vgrid_list_subvgrids(self): """Test retrieving subvgrids of given vgrid""" - status, subvgrids = vgrid_list_subvgrids(self.test_vgrid, - self.configuration) + status, subvgrids = vgrid_list_subvgrids( + self.test_vgrid, self.configuration + ) self.assertTrue(status) self.assertEqual(subvgrids, [self.test_subvgrid]) @@ -307,13 +391,18 @@ def test_vgrid_list_parents(self): def test_vgrid_match_resources(self): """Test resource filtering for vgrid""" - test_resources = ['res1', 'res2', 'invalid_res', self.TEST_RESOURCE_DN] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'resources', [self.TEST_RESOURCE_DN]) + test_resources = ["res1", "res2", "invalid_res", self.TEST_RESOURCE_DN] + added, msg = vgrid_add_entities( + self.configuration, + self.test_vgrid, + "resources", + [self.TEST_RESOURCE_DN], + ) self.assertTrue(added, msg) - matched = vgrid_match_resources(self.test_vgrid, test_resources, - self.configuration) + matched = vgrid_match_resources( + self.test_vgrid, test_resources, self.configuration + ) self.assertEqual(matched, [self.TEST_RESOURCE_DN]) # TODO: adjust API to allow enabling the next test @@ -321,25 +410,37 @@ def test_vgrid_match_resources(self): def test_vgrid_allow_restrict_write(self): """Test write restriction validation logic""" # Create parent-child structure - parent_write_setting = [('write_shared_files', 'none')] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', parent_write_setting) + parent_write_setting = [("write_shared_files", "none")] + added, msg = vgrid_add_entities( + self.configuration, + self.test_vgrid, + "settings", + parent_write_setting, + ) self.assertTrue(added, msg) # Child tries to set write_shared_files to "members" - result = vgrid_allow_restrict_write(self.test_subvgrid, 'members', - self.configuration, - auto_migrate=True) + result = vgrid_allow_restrict_write( + self.test_subvgrid, + "members", + self.configuration, + auto_migrate=True, + ) self.assertFalse(result) # Valid case: parent allows writes, child sets to "none" - parent_write_setting = [('write_shared_files', 'members')] - vgrid_set_entities(self.configuration, self.test_vgrid, - 'settings', parent_write_setting, True) + parent_write_setting = [("write_shared_files", "members")] + vgrid_set_entities( + self.configuration, + self.test_vgrid, + "settings", + parent_write_setting, + True, + ) - result = vgrid_allow_restrict_write(self.test_subvgrid, 'none', - self.configuration, - auto_migrate=True) + result = vgrid_allow_restrict_write( + self.test_subvgrid, "none", self.configuration, auto_migrate=True + ) self.assertTrue(result) # TODO: adjust API to allow enabling the next test @@ -347,289 +448,348 @@ def test_vgrid_allow_restrict_write(self): def test_vgrid_restrict_write(self): """Test write restriction enforcement""" # Setup test share - test_share = os.path.join(self.configuration.vgrid_files_home, - self.test_vgrid) + test_share = os.path.join( + self.configuration.vgrid_files_home, self.test_vgrid + ) ensure_dirs_exist(test_share) # Migrate to restricted mode - result = vgrid_restrict_write(self.test_vgrid, 'none', - self.configuration, auto_migrate=True) + result = vgrid_restrict_write( + self.test_vgrid, "none", self.configuration, auto_migrate=True + ) self.assertTrue(result) # Verify symlink points to readonly - flat_vgrid = self.test_vgrid.replace('/', vgrid_nest_sep) - read_path = os.path.join(self.configuration.vgrid_files_readonly, - flat_vgrid) - self.assertEqual(os.path.realpath(test_share), - os.path.realpath(read_path)) + flat_vgrid = self.test_vgrid.replace("/", vgrid_nest_sep) + read_path = os.path.join( + self.configuration.vgrid_files_readonly, flat_vgrid + ) + self.assertEqual( + os.path.realpath(test_share), os.path.realpath(read_path) + ) def test_vgrid_settings_scopes(self): """Test different vgrid settings lookup scopes""" # Local settings only - MUST include required vgrid_name field local_settings = [ - ('vgrid_name', self.test_subvgrid), - ('description', 'test subvgrid'), - ('write_shared_files', 'members'), + ("vgrid_name", self.test_subvgrid), + ("description", "test subvgrid"), + ("write_shared_files", "members"), ] - added, msg = vgrid_add_entities(self.configuration, self.test_subvgrid, - 'settings', local_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_subvgrid, "settings", local_settings + ) self.assertTrue(added, msg) # Check recursive vs direct status, recursive_settings = vgrid_settings( - self.test_subvgrid, self.configuration, recursive=True, as_dict=True) + self.test_subvgrid, + self.configuration, + recursive=True, + as_dict=True, + ) self.assertTrue(status) status, direct_settings = vgrid_settings( - self.test_subvgrid, self.configuration, recursive=False, as_dict=True) + self.test_subvgrid, + self.configuration, + recursive=False, + as_dict=True, + ) self.assertTrue(status) - self.assertEqual(direct_settings['description'], 'test subvgrid') - self.assertIn('write_shared_files', recursive_settings) # Inherited + self.assertEqual(direct_settings["description"], "test subvgrid") + self.assertIn("write_shared_files", recursive_settings) # Inherited def test_vgrid_add_owner_single(self): """Test vgrid_add_owners for initial owner""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN added, msg = vgrid_add_owners( - self.configuration, self.test_vgrid, [owner1]) + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1]) def test_vgrid_add_owner_with_rank(self): """Test owner ranking/ordering functionality""" - new_owner = '/C=DK/CN=New Owner' - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [new_owner], rank=0) # Add first owner + new_owner = "/C=DK/CN=New Owner" + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [new_owner], rank=0 + ) # Add first owner self.assertTrue(added, msg) - owners = vgrid_list(self.test_vgrid, 'owners', self.configuration)[1] + owners = vgrid_list(self.test_vgrid, "owners", self.configuration)[1] self.assertEqual(owners[0], new_owner) def test_vgrid_add_owners_twice(self): """Test vgrid_add_owners for initial and secondary owner""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2] + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1, owner2]) def test_vgrid_add_owners_twice_with_rank_zero(self): """Test vgrid_add_owners for two owners with 2nd inserted first""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2], rank=0) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2], rank=0 + ) self.assertTrue(added, msg) status, owners = vgrid_list( - self.test_vgrid, 'owners', self.configuration) + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner1]) def test_vgrid_add_owners_thrice_with_rank_one(self): """Test vgrid_add_owners for three owners with 3rd inserted in middle""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2] + ) self.assertTrue(added, msg) owner3 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner3], rank=1) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner3], rank=1 + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1, owner3, owner2]) def test_vgrid_add_owners_multiple(self): """Test vgrid_add_owners inserting list of owners at once""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) new_owners = [ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/' - 'emailAddress=owner5@example.com', - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/' - 'emailAddress=owner6@example.com' + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/" + "emailAddress=owner5@example.com", + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/" + "emailAddress=owner6@example.com", ] - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - new_owners) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, new_owners + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, new_owners) def test_vgrid_add_owners_repeat_inserts_only_one(self): """Test vgrid_add_owners inserting same owners twice does nothing""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(len(owners), 1) self.assertEqual(owners, [owner1]) def test_vgrid_add_owners_lifecycle(self): """Comprehensive life-cycle test for vgrid_add_owners functionality""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) # Test 1: Add initial owner owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) # Verify single owner - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1]) # Test 2: Prepend new owner owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2], rank=0) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2], rank=0 + ) self.assertTrue(added, msg) # Verify new order - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner1]) # Test 3: Append without rank owner3 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner3]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner3] + ) self.assertTrue(added, msg) # Verify append - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner1, owner3]) # Test 4: Insert at middle position - owner4 = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Four/'\ - 'emailAddress=owner4@example.com' - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner4], rank=1) + owner4 = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Four/" + "emailAddress=owner4@example.com" + ) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner4], rank=1 + ) self.assertTrue(added, msg) # Verify insertion - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner4, owner1, owner3]) # Test 5: Add multiple owners at once new_owners = [ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/' - 'emailAddress=owner5@example.com', - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/' - 'emailAddress=owner6@example.com' + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/" + "emailAddress=owner5@example.com", + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/" + "emailAddress=owner6@example.com", ] - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - new_owners, rank=2) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, new_owners, rank=2 + ) self.assertTrue(added, msg) # Verify multi-insert - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) expected = [owner2, owner4] + new_owners + [owner1, owner3] self.assertEqual(owners, expected) # Test 6: Prevent duplicate owner addition pre_add_count = len(owners) - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) # Verify no duplicate added - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(len(owners), pre_add_count) def test_vgrid_is_trigger(self): """Test trigger rule detection""" test_rule = { - 'rule_id': 'test_rule', - 'vgrid_name': self.test_vgrid, - 'path': '*.txt', - 'changes': ['modified'], - 'run_as': self.TEST_OWNER_DN, - 'action': 'copy', - 'arguments': [], - 'match_files': True, - 'match_dirs': False, - 'match_recursive': False, + "rule_id": "test_rule", + "vgrid_name": self.test_vgrid, + "path": "*.txt", + "changes": ["modified"], + "run_as": self.TEST_OWNER_DN, + "action": "copy", + "arguments": [], + "match_files": True, + "match_dirs": False, + "match_recursive": False, } # Add trigger to vgrid with all required fields - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'triggers', [test_rule]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "triggers", [test_rule] + ) self.assertTrue(added, msg) - self.assertTrue(vgrid_is_trigger( - self.test_vgrid, 'test_rule', self.configuration)) + self.assertTrue( + vgrid_is_trigger(self.test_vgrid, "test_rule", self.configuration) + ) def test_vgrid_sharelink_operations(self): """Test sharelink add/remove cycles""" test_share = { - 'share_id': 'test_share', - 'path': '/test/path', - 'access': ['read'], # Must be list type - 'invites': [self.TEST_MEMBER_DN], # Required field - 'single_file': True, # Correct field name (was 'is_dir') - 'expire': '-1', # Optional but included for completeness - 'owner': self.TEST_OWNER_DN, - 'created_timestamp': datetime.datetime.now() + "share_id": "test_share", + "path": "/test/path", + "access": ["read"], # Must be list type + "invites": [self.TEST_MEMBER_DN], # Required field + "single_file": True, # Correct field name (was 'is_dir') + "expire": "-1", # Optional but included for completeness + "owner": self.TEST_OWNER_DN, + "created_timestamp": datetime.datetime.now(), } - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'sharelinks', [test_share]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "sharelinks", [test_share] + ) self.assertTrue(added, msg) # Test removal - removed, msg = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'sharelinks', ['test_share'], True) + removed, msg = vgrid_remove_entities( + self.configuration, + self.test_vgrid, + "sharelinks", + ["test_share"], + True, + ) self.assertTrue(removed, msg) # TODO: adjust API to allow enabling the next test @@ -637,287 +797,357 @@ def test_vgrid_sharelink_operations(self): def test_vgrid_settings_validation(self): """Test settings key validation""" invalid_settings = [ - ('vgrid_name', self.test_vgrid), # Required field - ('invalid_key', 'value') # Invalid extra field + ("vgrid_name", self.test_vgrid), # Required field + ("invalid_key", "value"), # Invalid extra field ] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', invalid_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "settings", invalid_settings + ) # Should not accept invalid key even with valid required fields self.assertFalse(added) self.assertIn("unknown settings key 'invalid_key'", msg) - status, settings = vgrid_list(self.test_vgrid, 'settings', - self.configuration) + status, settings = vgrid_list( + self.test_vgrid, "settings", self.configuration + ) # Should never save invalid key even with valid required fields self.assertTrue(status) - self.assertNotIn('invalid_key', settings[0]) + self.assertNotIn("invalid_key", settings[0]) def test_vgrid_entity_listing(self): """Test direct entity listing functions""" # Test empty members and one owner listing from init - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertTrue(status) self.assertEqual(len(members), 0) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertTrue(status) self.assertEqual(len(owners), 1) # Populate and verify - vgrid_add_owners(self.configuration, self.test_vgrid, - [self.TEST_OWNER_DN]) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + vgrid_add_owners( + self.configuration, self.test_vgrid, [self.TEST_OWNER_DN] + ) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertTrue(status) self.assertEqual(owners, [self.TEST_OWNER_DN]) def test_vgrid_add_members_single(self): """Test vgrid_add_owners for initial member""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1]) def test_vgrid_add_member_with_rank(self): """Test member ranking/ordering functionality""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1], rank=0) # Add first owner + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1], rank=0 + ) # Add first owner self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1]) def test_vgrid_add_members_twice(self): """Test vgrid_add_members for initial and secondary member""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2] + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1, member2]) def test_vgrid_add_members_twice_with_rank_zero(self): """Test vgrid_add_members for two members with 2nd inserted first""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2], rank=0) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2], rank=0 + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member1]) def test_vgrid_add_members_thrice_with_rank_one(self): """Test vgrid_add_members for three members with 3rd inserted in middle""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2] + ) self.assertTrue(added, msg) member3 = self.TEST_OWNER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member3], rank=1) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member3], rank=1 + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1, member3, member2]) def test_vgrid_add_members_lifecycle(self): """Comprehensive life-cycle test for vgrid_add_members functionality""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) # Test 1: Add initial member member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) # Verify single member - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1]) # Test 2: Prepend new member member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2], rank=0) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2], rank=0 + ) self.assertTrue(added, msg) # Verify new order - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member1]) # Test 3: Append without rank member3 = self.TEST_OWNER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member3]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member3] + ) self.assertTrue(added, msg) # Verify append - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member1, member3]) # Test 4: Insert at middle position - member4 = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Four/'\ - 'emailAddress=member4@example.com' - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member4], rank=1) + member4 = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Four/" + "emailAddress=member4@example.com" + ) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member4], rank=1 + ) self.assertTrue(added, msg) # Verify insertion - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member4, member1, member3]) # Test 5: Add multiple members at once new_members = [ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Five/' - 'emailAddress=member5@example.com', - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Six/' - 'emailAddress=member6@example.com' + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Five/" + "emailAddress=member5@example.com", + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Six/" + "emailAddress=member6@example.com", ] - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - new_members, rank=2) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, new_members, rank=2 + ) self.assertTrue(added, msg) # Verify multi-insert - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) expected = [member2, member4] + new_members + [member1, member3] self.assertEqual(members, expected) # Test 6: Prevent duplicate member addition pre_add_count = len(members) - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) # Verify no duplicate added - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(len(members), pre_add_count) def test_flat_vgrid_name(self): """Test vgrid_flat_name conversion""" - nested_vgrid = 'testvgrid/sub' - expected_flat = '%s' % vgrid_nest_sep.join(['testvgrid', 'sub']) + nested_vgrid = "testvgrid/sub" + expected_flat = "%s" % vgrid_nest_sep.join(["testvgrid", "sub"]) converted = vgrid_flat_name(nested_vgrid, self.configuration) self.assertEqual(converted, expected_flat) - nested_vgrid = 'testvgrid/sub/test' - expected_flat = '%s' % vgrid_nest_sep.join( - ['testvgrid', 'sub', 'test']) + nested_vgrid = "testvgrid/sub/test" + expected_flat = "%s" % vgrid_nest_sep.join(["testvgrid", "sub", "test"]) converted = vgrid_flat_name(nested_vgrid, self.configuration) self.assertEqual(converted, expected_flat) def test_resource_signup_workflow(self): """Test full resource signup workflow""" # Sign up resource - added, msg = vgrid_add_resources(self.configuration, self.test_vgrid, - [self.TEST_RESOURCE_DN]) + added, msg = vgrid_add_resources( + self.configuration, self.test_vgrid, [self.TEST_RESOURCE_DN] + ) self.assertTrue(added, msg) # Verify visibility - matched = vgrid_match_resources(self.test_vgrid, [self.TEST_RESOURCE_DN], - self.configuration) + matched = vgrid_match_resources( + self.test_vgrid, [self.TEST_RESOURCE_DN], self.configuration + ) self.assertEqual(matched, [self.TEST_RESOURCE_DN]) def test_multi_level_inheritance(self): """Test settings propagation through multiple vgrid levels""" # Create grandchild vgrid - grandchild = os.path.join(self.test_subvgrid, 'grandchild') + grandchild = os.path.join(self.test_subvgrid, "grandchild") grandchild_path = os.path.join( - self.configuration.vgrid_home, grandchild) + self.configuration.vgrid_home, grandchild + ) ensure_dirs_exist(grandchild_path) # Set valid inherited setting at top level with required vgrid_name top_settings = [ - ('vgrid_name', self.test_vgrid), - ('hidden', True) # Valid inherited field with boolean value + ("vgrid_name", self.test_vgrid), + ("hidden", True), # Valid inherited field with boolean value ] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', top_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "settings", top_settings + ) self.assertTrue(added, msg) # Verify grandchild inheritance using 'hidden' field inherit=true - status, settings = vgrid_settings(grandchild, self.configuration, - recursive=True, as_dict=True) + status, settings = vgrid_settings( + grandchild, self.configuration, recursive=True, as_dict=True + ) self.assertTrue(status) - self.assertEqual(settings.get('hidden'), True) + self.assertEqual(settings.get("hidden"), True) # Verify vgrid_name is preserved - self.assertEqual(settings['vgrid_name'], grandchild) + self.assertEqual(settings["vgrid_name"], grandchild) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires tweaking of funcion") def test_workflow_job_priority(self): """Test workflow job queue ordering and limits""" # Create max jobs + 1 - job_entries = [{ - 'vgrid_name': self.test_vgrid, - 'client_id': self.TEST_OWNER_DN, - 'job_id': str(i), - 'run_as': self.TEST_OWNER_DN, # Required field - 'exe': '/bin/echo', # Required job field - 'arguments': ['Test job'], # Required job field - } for i in range(101)] + job_entries = [ + { + "vgrid_name": self.test_vgrid, + "client_id": self.TEST_OWNER_DN, + "job_id": str(i), + "run_as": self.TEST_OWNER_DN, # Required field + "exe": "/bin/echo", # Required job field + "arguments": ["Test job"], # Required job field + } + for i in range(101) + ] added, msg = vgrid_add_workflow_jobs( - self.configuration, - self.test_vgrid, - job_entries + self.configuration, self.test_vgrid, job_entries ) self.assertTrue(added, msg) status, jobs = vgrid_list( - self.test_vgrid, 'jobqueue', self.configuration) + self.test_vgrid, "jobqueue", self.configuration + ) self.assertTrue(status) # Should stay at max 100 by removing oldest self.assertEqual(len(jobs), 100) - self.assertEqual(jobs[-1]['job_id'], '100') # Newest at end + self.assertEqual(jobs[-1]["job_id"], "100") # Newest at end class TestMigSharedVgrid__legacy_main(MigTestCase): @@ -925,14 +1155,17 @@ class TestMigSharedVgrid__legacy_main(MigTestCase): def test_existing_main(self): """Run the legacy self-tests directly in module""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -942,5 +1175,5 @@ def record_last_print(value): legacy_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_vgridaccess.py b/tests/test_mig_shared_vgridaccess.py index 1dbe2a086..6bfe086d3 100644 --- a/tests/test_mig_shared_vgridaccess.py +++ b/tests/test_mig_shared_vgridaccess.py @@ -35,92 +35,157 @@ import mig.shared.vgridaccess as vgridaccess from mig.shared.fileio import pickle, read_file from mig.shared.vgrid import vgrid_list, vgrid_set_entities, vgrid_settings -from mig.shared.vgridaccess import CONF, MEMBERS, OWNERS, RESOURCES, SETTINGS, \ - USERID, USERS, VGRIDS, check_resources_modified, check_vgrid_access, \ - check_vgrids_modified, fill_placeholder_cache, force_update_resource_map, \ - force_update_user_map, force_update_vgrid_map, get_re_provider_map, \ - get_resource_map, get_user_map, get_vgrid_map, get_vgrid_map_vgrids, \ - is_vgrid_parent_placeholder, load_resource_map, load_user_map, \ - load_vgrid_map, mark_vgrid_modified, refresh_resource_map, \ - refresh_user_map, refresh_vgrid_map, res_vgrid_access, \ - reset_resources_modified, reset_vgrids_modified, resources_using_re, \ - unmap_inheritance, unmap_resource, unmap_vgrid, user_allowed_res_confs, \ - user_allowed_res_exes, user_allowed_res_stores, user_allowed_res_units, \ - user_allowed_user_confs, user_owned_res_exes, user_owned_res_stores, \ - user_vgrid_access, user_visible_res_confs, user_visible_res_exes, \ - user_visible_res_stores, user_visible_user_confs, vgrid_inherit_map +from mig.shared.vgridaccess import ( + CONF, + MEMBERS, + OWNERS, + RESOURCES, + SETTINGS, + USERID, + USERS, + VGRIDS, + check_resources_modified, + check_vgrid_access, + check_vgrids_modified, + fill_placeholder_cache, + force_update_resource_map, + force_update_user_map, + force_update_vgrid_map, + get_re_provider_map, + get_resource_map, + get_user_map, + get_vgrid_map, + get_vgrid_map_vgrids, + is_vgrid_parent_placeholder, + load_resource_map, + load_user_map, + load_vgrid_map, + mark_vgrid_modified, + refresh_resource_map, + refresh_user_map, + refresh_vgrid_map, + res_vgrid_access, + reset_resources_modified, + reset_vgrids_modified, + resources_using_re, + unmap_inheritance, + unmap_resource, + unmap_vgrid, + user_allowed_res_confs, + user_allowed_res_exes, + user_allowed_res_stores, + user_allowed_res_units, + user_allowed_user_confs, + user_owned_res_exes, + user_owned_res_stores, + user_vgrid_access, + user_visible_res_confs, + user_visible_res_exes, + user_visible_res_stores, + user_visible_user_confs, + vgrid_inherit_map, +) from tests.support import MigTestCase, ensure_dirs_exist, testmain -from tests.support.usersupp import UserAssertMixin, TEST_USER_DN +from tests.support.usersupp import TEST_USER_DN, UserAssertMixin class TestMigSharedVgridAccess(MigTestCase, UserAssertMixin): """Unit tests for vgridaccess related helper functions""" - TEST_OWNER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/'\ - 'emailAddress=owner@example.org' - TEST_MEMBER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/'\ - 'emailAddress=member@example.org' - TEST_OUTSIDER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/'\ - 'emailAddress=outsider@example.com' - TEST_RESOURCE_ID = 'test.example.org.0' - TEST_VGRID_NAME = 'testvgrid' - - TEST_OWNER_UUID = 'ff326a2b984828d9b32077c9b0b35a05' - TEST_MEMBER_UUID = 'ea9aedcbe69db279ca3676f83de94669' - TEST_RESOURCE_ALIAS = '0835f310d6422c36e33eeb7d0d3e9cf5' + TEST_OWNER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/" + "emailAddress=owner@example.org" + ) + TEST_MEMBER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/" + "emailAddress=member@example.org" + ) + TEST_OUTSIDER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/" + "emailAddress=outsider@example.com" + ) + TEST_RESOURCE_ID = "test.example.org.0" + TEST_VGRID_NAME = "testvgrid" + + TEST_OWNER_UUID = "ff326a2b984828d9b32077c9b0b35a05" + TEST_MEMBER_UUID = "ea9aedcbe69db279ca3676f83de94669" + TEST_RESOURCE_ALIAS = "0835f310d6422c36e33eeb7d0d3e9cf5" # Default vgrid is initially set up without settings when force loaded - MINIMAL_VGRIDS = {'Generic': {OWNERS: [], MEMBERS: [], RESOURCES: [], - SETTINGS: []}} + MINIMAL_VGRIDS = { + "Generic": {OWNERS: [], MEMBERS: [], RESOURCES: [], SETTINGS: []} + } def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' - - def _create_vgrid(self, vgrid_name, *, owners=None, members=None, - resources=None, settings=None, triggers=None): + return "testconfig" + + def _create_vgrid( + self, + vgrid_name, + *, + owners=None, + members=None, + resources=None, + settings=None, + triggers=None + ): """Helper to create valid skeleton vgrid for testing""" vgrid_path = os.path.join(self.configuration.vgrid_home, vgrid_name) ensure_dirs_exist(vgrid_path) # Save vgrid owners, members, resources, settings and triggers if owners is None: owners = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'owners', owners, allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, vgrid_name, "owners", owners, allow_empty=True + ) self.assertEqual(success_and_msg, (True, "")) if members is None: members = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'members', members, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "members", + members, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if resources is None: resources = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'resources', resources, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "resources", + resources, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if settings is None: - settings = [('vgrid_name', vgrid_name)] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'settings', settings, - allow_empty=True) + settings = [("vgrid_name", vgrid_name)] + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "settings", + settings, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if triggers is None: triggers = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'triggers', triggers, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "triggers", + triggers, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) def _create_resource(self, res_name, owners, config=None): """Helper to create valid skeleton resource for testing""" res_path = os.path.join(self.configuration.resource_home, res_name) - res_owners_path = os.path.join(res_path, 'owners') - res_config_path = os.path.join(res_path, 'config') + res_owners_path = os.path.join(res_path, "owners") + res_config_path = os.path.join(res_path, "config") # Add resource skeleton with owners ensure_dirs_exist(res_path) if owners is None: @@ -129,10 +194,11 @@ def _create_resource(self, res_name, owners, config=None): self.assertTrue(saved) if config is None: # Make sure conf has one valid field - config = {'HOSTURL': res_name, - 'EXECONFIG': [{'name': 'exe', 'vgrid': ['Generic']}], - 'STORECONFIG': [{'name': 'exe', 'vgrid': ['Generic']}] - } + config = { + "HOSTURL": res_name, + "EXECONFIG": [{"name": "exe", "vgrid": ["Generic"]}], + "STORECONFIG": [{"name": "exe", "vgrid": ["Generic"]}], + } saved = pickle(config, res_config_path, self.logger) self.assertTrue(saved) @@ -234,8 +300,10 @@ def test_force_update_vgrid_map(self): updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) self.assertTrue(updated_vgrid_map) - self.assertNotEqual(len(vgrid_map_before.get(VGRIDS, {})), - len(updated_vgrid_map.get(VGRIDS, {}))) + self.assertNotEqual( + len(vgrid_map_before.get(VGRIDS, {})), + len(updated_vgrid_map.get(VGRIDS, {})), + ) self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) def test_refresh_user_map(self): @@ -339,7 +407,7 @@ def test_get_vgrid_map_vgrids(self): vgrid_list = get_vgrid_map_vgrids(self.configuration) self.assertTrue(isinstance(vgrid_list, list)) - self.assertEqual(['Generic'], vgrid_list) + self.assertEqual(["Generic"], vgrid_list) def test_user_owned_res_exes(self): """Test user_owned_res_exes returns owned execution nodes""" @@ -367,7 +435,8 @@ def test_user_allowed_res_units(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) allowed = user_allowed_res_units( - self.configuration, self.TEST_OWNER_DN, "exe") + self.configuration, self.TEST_OWNER_DN, "exe" + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) @@ -394,7 +463,8 @@ def test_user_allowed_res_stores(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) allowed = user_allowed_res_stores( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) @@ -419,23 +489,29 @@ def test_user_visible_res_stores(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) visible = user_visible_res_stores( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(visible, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, visible) def test_user_allowed_user_confs(self): """Test user_allowed_user_confs returns allowed user confs""" - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) - - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) + + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_user_map(self.configuration) allowed = user_allowed_user_confs( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_OWNER_UUID, allowed) self.assertIn(self.TEST_MEMBER_UUID, allowed) @@ -443,32 +519,36 @@ def test_user_allowed_user_confs(self): def test_fill_placeholder_cache(self): """Test fill_placeholder_cache populates cache""" cache = {} - fill_placeholder_cache(self.configuration, cache, [ - self.TEST_VGRID_NAME]) + fill_placeholder_cache( + self.configuration, cache, [self.TEST_VGRID_NAME] + ) self.assertIn(self.TEST_VGRID_NAME, cache) def test_is_vgrid_parent_placeholder(self): """Test is_vgrid_parent_placeholder detection""" - test_path = os.path.join(self.configuration.user_home, 'testvgrid') - result = is_vgrid_parent_placeholder(self.configuration, test_path, - test_path) + test_path = os.path.join(self.configuration.user_home, "testvgrid") + result = is_vgrid_parent_placeholder( + self.configuration, test_path, test_path + ) self.assertIsNone(result) def test_resources_using_re_notfound(self): """Test RE with no assigned resources returns empty list""" # Nonexistent RE should have no resources - res_list = resources_using_re(self.configuration, 'NoSuchRE') + res_list = resources_using_re(self.configuration, "NoSuchRE") self.assertEqual(res_list, []) def test_vgrid_inherit_map_single(self): """Test inheritance mapping with single vgrid""" - test_settings = [('vgrid_name', self.TEST_VGRID_NAME), - ('hidden', True)] + test_settings = [ + ("vgrid_name", self.TEST_VGRID_NAME), + ("hidden", True), + ] test_map = { VGRIDS: { self.TEST_VGRID_NAME: { SETTINGS: test_settings, - OWNERS: [self.TEST_OWNER_DN] + OWNERS: [self.TEST_OWNER_DN], } } } @@ -477,13 +557,13 @@ def test_vgrid_inherit_map_single(self): self.assertIn(self.TEST_VGRID_NAME, vgrid_data) settings_dict = dict(vgrid_data[self.TEST_VGRID_NAME][SETTINGS]) self.assertIs(type(settings_dict), dict) - self.assertEqual(settings_dict.get('hidden'), True) + self.assertEqual(settings_dict.get("hidden"), True) # TODO: move these two modified tests to a test_mig_shared_modified.py def test_check_vgrids_modified_initial(self): """Verify initial modified vgrids list marks ALL and empty on reset""" modified, stamp = check_vgrids_modified(self.configuration) - self.assertEqual(modified, ['ALL']) + self.assertEqual(modified, ["ALL"]) reset_vgrids_modified(self.configuration) modified, stamp = check_vgrids_modified(self.configuration) self.assertEqual(modified, []) @@ -506,10 +586,9 @@ def test_user_vgrid_access(self): self._provision_test_user(self, TEST_USER_DN) # Start with global access to default vgrid - allowed_vgrids = user_vgrid_access(self.configuration, - TEST_USER_DN) + allowed_vgrids = user_vgrid_access(self.configuration, TEST_USER_DN) - self.assertIn('Generic', allowed_vgrids) + self.assertIn("Generic", allowed_vgrids) self.assertTrue(len(allowed_vgrids), 1) # Create private vgrid self._create_vgrid(self.TEST_VGRID_NAME, owners=[TEST_USER_DN]) @@ -517,20 +596,21 @@ def test_user_vgrid_access(self): initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) - allowed_vgrids = user_vgrid_access(self.configuration, - TEST_USER_DN) + allowed_vgrids = user_vgrid_access(self.configuration, TEST_USER_DN) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) def test_res_vgrid_access(self): """Minimal test for resource vgrid participation""" # Only Generic access initially allowed_vgrids = res_vgrid_access( - self.configuration, self.TEST_RESOURCE_ID) - self.assertEqual(allowed_vgrids, ['Generic']) + self.configuration, self.TEST_RESOURCE_ID + ) + self.assertEqual(allowed_vgrids, ["Generic"]) # Add to vgrid self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, resources=[ - self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, resources=[self.TEST_RESOURCE_ID] + ) # Refresh maps to reflect new content initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) @@ -550,33 +630,40 @@ def test_vgrid_map_refresh(self): self._verify_vgrid_map_integrity(updated_vgrid_map) vgrids = updated_vgrid_map.get(VGRIDS, {}) self.assertIn(self.TEST_VGRID_NAME, vgrids) - self.assertEqual(vgrids[self.TEST_VGRID_NAME] - [OWNERS], [self.TEST_OWNER_DN]) + self.assertEqual( + vgrids[self.TEST_VGRID_NAME][OWNERS], [self.TEST_OWNER_DN] + ) def test_user_map_access(self): """Test user permissions through cached access maps""" # Add user as member - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Verify member access - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) def test_resource_map_update(self): """Verify resource visibility in cache""" # Check cached resource map does not yet contain entry - res_map_before, _ = load_resource_map(self.configuration, - caching=True) + res_map_before, _ = load_resource_map(self.configuration, caching=True) self.assertEqual(res_map_before, {}) # Add vgrid with assigned resource self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + resources=[self.TEST_RESOURCE_ID], + ) updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) @@ -594,11 +681,13 @@ def test_resource_map_update(self): def test_settings_inheritance(self): """Test inherited settings propagation through cached maps""" # Create top and sub vgrids with 'hidden' setting on top vgrid - top_settings = [('vgrid_name', self.TEST_VGRID_NAME), - ('hidden', True)] - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - settings=top_settings) - sub_vgrid = os.path.join(self.TEST_VGRID_NAME, 'subvgrid') + top_settings = [("vgrid_name", self.TEST_VGRID_NAME), ("hidden", True)] + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + settings=top_settings, + ) + sub_vgrid = os.path.join(self.TEST_VGRID_NAME, "subvgrid") self._create_vgrid(sub_vgrid) # Force refresh of cached map @@ -618,7 +707,7 @@ def test_settings_inheritance(self): self.assertTrue(top_settings_dict) # Verify hidden setting in cache - self.assertEqual(top_settings_dict.get('hidden'), True) + self.assertEqual(top_settings_dict.get("hidden"), True) # Retrieve sub vgrid settings from cached map sub_vgrid_data = vgrid_data.get(sub_vgrid, {}) @@ -626,10 +715,9 @@ def test_settings_inheritance(self): sub_settings_dict = dict(sub_vgrid_data.get(SETTINGS, [])) # Verify hidden setting unset without inheritance - self.assertFalse(sub_settings_dict.get('hidden')) + self.assertFalse(sub_settings_dict.get("hidden")) - inherited_map = vgrid_inherit_map( - self.configuration, updated_vgrid_map) + inherited_map = vgrid_inherit_map(self.configuration, updated_vgrid_map) vgrid_data = inherited_map.get(VGRIDS, {}) self.assertTrue(vgrid_data) @@ -639,12 +727,12 @@ def test_settings_inheritance(self): sub_settings_dict = dict(sub_vgrid_data.get(SETTINGS, [])) # Verify hidden setting inheritance - self.assertEqual(sub_settings_dict.get('hidden'), True) + self.assertEqual(sub_settings_dict.get("hidden"), True) def test_unmap_inheritance(self): """Test unmap_inheritance clears inherited mappings""" self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN]) - sub_vgrid = os.path.join(self.TEST_VGRID_NAME, 'subvgrid') + sub_vgrid = os.path.join(self.TEST_VGRID_NAME, "subvgrid") self._create_vgrid(sub_vgrid) # Force refresh of cached map @@ -653,8 +741,9 @@ def test_unmap_inheritance(self): self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) # Unmap and verify mark modified - unmap_inheritance(self.configuration, self.TEST_VGRID_NAME, - self.TEST_OWNER_DN) + unmap_inheritance( + self.configuration, self.TEST_VGRID_NAME, self.TEST_OWNER_DN + ) modified, stamp = check_vgrids_modified(self.configuration) self.assertEqual(modified, [self.TEST_VGRID_NAME, sub_vgrid]) @@ -662,8 +751,9 @@ def test_unmap_inheritance(self): def test_user_map_fields(self): """Verify user map includes complete profile/settings data""" # First add a couple of test users - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) # Force fresh user map initial_vgrid_map = force_update_vgrid_map(self.configuration) @@ -680,8 +770,11 @@ def test_resource_revoked_access(self): """Verify resource removal propagates through cached maps""" # First add resource and vgrid self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + resources=[self.TEST_RESOURCE_ID], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) @@ -697,9 +790,14 @@ def test_resource_revoked_access(self): self.assertIn(self.TEST_RESOURCE_ID, initial_map) # Remove resource assignment from vgrid - success_and_msg = vgrid_set_entities(self.configuration, self.TEST_VGRID_NAME, - 'resources', [], allow_empty=True) - self.assertEqual(success_and_msg, (True, '')) + success_and_msg = vgrid_set_entities( + self.configuration, + self.TEST_VGRID_NAME, + "resources", + [], + allow_empty=True, + ) + self.assertEqual(success_and_msg, (True, "")) updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) @@ -717,9 +815,9 @@ def test_resource_revoked_access(self): def test_non_recursive_inheritance(self): """Verify non-recursive map excludes nested vgrids""" # Create parent+child vgrids - parent_vgrid = 'parent' + parent_vgrid = "parent" self._create_vgrid(parent_vgrid, owners=[self.TEST_OWNER_DN]) - child_vgrid = os.path.join(parent_vgrid, 'child') + child_vgrid = os.path.join(parent_vgrid, "child") self._create_vgrid(child_vgrid, members=[self.TEST_MEMBER_DN]) # Force update to avoid auto caching and get non-recursive map @@ -733,21 +831,25 @@ def test_non_recursive_inheritance(self): # Child should still appear when non-recursive but just not inherit self.assertIn(child_vgrid, vgrid_map.get(VGRIDS, {})) # Check owners and members to verify they aren't inherited - self.assertEqual(vgrid_map[VGRIDS][parent_vgrid][OWNERS], - [self.TEST_OWNER_DN]) + self.assertEqual( + vgrid_map[VGRIDS][parent_vgrid][OWNERS], [self.TEST_OWNER_DN] + ) self.assertEqual(len(vgrid_map[VGRIDS][parent_vgrid][MEMBERS]), 0) self.assertEqual(len(vgrid_map[VGRIDS][child_vgrid][OWNERS]), 0) - self.assertEqual(vgrid_map[VGRIDS][child_vgrid][MEMBERS], - [self.TEST_MEMBER_DN]) + self.assertEqual( + vgrid_map[VGRIDS][child_vgrid][MEMBERS], [self.TEST_MEMBER_DN] + ) def test_hidden_setting_propagation(self): """Verify hidden=True propagates to not infect parent settings""" - parent_vgrid = 'parent' + parent_vgrid = "parent" self._create_vgrid(parent_vgrid, owners=[self.TEST_OWNER_DN]) - child_vgrid = os.path.join(parent_vgrid, 'child') - self._create_vgrid(child_vgrid, owners=[self.TEST_OWNER_DN], - settings=[('vgrid_name', child_vgrid), - ('hidden', True)]) + child_vgrid = os.path.join(parent_vgrid, "child") + self._create_vgrid( + child_vgrid, + owners=[self.TEST_OWNER_DN], + settings=[("vgrid_name", child_vgrid), ("hidden", True)], + ) # Verify parent remains visible in cache updated_vgrid_map = force_update_vgrid_map(self.configuration) @@ -756,64 +858,79 @@ def test_hidden_setting_propagation(self): self.assertIn(child_vgrid, updated_vgrid_map.get(VGRIDS, {})) parent_data = updated_vgrid_map.get(VGRIDS, {}).get(parent_vgrid, {}) parent_settings = dict(parent_data.get(SETTINGS, [])) - self.assertNotEqual(parent_settings.get('hidden'), True) + self.assertNotEqual(parent_settings.get("hidden"), True) def test_default_vgrid_access(self): """Verify special access rules for default vgrid""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Even non-member should have access to default vgrid - participant = check_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN, - 'Generic') + participant = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, "Generic" + ) self.assertFalse(participant) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN) - self.assertIn('Generic', allowed_vgrids) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN + ) + self.assertIn("Generic", allowed_vgrids) # Invalid vgrid should not allow any participation or access - participant = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - 'invalid-vgrid-name') + participant = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, "invalid-vgrid-name" + ) self.assertFalse(participant) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_MEMBER_DN) - self.assertNotIn('invalid-vgrid-name', allowed_vgrids) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_MEMBER_DN + ) + self.assertNotIn("invalid-vgrid-name", allowed_vgrids) def test_general_vgrid_access(self): """Verify general access rules for vgrids""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Test vgrid must allow owner and members access - allowed = check_vgrid_access(self.configuration, self.TEST_OWNER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OWNER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OWNER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OWNER_DN + ) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_MEMBER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_MEMBER_DN + ) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) # Test vgrid must reject allow outsider access - allowed = check_vgrid_access(self.configuration, self.TEST_OUTSIDER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, self.TEST_VGRID_NAME + ) self.assertFalse(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN + ) self.assertNotIn(self.TEST_VGRID_NAME, allowed_vgrids) def test_user_allowed_res_confs(self): @@ -821,44 +938,49 @@ def test_user_allowed_res_confs(self): # Create test user and add test resource to vgrid self._provision_test_user(self, TEST_USER_DN) self._create_resource(self.TEST_RESOURCE_ID, [TEST_USER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[TEST_USER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[TEST_USER_DN], + resources=[self.TEST_RESOURCE_ID], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) # Owner should be allowed access - allowed = user_allowed_res_confs(self.configuration, - TEST_USER_DN) + allowed = user_allowed_res_confs(self.configuration, TEST_USER_DN) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) def test_user_visible_res_confs(self): """Minimal test for user_visible_res_confs""" # Owner should see owned resources even without vgrid access - self._create_resource(self.TEST_RESOURCE_ID, - owners=[self.TEST_OWNER_DN]) + self._create_resource( + self.TEST_RESOURCE_ID, owners=[self.TEST_OWNER_DN] + ) force_update_resource_map(self.configuration) - visible = user_visible_res_confs( - self.configuration, self.TEST_OWNER_DN) + visible = user_visible_res_confs(self.configuration, self.TEST_OWNER_DN) self.assertIn(self.TEST_RESOURCE_ALIAS, visible) def test_user_visible_user_confs(self): """Minimal test for user_visible_user_confs""" # Owners should see themselves in auto map # NOTE: use provision users to skip fixtures here - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) force_update_user_map(self.configuration) visible = user_visible_user_confs( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertIn(self.TEST_OWNER_UUID, visible) def test_get_re_provider_map(self): """Test RE provider map includes test resource""" - test_re = 'Python' - res_config = {'RUNTIMEENVIRONMENT': [(test_re, '/python/path')]} - self._create_resource(self.TEST_RESOURCE_ID, [ - self.TEST_OWNER_DN], res_config) + test_re = "Python" + res_config = {"RUNTIMEENVIRONMENT": [(test_re, "/python/path")]} + self._create_resource( + self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN], res_config + ) # Update maps to include new resource force_update_resource_map(self.configuration) @@ -870,10 +992,11 @@ def test_get_re_provider_map(self): def test_resources_using_re(self): """Test finding resources with specific runtime environment""" - test_re = 'Bash' - res_config = {'RUNTIMEENVIRONMENT': [(test_re, '/bash/path')]} - self._create_resource(self.TEST_RESOURCE_ID, [ - self.TEST_OWNER_DN], res_config) + test_re = "Bash" + res_config = {"RUNTIMEENVIRONMENT": [(test_re, "/bash/path")]} + self._create_resource( + self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN], res_config + ) # Refresh resource map force_update_resource_map(self.configuration) @@ -909,28 +1032,32 @@ def test_unmap_resource(self): def test_access_nonexistent_vgrid(self): """Ensure checks fail cleanly for non-existent vgrid""" - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - 'no-such-vgrid') + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, "no-such-vgrid" + ) self.assertFalse(allowed) # Should not appear in allowed vgrids allowed_vgrids = user_vgrid_access( - self.configuration, self.TEST_MEMBER_DN) - self.assertNotIn('no-such-vgrid', allowed_vgrids) + self.configuration, self.TEST_MEMBER_DN + ) + self.assertNotIn("no-such-vgrid", allowed_vgrids) def test_empty_member_access(self): """Verify members-only vgrid rejects outsiders""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, owners=[], members=[self.TEST_MEMBER_DN] + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Outsider should be blocked despite no owners - allowed = check_vgrid_access(self.configuration, self.TEST_OUTSIDER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, self.TEST_VGRID_NAME + ) self.assertFalse(allowed) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_unittest_testcore.py b/tests/test_mig_unittest_testcore.py index b27a74d33..127f77fe7 100644 --- a/tests/test_mig_unittest_testcore.py +++ b/tests/test_mig_unittest_testcore.py @@ -31,28 +31,28 @@ import os import sys -from tests.support import MigTestCase, testmain - from mig.unittest.testcore import main as testcore_main +from tests.support import MigTestCase, testmain class MigUnittestTestcore(MigTestCase): def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_existing_main(self): def raise_on_error_exit(exit_code, identifying_message=None): if exit_code != 0: if identifying_message is None: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) - print("") # account for wrapped tests printing to console + print("") # account for wrapped tests printing to console testcore_main(self.configuration, _exit=raise_on_error_exit) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_wsgibin.py b/tests/test_mig_wsgibin.py index 1d0f9ecdb..3c487ad60 100644 --- a/tests/test_mig_wsgibin.py +++ b/tests/test_mig_wsgibin.py @@ -38,12 +38,24 @@ # Imports required for the unit test wrapping import mig.shared.returnvalues as returnvalues -from mig.shared.base import allow_script, brief_list, client_dir_id, \ - client_id_dir, get_short_id, invisible_path +from mig.shared.base import ( + allow_script, + brief_list, + client_dir_id, + client_id_dir, + get_short_id, + invisible_path, +) from mig.shared.compat import SimpleNamespace + # Imports required for the unit tests themselves -from tests.support import MIG_BASE, MigTestCase, ensure_dirs_exist, \ - is_path_within, testmain +from tests.support import ( + MIG_BASE, + MigTestCase, + ensure_dirs_exist, + is_path_within, + testmain, +) from tests.support.snapshotsupp import SnapshotAssertMixin from tests.support.wsgisupp import WsgiAssertMixin, prepare_wsgi @@ -60,12 +72,12 @@ def __init__(self): def handle_decl(self, decl): try: - decltag, decltype = decl.split(' ') + decltag, decltype = decl.split(" ") except Exception: decltag = "" decltype = "" - if decltag.upper() == 'DOCTYPE': + if decltag.upper() == "DOCTYPE": self._saw_doctype = True else: decltype = "unknown" @@ -73,11 +85,11 @@ def handle_decl(self, decl): self._doctype = decltype def handle_starttag(self, tag, attrs): - if tag == 'html': + if tag == "html": if self._saw_tags: - tag_html = 'not_first' + tag_html = "not_first" else: - tag_html = 'was_first' + tag_html = "was_first" self._tag_html = tag_html self._saw_tags = True @@ -85,13 +97,13 @@ def assert_basics(self): if not self._saw_doctype: raise AssertionError("missing DOCTYPE") - if self._doctype != 'html': + if self._doctype != "html": raise AssertionError("non-html DOCTYPE") - if self._tag_html == 'none': + if self._tag_html == "none": raise AssertionError("missing ") - if self._tag_html != 'was_first': + if self._tag_html != "was_first": raise AssertionError("first tag seen was not ") @@ -110,13 +122,13 @@ def handle_data(self, *args, **kwargs): def handle_starttag(self, tag, attrs): DocumentBasicsHtmlParser.handle_starttag(self, tag, attrs) - if tag == 'title': + if tag == "title": self._within_title = True def handle_endtag(self, tag): DocumentBasicsHtmlParser.handle_endtag(self, tag) - if tag == 'title': + if tag == "title": self._within_title = False def title(self, trim_newlines=False): @@ -138,7 +150,7 @@ def _import_forcibly(module_name, relative_module_dir=None): that resides within a non-module directory. """ - module_path = os.path.join(MIG_BASE, 'mig') + module_path = os.path.join(MIG_BASE, "mig") if relative_module_dir is not None: module_path = os.path.join(module_path, relative_module_dir) sys.path.append(module_path) @@ -148,7 +160,7 @@ def _import_forcibly(module_name, relative_module_dir=None): # Imports of the code under test (indirect import needed here) -migwsgi = _import_forcibly('migwsgi', relative_module_dir='wsgi-bin') +migwsgi = _import_forcibly("migwsgi", relative_module_dir="wsgi-bin") class FakeBackend: @@ -160,8 +172,8 @@ class FakeBackend: def __init__(self): self.output_objects = [ - {'object_type': 'start'}, - {'object_type': 'title', 'text': 'ERROR'}, + {"object_type": "start"}, + {"object_type": "title", "text": "ERROR"}, ] self.return_value = returnvalues.ERROR @@ -175,6 +187,7 @@ def set_response(self, output_objects, returnvalue): def to_import_module(self): def _import_module(module_path): return self + return _import_module @@ -182,11 +195,11 @@ class MigWsgibin(MigTestCase, SnapshotAssertMixin, WsgiAssertMixin): """WSGI glue test cases""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/') + self.fake_wsgi = prepare_wsgi(self.configuration, "http://localhost/") self.application_args = ( self.fake_wsgi.environ, @@ -208,51 +221,45 @@ def assertHtmlTitle(self, value, title_text=None, trim_newlines=False): def test_top_level_request_returns_status_ok(self): wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) def test_objects_containing_only_title_has_expected_title(self): - output_objects = [ - {'object_type': 'title', 'text': 'TEST'} - ] + output_objects = [{"object_type": "title", "text": "TEST"}] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) - self.assertHtmlTitle(output, title_text='TEST', trim_newlines=True) + self.assertHtmlTitle(output, title_text="TEST", trim_newlines=True) def test_objects_containing_only_title_matches_snapshot(self): - output_objects = [ - {'object_type': 'title', 'text': 'TEST'} - ] + output_objects = [{"object_type": "title", "text": "TEST"}] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) - self.assertSnapshot(output, extension='html') + self.assertSnapshot(output, extension="html") -class MigWsgibin_output_objects(MigTestCase, WsgiAssertMixin, - SnapshotAssertMixin): +class MigWsgibin_output_objects( + MigTestCase, WsgiAssertMixin, SnapshotAssertMixin +): """Unit tests for output_object related part of wsgi functions.""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/') + self.fake_wsgi = prepare_wsgi(self.configuration, "http://localhost/") self.application_args = ( self.fake_wsgi.environ, @@ -273,51 +280,46 @@ def test_unknown_object_type_generates_valid_error_page(self): self.logger.forgive_errors() output_objects = [ { - 'object_type': 'nonexistent', # trigger error handling path + "object_type": "nonexistent", # trigger error handling path } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) - output, _ = self.assertWsgiResponse( - wsgi_result, self.fake_wsgi, 200) + output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) self.assertIsValidHtmlDocument(output) def test_objects_with_type_text(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) self.assertSnapshotOfHtmlContent(output) -class MigWsgibin_input_object(MigTestCase, WsgiAssertMixin, - SnapshotAssertMixin): +class MigWsgibin_input_object( + MigTestCase, WsgiAssertMixin, SnapshotAssertMixin +): """Unit tests for input_object related part of wsgi functions.""" - DUMMY_BYTES = 'dummyæøå-ßßß-value'.encode('utf-8') + DUMMY_BYTES = "dummyæøå-ßßß-value".encode("utf-8") def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() @@ -330,8 +332,9 @@ def _prepare_test(self, form_overrides=None, custom_env=None): # Set up a wsgi input with non-ascii bytes and open it in binary mode # If form_overrides is passed a list of tuples like [('key' 'val')] it # produces a fake_wsgi input on the form: b'key=val' - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/', - form=form_overrides) + self.fake_wsgi = prepare_wsgi( + self.configuration, "http://localhost/", form=form_overrides + ) # override the default environ fields from wsgisupp if custom_env: self.fake_wsgi.environ.update(custom_env) @@ -348,7 +351,7 @@ def _prepare_test(self, form_overrides=None, custom_env=None): # NOTE: enabled with underlying wsgi use of Fieldstorage fixed def test_put_text_plain_with_binary_input_succeeds(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = { "REQUEST_METHOD": "PUT", "CONTENT_TYPE": "text/plain", @@ -358,20 +361,16 @@ def test_put_text_plain_with_binary_input_succeeds(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must succeed with HTTP 200 when it parses input @@ -379,7 +378,7 @@ def test_put_text_plain_with_binary_input_succeeds(self): @unittest.skip("disabled with underlying wsgi use of Fieldstorage fixed") def test_put_text_plain_with_binary_input_fails(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = { "REQUEST_METHOD": "PUT", "CONTENT_TYPE": "text/plain", @@ -389,52 +388,44 @@ def test_put_text_plain_with_binary_input_fails(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) # TODO: can we add assertLogs to check error log explicitly? wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must fail with HTTP 500 from failing to parse input output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 500) def test_post_url_encoded_with_binary_input_succeeds(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = None self._prepare_test(test_form, test_env) output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must succeed with HTTP 200 when it parses input output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_support.py b/tests/test_support.py index 56b74f594..b68c4249c 100644 --- a/tests/test_support.py +++ b/tests/test_support.py @@ -28,15 +28,21 @@ """Unit tests for the tests module pointed to in the filename""" from __future__ import print_function + import os import sys import unittest -from tests.support import MigTestCase, PY2, testmain, temppath, \ - AssertOver, FakeConfiguration - from mig.shared.conf import get_configuration_object from mig.shared.configuration import Configuration +from tests.support import ( + PY2, + AssertOver, + FakeConfiguration, + MigTestCase, + temppath, + testmain, +) class InstrumentedAssertOver(AssertOver): @@ -62,6 +68,7 @@ def to_check_callable(self): def _wrapped_check_callable(): self._check_callable_called = True _check_callable() + self._check_callable = _wrapped_check_callable return _wrapped_check_callable @@ -71,8 +78,8 @@ class SupportTestCase(MigTestCase): def _class_attribute(self, name, **kwargs): cls = type(self) - if 'value' in kwargs: - setattr(cls, name, kwargs['value']) + if "value" in kwargs: + setattr(cls, name, kwargs["value"]) else: return getattr(cls, name, None) @@ -80,15 +87,17 @@ def test_requires_requesting_a_configuration(self): with self.assertRaises(AssertionError) as raised: self.configuration theexception = raised.exception - self.assertEqual(str(theexception), - "configuration access but testcase did not request it") + self.assertEqual( + str(theexception), + "configuration access but testcase did not request it", + ) @unittest.skipIf(PY2, "Python 3 only") def test_unclosed_files_are_recorded(self): tmp_path = temppath("support-unclosed", self) def open_without_close(): - with open(tmp_path, 'w'): + with open(tmp_path, "w"): pass open(tmp_path) return @@ -112,11 +121,13 @@ def assert_is_int(value): assert isinstance(value, int) attempt_wrapper = self.assert_over( - values=(1, 2, 3), _AssertOver=InstrumentedAssertOver) + values=(1, 2, 3), _AssertOver=InstrumentedAssertOver + ) # record the wrapper on the test case so the subsequent test can assert against it - self._class_attribute('surviving_attempt_wrapper', - value=attempt_wrapper) + self._class_attribute( + "surviving_attempt_wrapper", value=attempt_wrapper + ) with attempt_wrapper as attempt: attempt(assert_is_int) @@ -124,14 +135,15 @@ def assert_is_int(value): self.assertTrue(attempt_wrapper.has_check_callable()) # cleanup was recorded - self.assertIn(attempt_wrapper.get_check_callable(), - self._cleanup_checks) + self.assertIn( + attempt_wrapper.get_check_callable(), self._cleanup_checks + ) def test_when_asserting_over_multiple_values_after(self): # test name is purposefully after ..._recorded in sort order # such that we can check the check function was called correctly - attempt_wrapper = self._class_attribute('surviving_attempt_wrapper') + attempt_wrapper = self._class_attribute("surviving_attempt_wrapper") self.assertTrue(attempt_wrapper.was_check_callable_called()) @@ -139,7 +151,7 @@ class SupportTestCase_using_fakeconfig(MigTestCase): """Coverage of a MiG Testcase hat requests a fakeconfig""" def _provide_configuration(self): - return 'fakeconfig' + return "fakeconfig" def test_provides_a_fake_configuration(self): configuration = self.configuration @@ -157,10 +169,10 @@ class SupportTestCase_using_testconfig(MigTestCase): """Coverage of a MiG Testcase that requests a testconfig""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_provides_the_test_configuration(self): - expected_last_dir = 'testconfs-py2' if PY2 else 'testconfs-py3' + expected_last_dir = "testconfs-py2" if PY2 else "testconfs-py3" configuration = self.configuration @@ -173,5 +185,5 @@ def test_provides_the_test_configuration(self): self.assertTrue(config_file_last_dir, expected_last_dir) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_tests_support_assertover.py b/tests/test_tests_support_assertover.py index 702a80e94..ae9be10b2 100644 --- a/tests/test_tests_support_assertover.py +++ b/tests/test_tests_support_assertover.py @@ -35,7 +35,7 @@ def assert_a_thing(value): """A simple assert helper to test with""" - assert value.endswith(' thing'), "must end with a thing" + assert value.endswith(" thing"), "must end with a thing" class TestsSupportAssertOver(unittest.TestCase): @@ -44,7 +44,9 @@ class TestsSupportAssertOver(unittest.TestCase): def test_none_failing(self): saw_raise = False try: - with AssertOver(values=('some thing', 'other thing')) as value_block: + with AssertOver( + values=("some thing", "other thing") + ) as value_block: value_block(lambda _: assert_a_thing(_)) except Exception as exc: saw_raise = True @@ -52,13 +54,18 @@ def test_none_failing(self): def test_three_total_two_failing(self): with self.assertRaises(AssertionError) as raised: - with AssertOver(values=('some thing', 'other stuff', 'foobar')) as value_block: + with AssertOver( + values=("some thing", "other stuff", "foobar") + ) as value_block: value_block(lambda _: assert_a_thing(_)) theexception = raised.exception - self.assertEqual(str(theexception), """assertions raised for the following values: + self.assertEqual( + str(theexception), + """assertions raised for the following values: - <'other stuff'> : must end with a thing -- <'foobar'> : must end with a thing""") +- <'foobar'> : must end with a thing""", + ) def test_no_cases(self): with self.assertRaises(AssertionError) as raised: @@ -69,5 +76,5 @@ def test_no_cases(self): self.assertIsInstance(theexception, NoCasesError) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_tests_support_configsupp.py b/tests/test_tests_support_configsupp.py index 0e85cd2c1..8d8d23e2d 100644 --- a/tests/test_tests_support_configsupp.py +++ b/tests/test_tests_support_configsupp.py @@ -27,11 +27,10 @@ """Unit tests for the tests module pointed to in the filename""" +from mig.shared.configuration import Configuration from tests.support import MigTestCase, testmain from tests.support.configsupp import FakeConfiguration -from mig.shared.configuration import Configuration - class TestsSupportConfigsupp_FakeConfiguration(MigTestCase): """Check some basic behaviours of FakeConfiguration instances.""" @@ -43,13 +42,13 @@ def test_consistent_parameters(self): self.maxDiff = None self.assertEqual( Configuration.to_dict(default_configuration), - Configuration.to_dict(fake_configuration) + Configuration.to_dict(fake_configuration), ) def test_only_configuration_keys(self): with self.assertRaises(AssertionError): - FakeConfiguration(bar='1') + FakeConfiguration(bar="1") -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_tests_support_wsgisupp.py b/tests/test_tests_support_wsgisupp.py index 4adcb975d..03e5346e4 100644 --- a/tests/test_tests_support_wsgisupp.py +++ b/tests/test_tests_support_wsgisupp.py @@ -28,15 +28,15 @@ """Unit tests for the tests module pointed to in the filename""" import unittest -from mig.shared.compat import SimpleNamespace +from mig.shared.compat import SimpleNamespace from tests.support import AssertOver from tests.support.wsgisupp import prepare_wsgi def assert_a_thing(value): """A simple assert helper to test with""" - assert value.endswith(' thing'), "must end with a thing" + assert value.endswith(" thing"), "must end with a thing" class TestsSupportWsgisupp_prepare_wsgi(unittest.TestCase): @@ -44,56 +44,57 @@ class TestsSupportWsgisupp_prepare_wsgi(unittest.TestCase): def test_prepare_GET(self): configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, 'http://testhost/some/path') + environ, _ = prepare_wsgi(configuration, "http://testhost/some/path") - self.assertEqual(environ['MIG_CONF'], - '/path/to/the/confs/MiGserver.conf') - self.assertEqual(environ['HTTP_HOST'], 'testhost') - self.assertEqual(environ['PATH_INFO'], '/some/path') - self.assertEqual(environ['REQUEST_METHOD'], 'GET') + self.assertEqual( + environ["MIG_CONF"], "/path/to/the/confs/MiGserver.conf" + ) + self.assertEqual(environ["HTTP_HOST"], "testhost") + self.assertEqual(environ["PATH_INFO"], "/some/path") + self.assertEqual(environ["REQUEST_METHOD"], "GET") def test_prepare_GET_with_query(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, test_url, query={ - 'foo': 'true', - 'bar': 1 - }) + environ, _ = prepare_wsgi( + configuration, test_url, query={"foo": "true", "bar": 1} + ) - self.assertEqual(environ['QUERY_STRING'], 'foo=true&bar=1') + self.assertEqual(environ["QUERY_STRING"], "foo=true&bar=1") def test_prepare_POST(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, test_url, method='POST') + environ, _ = prepare_wsgi(configuration, test_url, method="POST") - self.assertEqual(environ['REQUEST_METHOD'], 'POST') + self.assertEqual(environ["REQUEST_METHOD"], "POST") def test_prepare_POST_with_headers(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) headers = { - 'Authorization': 'Basic XXXX', - 'Content-Length': 0, + "Authorization": "Basic XXXX", + "Content-Length": 0, } environ, _ = prepare_wsgi( - configuration, test_url, method='POST', headers=headers) + configuration, test_url, method="POST", headers=headers + ) - self.assertEqual(environ['CONTENT_LENGTH'], 0) - self.assertEqual(environ['HTTP_AUTHORIZATION'], 'Basic XXXX') + self.assertEqual(environ["CONTENT_LENGTH"], 0) + self.assertEqual(environ["HTTP_AUTHORIZATION"], "Basic XXXX") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()