From 4f2e5fac8a6e8c8438ba00c8b23f2d7be41144a0 Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Wed, 11 Mar 2026 14:03:54 +0100 Subject: [PATCH 1/2] Integrate black into the Makefile replacing previous autopep8 attempt. We are currently trying to have consistent formatting in the codebase but here is currently no standardised way to do this. Based on some growing consensus bring in black and switch local formatting and linting to it. A new `make fmt` target will format files with the appropriate style. The `make lint` target (broken/unused) invokes the same but in checking mode and thus will report violations. Given we are likely to adopt this in stages we define a variable in the Makefile listing the directories we wish to enforce cleanliness for. --- Makefile | 37 ++++++++++++++++++++++++++++++++++++- local-requirements.txt | 2 ++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cd4f81d89..f06bb85e6 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,11 @@ ifndef PY PY = 3 endif +FORMAT_ENFORCE_DIRS = state/ +FORMAT_EXCLUDE_REGEX = '.*' +FORMAT_EXCLUDE_GLOB = '*' +FORMAT_LINE_LENGTH = 80 + LOCAL_PYTHON_BIN = './envhelp/lpython' ifdef PYTHON_BIN @@ -35,7 +40,37 @@ ifneq ($(MIG_ENV),'local') @echo "unavailable outside local development environment" @exit 1 endif - $(LOCAL_PYTHON_BIN) -m autopep8 --ignore E402 -i + @make format-python + +.PHONY:format-python +format-python: + @$(LOCAL_PYTHON_BIN) -m black $(FORMAT_ENFORCE_DIRS) \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --exclude=$(FORMAT_EXCLUDE_REGEX) + @$(LOCAL_PYTHON_BIN) -m isort $(FORMAT_ENFORCE_DIRS) \ + --profile=black \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --skip-glob=$(FORMAT_EXCLUDE_GLOB) + +.PHONY: lint +lint: +ifneq ($(MIG_ENV),'local') + @echo "unavailable outside local development environment" + @exit 1 +endif + @make lint-python + +.PHONY: lint-python +lint-python: + @$(LOCAL_PYTHON_BIN) -m black $(FORMAT_ENFORCE_DIRS) \ + --check \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --exclude $(FORMAT_EXCLUDE_REGEX) + @$(LOCAL_PYTHON_BIN) -m isort $(FORMAT_ENFORCE_DIRS) \ + --check-only \ + --profile=black \ + --line-length=$(FORMAT_LINE_LENGTH) \ + --skip-glob=$(FORMAT_EXCLUDE_GLOB) .PHONY: clean clean: diff --git a/local-requirements.txt b/local-requirements.txt index 7faf3b026..4941e1a6b 100644 --- a/local-requirements.txt +++ b/local-requirements.txt @@ -3,6 +3,8 @@ # This list is mainly used to specify addons needed for the unit tests. # We only need autopep8 on py 3 as it's used in 'make fmt' (with py3) autopep8;python_version >= "3" +black +isort # We need paramiko for the ssh unit tests # NOTE: paramiko-3.0.0 dropped python2 and python3.6 support paramiko;python_version >= "3.7" From e1f035547ac349c7b1a96056bd77641f76156aea Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Wed, 11 Mar 2026 13:58:41 +0100 Subject: [PATCH 2/2] Enforce formatting in mig/lib and tests/ directories. --- Makefile | 7 +- mig/lib/accounting.py | 474 ++++---- mig/lib/events.py | 6 +- mig/lib/janitor.py | 78 +- mig/lib/lustrequota.py | 539 +++++---- mig/lib/quota.py | 16 +- tests/__init__.py | 12 +- tests/support/__init__.py | 106 +- tests/support/_env.py | 6 +- tests/support/assertover.py | 24 +- tests/support/configsupp.py | 19 +- tests/support/fixturesupp.py | 133 ++- tests/support/loggersupp.py | 44 +- tests/support/picklesupp.py | 4 +- tests/support/serversupp.py | 5 +- tests/support/snapshotsupp.py | 29 +- tests/support/suppconst.py | 9 +- tests/support/usersupp.py | 42 +- tests/support/wsgisupp.py | 64 +- tests/test_booleans.py | 3 +- tests/test_mig_install_generateconfs.py | 72 +- tests/test_mig_lib_accounting.py | 271 +++-- tests/test_mig_lib_daemon.py | 131 ++- tests/test_mig_lib_events.py | 158 ++- tests/test_mig_lib_janitor.py | 737 ++++++------ tests/test_mig_lib_quota.py | 11 +- tests/test_mig_lib_xgicore.py | 113 +- tests/test_mig_shared_accountreq.py | 345 +++--- tests/test_mig_shared_auth.py | 653 ++++++----- tests/test_mig_shared_base.py | 1008 ++++++++++------- tests/test_mig_shared_cloud.py | 155 +-- tests/test_mig_shared_compat.py | 12 +- tests/test_mig_shared_configuration.py | 320 +++--- tests/test_mig_shared_fileio.py | 791 +++++++------ tests/test_mig_shared_filemarks.py | 156 ++- tests/test_mig_shared_functionality_cat.py | 200 ++-- ...t_mig_shared_functionality_datatransfer.py | 43 +- tests/test_mig_shared_install.py | 207 ++-- tests/test_mig_shared_jupyter.py | 33 +- tests/test_mig_shared_localfile.py | 24 +- tests/test_mig_shared_pwcrypto.py | 548 +++++---- tests/test_mig_shared_safeeval.py | 61 +- tests/test_mig_shared_safeinput.py | 80 +- tests/test_mig_shared_serial.py | 12 +- tests/test_mig_shared_settings.py | 84 +- tests/test_mig_shared_ssh.py | 29 +- tests/test_mig_shared_transferfunctions.py | 133 ++- tests/test_mig_shared_url.py | 88 +- tests/test_mig_shared_useradm.py | 358 +++--- tests/test_mig_shared_userdb.py | 105 +- tests/test_mig_shared_userio.py | 12 +- tests/test_mig_shared_vgrid.py | 971 ++++++++++------ tests/test_mig_shared_vgridaccess.py | 499 +++++--- tests/test_mig_unittest_testcore.py | 14 +- tests/test_mig_wsgibin.py | 161 ++- tests/test_support.py | 48 +- tests/test_tests_support_assertover.py | 19 +- tests/test_tests_support_configsupp.py | 9 +- tests/test_tests_support_wsgisupp.py | 57 +- 59 files changed, 6099 insertions(+), 4249 deletions(-) diff --git a/Makefile b/Makefile index f06bb85e6..2f7c1ffda 100644 --- a/Makefile +++ b/Makefile @@ -6,11 +6,10 @@ ifndef PY PY = 3 endif -FORMAT_ENFORCE_DIRS = state/ -FORMAT_EXCLUDE_REGEX = '.*' -FORMAT_EXCLUDE_GLOB = '*' +FORMAT_ENFORCE_DIRS = ./mig/lib ./tests +FORMAT_EXCLUDE_REGEX = '.git|tests/data/|tests/fixture/' +FORMAT_EXCLUDE_GLOB = '.git/* tests/data/* tests/fixture/*' FORMAT_LINE_LENGTH = 80 - LOCAL_PYTHON_BIN = './envhelp/lpython' ifdef PYTHON_BIN diff --git a/mig/lib/accounting.py b/mig/lib/accounting.py index 82d3f63f9..3fce1e92f 100644 --- a/mig/lib/accounting.py +++ b/mig/lib/accounting.py @@ -41,11 +41,9 @@ from mig.shared.vgrid import vgrid_list, vgrid_list_vgrids -def __init_accounting_entry(user_bytes=0, - freeze_bytes=0, - vgrid_bytes=None, - peers=None, - ext_users=None): +def __init_accounting_entry( + user_bytes=0, freeze_bytes=0, vgrid_bytes=None, peers=None, ext_users=None +): """Return new user account dict entry""" if vgrid_bytes is None: vgrid_bytes = {} @@ -54,11 +52,13 @@ def __init_accounting_entry(user_bytes=0, if ext_users is None: ext_users = {} - return {'user_bytes': user_bytes, - 'freeze_bytes': freeze_bytes, - 'vgrid_bytes': vgrid_bytes, - 'peers': peers, - 'ext_users': ext_users} + return { + "user_bytes": user_bytes, + "freeze_bytes": freeze_bytes, + "vgrid_bytes": vgrid_bytes, + "peers": peers, + "ext_users": ext_users, + } def __get_owned_vgrid(configuration, verbose=False): @@ -67,17 +67,16 @@ def __get_owned_vgrid(configuration, verbose=False): NOTE: First owner of top-vgrid is primary owner""" logger = configuration.logger result = {} - (status, vgrids) = vgrid_list_vgrids(configuration) + status, vgrids = vgrid_list_vgrids(configuration) if status: for vgrid_name in vgrids: # print("checking vgrid: %s" % check_vgrid_name) - (owners_status, owners_list) = vgrid_list(vgrid_name, - 'owners', - configuration, - recursive=True) + owners_status, owners_list = vgrid_list( + vgrid_name, "owners", configuration, recursive=True + ) # Find first non-zero owner # NOTE: Some owner files contain empty owners) - owner = '' + owner = "" if owners_status and owners_list: owner = next(ent for ent in owners_list if ent) if owner: @@ -85,8 +84,7 @@ def __get_owned_vgrid(configuration, verbose=False): owned_vgrids.append(vgrid_name) result[owner] = owned_vgrids else: - msg = "Failed to find owner for vgrid: %s" \ - % vgrid_name + msg = "Failed to find owner for vgrid: %s" % vgrid_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -108,36 +106,37 @@ def __get_peers_map(configuration, verbose=False): accepted_peers = get_accepted_peers(configuration, client_id) for ext_client_id, value in accepted_peers.items(): if not isinstance(value, dict): - msg = "Invalid peers format: %s: %s: %s" \ - % (client_id, ext_client_id, value) + msg = "Invalid peers format: %s: %s: %s" % ( + client_id, + ext_client_id, + value, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue # Map external users to their peer - ext_users = peer_result.get('ext_users', {}) + ext_users = peer_result.get("ext_users", {}) ext_users[ext_client_id] = value - peer_result['ext_users'] = ext_users + peer_result["ext_users"] = ext_users # Map peers to their external user ext_result = result.get(ext_client_id, {}) - peers = ext_result.get('peers', {}) + peers = ext_result.get("peers", {}) peers[client_id] = value - ext_result['peers'] = peers + ext_result["peers"] = peers result[ext_client_id] = ext_result result[client_id] = peer_result return result -def update_accounting(configuration, - verbose=False): +def update_accounting(configuration, verbose=False): """Update user accounting information""" logger = configuration.logger retval = True - result = {'accounting': {}, - 'quota': {}} - accounting = result['accounting'] - result['timestamp'] = int(time.time()) + result = {"accounting": {}, "quota": {}} + accounting = result["accounting"] + result["timestamp"] = int(time.time()) # Map vgrid to their primary owner msg = "Creating vgrid owners map ..." @@ -184,27 +183,24 @@ def update_accounting(configuration, quota_info_json = entry.path quota_fs = entry.name.replace(".json", "") else: - logger.debug("Skipping non quota info entry: %s" - % entry.name) + logger.debug("Skipping non quota info entry: %s" % entry.name) continue quota_info = None # Try .pck first then .json if quota_info_pck: quota_info = unpickle(quota_info_pck, configuration.logger) elif quota_info_json: - quota_info = load_json(quota_info_json, - configuration.logger, - convert_utf8=False) + quota_info = load_json( + quota_info_json, configuration.logger, convert_utf8=False + ) if not quota_info: - msg = "Failed to load quota info for FS entry: %s" \ - % entry.name + msg = "Failed to load quota info for FS entry: %s" % entry.name logger.error(msg) if verbose: print("ERROR: %s" % msg) retval = False continue - quota_basepath = os.path.join(configuration.quota_home, - quota_fs) + quota_basepath = os.path.join(configuration.quota_home, quota_fs) if not os.path.isdir(quota_basepath): msg = "Missing quota_basepath: %r" % quota_basepath logger.error(msg) @@ -212,14 +208,15 @@ def update_accounting(configuration, print("ERROR: %s" % msg) retval = False continue - quota_mtime = quota_info.get('mtime', 0) - quota_datestr = datetime.datetime.fromtimestamp(quota_mtime) \ - .strftime('%d/%m/%Y-%H:%M:%S') - result['quota'][quota_fs] = {'mtime': quota_mtime} + quota_mtime = quota_info.get("mtime", 0) + quota_datestr = datetime.datetime.fromtimestamp( + quota_mtime + ).strftime("%d/%m/%Y-%H:%M:%S") + result["quota"][quota_fs] = {"mtime": quota_mtime} # User quota - user_path = os.path.join(quota_basepath, 'user') + user_path = os.path.join(quota_basepath, "user") if not os.path.isdir(user_path): msg = "Missing quota user path: %r" % user_path logger.error(msg) @@ -228,11 +225,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s user quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - user_path) + msg = "Scanning %s user quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + user_path, + ) logger.info(msg) if verbose: print(msg) @@ -241,30 +239,34 @@ def update_accounting(configuration, for user_entry in it2: if user_entry.name.endswith(".pck"): client_id = client_dir_id( - user_entry.name.replace('.pck', '')) + user_entry.name.replace(".pck", "") + ) elif user_entry.name.endswith(".json"): client_id = client_dir_id( - user_entry.name.replace('.json', '')) + user_entry.name.replace(".json", "") + ) else: - logger.debug("Skipping non-user entry: %s" - % user_entry.name) + logger.debug( + "Skipping non-user entry: %s" % user_entry.name + ) continue user_quota_files[client_id] = user_entry.path t2 = time.time() - msg = "Scanned %s user quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - user_path, - (t2 - t1)) + msg = "Scanned %s user quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + user_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) # Vgrid quota - vgrid_path = os.path.join(quota_basepath, 'vgrid') + vgrid_path = os.path.join(quota_basepath, "vgrid") if not os.path.isdir(vgrid_path): msg = "Missing quota vgrid path: %r" % vgrid_path logger.error(msg) @@ -273,11 +275,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s vgrid quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - vgrid_path) + msg = "Scanning %s vgrid quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + vgrid_path, + ) logger.info(msg) if verbose: print(msg) @@ -286,26 +289,29 @@ def update_accounting(configuration, for vgrid_entry in it2: if vgrid_entry.name.endswith(".pck"): vgrid_name = force_native_str( - vgrid_entry.name.replace('.pck', '')) + vgrid_entry.name.replace(".pck", "") + ) elif vgrid_entry.name.endswith(".json"): vgrid_name = force_native_str( - vgrid_entry.name.replace('.json', '')) + vgrid_entry.name.replace(".json", "") + ) else: # logger.debug("Skipping non-vgrid entry: %s" # % vgrid_entry.name) continue # NOTE: sub-vgrids uses ':' # as delimiter in 'vgrid_files_writable' - vgrid_name = vgrid_name.replace(':', '/') + vgrid_name = vgrid_name.replace(":", "/") # print("%s: %s" % (vgrid_name, vgrid_entry.path)) vgrid_quota_files[vgrid_name] = vgrid_entry.path t2 = time.time() - msg = "Scanned %s vgrid quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - vgrid_path, - (t2 - t1)) + msg = "Scanned %s vgrid quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + vgrid_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) @@ -313,7 +319,7 @@ def update_accounting(configuration, # Freeze quota if configuration.site_enable_freeze: - freeze_path = os.path.join(quota_basepath, 'freeze') + freeze_path = os.path.join(quota_basepath, "freeze") if not os.path.isdir(freeze_path): msg = "Missing quota freeze path: %r" % freeze_path logger.error(msg) @@ -322,11 +328,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s freeze quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - freeze_path) + msg = "Scanning %s freeze quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + freeze_path, + ) logger.info(msg) if verbose: print(msg) @@ -335,23 +342,27 @@ def update_accounting(configuration, for freeze_entry in it2: if freeze_entry.name.endswith(".pck"): freeze_client_id = client_dir_id( - freeze_entry.name.replace('.pck', '')) + freeze_entry.name.replace(".pck", "") + ) elif freeze_entry.name.endswith(".json"): freeze_client_id = client_dir_id( - freeze_entry.name.replace('.json', '')) + freeze_entry.name.replace(".json", "") + ) else: - logger.debug("Skipping non-freeze entry: %s" - % freeze_entry.name) + logger.debug( + "Skipping non-freeze entry: %s" + % freeze_entry.name + ) continue - freeze_quota_files[freeze_client_id] \ - = freeze_entry.path + freeze_quota_files[freeze_client_id] = freeze_entry.path t2 = time.time() - msg = "Scanned %s freeze quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - freeze_path, - (t2 - t1)) + msg = "Scanned %s freeze quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + freeze_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) @@ -361,17 +372,18 @@ def update_accounting(configuration, vgrids_accounted = [] for client_id, user_quota_filepath in user_quota_files.items(): # Init user accounting - peers = peers_map.get(client_id, {}).get('peers', {}) - ext_users = peers_map.get(client_id, {}).get('ext_users', {}) - accounting[client_id] = __init_accounting_entry(peers=peers, - ext_users=ext_users) + peers = peers_map.get(client_id, {}).get("peers", {}) + ext_users = peers_map.get(client_id, {}).get("ext_users", {}) + accounting[client_id] = __init_accounting_entry( + peers=peers, ext_users=ext_users + ) # Extract user bytes - if user_quota_filepath.endswith('.pck'): + if user_quota_filepath.endswith(".pck"): user_quota = unpickle(user_quota_filepath, configuration) - elif user_quota_filepath.endswith('.json'): - user_quota = load_json(user_quota_filepath, - configuration.logger, - convert_utf8=False) + elif user_quota_filepath.endswith(".json"): + user_quota = load_json( + user_quota_filepath, configuration.logger, convert_utf8=False + ) else: msg = "Invalid user quota file: %r" % user_quota_filepath logger.error(msg) @@ -380,11 +392,13 @@ def update_accounting(configuration, retval = False continue try: - accounting[client_id]['user_bytes'] = user_quota['bytes'] + accounting[client_id]["user_bytes"] = user_quota["bytes"] except Exception as err: - accounting[client_id]['user_bytes'] = 0 - msg = "Failed to load user quota: %r, error: %s" \ - % (user_quota_filepath, err) + accounting[client_id]["user_bytes"] = 0 + msg = "Failed to load user quota: %r, error: %s" % ( + user_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -394,27 +408,28 @@ def update_accounting(configuration, # Extract vgrid bytes for user 'client_id' for vgrid_name in owned_vgrid.get(client_id, []): - vgrid_quota_filepath = vgrid_quota_files.get(vgrid_name, '') + vgrid_quota_filepath = vgrid_quota_files.get(vgrid_name, "") if not os.path.exists(vgrid_quota_filepath): if verbose: # NOTE: Legacy vgrids are accounted at by top-vgrid - vgrid_array = vgrid_name.split('/') - legacy_vgrid = os.path.join(configuration.vgrid_files_home, - vgrid_name) - if not os.path.isdir(legacy_vgrid) \ - or len(vgrid_array) == 1: - msg = "Missing quota for vgrid: %r" \ - % vgrid_name + vgrid_array = vgrid_name.split("/") + legacy_vgrid = os.path.join( + configuration.vgrid_files_home, vgrid_name + ) + if not os.path.isdir(legacy_vgrid) or len(vgrid_array) == 1: + msg = "Missing quota for vgrid: %r" % vgrid_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue - if vgrid_quota_filepath.endswith('.pck'): + if vgrid_quota_filepath.endswith(".pck"): vgrid_quota = unpickle(vgrid_quota_filepath, configuration) - elif vgrid_quota_filepath.endswith('.json'): - vgrid_quota = load_json(vgrid_quota_filepath, - configuration.logger, - convert_utf8=False) + elif vgrid_quota_filepath.endswith(".json"): + vgrid_quota = load_json( + vgrid_quota_filepath, + configuration.logger, + convert_utf8=False, + ) else: msg = "Invalid vgrid quota file: %r" % vgrid_quota_filepath logger.error(msg) @@ -423,12 +438,15 @@ def update_accounting(configuration, retval = False continue try: - accounting[client_id]['vgrid_bytes'][vgrid_name] \ - = vgrid_quota['bytes'] + accounting[client_id]["vgrid_bytes"][vgrid_name] = vgrid_quota[ + "bytes" + ] except Exception as err: - accounting[client_id]['vgrid_bytes'][vgrid_name] = 0 - msg = "Failed to load vgrid quota: %r, error: %s" \ - % (vgrid_quota_filepath, err) + accounting[client_id]["vgrid_bytes"][vgrid_name] = 0 + msg = "Failed to load vgrid quota: %r, error: %s" % ( + vgrid_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -442,13 +460,15 @@ def update_accounting(configuration, for vgrid_name in vgrid_quota_files: if vgrid_name not in vgrids_accounted: - vgridowner = '' + vgridowner = "" for owner, owned_vgrids in owned_vgrid.items(): if vgrid_name in owned_vgrids: vgridowner = owner break - msg = "no accounting for vgrid: %r, missing owner?: %r" \ - % (vgrid_name, vgridowner) + msg = "no accounting for vgrid: %r, missing owner?: %r" % ( + vgrid_name, + vgridowner, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -457,24 +477,25 @@ def update_accounting(configuration, for freeze_name, freeze_quota_filepath in freeze_quota_files.items(): # Extract client_id from legacy freeze archive format - if freeze_name.startswith('archive-'): - legacy_freeze_meta_filepath \ - = os.path.join(configuration.freeze_home, - freeze_name, - 'meta.pck') - legacy_freeze_meta = unpickle(legacy_freeze_meta_filepath, - configuration.logger) + if freeze_name.startswith("archive-"): + legacy_freeze_meta_filepath = os.path.join( + configuration.freeze_home, freeze_name, "meta.pck" + ) + legacy_freeze_meta = unpickle( + legacy_freeze_meta_filepath, configuration.logger + ) if not legacy_freeze_meta: - msg = "Missing metadata for archive: %r" \ - % freeze_name + msg = "Missing metadata for archive: %r" % freeze_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue - client_id = legacy_freeze_meta.get('CREATOR', '') + client_id = legacy_freeze_meta.get("CREATOR", "") if not client_id: - msg = "Failed to extract client_id from: %r" \ - % legacy_freeze_meta_filepath + msg = ( + "Failed to extract client_id from: %r" + % legacy_freeze_meta_filepath + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -486,12 +507,12 @@ def update_accounting(configuration, # Load freeze quota freeze_bytes = 0 - if freeze_quota_filepath.endswith('.pck'): + if freeze_quota_filepath.endswith(".pck"): freeze_quota = unpickle(freeze_quota_filepath, configuration) - elif freeze_quota_filepath.endswith('.json'): - freeze_quota = load_json(freeze_quota_filepath, - configuration.logger, - convert_utf8=False) + elif freeze_quota_filepath.endswith(".json"): + freeze_quota = load_json( + freeze_quota_filepath, configuration.logger, convert_utf8=False + ) else: msg = "Invalid freeze quota file: %r" % freeze_quota_filepath logger.error(msg) @@ -500,11 +521,13 @@ def update_accounting(configuration, retval = False continue try: - freeze_bytes = int(freeze_quota['bytes']) + freeze_bytes = int(freeze_quota["bytes"]) except Exception as err: freeze_bytes = 0 - msg = "Failed to fetch freeze quota: %r, error: %s" \ - % (freeze_quota_filepath, err) + msg = "Failed to fetch freeze quota: %r, error: %s" % ( + freeze_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -512,24 +535,27 @@ def update_accounting(configuration, continue if freeze_bytes > 0: - freeze_accounting = accounting.get(client_id, '') + freeze_accounting = accounting.get(client_id, "") if not freeze_accounting: - msg = "added missing archive user: %r : %d" \ - % (client_id, freeze_bytes) + msg = "added missing archive user: %r : %d" % ( + client_id, + freeze_bytes, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) accounting[client_id] = __init_accounting_entry() freeze_accounting = accounting[client_id] - freeze_accounting['freeze_bytes'] += freeze_bytes + freeze_accounting["freeze_bytes"] += freeze_bytes # Save accounting result - accounting_filepath = os.path.join(configuration.accounting_home, - "%s.pck" % result['timestamp']) + accounting_filepath = os.path.join( + configuration.accounting_home, "%s.pck" % result["timestamp"] + ) status = pickle(result, accounting_filepath, configuration.logger) if status: - latest = os.path.join(configuration.accounting_home, 'latest') + latest = os.path.join(configuration.accounting_home, "latest") status = make_symlink(accounting_filepath, latest, logger, force=True) if not status: retval = False @@ -546,33 +572,26 @@ def human_readable_filesize(filesize): return "0 B" try: p = int(math.floor(math.log(filesize, 2) / 10)) - return "%.3f %s" % (filesize / math.pow(1024, p), - ['B', - 'KiB', - 'MiB', - 'GiB', - 'TiB', - 'PiB', - 'EiB', - 'ZiB', - 'YiB'][p]) + return "%.3f %s" % ( + filesize / math.pow(1024, p), + ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"][p], + ) except (ValueError, TypeError, IndexError): - return 'NaN' + return "NaN" -def get_usage(configuration, - userlist=[], - timestamp=0, - verbose=False): +def get_usage(configuration, userlist=[], timestamp=0, verbose=False): """Generate and return 'storage' usage""" # Load accounting if it exists logger = configuration.logger if timestamp == 0: - accounting_filepath = os.path.join(configuration.accounting_home, - "latest") + accounting_filepath = os.path.join( + configuration.accounting_home, "latest" + ) else: - accounting_filepath = os.path.join(configuration.accounting_home, - "%s.pck" % timestamp) + accounting_filepath = os.path.join( + configuration.accounting_home, "%s.pck" % timestamp + ) data = unpickle(accounting_filepath, configuration.logger) if not data: msg = "Failed to load accounting data from: %r" % accounting_filepath @@ -581,7 +600,7 @@ def get_usage(configuration, print("ERROR: %s" % msg) return None - accounting = data.get('accounting', {}) + accounting = data.get("accounting", {}) # Do not show external users as main accounts unless requested # or if the user act as both peer and external user @@ -590,10 +609,13 @@ def get_usage(configuration, peer_users = [] skip_ext_users = [] for values in accounting.values(): - ext_users.extend(list(values.get('ext_users', {}))) - peer_users.extend(list(values.get('peers', {}))) - skip_ext_users = [user for user in ext_users - if user not in userlist and user not in peer_users] + ext_users.extend(list(values.get("ext_users", {}))) + peer_users.extend(list(values.get("peers", {}))) + skip_ext_users = [ + user + for user in ext_users + if user not in userlist and user not in peer_users + ] # Create accounting report @@ -610,7 +632,7 @@ def get_usage(configuration, # Home usage - home_bytes = values.get('user_bytes', 0) + home_bytes = values.get("user_bytes", 0) total_bytes += home_bytes home_report = "" if create_reports: @@ -619,7 +641,7 @@ def get_usage(configuration, # Freeze archive usage - freeze_bytes = values.get('freeze_bytes', 0) + freeze_bytes = values.get("freeze_bytes", 0) total_bytes += freeze_bytes freeze_report = "" if create_reports and freeze_bytes > 0: @@ -630,32 +652,34 @@ def get_usage(configuration, vgrid_report = "" vgrid_total = 0 - for vgrid_name, vgrid_bytes in values.get('vgrid_bytes', {}).items(): + for vgrid_name, vgrid_bytes in values.get("vgrid_bytes", {}).items(): vgrid_total += vgrid_bytes if create_reports: vgrid_bytes_human = human_readable_filesize(vgrid_bytes) - vgrid_report += "\n - %s: %s" \ - % (vgrid_name, vgrid_bytes_human) + vgrid_report += "\n - %s: %s" % (vgrid_name, vgrid_bytes_human) if vgrid_report: - vgrid_report = "%s usage (total: %s)%s" \ - % (configuration.site_vgrid_label, - human_readable_filesize(vgrid_total), - vgrid_report) + vgrid_report = "%s usage (total: %s)%s" % ( + configuration.site_vgrid_label, + human_readable_filesize(vgrid_total), + vgrid_report, + ) total_bytes += vgrid_total # Create account usage entry - account_usage[username] = {'total_bytes': total_bytes, - 'home_total': home_bytes, - 'vgrid_total': vgrid_total, - 'freeze_total': freeze_bytes, - 'ext_users_total': 0, - 'total_report': '', - 'home_report': home_report, - 'freeze_report': freeze_report, - 'vgrid_report': vgrid_report, - 'ext_users_report': '', - 'peers_report': ''} + account_usage[username] = { + "total_bytes": total_bytes, + "home_total": home_bytes, + "vgrid_total": vgrid_total, + "freeze_total": freeze_bytes, + "ext_users_total": 0, + "total_report": "", + "home_report": home_report, + "freeze_report": freeze_report, + "vgrid_report": vgrid_report, + "ext_users_report": "", + "peers_report": "", + } # Create external users report # NOTE: We need total bytes and therefore we need the above full report @@ -667,13 +691,12 @@ def get_usage(configuration, if userlist and username not in userlist: continue # Create ext_users report - ext_users = values.get('ext_users', {}) - peers = values.get('peers', {}) + ext_users = values.get("ext_users", {}) + peers = values.get("peers", {}) if not ext_users: continue if ext_users and peers: - msg = "User %r acts as both peer and external user" \ - % username + msg = "User %r acts as both peer and external user" % username logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -683,20 +706,25 @@ def get_usage(configuration, ext_users_report = "" ext_users_total = 0 for ext_user in ext_users: - ext_user_total_bytes = account_usage.get( - ext_user, {}).get('total_bytes', 0) + ext_user_total_bytes = account_usage.get(ext_user, {}).get( + "total_bytes", 0 + ) ext_users_total += ext_user_total_bytes ext_user_total_bytes_human = human_readable_filesize( - ext_user_total_bytes) - ext_users_report += "\n - %s: %s" % (ext_user, - ext_user_total_bytes_human) + ext_user_total_bytes + ) + ext_users_report += "\n - %s: %s" % ( + ext_user, + ext_user_total_bytes_human, + ) if ext_users_report: - ext_users_report = "External users usage (total: %s):%s" \ - % (human_readable_filesize(ext_users_total), - ext_users_report) - account_usage[username]['ext_users_total'] = ext_users_total - account_usage[username]['ext_users_report'] = ext_users_report - account_usage[username]['total_bytes'] += ext_users_total + ext_users_report = "External users usage (total: %s):%s" % ( + human_readable_filesize(ext_users_total), + ext_users_report, + ) + account_usage[username]["ext_users_total"] = ext_users_total + account_usage[username]["ext_users_report"] = ext_users_report + account_usage[username]["total_bytes"] += ext_users_total # Create peers report @@ -705,22 +733,26 @@ def get_usage(configuration, peers_report += "\n - %s" % peer if peers_report: peers_report = "Accepted by the following peer:%s" % peers_report - account_usage[username]['peers_report'] = peers_report + account_usage[username]["peers_report"] = peers_report # Create total usage report for each user for usage in account_usage.values(): - usage['total_report'] = "Total usage: %s" \ - % human_readable_filesize(usage['total_bytes']) + usage["total_report"] = "Total usage: %s" % human_readable_filesize( + usage["total_bytes"] + ) # External users are accounted for by their peer # unless the external user also act as a peer result = {} - result['timestamp'] = data.get('timestamp', 0) - result['quota'] = data.get('quota', {}) - result['accounting'] = {username: values for username, values - in account_usage.items() - if not userlist or username in userlist - and username not in skip_ext_users} + result["timestamp"] = data.get("timestamp", 0) + result["quota"] = data.get("quota", {}) + result["accounting"] = { + username: values + for username, values in account_usage.items() + if not userlist + or username in userlist + and username not in skip_ext_users + } return result diff --git a/mig/lib/events.py b/mig/lib/events.py index 6b49b999c..c6716db5c 100644 --- a/mig/lib/events.py +++ b/mig/lib/events.py @@ -433,8 +433,7 @@ def run_cron_command( _restore_env(saved_environ, os.environ) raise exc logger.info( - "(%s) done running command for %s: %s" % ( - pid, target_path, command_str) + "(%s) done running command for %s: %s" % (pid, target_path, command_str) ) # logger.debug('(%s) raw output is: %s' % (pid, output_objects)) @@ -532,8 +531,7 @@ def run_events_command( _restore_env(saved_environ, os.environ) raise exc logger.info( - "(%s) done running command for %s: %s" % ( - pid, target_path, command_str) + "(%s) done running command for %s: %s" % (pid, target_path, command_str) ) # logger.debug('(%s) raw output is: %s' % (pid, output_objects)) diff --git a/mig/lib/janitor.py b/mig/lib/janitor.py index ea3e0051e..c18590226 100644 --- a/mig/lib/janitor.py +++ b/mig/lib/janitor.py @@ -36,8 +36,11 @@ import os import time -from mig.shared.accountreq import accept_account_req, existing_user_collision, \ - reject_account_req +from mig.shared.accountreq import ( + accept_account_req, + existing_user_collision, + reject_account_req, +) from mig.shared.base import get_user_id from mig.shared.fileio import delete_file, listdir from mig.shared.pwcrypto import verify_reset_token @@ -274,7 +277,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): _logger.info("%r made an invalid account request" % client_id) # NOTE: 'invalid' is a list of validation error strings if set reason = "invalid request: %s." % ". ".join(req_invalid) - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -284,21 +287,27 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not rej_status: _logger.warning( - "failed to reject invalid %r account request: %s" % (client_id, - rej_err) + "failed to reject invalid %r account request: %s" + % (client_id, rej_err) ) else: _logger.info("rejected invalid %r account request" % client_id) elif authorized: - _logger.info("%r requested renew and authorized password change" % - client_id) + _logger.info( + "%r requested renew and authorized password change" % client_id + ) peer_id = user_dict.get("peers", [None])[0] # NOTE: let authorized reqs (with valid peer) renew even with pw change default_renew = True - if accept_account_req(req_id, configuration, peer_id, - user_copy=user_copy, admin_copy=admin_copy, - auth_type=auth_type, - default_renew=default_renew): + if accept_account_req( + req_id, + configuration, + peer_id, + user_copy=user_copy, + admin_copy=admin_copy, + auth_type=auth_type, + default_renew=default_renew, + ): _logger.info("accepted authorized %r access renew" % client_id) else: _logger.warning("failed authorized %r access renew" % client_id) @@ -313,7 +322,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): "%r requested and authorized password reset" % client_id ) peer_id = user_dict.get("peers", [None])[0] - (acc_status, acc_err) = accept_account_req( + acc_status, acc_err = accept_account_req( req_id, configuration, peer_id, @@ -324,18 +333,18 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not acc_status: _logger.warning( - "failed to accept %r password reset: %s" % (client_id, - acc_err) + "failed to accept %r password reset: %s" + % (client_id, acc_err) ) else: _logger.info("accepted %r password reset" % client_id) else: _logger.warning( - "%r requested password reset with bad token: %s" % ( - client_id, reset_token) + "%r requested password reset with bad token: %s" + % (client_id, reset_token) ) reason = "invalid password reset token" - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -345,8 +354,8 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not rej_status: _logger.warning( - "failed to reject %r password reset: %s" % (client_id, - rej_err) + "failed to reject %r password reset: %s" + % (client_id, rej_err) ) else: _logger.info("rejected %r password reset" % client_id) @@ -354,7 +363,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): # NOTE: probably should no longer happen after initial auto clean _logger.warning("%r request is now past expire" % client_id) reason = "expired request - please re-request if still relevant" - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -363,15 +372,15 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): auth_type=auth_type, ) if not rej_status: - _logger.warning("failed to reject expired %r request: %s" % - (client_id, rej_err) - ) + _logger.warning( + "failed to reject expired %r request: %s" % (client_id, rej_err) + ) else: _logger.info("rejected %r request now past expire" % client_id) elif existing_user_collision(configuration, req_dict, client_id): _logger.warning("ID collision in request from %r" % client_id) reason = "ID collision - please re-request with *existing* ID fields" - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -381,8 +390,8 @@ def manage_single_req(configuration, req_id, req_path, db_path, now): ) if not rej_status: _logger.warning( - "failed to reject %r request with ID collision: %s" % - (client_id, rej_err) + "failed to reject %r request with ID collision: %s" + % (client_id, rej_err) ) else: _logger.info("rejected %r request with ID collision" % client_id) @@ -417,8 +426,7 @@ def manage_trivial_user_requests(configuration, now=None): continue req_id = filename req_path = os.path.join(configuration.user_pending, req_id) - _logger.debug("checking if account request in %r is trivial" % - req_path) + _logger.debug("checking if account request in %r is trivial" % req_path) req_age = now - os.path.getmtime(req_path) req_age_minutes = req_age / SECS_PER_MINUTE if req_age_minutes > MANAGE_TRIVIAL_REQ_MINUTES: @@ -428,8 +436,7 @@ def manage_trivial_user_requests(configuration, now=None): ) manage_single_req(configuration, req_id, req_path, db_path, now) handled += 1 - _logger.debug("handled %d trivial user account request action(s)" % - handled) + _logger.debug("handled %d trivial user account request action(s)" % handled) return handled @@ -474,7 +481,7 @@ def remind_and_expire_user_pending(configuration, now=None): ) user_copy = True admin_copy = True - (rej_status, rej_err) = reject_account_req( + rej_status, rej_err = reject_account_req( req_id, configuration, reason, @@ -483,11 +490,12 @@ def remind_and_expire_user_pending(configuration, now=None): auth_type=auth_type, ) if not rej_status: - _logger.warning("failed to expire %s request from %r: %s" % - (req_id, client_id, rej_err)) + _logger.warning( + "failed to expire %s request from %r: %s" + % (req_id, client_id, rej_err) + ) else: - _logger.info("expired %s request from %r" % (req_id, - client_id)) + _logger.info("expired %s request from %r" % (req_id, client_id)) handled += 1 _logger.debug("handled %d user account request action(s)" % handled) return handled diff --git a/mig/lib/lustrequota.py b/mig/lib/lustrequota.py index ea669b936..695461ebc 100644 --- a/mig/lib/lustrequota.py +++ b/mig/lib/lustrequota.py @@ -41,13 +41,24 @@ psutil = None from mig.shared.base import force_unicode -from mig.shared.fileio import make_symlink, makedirs_rec, pickle, save_json, \ - scandir, unpickle, walk, write_file +from mig.shared.fileio import ( + make_symlink, + makedirs_rec, + pickle, + save_json, + scandir, + unpickle, + walk, + write_file, +) from mig.shared.vgrid import vgrid_flat_name try: - from lustreclient.lfs import lfs_get_project_quota, lfs_set_project_id, \ - lfs_set_project_quota + from lustreclient.lfs import ( + lfs_get_project_quota, + lfs_set_project_id, + lfs_set_project_quota, + ) except ImportError: lfs_set_project_id = None lfs_get_project_quota = None @@ -63,16 +74,19 @@ def __get_lustre_basepath(configuration, lustre_basepath=None): valid_lustre_basepath = None for dpart in psutil.disk_partitions(all=True): if dpart.fstype == "lustre": - if lustre_basepath \ - and lustre_basepath.startswith(dpart.mountpoint) \ - and os.path.isdir(lustre_basepath): + if ( + lustre_basepath + and lustre_basepath.startswith(dpart.mountpoint) + and os.path.isdir(lustre_basepath) + ): valid_lustre_basepath = lustre_basepath break elif dpart.mountpoint.endswith(configuration.server_fqdn): valid_lustre_basepath = dpart.mountpoint else: - check_lustre_basepath = os.path.join(dpart.mountpoint, - configuration.server_fqdn) + check_lustre_basepath = os.path.join( + dpart.mountpoint, configuration.server_fqdn + ) if os.path.isdir(check_lustre_basepath): valid_lustre_basepath = check_lustre_basepath break @@ -85,8 +99,9 @@ def __get_gocryptfs_socket(configuration, gocryptfs_sock=None): otherwise return default if it exists""" valid_gocryptfs_sock = None if gocryptfs_sock is None: - gocryptfs_sock = "/var/run/gocryptfs.%s.sock" \ - % configuration.server_fqdn + gocryptfs_sock = ( + "/var/run/gocryptfs.%s.sock" % configuration.server_fqdn + ) if os.path.exists(gocryptfs_sock): gocryptfs_sock_stat = os.lstat(gocryptfs_sock) if stat.S_ISSOCK(gocryptfs_sock_stat.st_mode): @@ -95,12 +110,14 @@ def __get_gocryptfs_socket(configuration, gocryptfs_sock=None): return valid_gocryptfs_sock -def __shellexec(configuration, - command, - args=[], - stdin_str=None, - stdout_filepath=None, - stderr_filepath=None): +def __shellexec( + configuration, + command, + args=[], + stdin_str=None, + stdout_filepath=None, + stderr_filepath=None, +): """Execute shell command Returns (exit_code, stdout, stderr) of subprocess""" result = 0 @@ -116,10 +133,8 @@ def __shellexec(configuration, __args.extend(args) logger.debug("__args: %s" % __args) process = subprocess.Popen( - __args, - stdin=stdin_handle, - stdout=stdout_handle, - stderr=stderr_handle) + __args, stdin=stdin_handle, stdout=stdout_handle, stderr=stderr_handle + ) if stdin_str: process.stdin.write(stdin_str.encode()) stdout, stderr = process.communicate() @@ -145,28 +160,26 @@ def __shellexec(configuration, if stderr: stderr = force_unicode(stderr) if result == 0: - logger.debug("%s %s: rc: %s, stdout: %s, error: %s" - % (command, - " ".join(args), - rc, - stdout, - stderr)) + logger.debug( + "%s %s: rc: %s, stdout: %s, error: %s" + % (command, " ".join(args), rc, stdout, stderr) + ) else: - logger.error("shellexec: %s %s: rc: %s, stdout: %s, error: %s" - % (command, - " ".join(__args), - rc, - stdout, - stderr)) + logger.error( + "shellexec: %s %s: rc: %s, stdout: %s, error: %s" + % (command, " ".join(__args), rc, stdout, stderr) + ) return (rc, stdout, stderr) -def __set_project_id(configuration, - lustre_basepath, - quota_datapath, - quota_name, - quota_lustre_pid): +def __set_project_id( + configuration, + lustre_basepath, + quota_datapath, + quota_name, + quota_lustre_pid, +): """Set lustre project *quota_lustre_pid* Find the next *free* project id (PID) if *quota_lustre_pid* is occupied NOTE: lustre uses a global counter for project id's (PID) @@ -181,19 +194,22 @@ def __set_project_id(configuration, logger = configuration.logger next_lustre_pid = quota_lustre_pid while next_lustre_pid < max_lustre_pid: - (rc, currfiles, _, _, _) \ - = lfs_get_project_quota(lustre_basepath, next_lustre_pid) + rc, currfiles, _, _, _ = lfs_get_project_quota( + lustre_basepath, next_lustre_pid + ) if rc != 0: - logger.error("Failed to fetch quota for lustre project id: %d, %r" - % (next_lustre_pid, lustre_basepath) - + ", rc: %d" % rc) + logger.error( + "Failed to fetch quota for lustre project id: %d, %r" + % (next_lustre_pid, lustre_basepath) + + ", rc: %d" % rc + ) return -1 if currfiles == 0: break - logger.info("Skipping project id: %d" - % next_lustre_pid - + " already registered with %d files" - % currfiles) + logger.info( + "Skipping project id: %d" % next_lustre_pid + + " already registered with %d files" % currfiles + ) next_lustre_pid += 1 if next_lustre_pid == max_lustre_pid: @@ -202,22 +218,28 @@ def __set_project_id(configuration, # Set new project id - logger.info("Setting lustre project id: %d for %r: %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.info( + "Setting lustre project id: %d for %r: %r" + % (next_lustre_pid, quota_name, quota_datapath) + ) rc = lfs_set_project_id(quota_datapath, next_lustre_pid, 1) if rc != 0: - logger.error("lfs_set_project_id failed for lustre project id: %d for %r: %r" - % (next_lustre_pid, quota_name, quota_datapath) - + ", rc: %d" % rc) + logger.error( + "lfs_set_project_id failed for lustre project id: %d for %r: %r" + % (next_lustre_pid, quota_name, quota_datapath) + + ", rc: %d" % rc + ) return -1 # Dump lustre pid in quota_datapath and wait for it to appear in the quota - lustre_pid_filepath = os.path.join(quota_datapath, '.lustrepid') + lustre_pid_filepath = os.path.join(quota_datapath, ".lustrepid") status = write_file(next_lustre_pid, lustre_pid_filepath, logger) if not status: - logger.error("Failed write lustre project id: %d for %r to %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.error( + "Failed write lustre project id: %d for %r to %r" + % (next_lustre_pid, quota_name, quota_datapath) + ) return -1 # Wait for files to appear in quota before returning @@ -226,36 +248,44 @@ def __set_project_id(configuration, waiting = 0 max_waiting = 60 while files == 0 and waiting < max_waiting: - (rc, files, _, _, _) \ - = lfs_get_project_quota(lustre_basepath, next_lustre_pid) + rc, files, _, _, _ = lfs_get_project_quota( + lustre_basepath, next_lustre_pid + ) if rc != 0: files = 0 - logger.error("lfs_get_project_quota failed for:" - + " %d, %r, %r, rc: %d" - % (next_lustre_pid, quota_name, quota_datapath, rc)) + logger.error( + "lfs_get_project_quota failed for:" + + " %d, %r, %r, rc: %d" + % (next_lustre_pid, quota_name, quota_datapath, rc) + ) if files == 0: - logger.info("Waiting for lustre quota: %d: %r: %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.info( + "Waiting for lustre quota: %d: %r: %r" + % (next_lustre_pid, quota_name, quota_datapath) + ) time.sleep(1) max_waiting += 1 if waiting == max_waiting: - logger.error("Failed to fetch quota for:" - + " %d, %r, %r" - % (next_lustre_pid, quota_name, quota_datapath)) + logger.error( + "Failed to fetch quota for:" + + " %d, %r, %r" % (next_lustre_pid, quota_name, quota_datapath) + ) return -1 return next_lustre_pid -def __update_quota(configuration, - lustre_basepath, - lustre_setting, - quota_name, - quota_type, - data_basefs, - gocryptfs_sock, - timestamp): +def __update_quota( + configuration, + lustre_basepath, + lustre_setting, + quota_name, + quota_type, + data_basefs, + gocryptfs_sock, + timestamp, +): """Update quota for *quota_name*, if new entry then assign lustre project id and set default quota. If existing entry then update quota settings if changed @@ -263,15 +293,17 @@ def __update_quota(configuration, """ logger = configuration.logger quota_limits_changed = False - next_lustre_pid = lustre_setting.get('next_pid', -1) + next_lustre_pid = lustre_setting.get("next_pid", -1) if next_lustre_pid == -1: - logger.error("Invalid lustre quota next_pid: %d for: %r" - % (next_lustre_pid, quota_name)) + logger.error( + "Invalid lustre quota next_pid: %d for: %r" + % (next_lustre_pid, quota_name) + ) return False # Resolve quota limit and data basepath - if quota_type == 'vgrid': + if quota_type == "vgrid": default_quota_limit = configuration.quota_vgrid_limit data_basepath = configuration.vgrid_files_writable else: @@ -279,11 +311,13 @@ def __update_quota(configuration, data_basepath = configuration.user_home if data_basepath.startswith(configuration.state_path): - rel_data_basepath = data_basepath. \ - replace(configuration.state_path, "").lstrip(os.sep) + rel_data_basepath = data_basepath.replace( + configuration.state_path, "" + ).lstrip(os.sep) else: - logger.error("Failed to resolve relative data basepath from: %r" - % data_basepath) + logger.error( + "Failed to resolve relative data basepath from: %r" % data_basepath + ) return False # Resolve quota data path @@ -291,37 +325,38 @@ def __update_quota(configuration, if configuration.quota_backend == "lustre": quota_basefs = "lustre" - quota_datapath = os.path.join(lustre_basepath, - rel_data_basepath, - quota_name) + quota_datapath = os.path.join( + lustre_basepath, rel_data_basepath, quota_name + ) elif configuration.quota_backend == "lustre-gocryptfs": quota_basefs = "fuse.gocryptfs" stdin_str = os.path.join(rel_data_basepath, quota_name) cmd = "gocryptfs-xray -encrypt-paths %s" % gocryptfs_sock - (rc, stdout, stderr) = __shellexec(configuration, - cmd, - stdin_str=stdin_str) + rc, stdout, stderr = __shellexec( + configuration, cmd, stdin_str=stdin_str + ) if rc == 0 and stdout: encoded_path = stdout.strip() - quota_datapath = os.path.join(lustre_basepath, - encoded_path) + quota_datapath = os.path.join(lustre_basepath, encoded_path) else: - logger.error("Failed to resolve encrypted path for: %r" - % quota_name - + ", rc: %d, error: %s" - % (rc, stderr)) + logger.error( + "Failed to resolve encrypted path for: %r" % quota_name + + ", rc: %d, error: %s" % (rc, stderr) + ) return False else: - logger.error("Invalid quota backend: %r" - % configuration.quota_backend) + logger.error("Invalid quota backend: %r" % configuration.quota_backend) return False # Check if valid lustre data dir if not os.path.isdir(quota_datapath): - msg = "skipping entry: %r : %r, no lustre data path: %r" \ - % (quota_type, quota_name, quota_datapath) + msg = "skipping entry: %r : %r, no lustre data path: %r" % ( + quota_type, + quota_name, + quota_datapath, + ) # NOTE: log error and return false if dir is missing # and we expect data to be on lustre or gocryoptfs) if data_basefs == quota_basefs: @@ -335,143 +370,170 @@ def __update_quota(configuration, # Load quota if it exists otherwise new quota - quota_filepath = os.path.join(configuration.quota_home, - configuration.quota_backend, - quota_type, - "%s.pck" % quota_name) + quota_filepath = os.path.join( + configuration.quota_home, + configuration.quota_backend, + quota_type, + "%s.pck" % quota_name, + ) if os.path.exists(quota_filepath): quota = unpickle(quota_filepath, logger) if not quota: - logger.error("Failed to load quota settings for: %r from %r" - % (quota_name, quota_filepath)) + logger.error( + "Failed to load quota settings for: %r from %r" + % (quota_name, quota_filepath) + ) return False else: - quota = {'lustre_pid': next_lustre_pid, - 'files': -1, - 'bytes': -1, - 'softlimit_bytes': -1, - 'hardlimit_bytes': -1, - } + quota = { + "lustre_pid": next_lustre_pid, + "files": -1, + "bytes": -1, + "softlimit_bytes": -1, + "hardlimit_bytes": -1, + } # Fetch quota lustre pid - quota_lustre_pid = quota.get('lustre_pid', -1) + quota_lustre_pid = quota.get("lustre_pid", -1) if quota_lustre_pid == -1: - logger.error("Invalid quota lustre pid: %d for %r" - % (quota_lustre_pid, quota_name)) + logger.error( + "Invalid quota lustre pid: %d for %r" + % (quota_lustre_pid, quota_name) + ) return False # If new entry then set lustre project id new_lustre_pid = -1 if quota_lustre_pid == next_lustre_pid: - new_lustre_pid = __set_project_id(configuration, - lustre_basepath, - quota_datapath, - quota_name, - quota_lustre_pid) + new_lustre_pid = __set_project_id( + configuration, + lustre_basepath, + quota_datapath, + quota_name, + quota_lustre_pid, + ) if new_lustre_pid == -1: - logger.error("Failed to set project id: %d, %r, %r" - % (new_lustre_pid, quota_name, quota_datapath)) + logger.error( + "Failed to set project id: %d, %r, %r" + % (new_lustre_pid, quota_name, quota_datapath) + ) return False - lustre_setting['next_pid'] = new_lustre_pid + 1 - quota['lustre_pid'] = quota_lustre_pid = new_lustre_pid + lustre_setting["next_pid"] = new_lustre_pid + 1 + quota["lustre_pid"] = quota_lustre_pid = new_lustre_pid # Get current quota values for lustre_pid - (rc, currfiles, currbytes, softlimit_bytes, hardlimit_bytes) \ - = lfs_get_project_quota(lustre_basepath, quota_lustre_pid) + rc, currfiles, currbytes, softlimit_bytes, hardlimit_bytes = ( + lfs_get_project_quota(lustre_basepath, quota_lustre_pid) + ) if rc != 0: - logger.error("lfs_get_project_quota failed for: %d, %r, %r" - % (quota_lustre_pid, quota_name, quota_datapath) - + ", rc: %d" % rc) + logger.error( + "lfs_get_project_quota failed for: %d, %r, %r" + % (quota_lustre_pid, quota_name, quota_datapath) + + ", rc: %d" % rc + ) return False # Update quota info if currfiles == 0 or currbytes == 0: - logger.warning("lustre_basepath: %r: pid: %d: quota_type: %s" - % (lustre_basepath, quota_lustre_pid, quota_type) - + "quota_name: %s, files: %d, bytes: %d" - % (quota_name, currfiles, currbytes)) + logger.warning( + "lustre_basepath: %r: pid: %d: quota_type: %s" + % (lustre_basepath, quota_lustre_pid, quota_type) + + "quota_name: %s, files: %d, bytes: %d" + % (quota_name, currfiles, currbytes) + ) - quota['mtime'] = timestamp - quota['files'] = currfiles - quota['bytes'] = currbytes + quota["mtime"] = timestamp + quota["files"] = currfiles + quota["bytes"] = currbytes # If new entry use default quota # and update quota if changed if new_lustre_pid > -1: quota_limits_changed = True - quota['softlimit_bytes'] = default_quota_limit - quota['hardlimit_bytes'] = default_quota_limit - elif hardlimit_bytes != quota.get('hardlimit_bytes', -1) \ - or softlimit_bytes != quota.get('softlimit_bytes', -1): + quota["softlimit_bytes"] = default_quota_limit + quota["hardlimit_bytes"] = default_quota_limit + elif hardlimit_bytes != quota.get( + "hardlimit_bytes", -1 + ) or softlimit_bytes != quota.get("softlimit_bytes", -1): quota_limits_changed = True - quota['softlimit_bytes'] = softlimit_bytes - quota['hardlimit_bytes'] = hardlimit_bytes + quota["softlimit_bytes"] = softlimit_bytes + quota["hardlimit_bytes"] = hardlimit_bytes if quota_limits_changed: - rc = lfs_set_project_quota(quota_datapath, - quota_lustre_pid, - quota['softlimit_bytes'], - quota['hardlimit_bytes'], - ) + rc = lfs_set_project_quota( + quota_datapath, + quota_lustre_pid, + quota["softlimit_bytes"], + quota["hardlimit_bytes"], + ) if rc != 0: - logger.error("Failed to set quota limit: %d/%d" - % (softlimit_bytes, - hardlimit_bytes) - + " for lustre project id: %d, %r, %r, rc: %d" - % (quota_lustre_pid, - quota_name, - quota_datapath, - rc)) + logger.error( + "Failed to set quota limit: %d/%d" + % (softlimit_bytes, hardlimit_bytes) + + " for lustre project id: %d, %r, %r, rc: %d" + % (quota_lustre_pid, quota_name, quota_datapath, rc) + ) return False # Save current quota - new_quota_basepath = os.path.join(configuration.quota_home, - configuration.quota_backend, - quota_type, - str(timestamp)) - if not os.path.exists(new_quota_basepath) \ - and not makedirs_rec(new_quota_basepath, configuration): - logger.error("Failed to create new quota base path: %r" - % new_quota_basepath) + new_quota_basepath = os.path.join( + configuration.quota_home, + configuration.quota_backend, + quota_type, + str(timestamp), + ) + if not os.path.exists(new_quota_basepath) and not makedirs_rec( + new_quota_basepath, configuration + ): + logger.error( + "Failed to create new quota base path: %r" % new_quota_basepath + ) return False - new_quota_filepath_pck = os.path.join(new_quota_basepath, - "%s.pck" % quota_name) + new_quota_filepath_pck = os.path.join( + new_quota_basepath, "%s.pck" % quota_name + ) - logger.debug("Saving: %s: %s: %s -> %r" - % (quota_type, quota_name, quota, new_quota_filepath_pck)) + logger.debug( + "Saving: %s: %s: %s -> %r" + % (quota_type, quota_name, quota, new_quota_filepath_pck) + ) status = pickle(quota, new_quota_filepath_pck, logger) if not status: - logger.error("Failed to save quota for: %r to %r" - % (quota_name, new_quota_filepath_pck)) + logger.error( + "Failed to save quota for: %r to %r" + % (quota_name, new_quota_filepath_pck) + ) return False - new_quota_filepath_json = os.path.join(new_quota_basepath, - "%s.json" % quota_name) - status = save_json(quota, - new_quota_filepath_json, - logger) + new_quota_filepath_json = os.path.join( + new_quota_basepath, "%s.json" % quota_name + ) + status = save_json(quota, new_quota_filepath_json, logger) if not status: - logger.error("Failed to save quota for: %r to %r" - % (quota_name, new_quota_filepath_json)) + logger.error( + "Failed to save quota for: %r to %r" + % (quota_name, new_quota_filepath_json) + ) return False # Create symlink to new quota - status = make_symlink(new_quota_filepath_pck, - quota_filepath, - logger, - force=True) + status = make_symlink( + new_quota_filepath_pck, quota_filepath, logger, force=True + ) if not status: - logger.error("Failed to make quota symlink for: %r: %r -> %r" - % (quota_name, new_quota_filepath_pck, quota_filepath)) + logger.error( + "Failed to make quota symlink for: %r: %r -> %r" + % (quota_name, new_quota_filepath_pck, quota_filepath) + ) return False return True @@ -483,9 +545,11 @@ def update_lustre_quota(configuration): # Check if lustreclient module was imported correctly - if lfs_set_project_id is None \ - or lfs_get_project_quota is None \ - or lfs_set_project_quota is None: + if ( + lfs_set_project_id is None + or lfs_get_project_quota is None + or lfs_set_project_quota is None + ): logger.error("Failed to import lustreclient module") return False @@ -496,11 +560,11 @@ def update_lustre_quota(configuration): lustre_basepath = __get_lustre_basepath(configuration) if lustre_basepath: - logger.debug("Using lustre basepath: %r" - % lustre_basepath) + logger.debug("Using lustre basepath: %r" % lustre_basepath) else: - logger.error("Found no valid lustre mounts for: %s" - % configuration.server_fqdn) + logger.error( + "Found no valid lustre mounts for: %s" % configuration.server_fqdn + ) return False # Get gocryptfs socket if enabled @@ -509,37 +573,38 @@ def update_lustre_quota(configuration): if configuration.quota_backend == "lustre-gocryptfs": gocryptfs_sock = __get_gocryptfs_socket(configuration) if gocryptfs_sock: - logger.debug("Using gocryptfs socket: %r" - % gocryptfs_sock) + logger.debug("Using gocryptfs socket: %r" % gocryptfs_sock) else: logger.error("Missing gocryptfs socket") return False # Load lustre quota settings - lustre_setting_filepath = os.path.join(configuration.quota_home, - '%s.pck' - % configuration.quota_backend) + lustre_setting_filepath = os.path.join( + configuration.quota_home, "%s.pck" % configuration.quota_backend + ) if os.path.exists(lustre_setting_filepath): - lustre_setting = unpickle(lustre_setting_filepath, - logger) + lustre_setting = unpickle(lustre_setting_filepath, logger) if not lustre_setting: - logger.error("Failed to load lustre quota: %r" - % lustre_setting_filepath) + logger.error( + "Failed to load lustre quota: %r" % lustre_setting_filepath + ) return False else: - lustre_setting = {'next_pid': 1, - 'mtime': 0} + lustre_setting = {"next_pid": 1, "mtime": 0} # Update quota - quota_targets = {'vgrid': {'basefs': 'lustre', - 'entries': {}, - }, - 'user': {'basefs': 'lustre', - 'entries': {}, - }, - } + quota_targets = { + "vgrid": { + "basefs": "lustre", + "entries": {}, + }, + "user": { + "basefs": "lustre", + "entries": {}, + }, + } # Resolve basefs if possible @@ -547,26 +612,29 @@ def update_lustre_quota(configuration): mountpoint = dpart.mountpoint.rstrip(os.sep) fstype = dpart.fstype if mountpoint == configuration.vgrid_files_writable.rstrip(os.sep): - logger.debug("Found basefs for vgrid data: %r : %r" - % (mountpoint, fstype)) - quota_targets['vgrid']['basefs'] = fstype + logger.debug( + "Found basefs for vgrid data: %r : %r" % (mountpoint, fstype) + ) + quota_targets["vgrid"]["basefs"] = fstype if mountpoint == configuration.user_home.rstrip(os.sep): - logger.debug("Found basefs for user data: %r : %r" - % (mountpoint, fstype)) - quota_targets['user']['basefs'] = fstype + logger.debug( + "Found basefs for user data: %r : %r" % (mountpoint, fstype) + ) + quota_targets["user"]["basefs"] = fstype # Resolve vgrids and sub-vgrids for root, dirs, _ in walk(configuration.vgrid_home, topdown=True): for dirent in dirs: vgrid_dirpath = os.path.join(root, dirent) - owners_filepath = os.path.join(vgrid_dirpath, 'owners') + owners_filepath = os.path.join(vgrid_dirpath, "owners") if os.path.isfile(owners_filepath): vgrid = vgrid_flat_name( - vgrid_dirpath[len(configuration.vgrid_home):], - configuration) + vgrid_dirpath[len(configuration.vgrid_home) :], + configuration, + ) logger.debug("Found vgrid: %r" % vgrid) - quota_targets['vgrid']['entries'][vgrid] = 1 + quota_targets["vgrid"]["entries"][vgrid] = 1 # Resolve users @@ -576,45 +644,46 @@ def update_lustre_quota(configuration): userhome = os.readlink(entry.path) # NOTE: Relative links are prefixed with 'user_home' if not userhome.startswith(os.sep): - userhome = os.path.join(configuration.user_home, - userhome) + userhome = os.path.join(configuration.user_home, userhome) else: userhome = entry.path if os.path.isdir(userhome): user = os.path.basename(userhome) else: - logger.debug("skipping non-userhome: %r (%r)" - % (userhome, entry.path)) + logger.debug( + "skipping non-userhome: %r (%r)" % (userhome, entry.path) + ) continue # NOTE: Multiple links might point to same user - quota_targets['user']['entries'][user] = True + quota_targets["user"]["entries"][user] = True # Update quotas for quota_type in quota_targets: target = quota_targets.get(quota_type, {}) - data_basefs = target.get('basefs', 'lustre') - quota_entries = target.get('entries', {}) + data_basefs = target.get("basefs", "lustre") + quota_entries = target.get("entries", {}) for quota_entry in quota_entries: - status = __update_quota(configuration, - lustre_basepath, - lustre_setting, - quota_entry, - quota_type, - data_basefs, - gocryptfs_sock, - timestamp) + status = __update_quota( + configuration, + lustre_basepath, + lustre_setting, + quota_entry, + quota_type, + data_basefs, + gocryptfs_sock, + timestamp, + ) if not status: retval = False # Save updated lustre quota settings - lustre_setting['mtime'] = timestamp - status = pickle(lustre_setting, - lustre_setting_filepath, - logger) + lustre_setting["mtime"] = timestamp + status = pickle(lustre_setting, lustre_setting_filepath, logger) if not status: - logger.error("Failed to save lustra quota settings: %r" - % lustre_setting_filepath) + logger.error( + "Failed to save lustra quota settings: %r" % lustre_setting_filepath + ) return retval diff --git a/mig/lib/quota.py b/mig/lib/quota.py index e93830d28..85b795f2c 100644 --- a/mig/lib/quota.py +++ b/mig/lib/quota.py @@ -30,20 +30,22 @@ from mig.lib.lustrequota import update_lustre_quota - -supported_quota_backends = ['lustre', 'lustre-gocryptfs'] +supported_quota_backends = ["lustre", "lustre-gocryptfs"] def update_quota(configuration): """Update quota for users and vgrids""" retval = False logger = configuration.logger - if configuration.quota_backend == 'lustre' \ - or configuration.quota_backend == 'lustre-gocryptfs': + if ( + configuration.quota_backend == "lustre" + or configuration.quota_backend == "lustre-gocryptfs" + ): retval = update_lustre_quota(configuration) else: - logger.error("quota_backend: %r not in supported_quota_backends: %r" - % (configuration.quota_backend, - supported_quota_backends)) + logger.error( + "quota_backend: %r not in supported_quota_backends: %r" + % (configuration.quota_backend, supported_quota_backends) + ) return retval diff --git a/tests/__init__.py b/tests/__init__.py index bcec2ab8a..d1e480f4d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,10 +1,14 @@ def _print_identity(): import os import sys - python_version_string = sys.version.split(' ')[0] - mig_env = os.environ.get('MIG_ENV', 'local') - print("running with MIG_ENV='%s' under Python %s" % - (mig_env, python_version_string)) + + python_version_string = sys.version.split(" ")[0] + mig_env = os.environ.get("MIG_ENV", "local") + print( + "running with MIG_ENV='%s' under Python %s" + % (mig_env, python_version_string) + ) print("") + _print_identity() diff --git a/tests/support/__init__.py b/tests/support/__init__.py index eb2c8019f..da75f1b7c 100644 --- a/tests/support/__init__.py +++ b/tests/support/__init__.py @@ -28,8 +28,6 @@ """Supporting functions for the unit test framework""" -from collections import defaultdict -from configparser import ConfigParser import difflib import errno import io @@ -40,34 +38,40 @@ import shutil import stat import sys +from collections import defaultdict +from configparser import ConfigParser from types import SimpleNamespace -from unittest import TestCase, main as testmain +from unittest import TestCase +from unittest import main as testmain +from tests.support._env import MIG_ENV, PY2 from tests.support.configsupp import FakeConfiguration from tests.support.fixturesupp import _PreparedFixture -from tests.support.suppconst import MIG_BASE, TEST_BASE, \ - TEST_DATA_DIR, TEST_OUTPUT_DIR, ENVHELP_OUTPUT_DIR +from tests.support.suppconst import ( + ENVHELP_OUTPUT_DIR, + MIG_BASE, + TEST_BASE, + TEST_DATA_DIR, + TEST_OUTPUT_DIR, +) from tests.support.usersupp import UserAssertMixin -from tests.support._env import MIG_ENV, PY2 - - # Provide access to a configuration file for the active environment. -if MIG_ENV in ('local', 'docker'): +if MIG_ENV in ("local", "docker"): # force local testconfig - _output_dir = os.path.join(MIG_BASE, 'envhelp/output') + _output_dir = os.path.join(MIG_BASE, "envhelp/output") _conf_dir_name = "testconfs-%s" % (MIG_ENV,) _conf_dir = os.path.join(_output_dir, _conf_dir_name) - _local_conf = os.path.join(_conf_dir, 'MiGserver.conf') - _config_file = os.getenv('MIG_CONF', None) + _local_conf = os.path.join(_conf_dir, "MiGserver.conf") + _config_file = os.getenv("MIG_CONF", None) if _config_file is None: - os.environ['MIG_CONF'] = _local_conf + os.environ["MIG_CONF"] = _local_conf # adjust the link through which confs are accessed to suit the environment - _conf_link = os.path.join(_output_dir, 'testconfs') + _conf_link = os.path.join(_output_dir, "testconfs") assert os.path.lexists(_conf_link) # it must already exist - os.remove(_conf_link) # blow it away + os.remove(_conf_link) # blow it away os.symlink(_conf_dir, _conf_link) # recreate it using the active MIG_BASE else: raise NotImplementedError() @@ -97,7 +101,6 @@ from tests.support.loggersupp import FakeLogger, FakeLoggerChecker from tests.support.serversupp import make_wrapped_server - # Basic global logging configuration for testing @@ -181,7 +184,7 @@ def tearDown(self): @classmethod def tearDownClass(cls): - if MIG_ENV == 'docker': + if MIG_ENV == "docker": # the permissions story wrt running inside docker containers is # such that we can end up with files from previous test runs left # around that might subsequently cause spurious permissions errors @@ -209,20 +212,24 @@ def _reset_logging(self, stream): @staticmethod def _make_configuration_instance(testcase, configuration_to_make): - if configuration_to_make == 'fakeconfig': + if configuration_to_make == "fakeconfig": return FakeConfiguration(logger=testcase.logger) - elif configuration_to_make == 'testconfig': + elif configuration_to_make == "testconfig": from mig.shared.conf import get_configuration_object - configuration = get_configuration_object(skip_log=True, - disable_auth_log=True) + + configuration = get_configuration_object( + skip_log=True, disable_auth_log=True + ) configuration.logger = testcase.logger return configuration else: raise AssertionError( - "MigTestCase: unknown configuration %r" % (configuration_to_make,)) + "MigTestCase: unknown configuration %r" + % (configuration_to_make,) + ) def _provide_configuration(self): - return 'unspecified' + return "unspecified" @property def configuration(self): @@ -233,14 +240,16 @@ def configuration(self): configuration_to_make = self._provide_configuration() - if configuration_to_make == 'unspecified': + if configuration_to_make == "unspecified": raise AssertionError( - "configuration access but testcase did not request it") + "configuration access but testcase did not request it" + ) configuration_instance = self._make_configuration_instance( - self, configuration_to_make) + self, configuration_to_make + ) - if configuration_to_make == 'testconfig': + if configuration_to_make == "testconfig": # use the paths defined by the loaded configuration to create # the directories which are expected to be present by the code os.mkdir(configuration_instance.certs_path) @@ -274,7 +283,8 @@ def assertDirEmpty(self, relative_path): """Make sure the supplied path is an empty directory""" path_kind = self.assertPathExists(relative_path) assert path_kind == "dir", "expected a directory but found %s" % ( - path_kind, ) + path_kind, + ) absolute_path = os.path.join(TEST_OUTPUT_DIR, relative_path) entries = os.listdir(absolute_path) assert not entries, "directory is not empty" @@ -283,7 +293,8 @@ def assertDirNotEmpty(self, relative_path): """Make sure the supplied path is a non-empty directory""" path_kind = self.assertPathExists(relative_path) assert path_kind == "dir", "expected a directory but found %s" % ( - path_kind, ) + path_kind, + ) absolute_path = os.path.join(TEST_OUTPUT_DIR, relative_path) entries = os.listdir(absolute_path) assert entries, "directory is empty" @@ -291,28 +302,35 @@ def assertDirNotEmpty(self, relative_path): def assertFileContentIdentical(self, file_actual, file_expected): """Make sure file_actual and file_expected are identical""" - with io.open(file_actual) as f_actual, io.open(file_expected) as f_expected: + with io.open(file_actual) as f_actual, io.open( + file_expected + ) as f_expected: lhs = f_actual.readlines() rhs = f_expected.readlines() different_lines = list(difflib.unified_diff(rhs, lhs)) try: self.assertEqual(len(different_lines), 0) except AssertionError: - raise AssertionError("""differences found between files + raise AssertionError( + """differences found between files * %s * %s included: %s - """ % ( - os.path.relpath(file_expected, MIG_BASE), - os.path.relpath(file_actual, MIG_BASE), - ''.join(different_lines))) + """ + % ( + os.path.relpath(file_expected, MIG_BASE), + os.path.relpath(file_actual, MIG_BASE), + "".join(different_lines), + ) + ) def assertFileExists(self, relative_path): """Make sure relative_path exists and is a file""" path_kind = self.assertPathExists(relative_path) assert path_kind == "file", "expected a file but found %s" % ( - path_kind, ) + path_kind, + ) return os.path.join(TEST_OUTPUT_DIR, relative_path) def assertPathExists(self, relative_path): @@ -342,13 +360,14 @@ def assertPathWithin(self, path, start=None): """Make sure path is within start directory""" if not is_path_within(path, start=start): raise AssertionError( - "path %s is not within directory %s" % (path, start)) + "path %s is not within directory %s" % (path, start) + ) @staticmethod def pretty_display_path(absolute_path): assert os.path.isabs(absolute_path) relative_path = os.path.relpath(absolute_path, start=MIG_BASE) - assert not relative_path.startswith('..') + assert not relative_path.startswith("..") return relative_path @staticmethod @@ -359,7 +378,9 @@ def _provision_test_user(testcase, distinguished_name): Note that this method, along with a number of others, are defined in the user portion of the test support libraries. """ - return UserAssertMixin._provision_test_user(testcase, distinguished_name) + return UserAssertMixin._provision_test_user( + testcase, distinguished_name + ) def is_path_within(path, start=None, _msg=None): @@ -369,7 +390,7 @@ def is_path_within(path, start=None, _msg=None): relative = os.path.relpath(path, start=start) except: return False - return not relative.startswith('..') + return not relative.startswith("..") def ensure_dirs_exist(absolute_dir): @@ -400,7 +421,7 @@ def temppath(relative_path, test_case, ensure_dir=False): # failsafe path checking that supplied paths are rooted within valid paths is_tmp_path_within_safe_dir = False - for start in (ENVHELP_OUTPUT_DIR): + for start in ENVHELP_OUTPUT_DIR: is_tmp_path_within_safe_dir = is_path_within(tmp_path, start=start) if is_tmp_path_within_safe_dir: break @@ -413,7 +434,8 @@ def temppath(relative_path, test_case, ensure_dir=False): except OSError as oserr: if oserr.errno == errno.EEXIST: raise AssertionError( - "ABORT: use of unclean output path: %s" % tmp_path) + "ABORT: use of unclean output path: %s" % tmp_path + ) return tmp_path diff --git a/tests/support/_env.py b/tests/support/_env.py index 2c71386a4..d86bbd2d3 100644 --- a/tests/support/_env.py +++ b/tests/support/_env.py @@ -2,10 +2,10 @@ import sys # expose the configured environment as a constant -MIG_ENV = os.environ.get('MIG_ENV', 'local') +MIG_ENV = os.environ.get("MIG_ENV", "local") # force the chosen environment globally -os.environ['MIG_ENV'] = MIG_ENV +os.environ["MIG_ENV"] = MIG_ENV # expose a boolean indicating whether we are executing on Python 2 -PY2 = (sys.version_info[0] == 2) +PY2 = sys.version_info[0] == 2 diff --git a/tests/support/assertover.py b/tests/support/assertover.py index 52b445ab2..3e31b7b5b 100644 --- a/tests/support/assertover.py +++ b/tests/support/assertover.py @@ -29,11 +29,13 @@ class NoBlockError(AssertionError): """Decorate AssertionError for our own convenience""" + pass class NoCasesError(AssertionError): """Decorate AssertionError for our own convenience""" + pass @@ -76,10 +78,15 @@ def __exit__(self, exc_type, exc_value, traceback): if not any(self._attempts): return True - value_lines = ["- <%r> : %s" % (attempt[0], str(attempt[1])) for - attempt in self._attempts if attempt] - raise AssertionError("assertions raised for the following values:\n%s" - % '\n'.join(value_lines)) + value_lines = [ + "- <%r> : %s" % (attempt[0], str(attempt[1])) + for attempt in self._attempts + if attempt + ] + raise AssertionError( + "assertions raised for the following values:\n%s" + % "\n".join(value_lines) + ) def record_attempt(self, attempt_info): """Record the result of a test attempt""" @@ -89,7 +96,9 @@ def to_check_callable(self): def raise_unless_consulted(): if not self._consulted: raise AssertionError( - "no examiniation made of assertion of multiple values") + "no examiniation made of assertion of multiple values" + ) + return raise_unless_consulted def assert_success(self): @@ -103,4 +112,7 @@ def _execute_block(cls, block, block_value): block.__call__(block_value) return None except Exception as blockexc: - return (block_value, blockexc,) + return ( + block_value, + blockexc, + ) diff --git a/tests/support/configsupp.py b/tests/support/configsupp.py index 0846e465d..bb8511542 100644 --- a/tests/support/configsupp.py +++ b/tests/support/configsupp.py @@ -27,20 +27,21 @@ """Configuration related details within the test support library.""" -from tests.support.loggersupp import FakeLogger - from mig.shared.compat import SimpleNamespace -from mig.shared.configuration import \ - _CONFIGURATION_ARGUMENTS, _CONFIGURATION_PROPERTIES +from mig.shared.configuration import ( + _CONFIGURATION_ARGUMENTS, + _CONFIGURATION_PROPERTIES, +) +from tests.support.loggersupp import FakeLogger def _ensure_only_configuration_keys(thedict): - """Check a dictionary contains only keys valid as Configuration properties. - """ + """Check a dictionary contains only keys valid as Configuration properties.""" unknown_keys = set(thedict.keys()) - set(_CONFIGURATION_ARGUMENTS) - assert len(unknown_keys) == 0, \ - "non-Configuration keys: %s" % (', '.join(unknown_keys),) + assert len(unknown_keys) == 0, "non-Configuration keys: %s" % ( + ", ".join(unknown_keys), + ) def _generate_namespace_kwargs(): @@ -49,7 +50,7 @@ def _generate_namespace_kwargs(): """ properties_and_defaults = dict(_CONFIGURATION_PROPERTIES) - properties_and_defaults['logger'] = None + properties_and_defaults["logger"] = None return properties_and_defaults diff --git a/tests/support/fixturesupp.py b/tests/support/fixturesupp.py index c418fc7a2..0172a3e43 100644 --- a/tests/support/fixturesupp.py +++ b/tests/support/fixturesupp.py @@ -27,12 +27,12 @@ """Fixture related details within the test support library.""" -from configparser import ConfigParser -from datetime import date, timedelta import json import os import pickle import shutil +from configparser import ConfigParser +from datetime import date, timedelta from time import mktime from types import SimpleNamespace @@ -51,30 +51,35 @@ def _fixturefile_loadrelative(fixture_name, fixture_format=None): assert fixture_format is not None, "fixture format must be specified" relative_path_with_ext = "%s.%s" % (fixture_name, fixture_format) tmp_path = os.path.join(TEST_FIXTURE_DIR, relative_path_with_ext) - assert os.path.isfile(tmp_path), \ - 'fixture named "%s" with format %s is not present: %s' % \ - (fixture_name, fixture_format, relative_path_with_ext) + assert os.path.isfile( + tmp_path + ), 'fixture named "%s" with format %s is not present: %s' % ( + fixture_name, + fixture_format, + relative_path_with_ext, + ) data = None - if fixture_format == 'binary': - with open(tmp_path, 'rb') as binfile: + if fixture_format == "binary": + with open(tmp_path, "rb") as binfile: data = binfile.read() - elif fixture_format == 'json': + elif fixture_format == "json": with open(tmp_path) as jsonfile: data = json.load(jsonfile, object_hook=_FixtureHint.object_hook) _hints_apply_from_instances_if_present(data) _hints_apply_from_fixture_ini_if_present(fixture_name, data) else: raise AssertionError( - "unsupported fixture format: %s" % (fixture_format,)) + "unsupported fixture format: %s" % (fixture_format,) + ) return data, tmp_path -def _fixturefile_normname(relative_path, prefix=''): +def _fixturefile_normname(relative_path, prefix=""): """Grab normname from relative_path and optionally add a path prefix""" - normname, _ = relative_path.split('--') + normname, _ = relative_path.split("--") if prefix: return os.path.join(prefix, normname) return normname @@ -96,6 +101,7 @@ def _fixturefile_normname(relative_path, prefix=''): # # + def _hints_apply_array_of_tuples(value, modifier): """ Convert list of lists such that its values are instead tuples. @@ -109,7 +115,7 @@ def _hints_apply_today_relative(value, modifier): Geneate a time value by applying a declared delta to today's date. """ - kind, delta = modifier.split('|') + kind, delta = modifier.split("|") if kind == "days": time_delta = timedelta(days=int(delta)) adjusted_datetime = date.today() + time_delta @@ -131,15 +137,17 @@ def _hints_apply_dict_bytes_to_strings_kv(input_dict, modifier): for k, v in input_dict.items(): key_to_use = k if isinstance(k, bytes): - key_to_use = str(k, 'utf8') + key_to_use = str(k, "utf8") if isinstance(v, dict): - output_dict[key_to_use] = _hints_apply_dict_bytes_to_strings_kv(v, modifier) + output_dict[key_to_use] = _hints_apply_dict_bytes_to_strings_kv( + v, modifier + ) continue val_to_use = v if isinstance(v, bytes): - val_to_use = str(v, 'utf8') + val_to_use = str(v, "utf8") output_dict[key_to_use] = val_to_use @@ -159,15 +167,17 @@ def _hints_apply_dict_strings_to_bytes_kv(input_dict, modifier): for k, v in input_dict.items(): key_to_use = k if isinstance(k, str): - key_to_use = bytes(k, 'utf8') + key_to_use = bytes(k, "utf8") if isinstance(v, dict): - output_dict[key_to_use] = _hints_apply_dict_strings_to_bytes_kv(v, modifier) + output_dict[key_to_use] = _hints_apply_dict_strings_to_bytes_kv( + v, modifier + ) continue val_to_use = v if isinstance(v, str): - val_to_use = bytes(v, 'utf8') + val_to_use = bytes(v, "utf8") output_dict[key_to_use] = val_to_use @@ -176,21 +186,21 @@ def _hints_apply_dict_strings_to_bytes_kv(input_dict, modifier): # hints that can be aplied without an additional modifier argument _HINTS_APPLIERS_ARGLESS = { - 'array_of_tuples': _hints_apply_array_of_tuples, - 'today_relative': _hints_apply_today_relative, - 'convert_dict_bytes_to_strings_kv': _hints_apply_dict_bytes_to_strings_kv, - 'convert_dict_strings_to_bytes_kv': _hints_apply_dict_strings_to_bytes_kv, + "array_of_tuples": _hints_apply_array_of_tuples, + "today_relative": _hints_apply_today_relative, + "convert_dict_bytes_to_strings_kv": _hints_apply_dict_bytes_to_strings_kv, + "convert_dict_strings_to_bytes_kv": _hints_apply_dict_strings_to_bytes_kv, } # hints applicable to the conversion of attributes during fixture loading _FIXTUREFILE_APPLIERS_ATTRIBUTES = { - 'array_of_tuples': _hints_apply_array_of_tuples, - 'today_relative': _hints_apply_today_relative, + "array_of_tuples": _hints_apply_array_of_tuples, + "today_relative": _hints_apply_today_relative, } # hints applied when writing the contents of a fixture as a temporary file _FIXTUREFILE_APPLIERS_ONWRITE = { - 'convert_dict_strings_to_bytes_kv': _hints_apply_dict_strings_to_bytes_kv, + "convert_dict_strings_to_bytes_kv": _hints_apply_dict_strings_to_bytes_kv, } @@ -222,7 +232,7 @@ def _load_hints_ini_for_fixture_if_present(fixture_name): pass # ensure empty required fixture to avoid extra conditionals later - for required_section in ['ATTRIBUTES']: + for required_section in ["ATTRIBUTES"]: if not hints.has_section(required_section): hints.add_section(required_section) @@ -239,10 +249,10 @@ def _hints_apply_from_fixture_ini_if_present(fixture_name, json_object): # apply any attriutes hints ahead of specified conversions such that any # key can be specified matching what is visible within the loaded fixture - for item_name, item_hint_unparsed in hints['ATTRIBUTES'].items(): + for item_name, item_hint_unparsed in hints["ATTRIBUTES"].items(): loaded_value = json_object[item_name] - item_hint_and_maybe_modifier = item_hint_unparsed.split('--') + item_hint_and_maybe_modifier = item_hint_unparsed.split("--") item_hint = item_hint_and_maybe_modifier[0] if len(item_hint_and_maybe_modifier) == 2: modifier = item_hint_and_maybe_modifier[1] @@ -267,7 +277,9 @@ def __init__(self, hint=None, modifier=None, value=None): def decode_hint(hint_obj): """Produce a value based on the properties of a hint instance.""" assert isinstance(hint_obj, _FixtureHint) - value_from_loaded_value = _FIXTUREFILE_APPLIERS_ATTRIBUTES[hint_obj.hint] + value_from_loaded_value = _FIXTUREFILE_APPLIERS_ATTRIBUTES[ + hint_obj.hint + ] return value_from_loaded_value(hint_obj.value, hint_obj.modifier) @staticmethod @@ -278,11 +290,14 @@ def object_hook(decoded_object): """ if "_FixtureHint" in decoded_object: - fixture_hint = _FixtureHint(decoded_object["hint"], decoded_object["modifier"]) + fixture_hint = _FixtureHint( + decoded_object["hint"], decoded_object["modifier"] + ) return _FixtureHint.decode_hint(fixture_hint) return decoded_object + # @@ -295,7 +310,7 @@ def fixturepath(relative_path): def _to_display_path(value): """Convert an absolute path to one to be shown as part of test output.""" display_path = os.path.relpath(value, MIG_BASE) - if not display_path.startswith('.'): + if not display_path.startswith("."): return "./" + display_path return display_path @@ -307,10 +322,9 @@ class _PreparedFixture: NO_DATA = object() - def __init__(self, testcase, - fixture_name, - fixture_format='', - fixture_data=NO_DATA): + def __init__( + self, testcase, fixture_name, fixture_format="", fixture_data=NO_DATA + ): self.testcase = testcase self.fixture_name = fixture_name self.fixture_format = fixture_format @@ -336,9 +350,12 @@ def assertAgainstFixture(self, value): if self.fixture_format: message_infix = " with format %s" % (self.fixture_format,) else: - message_infix = '' + message_infix = "" message = "value differed from fixture named %s%s\n\n%s" % ( - self.fixture_name, message_infix, raised_exception) + self.fixture_name, + message_infix, + raised_exception, + ) raise AssertionError(message) def write_to_dir(self, target_dir, output_format=None): @@ -347,42 +364,47 @@ def write_to_dir(self, target_dir, output_format=None): directory applying any onwrite hints that may be specified. """ - assert self.fixture_data is not self.NO_DATA, \ - "fixture is not populated with data" + assert ( + self.fixture_data is not self.NO_DATA + ), "fixture is not populated with data" assert os.path.isabs(target_dir) # convert fixture name (which includes the varaint) to the target file - fixture_file_target = _fixturefile_normname(self.fixture_name, prefix=target_dir) + fixture_file_target = _fixturefile_normname( + self.fixture_name, prefix=target_dir + ) output_data = self.fixture_data # now apply any onwrite conversions hints = _load_hints_ini_for_fixture_if_present(self.fixture_name) - for item_name in hints['ONWRITE']: + for item_name in hints["ONWRITE"]: if item_name not in _FIXTUREFILE_APPLIERS_ONWRITE: raise AssertionError( - "unsupported fixture conversion: %s" % (item_name,)) + "unsupported fixture conversion: %s" % (item_name,) + ) - enabled = hints.getboolean('ONWRITE', item_name) + enabled = hints.getboolean("ONWRITE", item_name) if not enabled: continue hint_fn = _FIXTUREFILE_APPLIERS_ONWRITE[item_name] output_data = hint_fn(output_data, None) - if output_format == 'binary': - with open(fixture_file_target, 'wb') as fixture_outputfile: + if output_format == "binary": + with open(fixture_file_target, "wb") as fixture_outputfile: fixture_outputfile.write(output_data) - elif output_format == 'json': - with open(fixture_file_target, 'w') as fixture_outputfile: + elif output_format == "json": + with open(fixture_file_target, "w") as fixture_outputfile: json.dump(output_data, fixture_outputfile) - elif output_format == 'pickle': - with open(fixture_file_target, 'wb') as fixture_outputfile: + elif output_format == "pickle": + with open(fixture_file_target, "wb") as fixture_outputfile: pickle.dump(output_data, fixture_outputfile) else: raise AssertionError( - "unsupported fixture format: %s" % (output_format,)) + "unsupported fixture format: %s" % (output_format,) + ) @staticmethod def from_relpath(testcase, fixture_name, fixture_format): @@ -392,11 +414,16 @@ def from_relpath(testcase, fixture_name, fixture_format): """ fixture_data, fixture_path = _fixturefile_loadrelative( - fixture_name, fixture_format) - return _PreparedFixture(testcase, fixture_name, fixture_format, fixture_data) + fixture_name, fixture_format + ) + return _PreparedFixture( + testcase, fixture_name, fixture_format, fixture_data + ) class FixtureAssertMixin: def prepareFixtureAssert(self, fixture_relpath, fixture_format=None): """Prepare to assert a value against a fixture.""" - return _PreparedFixture.from_relpath(self, fixture_relpath, fixture_format) + return _PreparedFixture.from_relpath( + self, fixture_relpath, fixture_format + ) diff --git a/tests/support/loggersupp.py b/tests/support/loggersupp.py index b1eb2c295..5de13fb8a 100644 --- a/tests/support/loggersupp.py +++ b/tests/support/loggersupp.py @@ -28,9 +28,9 @@ """Logger related details within the test support library.""" -from collections import defaultdict import os import re +from collections import defaultdict from tests.support.suppconst import MIG_BASE, TEST_BASE @@ -47,7 +47,8 @@ class FakeLogger: """ RE_UNCLOSEDFILE = re.compile( - 'unclosed file <.*? name=\'(?P.*?)\'( .*?)?>') + "unclosed file <.*? name='(?P.*?)'( .*?)?>" + ) def __init__(self): self.channels_dict = defaultdict(list) @@ -70,13 +71,19 @@ def check_empty_and_reset(self): # complain loudly (and in detail) in the case of unclosed files if len(unclosed_by_file) > 0: - messages = '\n'.join({' --> %s: line=%s, file=%s' % (fname, lineno, outname) - for fname, (lineno, outname) in unclosed_by_file.items()}) - raise RuntimeError('unclosed files encountered:\n%s' % (messages,)) - - if channels_dict['error'] and not forgive_by_channel['error']: - raise RuntimeError('errors reported to logger:\n%s' % - '\n'.join(channels_dict['error'])) + messages = "\n".join( + { + " --> %s: line=%s, file=%s" % (fname, lineno, outname) + for fname, (lineno, outname) in unclosed_by_file.items() + } + ) + raise RuntimeError("unclosed files encountered:\n%s" % (messages,)) + + if channels_dict["error"] and not forgive_by_channel["error"]: + raise RuntimeError( + "errors reported to logger:\n%s" + % "\n".join(channels_dict["error"]) + ) def forgive_errors(self): """Allow log errors for cases where they are expected""" @@ -90,26 +97,26 @@ def forgive_messages_on(self, *, channel_name=None): def debug(self, line): """Mock log action of same name""" - self._append_as('debug', line) + self._append_as("debug", line) def error(self, line): """Mock log action of same name""" - self._append_as('error', line) + self._append_as("error", line) def info(self, line): """Mock log action of same name""" - self._append_as('info', line) + self._append_as("info", line) def warning(self, line): """Mock log action of same name""" - self._append_as('warning', line) + self._append_as("warning", line) def write(self, message): """Actual write handler""" - channel, namespace, specifics = message.split(':', 2) + channel, namespace, specifics = message.split(":", 2) # ignore everything except warnings sent by the python runtime - if not (channel == 'WARNING' and namespace == 'py.warnings'): + if not (channel == "WARNING" and namespace == "py.warnings"): return filename_and_datatuple = FakeLogger.identify_unclosed_file(specifics) @@ -119,10 +126,10 @@ def write(self, message): @staticmethod def identify_unclosed_file(specifics): """Warn about unclosed files""" - filename, lineno, exc_name, message = specifics.split(':', 3) + filename, lineno, exc_name, message = specifics.split(":", 3) exc_name = exc_name.lstrip() - if exc_name != 'ResourceWarning': + if exc_name != "ResourceWarning": return matched = FakeLogger.RE_UNCLOSEDFILE.match(message.lstrip()) @@ -131,7 +138,8 @@ def identify_unclosed_file(specifics): relative_testfile = os.path.relpath(filename, start=MIG_BASE) relative_outputfile = os.path.relpath( - matched.groups('location')[0], start=TEST_BASE) + matched.groups("location")[0], start=TEST_BASE + ) return (relative_testfile, (lineno, relative_outputfile)) diff --git a/tests/support/picklesupp.py b/tests/support/picklesupp.py index 667dd4b01..262e4c50c 100644 --- a/tests/support/picklesupp.py +++ b/tests/support/picklesupp.py @@ -29,8 +29,8 @@ import pickle -from tests.support.suppconst import TEST_OUTPUT_DIR from tests.support.fixturesupp import _HINTS_APPLIERS_ARGLESS +from tests.support.suppconst import TEST_OUTPUT_DIR class PickleAssertMixin: @@ -44,7 +44,7 @@ def assertPickledFile(self, pickle_file_path, apply_hints=None): having been optionally transformed as requested by hints. """ - with open(pickle_file_path, 'rb') as picklefile: + with open(pickle_file_path, "rb") as picklefile: pickled = pickle.load(picklefile) if not apply_hints: diff --git a/tests/support/serversupp.py b/tests/support/serversupp.py index 0e0fd4b94..bdca0fa33 100644 --- a/tests/support/serversupp.py +++ b/tests/support/serversupp.py @@ -27,7 +27,8 @@ """Server threading related details within the test support library""" -from threading import Thread, Event as ThreadEvent +from threading import Event as ThreadEvent +from threading import Thread class ServerWithinThreadExecutor: @@ -50,7 +51,7 @@ def run(self): """Mimic the same method from the standard thread API""" server_args, server_kwargs = self._arguments - server_kwargs['on_start'] = lambda _: self._started.set() + server_kwargs["on_start"] = lambda _: self._started.set() self._wrapped = self._serverclass(*server_args, **server_kwargs) diff --git a/tests/support/snapshotsupp.py b/tests/support/snapshotsupp.py index 355095814..a24bb2238 100644 --- a/tests/support/snapshotsupp.py +++ b/tests/support/snapshotsupp.py @@ -28,14 +28,14 @@ import difflib import errno -import re import os +import re from tests.support.suppconst import TEST_BASE -HTML_TAG = '' -MARKER_CONTENT_BEGIN = '' -MARKER_CONTENT_END = '' +HTML_TAG = "" +MARKER_CONTENT_BEGIN = "" +MARKER_CONTENT_END = "" TEST_SNAPSHOTS_DIR = os.path.join(TEST_BASE, "snapshots") try: @@ -57,9 +57,9 @@ def _html_content_only(value): # set the index after the content marker content_start_index += len(MARKER_CONTENT_BEGIN) # we now need to remove the container div inside it ..first find it - content_start_inner_div = value.find('', content_start_inner_div) + 1 + content_start_index = value.find(">", content_start_inner_div) + 1 content_end_index = value.find(MARKER_CONTENT_END) assert content_end_index > -1, "unable to locate end of content" @@ -77,7 +77,7 @@ def _delimited_lines(value): lines = [] while from_index < last_index: - found_index = value.find('\n', from_index) + found_index = value.find("\n", from_index) if found_index == -1: break found_index += 1 @@ -93,8 +93,8 @@ def _delimited_lines(value): def _force_refresh_snapshots(): """Check whether the environment specifies snapshots should be refreshed.""" - env_refresh_snapshots = os.environ.get('REFRESH_SNAPSHOTS', 'no').lower() - return env_refresh_snapshots in ('true', 'yes', '1') + env_refresh_snapshots = os.environ.get("REFRESH_SNAPSHOTS", "no").lower() + return env_refresh_snapshots in ("true", "yes", "1") class SnapshotAssertMixin: @@ -107,7 +107,7 @@ def _snapshotsupp_compare_snapshot(self, extension, actual_content): In the case a snapshot does not exist it is saved on first invocation. """ - file_name = ''.join([self._testMethodName, ".", extension]) + file_name = "".join([self._testMethodName, ".", extension]) file_path = os.path.join(TEST_SNAPSHOTS_DIR, file_name) if not os.path.isfile(file_path) or _force_refresh_snapshots(): @@ -126,11 +126,12 @@ def _snapshotsupp_compare_snapshot(self, extension, actual_content): udiff = difflib.unified_diff( _delimited_lines(expected_content), _delimited_lines(actual_content), - 'expected', - 'actual' + "expected", + "actual", ) raise AssertionError( - "content did not match snapshot\n\n%s" % (''.join(udiff),)) + "content did not match snapshot\n\n%s" % ("".join(udiff),) + ) def assertSnapshot(self, actual_content, extension=None): """Load a snapshot corresponding to the named test and check that what @@ -148,4 +149,4 @@ def assertSnapshotOfHtmlContent(self, actual_content): """ actual_content = _html_content_only(actual_content) - self._snapshotsupp_compare_snapshot('html', actual_content) + self._snapshotsupp_compare_snapshot("html", actual_content) diff --git a/tests/support/suppconst.py b/tests/support/suppconst.py index 148303f0d..204b0ba29 100644 --- a/tests/support/suppconst.py +++ b/tests/support/suppconst.py @@ -29,11 +29,11 @@ from tests.support._env import MIG_ENV -if MIG_ENV == 'local': +if MIG_ENV == "local": # Use abspath for __file__ on Py2 _SUPPORT_DIR = os.path.dirname(os.path.abspath(__file__)) -elif MIG_ENV == 'docker': - _SUPPORT_DIR = '/usr/src/app/tests/support' +elif MIG_ENV == "docker": + _SUPPORT_DIR = "/usr/src/app/tests/support" else: raise NotImplementedError("ABORT: unsupported environment: %s" % (MIG_ENV,)) @@ -46,7 +46,8 @@ ENVHELP_OUTPUT_DIR = os.path.join(ENVHELP_DIR, "output") -if __name__ == '__main__': +if __name__ == "__main__": + def print_root_relative(prefix, path): print("%s = /%s" % (prefix, os.path.relpath(path, MIG_BASE))) diff --git a/tests/support/usersupp.py b/tests/support/usersupp.py index 65a02fa4a..849c98f92 100644 --- a/tests/support/usersupp.py +++ b/tests/support/usersupp.py @@ -33,15 +33,11 @@ import pickle from mig.shared.base import client_id_dir - from tests.support.fixturesupp import _PreparedFixture +TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" -TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' - -_FIXTURE_NAME_BY_USER_DN = { - TEST_USER_DN: 'MiG-users.db--example' -} +_FIXTURE_NAME_BY_USER_DN = {TEST_USER_DN: "MiG-users.db--example"} class UserAssertMixin: @@ -59,9 +55,9 @@ def _provision_user_db_dir(testcase): conf_user_db_home = testcase.configuration.user_db_home os.makedirs(conf_user_db_home, exist_ok=True) - user_db_file = os.path.join(conf_user_db_home, 'MiG-users.db') + user_db_file = os.path.join(conf_user_db_home, "MiG-users.db") if os.path.exists(user_db_file): - raise AssertionError('a user database file already exists') + raise AssertionError("a user database file already exists") return conf_user_db_home @@ -79,19 +75,19 @@ def _provision_test_user(testcase, distinguished_name): try: fixture_relpath = _FIXTURE_NAME_BY_USER_DN[distinguished_name] except KeyError: - raise AssertionError('supplied test user is not known as a fixture') + raise AssertionError("supplied test user is not known as a fixture") # note: this is a non-standard direct use of fixture preparation due # to this being bootstrap code and should not be used elsewhere prepared_fixture = _PreparedFixture.from_relpath( - testcase, - fixture_relpath, - fixture_format='json' + testcase, fixture_relpath, fixture_format="json" ) # write out the user database fixture containing the user - prepared_fixture.write_to_dir(conf_user_db_home, output_format='pickle') + prepared_fixture.write_to_dir(conf_user_db_home, output_format="pickle") - test_user_dir = UserAssertMixin._provision_test_user_dirs(testcase, distinguished_name) + test_user_dir = UserAssertMixin._provision_test_user_dirs( + testcase, distinguished_name + ) return test_user_dir @@ -112,12 +108,16 @@ def _provision_test_user_dirs(testcase, distinguished_name): # create the test user settings directory conf_user_settings = os.path.normpath(self.configuration.user_settings) - test_user_settings_dir = os.path.join(conf_user_settings, test_client_dir_name) + test_user_settings_dir = os.path.join( + conf_user_settings, test_client_dir_name + ) os.makedirs(test_user_settings_dir) # create an empty user settings file - test_user_settings_file = os.path.join(test_user_settings_dir, 'settings') - with open(test_user_settings_file, 'wb') as outfile: + test_user_settings_file = os.path.join( + test_user_settings_dir, "settings" + ) + with open(test_user_settings_file, "wb") as outfile: pickle.dump({}, outfile) return test_user_dir @@ -154,9 +154,11 @@ def _provision_test_users(testcase, *distinguished_names): # write out all the users we have assembled by populating an empty # fixture with their data but using a known fixture name and thus one # suitably hinted so a production format pickle file ends up on-disk - prepared_fixture = _PreparedFixture(testcase, 'MiG-users.db--example') + prepared_fixture = _PreparedFixture(testcase, "MiG-users.db--example") prepared_fixture.fixture_data = users_by_dn - prepared_fixture.write_to_dir(conf_user_db_home, output_format='pickle') + prepared_fixture.write_to_dir(conf_user_db_home, output_format="pickle") for distinguished_name in distinguished_names: - UserAssertMixin._provision_test_user_dirs(testcase, distinguished_name) + UserAssertMixin._provision_test_user_dirs( + testcase, distinguished_name + ) diff --git a/tests/support/wsgisupp.py b/tests/support/wsgisupp.py index 1105d0db8..2a6a209d2 100644 --- a/tests/support/wsgisupp.py +++ b/tests/support/wsgisupp.py @@ -27,22 +27,21 @@ """Test support library for WSGI.""" -from collections import namedtuple import codecs +from collections import namedtuple from io import BytesIO from urllib.parse import urlencode, urlparse from werkzeug.datastructures import MultiDict - # named type representing the tuple that is passed to WSGI handlers -_PreparedWsgi = namedtuple('_PreparedWsgi', ['environ', 'start_response']) +_PreparedWsgi = namedtuple("_PreparedWsgi", ["environ", "start_response"]) class FakeWsgiStartResponse: """Glue object that conforms to the same interface as the start_response() - in the WSGI specs but records the calls to it such that they can be - inspected and, for our purposes, asserted against.""" + in the WSGI specs but records the calls to it such that they can be + inspected and, for our purposes, asserted against.""" def __init__(self): self.calls = [] @@ -51,7 +50,9 @@ def __call__(self, status, headers, exc=None): self.calls.append((status, headers, exc)) -def create_wsgi_environ(configuration, wsgi_url, method='GET', query=None, headers=None, form=None): +def create_wsgi_environ( + configuration, wsgi_url, method="GET", query=None, headers=None, form=None +): """Populate the necessary variables that will constitute a valid WSGI environment given a URL to which we will make a requests under test and various other options that set up the nature of that request.""" @@ -59,21 +60,21 @@ def create_wsgi_environ(configuration, wsgi_url, method='GET', query=None, heade parsed_url = urlparse(wsgi_url) if query: - method = 'GET' + method = "GET" request_query = urlencode(query) wsgi_input = () elif form: - method = 'POST' - request_query = '' + method = "POST" + request_query = "" - body = urlencode(MultiDict(form)).encode('ascii') + body = urlencode(MultiDict(form)).encode("ascii") headers = headers or {} - if not 'Content-Type' in headers: - headers['Content-Type'] = 'application/x-www-form-urlencoded' + if not "Content-Type" in headers: + headers["Content-Type"] = "application/x-www-form-urlencoded" - headers['Content-Length'] = str(len(body)) + headers["Content-Length"] = str(len(body)) wsgi_input = BytesIO(body) else: request_query = parsed_url.query @@ -83,26 +84,27 @@ class _errors: """Internal helper to ignore wsgi.errors close method calls""" def close(self, *ars, **kwargs): - """"Simply ignore""" + """ "Simply ignore""" pass environ = {} - environ['wsgi.errors'] = _errors() - environ['wsgi.input'] = wsgi_input - environ['wsgi.url_scheme'] = parsed_url.scheme - environ['wsgi.version'] = (1, 0) - environ['MIG_CONF'] = configuration.config_file - environ['HTTP_HOST'] = parsed_url.netloc - environ['PATH_INFO'] = parsed_url.path - environ['QUERY_STRING'] = request_query - environ['REQUEST_METHOD'] = method - environ['SCRIPT_URI'] = ''.join( - ('http://', environ['HTTP_HOST'], environ['PATH_INFO'])) + environ["wsgi.errors"] = _errors() + environ["wsgi.input"] = wsgi_input + environ["wsgi.url_scheme"] = parsed_url.scheme + environ["wsgi.version"] = (1, 0) + environ["MIG_CONF"] = configuration.config_file + environ["HTTP_HOST"] = parsed_url.netloc + environ["PATH_INFO"] = parsed_url.path + environ["QUERY_STRING"] = request_query + environ["REQUEST_METHOD"] = method + environ["SCRIPT_URI"] = "".join( + ("http://", environ["HTTP_HOST"], environ["PATH_INFO"]) + ) if headers: for k, v in headers.items(): - header_key = k.replace('-', '_').upper() - if header_key.startswith('CONTENT'): + header_key = k.replace("-", "_").upper() + if header_key.startswith("CONTENT"): # Content-* headers must not be prefixed in WSGI pass else: @@ -119,15 +121,15 @@ def create_wsgi_start_response(): def prepare_wsgi(configuration, url, **kwargs): return _PreparedWsgi( create_wsgi_environ(configuration, url, **kwargs), - create_wsgi_start_response() + create_wsgi_start_response(), ) def _trigger_and_unpack_result(wsgi_result): chunks = list(wsgi_result) assert len(chunks) > 0, "invocation returned no output" - complete_value = b''.join(chunks) - decoded_value = codecs.decode(complete_value, 'utf8') + complete_value = b"".join(chunks) + decoded_value = codecs.decode(complete_value, "utf8") return decoded_value @@ -140,7 +142,7 @@ def assertWsgiResponse(self, wsgi_result, fake_wsgi, expected_status_code): content = _trigger_and_unpack_result(wsgi_result) def called_once(fake): - assert hasattr(fake, 'calls') + assert hasattr(fake, "calls") return len(fake.calls) == 1 fake_start_response = fake_wsgi.start_response diff --git a/tests/test_booleans.py b/tests/test_booleans.py index 5246197ee..a3564b8ff 100644 --- a/tests/test_booleans.py +++ b/tests/test_booleans.py @@ -2,6 +2,7 @@ from tests.support import MigTestCase, testmain + class TestBooleans(MigTestCase): def test_true(self): self.assertEqual(True, True) @@ -10,5 +11,5 @@ def test_false(self): self.assertEqual(False, False) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_install_generateconfs.py b/tests/test_mig_install_generateconfs.py index e8ed90241..c68318cd9 100644 --- a/tests/test_mig_install_generateconfs.py +++ b/tests/test_mig_install_generateconfs.py @@ -33,13 +33,13 @@ import os import sys -from tests.support import MIG_BASE, MigTestCase, testmain, cleanpath +from tests.support import MIG_BASE, MigTestCase, cleanpath, testmain def _import_generateconfs(): """Internal helper to work around non-package import location""" - sys.path.append(os.path.join(MIG_BASE, 'mig/install')) - mod = importlib.import_module('generateconfs') + sys.path.append(os.path.join(MIG_BASE, "mig/install")) + mod = importlib.import_module("generateconfs") sys.path.pop(-1) return mod @@ -51,12 +51,14 @@ def _import_generateconfs(): def create_fake_generate_confs(return_dict=None): """Fake generate confs helper""" + def _generate_confs(*args, **kwargs): _generate_confs.settings = kwargs if return_dict: return (return_dict, {}) else: return ({}, {}) + _generate_confs.settings = None return _generate_confs @@ -69,52 +71,64 @@ class MigInstallGenerateconfs__main(MigTestCase): """Unit test helper for the migrid code pointed to in class name""" def test_option_permanent_freeze(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--permanent_freeze', 'yes'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--permanent_freeze", "yes"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) self.assertEqual(exit_code, 0) def test_option_storage_protocols(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--storage_protocols', 'proto1 proto2 proto3'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--storage_protocols", "proto1 proto2 proto3"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) self.assertEqual(exit_code, 0) settings = fake_generate_confs.settings - self.assertIn('storage_protocols', settings) - self.assertEqual(settings['storage_protocols'], 'proto1 proto2 proto3') + self.assertIn("storage_protocols", settings) + self.assertEqual(settings["storage_protocols"], "proto1 proto2 proto3") def test_option_wwwserve_max_bytes(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--wwwserve_max_bytes', '43211234'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--wwwserve_max_bytes", "43211234"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) settings = fake_generate_confs.settings - self.assertIn('wwwserve_max_bytes', settings) - self.assertEqual(settings['wwwserve_max_bytes'], 43211234) + self.assertIn("wwwserve_max_bytes", settings) + self.assertEqual(settings["wwwserve_max_bytes"], 43211234) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_lib_accounting.py b/tests/test_mig_lib_accounting.py index fc0003fc4..9eb60a931 100644 --- a/tests/test_mig_lib_accounting.py +++ b/tests/test_mig_lib_accounting.py @@ -30,8 +30,11 @@ import os import pickle -from mig.lib.accounting import get_usage, human_readable_filesize, \ - update_accounting +from mig.lib.accounting import ( + get_usage, + human_readable_filesize, + update_accounting, +) from mig.shared.base import client_id_dir from mig.shared.defaults import peers_filename from tests.support import MigTestCase, ensure_dirs_exist @@ -39,71 +42,90 @@ TEST_MTIME = 1768925307 TEST_SOFTLIMIT_BYTES = 109951162777600 TEST_HARDLIMIT_BYTES = 109951162777600 -TEST_CLIENT_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@user.com' +TEST_CLIENT_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@user.com" +) TEST_CLIENT_BYTES = 206128256 -TEST_EXT_DN = '/C=DK/ST=NA/L=NA/O=PEER Org/OU=NA/CN=Test Peer/emailAddress=peer@example.com' +TEST_EXT_DN = "/C=DK/ST=NA/L=NA/O=PEER Org/OU=NA/CN=Test Peer/emailAddress=peer@example.com" TEST_EXT_BYTES = 16806128256 TEST_FREEZE_BYTES = 128256 -TEST_VGRID_NAME1 = 'TestVgrid1' +TEST_VGRID_NAME1 = "TestVgrid1" TEST_VGRID_BYTES1 = 406128256 -TEST_VGRID_NAME2 = 'TestVgrid2' +TEST_VGRID_NAME2 = "TestVgrid2" TEST_VGRID_BYTES2 = 606128256 -TEST_VGRID_NAME3 = 'TestVgrid3' +TEST_VGRID_NAME3 = "TestVgrid3" TEST_VGRID_BYTES3 = 806128256 -TEST_VGRID_TOTAL_BYTES = TEST_VGRID_BYTES1 \ - + TEST_VGRID_BYTES2 \ - + TEST_VGRID_BYTES3 -TEST_TOTAL_BYTES = TEST_CLIENT_BYTES \ - + TEST_EXT_BYTES \ - + TEST_FREEZE_BYTES \ +TEST_VGRID_TOTAL_BYTES = ( + TEST_VGRID_BYTES1 + TEST_VGRID_BYTES2 + TEST_VGRID_BYTES3 +) +TEST_TOTAL_BYTES = ( + TEST_CLIENT_BYTES + + TEST_EXT_BYTES + + TEST_FREEZE_BYTES + TEST_VGRID_TOTAL_BYTES -TEST_LUSTRE_QUOTA_INFO = {'next_pid': 192, 'mtime': TEST_MTIME} -TEST_CLIENT_USAGE = {'lustre_pid': 42, - 'files': 11, - 'bytes': TEST_CLIENT_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE1 = {'lustre_pid': 43, - 'files': 111, - 'bytes': TEST_VGRID_BYTES1, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE2 = {'lustre_pid': 44, - 'files': 222, - 'bytes': TEST_VGRID_BYTES2, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE3 = {'lustre_pid': 45, - 'files': 333, - 'bytes': TEST_VGRID_BYTES3, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_EXT_USAGE = {'lustre_pid': 46, - 'files': 1, - 'bytes': TEST_EXT_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_FREEZE_USAGE = {'lustre_pid': 47, - 'files': 1, - 'bytes': TEST_FREEZE_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_PEERS = {TEST_EXT_DN: {'kind': 'collaboration', - 'distinguished_name': TEST_EXT_DN, - 'country': 'DK', - 'label': 'TEST', - 'state': '', - 'expire': '2222-12-31', - 'full_name': 'Test Peer', - 'organization': 'PEER Org', - 'email': 'peer@example.com' - }} +) +TEST_LUSTRE_QUOTA_INFO = {"next_pid": 192, "mtime": TEST_MTIME} +TEST_CLIENT_USAGE = { + "lustre_pid": 42, + "files": 11, + "bytes": TEST_CLIENT_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE1 = { + "lustre_pid": 43, + "files": 111, + "bytes": TEST_VGRID_BYTES1, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE2 = { + "lustre_pid": 44, + "files": 222, + "bytes": TEST_VGRID_BYTES2, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE3 = { + "lustre_pid": 45, + "files": 333, + "bytes": TEST_VGRID_BYTES3, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_EXT_USAGE = { + "lustre_pid": 46, + "files": 1, + "bytes": TEST_EXT_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_FREEZE_USAGE = { + "lustre_pid": 47, + "files": 1, + "bytes": TEST_FREEZE_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_PEERS = { + TEST_EXT_DN: { + "kind": "collaboration", + "distinguished_name": TEST_EXT_DN, + "country": "DK", + "label": "TEST", + "state": "", + "expire": "2222-12-31", + "full_name": "Test Peer", + "organization": "PEER Org", + "email": "peer@example.com", + } +} class MigLibAccounting(MigTestCase): @@ -111,7 +133,7 @@ class MigLibAccounting(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up test configuration and reset state before each test""" @@ -120,15 +142,17 @@ def before_each(self): self.configuration.site_enable_quota = True self.configuration.site_enable_accounting = True - self.configuration.quota_backend = 'lustre' - - quota_basepath = os.path.join(self.configuration.quota_home, - self.configuration.quota_backend) - quota_user_path = os.path.join(quota_basepath, 'user') - quota_vgrid_path = os.path.join(quota_basepath, 'vgrid') - quota_freeze_path = os.path.join(quota_basepath, 'freeze') - test_client_peers_path = os.path.join(self.configuration.user_settings, - client_id_dir(TEST_CLIENT_DN)) + self.configuration.quota_backend = "lustre" + + quota_basepath = os.path.join( + self.configuration.quota_home, self.configuration.quota_backend + ) + quota_user_path = os.path.join(quota_basepath, "user") + quota_vgrid_path = os.path.join(quota_basepath, "vgrid") + quota_freeze_path = os.path.join(quota_basepath, "freeze") + test_client_peers_path = os.path.join( + self.configuration.user_settings, client_id_dir(TEST_CLIENT_DN) + ) ensure_dirs_exist(self.configuration.vgrid_home) ensure_dirs_exist(self.configuration.user_settings) @@ -141,61 +165,69 @@ def before_each(self): # Ensure fake vgrid and write owner - for vgrid_name in [TEST_VGRID_NAME1, - TEST_VGRID_NAME2, - TEST_VGRID_NAME3]: + for vgrid_name in [ + TEST_VGRID_NAME1, + TEST_VGRID_NAME2, + TEST_VGRID_NAME3, + ]: vgrid_home_path = os.path.join( - self.configuration.vgrid_home, vgrid_name) + self.configuration.vgrid_home, vgrid_name + ) ensure_dirs_exist(vgrid_home_path) - vgrid_owners_filepath = os.path.join(vgrid_home_path, 'owners') - with open(vgrid_owners_filepath, 'wb') as fh: + vgrid_owners_filepath = os.path.join(vgrid_home_path, "owners") + with open(vgrid_owners_filepath, "wb") as fh: fh.write(pickle.dumps([TEST_CLIENT_DN])) # Write fake quota - test_lustre_quota_info_filepath \ - = os.path.join(self.configuration.quota_home, - '%s.pck' % self.configuration.quota_backend) - with open(test_lustre_quota_info_filepath, 'wb') as fh: + test_lustre_quota_info_filepath = os.path.join( + self.configuration.quota_home, + "%s.pck" % self.configuration.quota_backend, + ) + with open(test_lustre_quota_info_filepath, "wb") as fh: fh.write(pickle.dumps(TEST_LUSTRE_QUOTA_INFO)) - quota_test_client_path \ - = os.path.join(quota_user_path, - "%s.pck" % client_id_dir(TEST_CLIENT_DN)) + quota_test_client_path = os.path.join( + quota_user_path, "%s.pck" % client_id_dir(TEST_CLIENT_DN) + ) - with open(quota_test_client_path, 'wb') as fh: + with open(quota_test_client_path, "wb") as fh: fh.write(pickle.dumps(TEST_CLIENT_USAGE)) - quot_test_vgrid_filepath1 = os.path.join(quota_vgrid_path, - "%s.pck" % TEST_VGRID_NAME1) - with open(quot_test_vgrid_filepath1, 'wb') as fh: + quot_test_vgrid_filepath1 = os.path.join( + quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME1 + ) + with open(quot_test_vgrid_filepath1, "wb") as fh: fh.write(pickle.dumps(TEST_VGRID_USAGE1)) - quot_test_vgrid_filepath2 = os.path.join(quota_vgrid_path, - "%s.pck" % TEST_VGRID_NAME2) - with open(quot_test_vgrid_filepath2, 'wb') as fh: + quot_test_vgrid_filepath2 = os.path.join( + quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME2 + ) + with open(quot_test_vgrid_filepath2, "wb") as fh: fh.write(pickle.dumps(TEST_VGRID_USAGE2)) - quot_test_vgrid_filepath3 = os.path.join(quota_vgrid_path, - "%s.pck" % TEST_VGRID_NAME3) - with open(quot_test_vgrid_filepath3, 'wb') as fh: + quot_test_vgrid_filepath3 = os.path.join( + quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME3 + ) + with open(quot_test_vgrid_filepath3, "wb") as fh: fh.write(pickle.dumps(TEST_VGRID_USAGE3)) test_client_peers_filepath = os.path.join( - test_client_peers_path, peers_filename) - with open(test_client_peers_filepath, 'wb') as fh: + test_client_peers_path, peers_filename + ) + with open(test_client_peers_filepath, "wb") as fh: fh.write(pickle.dumps(TEST_PEERS)) - quota_test_client_ext_path \ - = os.path.join(quota_user_path, - "%s.pck" % client_id_dir(TEST_EXT_DN)) - with open(quota_test_client_ext_path, 'wb') as fh: + quota_test_client_ext_path = os.path.join( + quota_user_path, "%s.pck" % client_id_dir(TEST_EXT_DN) + ) + with open(quota_test_client_ext_path, "wb") as fh: fh.write(pickle.dumps(TEST_EXT_USAGE)) - quota_test_freeze_path = os.path.join(quota_freeze_path, - "%s.pck" - % client_id_dir(TEST_CLIENT_DN)) - with open(quota_test_freeze_path, 'wb') as fh: + quota_test_freeze_path = os.path.join( + quota_freeze_path, "%s.pck" % client_id_dir(TEST_CLIENT_DN) + ) + with open(quota_test_freeze_path, "wb") as fh: fh.write(pickle.dumps(TEST_FREEZE_USAGE)) def test_accounting(self): @@ -211,36 +243,43 @@ def test_accounting(self): usage = get_usage(self.configuration) self.assertNotEqual(usage, {}) - accounting = usage.get('accounting', {}) + accounting = usage.get("accounting", {}) test_user_accounting = accounting.get(TEST_CLIENT_DN, {}) self.assertNotEqual(test_user_accounting, {}) - home_total = test_user_accounting.get('home_total', 0) + home_total = test_user_accounting.get("home_total", 0) self.assertEqual(home_total, TEST_CLIENT_BYTES) - vgrid_total = test_user_accounting.get('vgrid_total', 0) + vgrid_total = test_user_accounting.get("vgrid_total", 0) self.assertEqual(vgrid_total, TEST_VGRID_TOTAL_BYTES) - ext_users_total = test_user_accounting.get('ext_users_total', 0) + ext_users_total = test_user_accounting.get("ext_users_total", 0) self.assertEqual(ext_users_total, TEST_EXT_BYTES) - freeze_total = test_user_accounting.get('freeze_total', 0) + freeze_total = test_user_accounting.get("freeze_total", 0) self.assertEqual(freeze_total, TEST_FREEZE_BYTES) - total_bytes = test_user_accounting.get('total_bytes', 0) + total_bytes = test_user_accounting.get("total_bytes", 0) self.assertEqual(total_bytes, TEST_TOTAL_BYTES) def test_human_readable_filesize_valid(self): """Test human-friendly format helper success on valid byte sizes""" - valid = [(0, "0 B"), (42, "42.000 B"), (2**10, "1.000 KiB"), - (2**30, "1.000 GiB"), (2**50, "1.000 PiB"), - (2**89, "512.000 YiB"), (2**90 - 2**70, "1023.999 YiB")] - for (size, expect) in valid: + valid = [ + (0, "0 B"), + (42, "42.000 B"), + (2**10, "1.000 KiB"), + (2**30, "1.000 GiB"), + (2**50, "1.000 PiB"), + (2**89, "512.000 YiB"), + (2**90 - 2**70, "1023.999 YiB"), + ] + for size, expect in valid: self.assertEqual(human_readable_filesize(size), expect) def test_human_readable_filesize_invalid(self): """Test human-friendly format helper failure on invalid byte sizes""" - invalid = [(i, "NaN") for i in [False, None, "", "one", -1, 1.2, 2**90, - 2**128]] - for (size, expect) in invalid: + invalid = [ + (i, "NaN") for i in [False, None, "", "one", -1, 1.2, 2**90, 2**128] + ] + for size, expect in invalid: self.assertEqual(human_readable_filesize(size), expect) diff --git a/tests/test_mig_lib_daemon.py b/tests/test_mig_lib_daemon.py index bd8cd901c..0cd55ea4e 100644 --- a/tests/test_mig_lib_daemon.py +++ b/tests/test_mig_lib_daemon.py @@ -31,10 +31,22 @@ import signal import time -from mig.lib.daemon import _run_event, _stop_event, check_run, check_stop, \ - do_run, interruptible_sleep, register_run_handler, register_stop_handler, \ - reset_run, reset_stop, run_handler, stop_handler, stop_running, \ - unregister_signal_handlers +from mig.lib.daemon import ( + _run_event, + _stop_event, + check_run, + check_stop, + do_run, + interruptible_sleep, + register_run_handler, + register_stop_handler, + reset_run, + reset_stop, + run_handler, + stop_handler, + stop_running, + unregister_signal_handlers, +) from tests.support import FakeConfiguration, FakeLogger, MigTestCase @@ -42,19 +54,25 @@ class MigLibDaemon(MigTestCase): """Unit tests for daemon related helper functions""" # Signals registered across the tests and explicitly unregistered on init - _used_signals = [signal.SIGCONT, signal.SIGINT, signal.SIGALRM, - signal.SIGABRT, signal.SIGUSR1, signal.SIGUSR2] + _used_signals = [ + signal.SIGCONT, + signal.SIGINT, + signal.SIGALRM, + signal.SIGABRT, + signal.SIGUSR1, + signal.SIGUSR2, + ] def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up any test configuration and reset state before each test""" # Create dummy sig and frame values for isolated test use - self.sig = 'SIGNAL' - self.frame = 'FRAME' + self.sig = "SIGNAL" + self.frame = "FRAME" # Reset event states reset_run() @@ -94,7 +112,7 @@ def test_interruptible_sleep(self): max_secs = 4.2 start = time.time() signal.alarm(1) - interruptible_sleep(self.configuration, max_secs, (check_run, )) + interruptible_sleep(self.configuration, max_secs, (check_run,)) self.assertTrue(check_run()) end = time.time() self.assertTrue(end - start < max_secs) @@ -303,14 +321,17 @@ def test_concurrent_event_handling(self): def test_interruptible_sleep_immediate_break(self): """Test interruptible_sleep with immediate break condition""" + def immediate_true(): return True start = time.time() interruptible_sleep(self.configuration, 5.0, [immediate_true]) duration = time.time() - start - self.assertTrue(duration < 0.1, - "Sleep should exit immediately but took %s" % duration) + self.assertTrue( + duration < 0.1, + "Sleep should exit immediately but took %s" % duration, + ) def test_reset_event_helpers(self): """Test simple event reset helpers""" @@ -337,36 +358,41 @@ def test_unregister_signal_handlers_explicit(self): register_stop_handler(self.configuration, signal.SIGABRT) # Verify handlers were set - self.assertEqual(signal.getsignal(signal.SIGALRM).__name__, - 'run_handler') - self.assertEqual(signal.getsignal(signal.SIGABRT).__name__, - 'stop_handler') + self.assertEqual( + signal.getsignal(signal.SIGALRM).__name__, "run_handler" + ) + self.assertEqual( + signal.getsignal(signal.SIGABRT).__name__, "stop_handler" + ) # Unregister specific signals - unregister_signal_handlers(self.configuration, [signal.SIGALRM, - signal.SIGABRT]) + unregister_signal_handlers( + self.configuration, [signal.SIGALRM, signal.SIGABRT] + ) self.assertEqual(signal.getsignal(signal.SIGALRM), signal.SIG_IGN) self.assertEqual(signal.getsignal(signal.SIGABRT), signal.SIG_IGN) def test_interruptible_sleep_condition_after_interval(self): """Test interruptible_sleep break condition after one interval""" - state = {'count': 0} + state = {"count": 0} def counter_condition(): - state['count'] += 1 - return state['count'] >= 2 + state["count"] += 1 + return state["count"] >= 2 start = time.time() - interruptible_sleep(self.configuration, 5.0, [counter_condition], - nap_secs=0.1) + interruptible_sleep( + self.configuration, 5.0, [counter_condition], nap_secs=0.1 + ) duration = time.time() - start self.assertAlmostEqual(duration, 0.2, delta=0.15) def test_interruptible_sleep_maxsecs_equals_napsecs(self): """Test interruptible_sleep with max_secs exactly matching nap_secs""" start = time.time() - interruptible_sleep(self.configuration, 0.1, [lambda: False], - nap_secs=0.1) + interruptible_sleep( + self.configuration, 0.1, [lambda: False], nap_secs=0.1 + ) duration = time.time() - start self.assertAlmostEqual(duration, 0.1, delta=0.05) @@ -379,8 +405,9 @@ def faulty_condition(): self.logger.error(SLEEP_ERR) start = time.time() - interruptible_sleep(self.configuration, 0.1, [faulty_condition], - nap_secs=0.01) + interruptible_sleep( + self.configuration, 0.1, [faulty_condition], nap_secs=0.01 + ) duration = time.time() - start self.assertAlmostEqual(duration, 0.1, delta=0.05) try: @@ -402,22 +429,26 @@ def test_unregister_default_signals(self): self.assertEqual(signal.getsignal(signal.SIGCONT), signal.SIG_IGN) self.assertEqual(signal.getsignal(signal.SIGUSR2), signal.SIG_IGN) # Verify custom signals remain after default unregister - self.assertEqual(signal.getsignal(signal.SIGINT).__name__, - 'run_handler') - self.assertEqual(signal.getsignal(signal.SIGALRM).__name__, - 'stop_handler') + self.assertEqual( + signal.getsignal(signal.SIGINT).__name__, "run_handler" + ) + self.assertEqual( + signal.getsignal(signal.SIGALRM).__name__, "stop_handler" + ) def test_register_default_signal(self): """Test handler registration with default signal values""" # Run handler should default to SIGCONT register_run_handler(self.configuration) - self.assertEqual(signal.getsignal(signal.SIGCONT).__name__, - 'run_handler') + self.assertEqual( + signal.getsignal(signal.SIGCONT).__name__, "run_handler" + ) # Stop handler should default to SIGINT register_stop_handler(self.configuration) - self.assertEqual(signal.getsignal(signal.SIGINT).__name__, - 'stop_handler') + self.assertEqual( + signal.getsignal(signal.SIGINT).__name__, "stop_handler" + ) def test_reset_unregistered_signals(self): """Test unregister responds gracefully to previously unregistered signals""" @@ -440,21 +471,22 @@ def test_interruptible_sleep_break_not_callable(self): def test_interruptible_sleep_all_conditions_checked(self): """Verify all break conditions are checked each sleep interval""" - counter = {'count': 0} + counter = {"count": 0} max_checks = 3 def counter_condition(): - if counter['count'] < max_checks: - counter['count'] += 1 - return counter['count'] >= max_checks + if counter["count"] < max_checks: + counter["count"] += 1 + return counter["count"] >= max_checks start = time.time() - interruptible_sleep(self.configuration, 5.0, [counter_condition], - nap_secs=0.1) + interruptible_sleep( + self.configuration, 5.0, [counter_condition], nap_secs=0.1 + ) duration = time.time() - start # Should run for ~0.3 sec (3 naps of 0.1 sec) self.assertAlmostEqual(duration, 0.3, delta=0.15) - self.assertEqual(counter['count'], max_checks) + self.assertEqual(counter["count"], max_checks) def test_interruptible_sleep_naps_remaining(self): """Test interruptible_sleep counts down remaining naps correctly""" @@ -517,21 +549,20 @@ def test_event_state_persistence(self): def test_signal_handler_dispatch(self): """Verify signal handlers dispatch correct signals""" - test_signals = { - 'run': [signal.SIGUSR1], - 'stop': [signal.SIGUSR2] - } + test_signals = {"run": [signal.SIGUSR1], "stop": [signal.SIGUSR2]} - for func, sigs in [(register_run_handler, test_signals['run']), - (register_stop_handler, test_signals['stop'])]: + for func, sigs in [ + (register_run_handler, test_signals["run"]), + (register_stop_handler, test_signals["stop"]), + ]: for sig in sigs: func(self.configuration, sig) # Verify handler registration dispatch = signal.getsignal(sig) if func == register_run_handler: - self.assertEqual(dispatch.__name__, 'run_handler') + self.assertEqual(dispatch.__name__, "run_handler") else: - self.assertEqual(dispatch.__name__, 'stop_handler') + self.assertEqual(dispatch.__name__, "stop_handler") def test_event_set_unset_lifecycle(self): """Verify full event lifecycle""" diff --git a/tests/test_mig_lib_events.py b/tests/test_mig_lib_events.py index 96ba8c91d..49bf6d5bb 100644 --- a/tests/test_mig_lib_events.py +++ b/tests/test_mig_lib_events.py @@ -32,12 +32,28 @@ import unittest # Imports of the code under test -from mig.lib.events import _restore_env, _save_env, at_remain, cron_match, \ - get_path_expand_map, get_time_expand_map, load_atjobs, load_crontab +from mig.lib.events import ( + _restore_env, + _save_env, + at_remain, + cron_match, + get_path_expand_map, + get_time_expand_map, + load_atjobs, + load_crontab, +) from mig.lib.events import main as events_main -from mig.lib.events import parse_and_save_atjobs, parse_and_save_crontab, \ - parse_atjobs, parse_atjobs_contents, parse_crontab, \ - parse_crontab_contents, run_cron_command, run_events_command +from mig.lib.events import ( + parse_and_save_atjobs, + parse_and_save_crontab, + parse_atjobs, + parse_atjobs_contents, + parse_crontab, + parse_crontab_contents, + run_cron_command, + run_events_command, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist @@ -356,8 +372,7 @@ def test_cron_match_with_wildcards(self): ), ] for job, expected in test_cases: - self.assertEqual(cron_match( - self.configuration, now, job), expected) + self.assertEqual(cron_match(self.configuration, now, job), expected) def test_cron_match_specific_time(self): """Test cron_match rejects non-matching time""" @@ -507,7 +522,8 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ( { @@ -519,7 +535,10 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ( { @@ -531,7 +550,8 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ( { @@ -543,7 +563,10 @@ def test_cron_match_with_leading_zero_match(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ] for job, now in test_cases: @@ -643,7 +666,10 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ( { @@ -655,7 +681,8 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ( { @@ -667,7 +694,10 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Friday of current month now.replace(day=7).replace( - day=4 + now.replace(day=7).day - now.replace(day=7).weekday()), + day=4 + + now.replace(day=7).day + - now.replace(day=7).weekday() + ), ), ( { @@ -679,7 +709,8 @@ def test_cron_match_with_leading_zero_mismatch(self): }, # Get first Monday of current month now.replace(day=7).replace( - day=now.replace(day=7).day - now.replace(day=7).weekday()), + day=now.replace(day=7).day - now.replace(day=7).weekday() + ), ), ] for job, now in test_cases: @@ -979,7 +1010,13 @@ def test_at_remain_with_past_leap_second(self): microsecond=0, ) t_plus_sixtyone = now.replace( - year=2017, month=1, day=1, hour=0, minute=1, second=0, microsecond=0 + year=2017, + month=1, + day=1, + hour=0, + minute=1, + second=0, + microsecond=0, ) self.assertEqual( (t_plus_sixty - t_minus_sixty).total_seconds(), @@ -1409,8 +1446,7 @@ def test_get_path_expand_map_with_relative_path(self): trigger_path = "../relative/path/file.txt" rule = {"vgrid_name": "test", "run_as": DUMMY_USER_DN} expanded = get_path_expand_map(trigger_path, rule, "modified") - self.assertEqual(expanded["+TRIGGERPATH+"], - "../relative/path/file.txt") + self.assertEqual(expanded["+TRIGGERPATH+"], "../relative/path/file.txt") self.assertEqual(expanded["+TRIGGERFILENAME+"], "file.txt") self.assertEqual(expanded["+TRIGGERPREFIX+"], "file") self.assertEqual(expanded["+TRIGGEREXTENSION+"], ".txt") @@ -1807,8 +1843,9 @@ def test_parse_and_save_crontab(self): def test_parse_atjobs(self): """Test parsing atjobs content lines""" parsed = parse_atjobs_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_ATJOBS_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_ATJOBS_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 1) self.assertEqual(parsed[0]["command"], ["/bin/future_command"]) @@ -1816,8 +1853,9 @@ def test_parse_atjobs(self): def test_parse_atjobs_contents(self): """Test parsing atjobs content lines""" parsed = parse_atjobs_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_ATJOBS_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_ATJOBS_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 1) self.assertEqual(parsed[0]["command"], ["/bin/future_command"]) @@ -1825,8 +1863,9 @@ def test_parse_atjobs_contents(self): def test_parse_crontab(self): """Test parsing crontab content lines""" parsed = parse_crontab_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_CRONTAB_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_CRONTAB_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 2) self.assertEqual(parsed[0]["command"], ["/bin/test_command"]) @@ -1834,8 +1873,9 @@ def test_parse_crontab(self): def test_parse_crontab_contents(self): """Test parsing crontab content lines""" parsed = parse_crontab_contents( - self.configuration, DUMMY_USER_DN, - DUMMY_CRONTAB_CONTENT.splitlines() + self.configuration, + DUMMY_USER_DN, + DUMMY_CRONTAB_CONTENT.splitlines(), ) self.assertEqual(len(parsed), 2) self.assertEqual(parsed[0]["command"], ["/bin/test_command"]) @@ -2655,7 +2695,10 @@ def test_run_cron_command_with_invalid_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -2980,7 +3023,10 @@ def test_run_cron_command_with_here_document(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3014,7 +3060,10 @@ def test_run_cron_command_with_subshell(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3074,7 +3123,10 @@ def test_run_cron_command_with_true_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3108,7 +3160,10 @@ def test_run_cron_command_with_false_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3142,7 +3197,10 @@ def test_run_cron_command_with_null_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3176,7 +3234,10 @@ def test_run_cron_command_with_builtin_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3210,7 +3271,10 @@ def test_run_cron_command_with_alias_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3244,7 +3308,10 @@ def test_run_cron_command_with_function_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3278,7 +3345,10 @@ def test_run_cron_command_with_reserved_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3295,7 +3365,10 @@ def test_run_cron_command_with_dot_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3329,7 +3402,10 @@ def test_run_cron_command_with_colon_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3346,7 +3422,10 @@ def test_run_cron_command_with_bracket_command(self): with self.assertLogs(level="ERROR") as log_capture: with self.assertRaises(Exception): run_cron_command( - command_list, target_path, crontab_entry, self.configuration + command_list, + target_path, + crontab_entry, + self.configuration, ) self.assertTrue( any( @@ -3387,6 +3466,7 @@ def before_each(self): def test_existing_main(self): """Wrap existing self-tests""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: diff --git a/tests/test_mig_lib_janitor.py b/tests/test_mig_lib_janitor.py index 59a522cde..2daaadbc5 100644 --- a/tests/test_mig_lib_janitor.py +++ b/tests/test_mig_lib_janitor.py @@ -32,18 +32,36 @@ import time import unittest -from mig.lib.janitor import EXPIRE_DUMMY_JOBS_DAYS, EXPIRE_REQ_DAYS, \ - EXPIRE_STATE_DAYS, EXPIRE_TWOFACTOR_DAYS, MANAGE_TRIVIAL_REQ_MINUTES, \ - REMIND_REQ_DAYS, SECS_PER_DAY, SECS_PER_HOUR, SECS_PER_MINUTE, \ - _clean_stale_state_files, _lookup_last_run, _update_last_run, \ - clean_mig_system_files, clean_no_job_helpers, \ - clean_sessid_to_mrls_link_home, clean_twofactor_sessions, \ - clean_webserver_home, handle_cache_updates, handle_janitor_tasks, \ - handle_pending_requests, handle_session_cleanup, handle_state_cleanup, \ - manage_single_req, manage_trivial_user_requests, \ - remind_and_expire_user_pending, task_triggers +from mig.lib.janitor import ( + EXPIRE_DUMMY_JOBS_DAYS, + EXPIRE_REQ_DAYS, + EXPIRE_STATE_DAYS, + EXPIRE_TWOFACTOR_DAYS, + MANAGE_TRIVIAL_REQ_MINUTES, + REMIND_REQ_DAYS, + SECS_PER_DAY, + SECS_PER_HOUR, + SECS_PER_MINUTE, + _clean_stale_state_files, + _lookup_last_run, + _update_last_run, + clean_mig_system_files, + clean_no_job_helpers, + clean_sessid_to_mrls_link_home, + clean_twofactor_sessions, + clean_webserver_home, + handle_cache_updates, + handle_janitor_tasks, + handle_pending_requests, + handle_session_cleanup, + handle_state_cleanup, + manage_single_req, + manage_trivial_user_requests, + remind_and_expire_user_pending, + task_triggers, +) from mig.shared.accountreq import save_account_request -from mig.shared.base import distinguished_name_to_user, client_id_dir +from mig.shared.base import client_id_dir, distinguished_name_to_user from mig.shared.pwcrypto import generate_reset_token from tests.support import MigTestCase, ensure_dirs_exist @@ -52,25 +70,31 @@ TEST_USER_ORG = "Test Org" TEST_USER_EMAIL = "test@example.com" # TODO: move next to support.usersupp? -TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=%s/OU=NA/CN=%s/emailAddress=%s' % \ - (TEST_USER_ORG, TEST_USER_FULLNAME, TEST_USER_EMAIL) -TEST_SKIP_EMAIL = '' +TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=%s/OU=NA/CN=%s/emailAddress=%s" % ( + TEST_USER_ORG, + TEST_USER_FULLNAME, + TEST_USER_EMAIL, +) +TEST_SKIP_EMAIL = "" # TODO: adjust password reset token helpers to handle configured services # it currently silently fails if not in migoid(c) or migcert # TEST_SERVICE = 'dummy-svc' -TEST_AUTH = TEST_SERVICE = 'migoid' -TEST_USERDB = 'MiG-users.db' -TEST_PEER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com' +TEST_AUTH = TEST_SERVICE = "migoid" +TEST_USERDB = "MiG-users.db" +TEST_PEER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com" # NOTE: these passwords are not and should not ever be used outside unit tests -TEST_MODERN_PW = 'NoSuchPassword_42' -TEST_MODERN_PW_PBKDF2 = \ +TEST_MODERN_PW = "NoSuchPassword_42" +TEST_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix" -TEST_NEW_MODERN_PW_PBKDF2 = \ +) +TEST_NEW_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf10n581pjXFHrn" -TEST_INVALID_PW_PBKDF2 = \ +) +TEST_INVALID_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf1rn1pjX0n58FH" +) # NOTE: tokens always should contain a multiple of 4 chars -INVALID_TEST_TOKEN = 'THIS_RESET_TOKEN_WAS_NEVER_VALID' +INVALID_TEST_TOKEN = "THIS_RESET_TOKEN_WAS_NEVER_VALID" class MigLibJanitor(MigTestCase): @@ -78,11 +102,11 @@ class MigLibJanitor(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" - def _prepare_test_file(self, path, times=None, content='test'): + def _prepare_test_file(self, path, times=None, content="test"): """Prepare file in path with optional times for timestamp""" - with open(path, 'w') as fp: + with open(path, "w") as fp: fp.write(content) os.utime(path, times) @@ -94,8 +118,9 @@ def before_each(self): self.configuration.site_login_methods.append(TEST_AUTH) # Prevent admin email during reject, etc. self.configuration.admin_email = TEST_SKIP_EMAIL - self.user_db_path = os.path.join(self.configuration.user_db_home, - TEST_USERDB) + self.user_db_path = os.path.join( + self.configuration.user_db_home, TEST_USERDB + ) # Create fake fs layout matching real systems ensure_dirs_exist(self.configuration.user_pending) ensure_dirs_exist(self.configuration.user_db_home) @@ -110,8 +135,9 @@ def before_each(self): ensure_dirs_exist(self.configuration.sessid_to_mrsl_link_home) ensure_dirs_exist(self.configuration.mrsl_files_dir) ensure_dirs_exist(self.configuration.resource_pending) - dummy_job = os.path.join(self.configuration.user_home, - "no_grid_jobs_in_grid_scheduler") + dummy_job = os.path.join( + self.configuration.user_home, "no_grid_jobs_in_grid_scheduler" + ) ensure_dirs_exist(dummy_job) # Prepare user DB with a single dummy user for all tests @@ -124,20 +150,20 @@ def before_each(self): def test_last_run_bookkeeping(self): """Register a last run timestamp and check it""" expect = -1 - stamp = _lookup_last_run(self.configuration, 'janitor_task') + stamp = _lookup_last_run(self.configuration, "janitor_task") self.assertEqual(stamp, expect) expect = 42 - stamp = _update_last_run(self.configuration, 'janitor_task', expect) + stamp = _update_last_run(self.configuration, "janitor_task", expect) self.assertEqual(stamp, expect) expect = time.time() - stamp = _update_last_run(self.configuration, 'janitor_task', expect) + stamp = _update_last_run(self.configuration, "janitor_task", expect) self.assertEqual(stamp, expect) def test_clean_mig_system_files(self): """Test clean_mig system files helper""" test_time = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 - valid_filenames = ['fresh.log', 'current.tmp'] - stale_filenames = ['tmp_expired.txt', 'no_grid_jobs.123'] + valid_filenames = ["fresh.log", "current.tmp"] + stale_filenames = ["tmp_expired.txt", "no_grid_jobs.123"] for name in valid_filenames + stale_filenames: path = os.path.join(self.configuration.mig_system_files, name) self._prepare_test_file(path, (test_time, test_time)) @@ -145,8 +171,10 @@ def test_clean_mig_system_files(self): handled = clean_mig_system_files(self.configuration) self.assertEqual(handled, len(stale_filenames)) - self.assertEqual(len(os.listdir(self.configuration.mig_system_files)), - len(valid_filenames)) + self.assertEqual( + len(os.listdir(self.configuration.mig_system_files)), + len(valid_filenames), + ) for name in valid_filenames: path = os.path.join(self.configuration.mig_system_files, name) self.assertTrue(os.path.exists(path)) @@ -158,8 +186,8 @@ def test_clean_webserver_home(self): """Test clean webserver files helper""" stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 test_dir = self.configuration.webserver_home - valid_filename = 'fresh.log' - stale_filename = 'stale.log' + valid_filename = "fresh.log" + stale_filename = "stale.log" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -175,10 +203,11 @@ def test_clean_webserver_home(self): def test_clean_no_job_helpers(self): """Test clean dummy job helper files""" stale_stamp = time.time() - EXPIRE_DUMMY_JOBS_DAYS * SECS_PER_DAY - 1 - test_dir = os.path.join(self.configuration.user_home, - "no_grid_jobs_in_grid_scheduler") - valid_filename = 'alive.txt' - stale_filename = 'expired.txt' + test_dir = os.path.join( + self.configuration.user_home, "no_grid_jobs_in_grid_scheduler" + ) + valid_filename = "alive.txt" + stale_filename = "expired.txt" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -195,8 +224,8 @@ def test_clean_twofactor_sessions(self): """Test clean twofactor sessions""" stale_stamp = time.time() - EXPIRE_TWOFACTOR_DAYS * SECS_PER_DAY - 1 test_dir = self.configuration.twofactor_home - valid_filename = 'current' - stale_filename = 'expired' + valid_filename = "current" + stale_filename = "expired" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -213,8 +242,8 @@ def test_clean_sessid_to_mrls_link_home(self): """Test clean session MRSL link files""" stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 test_dir = self.configuration.sessid_to_mrsl_link_home - valid_filename = 'active_session_link' - stale_filename = 'expired_session_link' + valid_filename = "active_session_link" + stale_filename = "expired_session_link" valid_path = os.path.join(test_dir, valid_filename) stale_path = os.path.join(test_dir, stale_filename) self._prepare_test_file(valid_path) @@ -232,12 +261,14 @@ def test_handle_state_cleanup(self): # Create a stale file in each location to clean up stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 mig_path = os.path.join( - self.configuration.mig_system_files, 'tmpAbCd1234') - web_path = os.path.join(self.configuration.webserver_home, 'stale.txt') + self.configuration.mig_system_files, "tmpAbCd1234" + ) + web_path = os.path.join(self.configuration.webserver_home, "stale.txt") empty_job_path = os.path.join( - os.path.join(self.configuration.user_home, - "no_grid_jobs_in_grid_scheduler"), - 'sleep.job' + os.path.join( + self.configuration.user_home, "no_grid_jobs_in_grid_scheduler" + ), + "sleep.job", ) stale_paths = [mig_path, web_path, empty_job_path] for path in stale_paths: @@ -252,12 +283,17 @@ def test_handle_state_cleanup(self): def test_handle_session_cleanup(self): """Test combined session cleanup""" - stale_stamp = time.time() - max(EXPIRE_STATE_DAYS, - EXPIRE_TWOFACTOR_DAYS) * SECS_PER_DAY - 1 + stale_stamp = ( + time.time() + - max(EXPIRE_STATE_DAYS, EXPIRE_TWOFACTOR_DAYS) * SECS_PER_DAY + - 1 + ) session_path = os.path.join( - self.configuration.sessid_to_mrsl_link_home, 'expired.txt') + self.configuration.sessid_to_mrsl_link_home, "expired.txt" + ) twofactor_path = os.path.join( - self.configuration.twofactor_home, 'expired.txt') + self.configuration.twofactor_home, "expired.txt" + ) test_paths = [session_path, twofactor_path] for path in test_paths: os.makedirs(os.path.dirname(path), exist_ok=True) @@ -271,17 +307,17 @@ def test_handle_session_cleanup(self): def test_manage_pending_user_request(self): """Test pending user request management""" - req_id = 'req_id' + req_id = "req_id" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) @@ -293,24 +329,25 @@ def test_manage_pending_user_request(self): os.utime(req_path, (req_age, req_age)) # Need user DB and path to simulate existing user - user_dir = os.path.join(self.configuration.user_home, - client_id_dir(TEST_USER_DN)) + user_dir = os.path.join( + self.configuration.user_home, client_id_dir(TEST_USER_DN) + ) os.makedirs(user_dir, exist_ok=True) handled = manage_trivial_user_requests(self.configuration) self.assertEqual(handled, 1) def test_expire_user_pending(self): """Test pending user request expiration reminders""" - req_id = 'expired_req' + req_id = "expired_req" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) saved, req_path = save_account_request(self.configuration, req_dict) @@ -333,42 +370,46 @@ def test_handle_pending_requests(self): """Test combined request handling""" # Create requests (valid, expired) valid_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) - saved, valid_req_path = save_account_request(self.configuration, - valid_dict) + saved, valid_req_path = save_account_request( + self.configuration, valid_dict + ) self.assertTrue(saved, "failed to save valid req") self.assertDirNotEmpty(self.configuration.user_pending) valid_id = os.path.basename(valid_req_path) - expired_id = 'expired_req' + expired_id = "expired_req" expired_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } saved, expired_req_path = save_account_request( - self.configuration, expired_dict) + self.configuration, expired_dict + ) self.assertTrue(saved, "failed to save expired req") expired_id = os.path.basename(expired_req_path) # Make just one old enough to expire expire_time = time.time() - EXPIRE_REQ_DAYS * SECS_PER_DAY - 1 - os.utime(os.path.join(self.configuration.user_pending, expired_id), - (expire_time, expire_time)) + os.utime( + os.path.join(self.configuration.user_pending, expired_id), + (expire_time, expire_time), + ) # NOTE: when using real user mail we currently hit send email errors. # We forgive those errors here and only check any known warnings. @@ -383,25 +424,29 @@ def test_handle_janitor_tasks_full(self): """Test full janitor task scheduler""" # Prepare environment with pending tasks of each kind mig_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 - mig_path = os.path.join(self.configuration.mig_system_files, - 'tmp-stale.txt') - two_path = os.path.join(self.configuration.twofactor_home, 'stale.txt') + mig_path = os.path.join( + self.configuration.mig_system_files, "tmp-stale.txt" + ) + two_path = os.path.join(self.configuration.twofactor_home, "stale.txt") two_stamp = time.time() - EXPIRE_TWOFACTOR_DAYS * SECS_PER_DAY - 1 - stale_tests = ((mig_path, mig_stamp), (two_path, two_stamp), ) - for (stale_path, stale_stamp) in stale_tests: + stale_tests = ( + (mig_path, mig_stamp), + (two_path, two_stamp), + ) + for stale_path, stale_stamp in stale_tests: self._prepare_test_file(stale_path, (stale_stamp, stale_stamp)) self.assertTrue(os.path.exists(stale_path)) - req_id = 'expired_request' + req_id = "expired_request" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'peers': [TEST_PEER_DN], - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "peers": [TEST_PEER_DN], + "email": TEST_USER_EMAIL, } self.assertDirEmpty(self.configuration.user_pending) saved, req_path = save_account_request(self.configuration, req_dict) @@ -410,8 +455,10 @@ def test_handle_janitor_tasks_full(self): req_id = os.path.basename(req_path) # Make request very old req_age = time.time() - EXPIRE_REQ_DAYS * SECS_PER_DAY - 1 - os.utime(os.path.join(self.configuration.user_pending, req_id), - (req_age, req_age)) + os.utime( + os.path.join(self.configuration.user_pending, req_id), + (req_age, req_age), + ) # Set no last run timestamps to trigger all tasks now = time.time() @@ -426,26 +473,26 @@ def test_handle_janitor_tasks_full(self): handled = handle_janitor_tasks(self.configuration, now=now) # self.assertEqual(handled, 3) # state+session+requests self.assertEqual(handled, 5) # state+session+3*request - for (stale_path, _) in stale_tests: + for stale_path, _ in stale_tests: self.assertFalse(os.path.exists(stale_path), stale_path) def test__clean_stale_state_files(self): """Test core stale state file cleaner helper""" - test_dir = self.temppath('stale_state_test', ensure_dir=True) - patterns = ['tmp_*', 'session_*'] + test_dir = self.temppath("stale_state_test", ensure_dir=True) + patterns = ["tmp_*", "session_*"] # Create test files (fresh, expired, unexpired, non-matching) test_remove = [ - ('tmp_expired.txt', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), - ('session_old.dat', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), + ("tmp_expired.txt", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), + ("session_old.dat", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), ] test_keep = [ - ('tmp_fresh.txt', -1), - ('session_valid.dat', 0), - ('other_file.log', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), + ("tmp_fresh.txt", -1), + ("session_valid.dat", 0), + ("other_file.log", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1), ] - for (name, age_diff) in test_keep + test_remove: + for name, age_diff in test_keep + test_remove: path = os.path.join(test_dir, name) stamp = time.time() - age_diff self._prepare_test_file(path, (stamp, stamp)) @@ -457,27 +504,27 @@ def test__clean_stale_state_files(self): patterns, EXPIRE_STATE_DAYS, time.time(), - include_dotfiles=False + include_dotfiles=False, ) self.assertEqual(handled, 2) # tmp_expired.txt + session_old.dat - for (name, _) in test_keep: + for name, _ in test_keep: path = os.path.join(test_dir, name) self.assertTrue(os.path.exists(path)) - for (name, _) in test_remove: + for name, _ in test_remove: path = os.path.join(test_dir, name) self.assertFalse(os.path.exists(path)) def test_manage_single_req_invalid(self): """Test request handling for invalid request""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'invalid': ['Missing required field: organization'], - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'password_hash': TEST_MODERN_PW_PBKDF2, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "invalid": ["Missing required field: organization"], + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "password_hash": TEST_MODERN_PW_PBKDF2, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -486,17 +533,18 @@ def test_manage_single_req_invalid(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='INFO') as log_capture: + with self.assertLogs(level="INFO") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('invalid account request' in msg - for msg in log_capture.output)) + self.assertTrue( + any("invalid account request" in msg for msg in log_capture.output) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed to clean invalid req for %s" % req_path) @@ -504,24 +552,24 @@ def test_manage_single_req_invalid(self): def test_manage_single_req_expired_token(self): """Test request handling with expired reset token""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, + "email": TEST_USER_EMAIL, + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, } # Mimic proper but old expired token timestamp = 42 # IMPORTANT: we can't use a fixed token here due to dynamic crypto seed - req_dict['reset_token'] = generate_reset_token(self.configuration, - req_dict, TEST_SERVICE, - timestamp) + req_dict["reset_token"] = generate_reset_token( + self.configuration, req_dict, TEST_SERVICE, timestamp + ) # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -529,17 +577,21 @@ def test_manage_single_req_expired_token(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('reject expired reset token' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "reject expired reset token" in msg + for msg in log_capture.output + ) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed to clean token req for %s" % req_path) @@ -548,52 +600,57 @@ def test_manage_single_req_expired_token(self): def test_manage_single_req_invalid_token(self): """Test request handling with invalid reset token""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() - SECS_PER_DAY, + "email": TEST_USER_EMAIL, + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() - SECS_PER_DAY, } # Inject known invalid reset token - req_dict['reset_token'] = INVALID_TEST_TOKEN + req_dict["reset_token"] = INVALID_TEST_TOKEN # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('reset with bad token' in msg - for msg in log_capture.output)) - self.assertFalse(os.path.exists(req_path), - "Failed to clean token req for %s" % req_path) + self.assertTrue( + any("reset with bad token" in msg for msg in log_capture.output) + ) + self.assertFalse( + os.path.exists(req_path), + "Failed to clean token req for %s" % req_path, + ) def test_manage_single_req_collision(self): """Test request handling with existing user collision""" # Create collision with the already provisioned user with TEST_USER_DN changed_full_name = "Changed Test Name" req_dict = { - 'client_id': TEST_USER_DN.replace(TEST_USER_FULLNAME, - changed_full_name), - 'distinguished_name': TEST_USER_DN.replace(TEST_USER_FULLNAME, - changed_full_name), - 'auth': [TEST_AUTH], - 'full_name': changed_full_name, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, + "client_id": TEST_USER_DN.replace( + TEST_USER_FULLNAME, changed_full_name + ), + "distinguished_name": TEST_USER_DN.replace( + TEST_USER_FULLNAME, changed_full_name + ), + "auth": [TEST_AUTH], + "full_name": changed_full_name, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, # NOTE: we need original email here to cause collision - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -602,16 +659,17 @@ def test_manage_single_req_collision(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), + ) + self.assertTrue( + any("ID collision" in msg for msg in log_capture.output) ) - self.assertTrue(any('ID collision' in msg - for msg in log_capture.output)) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed cleanup collision for %s" % req_path) @@ -619,20 +677,20 @@ def test_manage_single_req_collision(self): def test_manage_single_req_auth_change(self): """Test request handling with auth password change""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password': '', - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, + "email": TEST_USER_EMAIL, + "password": "", + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, } # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 - req_dict['authorized'] = True + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["authorized"] = True saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -640,19 +698,20 @@ def test_manage_single_req_auth_change(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='INFO') as log_capture: + with self.assertLogs(level="INFO") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue( - any('accepted' in msg for msg in log_capture.output)) - self.assertFalse(os.path.exists(req_path), - "Failed to clean token req for %s" % req_path) + self.assertTrue(any("accepted" in msg for msg in log_capture.output)) + self.assertFalse( + os.path.exists(req_path), + "Failed to clean token req for %s" % req_path, + ) def test_handle_cache_updates_stub(self): """Test handle_cache_updates placeholder returns zero""" @@ -662,7 +721,7 @@ def test_handle_cache_updates_stub(self): def test_janitor_update_timestamps(self): """Test task trigger timestamp updates in janitor""" now = time.time() - task = 'test-task' + task = "test-task" # Initial state stamp = _lookup_last_run(self.configuration, task) @@ -678,24 +737,24 @@ def test_janitor_update_timestamps(self): def test__clean_stale_state_files_edge(self): """Test state file cleaner with special cases""" - test_dir = self.temppath('edge_case_test', ensure_dir=True) + test_dir = self.temppath("edge_case_test", ensure_dir=True) # Dot file - dot_path = os.path.join(test_dir, '.hidden.tmp') + dot_path = os.path.join(test_dir, ".hidden.tmp") stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 self._prepare_test_file(dot_path, (stamp, stamp)) # Directory - dir_path = os.path.join(test_dir, 'subdir') + dir_path = os.path.join(test_dir, "subdir") os.makedirs(dir_path) handled = _clean_stale_state_files( self.configuration, test_dir, - ['*'], + ["*"], EXPIRE_STATE_DAYS, time.time(), - include_dotfiles=False + include_dotfiles=False, ) self.assertEqual(handled, 0) @@ -703,46 +762,50 @@ def test__clean_stale_state_files_edge(self): handled = _clean_stale_state_files( self.configuration, test_dir, - ['*'], + ["*"], EXPIRE_STATE_DAYS, time.time(), - include_dotfiles=True + include_dotfiles=True, ) self.assertEqual(handled, 1) @unittest.skip("TODO: enable once unpickling error handling is improved") def test_manage_single_req_corrupted_file(self): """Test manage_single_req with corrupted request file""" - req_id = 'corrupted_req' + req_id = "corrupted_req" req_path = os.path.join(self.configuration.user_pending, req_id) - with open(req_path, 'w') as fp: - fp.write('invalid pickle content') + with open(req_path, "w") as fp: + fp.write("invalid pickle content") - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('Failed to load request from' in msg - or 'Could not load saved request' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "Failed to load request from" in msg + or "Could not load saved request" in msg + for msg in log_capture.output + ) + ) self.assertFalse(os.path.exists(req_path)) def test_manage_single_req_nonexistent_userdb(self): """Test manage_single_req with missing user database""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password_hash': TEST_MODERN_PW_PBKDF2, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password_hash": TEST_MODERN_PW_PBKDF2, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -750,38 +813,39 @@ def test_manage_single_req_nonexistent_userdb(self): # Remove user database os.remove(self.user_db_path) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('Failed to load user DB' in msg - for msg in log_capture.output)) + self.assertTrue( + any("Failed to load user DB" in msg for msg in log_capture.output) + ) def test_verify_reset_token_failure_logging(self): """Test token verification failure creates proper log entries""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, # Future expiration + "email": TEST_USER_EMAIL, + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, # Future expiration } timestamp = time.time() # Now change to another pw hash and generate invalid token from it - req_dict['password_hash'] = TEST_INVALID_PW_PBKDF2 - req_dict['reset_token'] = generate_reset_token(self.configuration, - req_dict, TEST_SERVICE, - timestamp) + req_dict["password_hash"] = TEST_INVALID_PW_PBKDF2 + req_dict["reset_token"] = generate_reset_token( + self.configuration, req_dict, TEST_SERVICE, timestamp + ) saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -790,17 +854,18 @@ def test_verify_reset_token_failure_logging(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='WARNING') as log_capture: + with self.assertLogs(level="WARNING") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('wrong hash' in msg.lower() - for msg in log_capture.output)) + self.assertTrue( + any("wrong hash" in msg.lower() for msg in log_capture.output) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed cleanup invalid token for %s" % req_path) @@ -808,23 +873,24 @@ def test_verify_reset_token_failure_logging(self): def test_verify_reset_token_success(self): """Test token verification success with valid token""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'email': TEST_USER_EMAIL, - 'password': '', - 'password_hash': TEST_MODERN_PW_PBKDF2, - 'expire': time.time() + SECS_PER_DAY, # Future expiration + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "email": TEST_USER_EMAIL, + "password": "", + "password_hash": TEST_MODERN_PW_PBKDF2, + "expire": time.time() + SECS_PER_DAY, # Future expiration } timestamp = time.time() - reset_token = generate_reset_token(self.configuration, req_dict, - TEST_SERVICE, timestamp) - req_dict['reset_token'] = reset_token + reset_token = generate_reset_token( + self.configuration, req_dict, TEST_SERVICE, timestamp + ) + req_dict["reset_token"] = reset_token # Change password_hash here to mimic pw change - req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2 + req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2 saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -832,17 +898,18 @@ def test_verify_reset_token_success(self): # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - with self.assertLogs(level='INFO') as log_capture: + with self.assertLogs(level="INFO") as log_capture: manage_single_req( self.configuration, req_id, req_path, self.user_db_path, - time.time() + time.time(), ) - self.assertTrue(any('accepted' in msg.lower() - for msg in log_capture.output)) + self.assertTrue( + any("accepted" in msg.lower() for msg in log_capture.output) + ) # TODO: enable check for removed req once skip email allows it # self.assertFalse(os.path.exists(req_path), # "Failed cleanup invalid token for %s" % req_path) @@ -851,23 +918,22 @@ def test_remind_and_expire_edge_cases(self): """Test request expiration with exact boundary timestamps""" now = time.time() test_cases = [ - ('exact_remind', now - REMIND_REQ_DAYS * SECS_PER_DAY), - ('exact_expire', now - EXPIRE_REQ_DAYS * SECS_PER_DAY), + ("exact_remind", now - REMIND_REQ_DAYS * SECS_PER_DAY), + ("exact_expire", now - EXPIRE_REQ_DAYS * SECS_PER_DAY), ] - for (req_id, mtime) in test_cases: + for req_id, mtime in test_cases: req_path = os.path.join(self.configuration.user_pending, req_id) req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, - 'password': TEST_MODERN_PW, - 'email': TEST_USER_EMAIL, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, + "password": TEST_MODERN_PW, + "email": TEST_USER_EMAIL, } - saved, req_path = save_account_request( - self.configuration, req_dict) + saved, req_path = save_account_request(self.configuration, req_dict) os.utime(req_path, (mtime, mtime)) # NOTE: when using real user mail we currently hit send email errors. @@ -884,29 +950,43 @@ def test_handle_janitor_tasks_time_thresholds(self): """Test janitor task frequency thresholds""" now = time.time() - self.assertEqual(_lookup_last_run( - self.configuration, "state-cleanup"), -1) - self.assertEqual(_lookup_last_run( - self.configuration, "session-cleanup"), -1) - self.assertEqual(_lookup_last_run( - self.configuration, "pending-reqs"), -1) - self.assertEqual(_lookup_last_run( - self.configuration, "cache-updates"), -1) + self.assertEqual( + _lookup_last_run(self.configuration, "state-cleanup"), -1 + ) + self.assertEqual( + _lookup_last_run(self.configuration, "session-cleanup"), -1 + ) + self.assertEqual( + _lookup_last_run(self.configuration, "pending-reqs"), -1 + ) + self.assertEqual( + _lookup_last_run(self.configuration, "cache-updates"), -1 + ) # Test all tasks EXCEPT cache-updates are past threshold last_state_cleanup = now - SECS_PER_DAY - 3 last_session_cleanup = now - SECS_PER_HOUR - 3 last_pending_reqs = now - SECS_PER_MINUTE - 3 last_cache_update = now - SECS_PER_MINUTE + 10 # Not expired - task_triggers.update({'state-cleanup': last_state_cleanup, - 'session-cleanup': last_session_cleanup, - 'pending-reqs': last_pending_reqs, - 'cache-updates': last_cache_update}) - self.assertEqual(_lookup_last_run( - self.configuration, "state-cleanup"), last_state_cleanup) - self.assertEqual(_lookup_last_run( - self.configuration, "session-cleanup"), last_session_cleanup) - self.assertEqual(_lookup_last_run( - self.configuration, "cache-updates"), last_cache_update) + task_triggers.update( + { + "state-cleanup": last_state_cleanup, + "session-cleanup": last_session_cleanup, + "pending-reqs": last_pending_reqs, + "cache-updates": last_cache_update, + } + ) + self.assertEqual( + _lookup_last_run(self.configuration, "state-cleanup"), + last_state_cleanup, + ) + self.assertEqual( + _lookup_last_run(self.configuration, "session-cleanup"), + last_session_cleanup, + ) + self.assertEqual( + _lookup_last_run(self.configuration, "cache-updates"), + last_cache_update, + ) # TODO: handled does NOT count no action runs - add dummies to handle? handled = handle_janitor_tasks(self.configuration, now=now) @@ -914,26 +994,32 @@ def test_handle_janitor_tasks_time_thresholds(self): self.assertEqual(handled, 0) # ran with nothing to do # Verify last run timestamps updated - self.assertEqual(_lookup_last_run( - self.configuration, "state-cleanup"), now) - self.assertEqual(_lookup_last_run( - self.configuration, "session-cleanup"), now) - self.assertEqual(_lookup_last_run( - self.configuration, "pending-reqs"), now) - self.assertEqual(_lookup_last_run( - self.configuration, "cache-updates"), last_cache_update) + self.assertEqual( + _lookup_last_run(self.configuration, "state-cleanup"), now + ) + self.assertEqual( + _lookup_last_run(self.configuration, "session-cleanup"), now + ) + self.assertEqual( + _lookup_last_run(self.configuration, "pending-reqs"), now + ) + self.assertEqual( + _lookup_last_run(self.configuration, "cache-updates"), + last_cache_update, + ) @unittest.skip("TODO: enable once cleaner has improved error handling") def test_clean_stale_files_nonexistent_dir(self): """Test state cleaner with invalid directory path""" - target_dir = os.path.join(self.configuration.mig_system_files, - "non_existing_dir") + target_dir = os.path.join( + self.configuration.mig_system_files, "non_existing_dir" + ) handled = _clean_stale_state_files( self.configuration, target_dir, ["*"], EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 0) @@ -947,13 +1033,13 @@ def test_clean_stale_files_permission_error(self): stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1 self._prepare_test_file(test_path, (stamp, stamp)) - with self.assertLogs(level='ERROR'): + with self.assertLogs(level="ERROR"): handled = _clean_stale_state_files( self.configuration, test_dir, ["*"], EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 0) @@ -976,14 +1062,14 @@ def test_handle_empty_pending_dir(self): def test_janitor_task_cleanup_after_reject(self): """Verify proper cleanup after request rejection""" req_dict = { - 'client_id': TEST_USER_DN, - 'distinguished_name': TEST_USER_DN, - 'invalid': ['Test intentional invalid'], - 'auth': [TEST_AUTH], - 'full_name': TEST_USER_FULLNAME, - 'organization': TEST_USER_ORG, + "client_id": TEST_USER_DN, + "distinguished_name": TEST_USER_DN, + "invalid": ["Test intentional invalid"], + "auth": [TEST_AUTH], + "full_name": TEST_USER_FULLNAME, + "organization": TEST_USER_ORG, # NOTE: we need original email here to match provisioned user - 'email': TEST_USER_EMAIL, + "email": TEST_USER_EMAIL, } saved, req_path = save_account_request(self.configuration, req_dict) req_id = os.path.basename(req_path) @@ -1000,7 +1086,7 @@ def test_janitor_task_cleanup_after_reject(self): req_id, req_path, self.user_db_path, - time.time() + time.time(), ) # TODO: enable check for removed req once skip email allows it @@ -1009,15 +1095,15 @@ def test_janitor_task_cleanup_after_reject(self): def test_cleaner_with_multiple_patterns(self): """Test state cleaner with multiple filename patterns""" - test_dir = self.temppath('multi_pattern_test', ensure_dir=True) - clean_patterns = ['*.tmp', '*.log', 'temp*'] + test_dir = self.temppath("multi_pattern_test", ensure_dir=True) + clean_patterns = ["*.tmp", "*.log", "temp*"] clean_pairs = [ - ('should_keep_recent.log', EXPIRE_STATE_DAYS - 1), - ('should_remove_stale.tmp', EXPIRE_STATE_DAYS + 1), - ('should_keep_other.pck', EXPIRE_STATE_DAYS + 1) + ("should_keep_recent.log", EXPIRE_STATE_DAYS - 1), + ("should_remove_stale.tmp", EXPIRE_STATE_DAYS + 1), + ("should_keep_other.pck", EXPIRE_STATE_DAYS + 1), ] - for (name, age_days) in clean_pairs: + for name, age_days in clean_pairs: path = os.path.join(test_dir, name) stamp = time.time() - age_days * SECS_PER_DAY self._prepare_test_file(path, (stamp, stamp)) @@ -1028,15 +1114,18 @@ def test_cleaner_with_multiple_patterns(self): test_dir, clean_patterns, EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 1) - self.assertTrue(os.path.exists( - os.path.join(test_dir, 'should_keep_recent.log'))) - self.assertFalse(os.path.exists( - os.path.join(test_dir, 'should_remove_stale.tmp'))) - self.assertTrue(os.path.exists( - os.path.join(test_dir, 'should_keep_other.pck'))) + self.assertTrue( + os.path.exists(os.path.join(test_dir, "should_keep_recent.log")) + ) + self.assertFalse( + os.path.exists(os.path.join(test_dir, "should_remove_stale.tmp")) + ) + self.assertTrue( + os.path.exists(os.path.join(test_dir, "should_keep_other.pck")) + ) def test_absent_jobs_flag(self): """Test clean_no_job_helpers with site_enable_jobs disabled""" diff --git a/tests/test_mig_lib_quota.py b/tests/test_mig_lib_quota.py index e4057b5d1..c99b0a920 100644 --- a/tests/test_mig_lib_quota.py +++ b/tests/test_mig_lib_quota.py @@ -29,6 +29,7 @@ # Imports of the code under test from mig.lib.quota import update_quota + # Imports required for the unit tests themselves from tests.support import MigTestCase @@ -38,7 +39,7 @@ class MigLibQouta(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up test configuration and reset state before each test""" @@ -47,7 +48,9 @@ def before_each(self): def test_invalid_quota_backend(self): """Test invalid quota_backend in configuration""" self.configuration.quota_backend = "NEVERNEVER" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: update_quota(self.configuration) - self.assertTrue("'NEVERNEVER' not in supported_quota_backends:" in msg - for msg in log_capture.output) + self.assertTrue( + "'NEVERNEVER' not in supported_quota_backends:" in msg + for msg in log_capture.output + ) diff --git a/tests/test_mig_lib_xgicore.py b/tests/test_mig_lib_xgicore.py index e63e22b7d..106b08eea 100644 --- a/tests/test_mig_lib_xgicore.py +++ b/tests/test_mig_lib_xgicore.py @@ -30,9 +30,8 @@ import os import sys -from tests.support import MigTestCase, FakeConfiguration, testmain - from mig.lib.xgicore import * +from tests.support import FakeConfiguration, MigTestCase, testmain class MigLibXgicore__get_output_format(MigTestCase): @@ -42,28 +41,32 @@ def test_default_when_missing(self): """Test that default output_format is returned when not set.""" expected = "html" user_args = {} - actual = get_output_format(FakeConfiguration(), user_args, - default_format=expected) - self.assertEqual(actual, expected, - "mismatch in default output_format") + actual = get_output_format( + FakeConfiguration(), user_args, default_format=expected + ) + self.assertEqual(actual, expected, "mismatch in default output_format") def test_get_single_requested_format(self): """Test that the requested output_format is returned.""" expected = "file" - user_args = {'output_format': [expected]} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) def test_get_first_requested_format(self): """Test that first requested output_format is returned.""" expected = "file" - user_args = {'output_format': [expected, 'BOGUS']} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected, "BOGUS"]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) class MigLibXgicore__override_output_format(MigTestCase): @@ -74,29 +77,35 @@ def test_unchanged_without_override(self): expected = "html" user_args = {} out_objs = [] - actual = override_output_format(FakeConfiguration(), user_args, - out_objs, expected) - self.assertEqual(actual, expected, - "mismatch in unchanged output_format") + actual = override_output_format( + FakeConfiguration(), user_args, out_objs, expected + ) + self.assertEqual( + actual, expected, "mismatch in unchanged output_format" + ) def test_get_single_requested_format(self): """Test that the requested output_format is returned if overriden.""" expected = "file" - user_args = {'output_format': [expected]} - out_objs = [{'object_type': 'start', 'override_format': True}] - actual = override_output_format(FakeConfiguration(), user_args, - out_objs, 'OVERRIDE') - self.assertEqual(actual, expected, - "mismatch in overriden output_format") + user_args = {"output_format": [expected]} + out_objs = [{"object_type": "start", "override_format": True}] + actual = override_output_format( + FakeConfiguration(), user_args, out_objs, "OVERRIDE" + ) + self.assertEqual( + actual, expected, "mismatch in overriden output_format" + ) def test_get_first_requested_format(self): """Test that first requested output_format is returned if overriden.""" expected = "file" - user_args = {'output_format': [expected, 'BOGUS']} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected, "BOGUS"]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) class MigLibXgicore__fill_start_headers(MigTestCase): @@ -105,35 +114,41 @@ class MigLibXgicore__fill_start_headers(MigTestCase): def test_unchanged_when_set(self): """Test that existing valid start entry is returned as-is.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream'), - ('Content-Size', 42)] - expected = {'object_type': 'start', 'headers': headers} - out_objs = [expected, {'object_type': 'binary', 'data': 42*b'0'}] + headers = [ + ("Content-Type", "application/octet-stream"), + ("Content-Size", 42), + ] + expected = {"object_type": "start", "headers": headers} + out_objs = [expected, {"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in unchanged start entry") + self.assertEqual(actual, expected, "mismatch in unchanged start entry") def test_headers_added_when_missing(self): """Test that start entry headers are added if missing.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream')] - minimal_start = {'object_type': 'start'} - expected = {'object_type': 'start', 'headers': headers} - out_objs = [minimal_start, {'object_type': 'binary', 'data': 42*b'0'}] + headers = [("Content-Type", "application/octet-stream")] + minimal_start = {"object_type": "start"} + expected = {"object_type": "start", "headers": headers} + out_objs = [ + minimal_start, + {"object_type": "binary", "data": 42 * b"0"}, + ] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in auto initialized start entry") + self.assertEqual( + actual, expected, "mismatch in auto initialized start entry" + ) def test_start_added_when_missing(self): """Test that start entry is added if missing.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream')] - expected = {'object_type': 'start', 'headers': headers} - out_objs = [{'object_type': 'binary', 'data': 42*b'0'}] + headers = [("Content-Type", "application/octet-stream")] + expected = {"object_type": "start", "headers": headers} + out_objs = [{"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in auto initialized start entry") + self.assertEqual( + actual, expected, "mismatch in auto initialized start entry" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_accountreq.py b/tests/test_mig_shared_accountreq.py index ecba42ff3..0b7a6919d 100644 --- a/tests/test_mig_shared_accountreq.py +++ b/tests/test_mig_shared_accountreq.py @@ -35,9 +35,11 @@ # Imports of the code under test import mig.shared.accountreq as accountreq + # Imports required for the unit test wrapping from mig.shared.base import distinguished_name_to_user, fill_distinguished_name from mig.shared.defaults import keyword_auto + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain from tests.support.fixturesupp import FixtureAssertMixin @@ -47,8 +49,8 @@ class MigSharedAccountreq__peers(MigTestCase, FixtureAssertMixin): """Unit tests for peers related functions within the accountreq module""" - TEST_PEER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com' - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_PEER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com" + TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" @property def user_settings_dir(self): @@ -60,40 +62,46 @@ def user_pending_dir(self): def _load_saved_peer(self, absolute_path): self.assertPathWithin(absolute_path, start=self.user_pending_dir) - with open(absolute_path, 'rb') as pickle_file: + with open(absolute_path, "rb") as pickle_file: value = pickle.load(pickle_file) def _string_if_bytes(value): if isinstance(value, bytes): - return str(value, 'utf8') + return str(value, "utf8") else: return value - return {_string_if_bytes(x): _string_if_bytes(y) - for x, y in value.items()} + + return { + _string_if_bytes(x): _string_if_bytes(y) for x, y in value.items() + } def _peer_dict_from_fixture(self): - prepared_fixture = self.prepareFixtureAssert("peer_user_dict", - fixture_format="json") + prepared_fixture = self.prepareFixtureAssert( + "peer_user_dict", fixture_format="json" + ) fixture_data = prepared_fixture.fixture_data assert fixture_data["distinguished_name"] == self.TEST_PEER_DN return fixture_data - def _record_peer_acceptance(self, test_client_dir_name, - peer_distinguished_name): - """Fabricate a peer acceptance record in a particular user settings dir. - """ + def _record_peer_acceptance( + self, test_client_dir_name, peer_distinguished_name + ): + """Fabricate a peer acceptance record in a particular user settings dir.""" test_user_accepted_peers_file = os.path.join( - self.user_settings_dir, test_client_dir_name, "peers") + self.user_settings_dir, test_client_dir_name, "peers" + ) expire_tomorrow = datetime.date.today() + datetime.timedelta(days=1) - with open(test_user_accepted_peers_file, "wb") as \ - test_user_accepted_peers: - pickle.dump({peer_distinguished_name: - {'expire': str(expire_tomorrow)}}, - test_user_accepted_peers) + with open( + test_user_accepted_peers_file, "wb" + ) as test_user_accepted_peers: + pickle.dump( + {peer_distinguished_name: {"expire": str(expire_tomorrow)}}, + test_user_accepted_peers, + ) def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): ensure_dirs_exist(self.configuration.user_cache) @@ -108,8 +116,9 @@ def test_a_new_peer(self): self.assertDirEmpty(self.configuration.user_pending) request_dict = self._peer_dict_from_fixture() - success, _ = accountreq.save_account_request(self.configuration, - request_dict) + success, _ = accountreq.save_account_request( + self.configuration, request_dict + ) # check that we have an output directory now absolute_files = self.assertDirNotEmpty(self.user_pending_dir) @@ -131,10 +140,11 @@ def test_listing_peers(self): # check the fabricated peer was listed # sadly listing returns _relative_ dirs peer_temp_file_name = listing[0] - peer_pickle_file = os.path.join(self.user_pending_dir, - peer_temp_file_name) + peer_pickle_file = os.path.join( + self.user_pending_dir, peer_temp_file_name + ) peer_pickle = self._load_saved_peer(peer_pickle_file) - self.assertEqual(peer_pickle['distinguished_name'], self.TEST_PEER_DN) + self.assertEqual(peer_pickle["distinguished_name"], self.TEST_PEER_DN) def test_peer_acceptance(self): test_client_dir = self._provision_test_user(self, self.TEST_USER_DN) @@ -142,17 +152,18 @@ def test_peer_acceptance(self): self._record_peer_acceptance(test_client_dir_name, self.TEST_PEER_DN) self.assertDirEmpty(self.user_pending_dir) request_dict = self._peer_dict_from_fixture() - success, req_path = accountreq.save_account_request(self.configuration, - request_dict) + success, req_path = accountreq.save_account_request( + self.configuration, request_dict + ) arranged_req_id = os.path.basename(req_path) # NOTE: when using real user mail we currently hit send email errors. # We forgive those errors here and only check any known warnings. # TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - success, message = accountreq.accept_account_req(arranged_req_id, - self.configuration, - keyword_auto) + success, message = accountreq.accept_account_req( + arranged_req_id, self.configuration, keyword_auto + ) self.assertTrue(success) @@ -160,29 +171,44 @@ def test_peer_acceptance(self): class MigSharedAccountreq__filters(MigTestCase, UserAssertMixin): """Unit tests for filter related functions within the accountreq module""" - TEST_SERVICE = 'migoid' - TEST_INTERNAL_DN = '/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=Test Name/emailAddress=test@local.org' - TEST_EXTERNAL_DN = '/C=DK/ST=NA/L=NA/O=External Org/OU=NA/CN=Test User/emailAddress=test@external.org' - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' - TEST_ADMIN_DN = '/C=DK/ST=NA/L=NA/O=DIKU/OU=NA/CN=Test Admin/emailAddress=siteadm@di.ku.dk' - - TEST_INT_PW = 'PW74deb6609F109f504d' - TEST_EXT_PW = 'PW174db6509F109e1531' - TEST_USER_PW = 'foobar' - TEST_INT_PW_HASH = 'PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$epib2rEg/HYTQZFnCp7hmIGZ6rzHnViy' - TEST_EXT_PW_HASH = 'PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$TQZFnCp7hmIGZ6ep2rEg/HYrzHnVyiib' - TEST_USER_PW_HASH = 'PBKDF2$sha256$10000$/TkhLk4yMGf6XhaY$7HUeQ9iwCkE4YMQAaCd+ZdrN+y8EzkJH' - - TEST_INTERNAL_EMAILS = ['john.doe@science.ku.dk', 'abc123@ku.dk', - 'john.doe@a.b.c.ku.dk'] - TEST_EXTERNAL_EMAILS = ['john@doe.org', 'a@b.c.org', 'a@ku.dk.com', - 'a@sci.ku.dk.org', 'a@diku.dk', 'a@nbi.dk'] - TEST_EXTERNAL_EMAIL_PATTERN = r'^.+(?@') + "connection_string": "user:password@db.example.com", + "other_field": "some_value", } + subst_map = {"connection_string": (r":.*@", r":@")} masked = mask_creds(user_dict, subst_map=subst_map) - self.assertEqual(masked['connection_string'], - 'user:@db.example.com') - self.assertEqual(masked['other_field'], 'some_value') + self.assertEqual( + masked["connection_string"], "user:@db.example.com" + ) + self.assertEqual(masked["other_field"], "some_value") def test_mask_creds_no_maskable_fields(self): """Test mask_creds with a dictionary containing no maskable fields.""" - user_dict = {'username': 'test', 'role': 'user'} + user_dict = {"username": "test", "role": "user"} masked = mask_creds(user_dict) self.assertEqual(user_dict, masked) @@ -230,55 +275,58 @@ def test_mask_creds_empty_dict(self): def test_mask_creds_csrf_field(self): """Test that the default csrf_field is masked.""" - user_dict = {csrf_field: 'some_csrf_token', 'other': 'value'} + user_dict = {csrf_field: "some_csrf_token", "other": "value"} masked = mask_creds(user_dict) - self.assertEqual(masked[csrf_field], '**HIDDEN**') - self.assertEqual(masked['other'], 'value') + self.assertEqual(masked[csrf_field], "**HIDDEN**") + self.assertEqual(masked["other"], "value") def test_extract_field_exists(self): """Test extracting an existing field from a distinguished name.""" - self.assertEqual(extract_field(TEST_USER_ID, 'full_name'), - TEST_FULL_NAME) - self.assertEqual(extract_field(TEST_USER_ID, 'organization'), - TEST_ORGANIZATION) - self.assertEqual(extract_field(TEST_USER_ID, 'country'), TEST_COUNTRY) - self.assertEqual(extract_field(TEST_USER_ID, 'email'), TEST_EMAIL) + self.assertEqual( + extract_field(TEST_USER_ID, "full_name"), TEST_FULL_NAME + ) + self.assertEqual( + extract_field(TEST_USER_ID, "organization"), TEST_ORGANIZATION + ) + self.assertEqual(extract_field(TEST_USER_ID, "country"), TEST_COUNTRY) + self.assertEqual(extract_field(TEST_USER_ID, "email"), TEST_EMAIL) def test_extract_field_not_exists(self): """Test extracting a non-existent field returns None.""" - self.assertIsNone(extract_field(TEST_USER_ID, 'missing')) - self.assertIsNone(extract_field(TEST_USER_ID, 'dummy')) + self.assertIsNone(extract_field(TEST_USER_ID, "missing")) + self.assertIsNone(extract_field(TEST_USER_ID, "dummy")) def test_extract_field_with_na_value(self): """Test extracting a field with 'NA' value, which should be an empty string.""" - self.assertEqual(extract_field('/C=DK/DUMMY=NA/CN=TEST', 'DUMMY'), '') + self.assertEqual(extract_field("/C=DK/DUMMY=NA/CN=TEST", "DUMMY"), "") def test_extract_field_custom_field(self): """Test extracting a custom (non-standard) field.""" - self.assertEqual(extract_field('/C=DK/DUMMY=proj1/CN=Test', 'DUMMY'), - 'proj1') + self.assertEqual( + extract_field("/C=DK/DUMMY=proj1/CN=Test", "DUMMY"), "proj1" + ) def test_extract_field_empty_dn(self): """Test extracting from an empty distinguished name.""" - self.assertIsNone(extract_field("", 'full_name')) + self.assertIsNone(extract_field("", "full_name")) def test_extract_field_malformed_dn(self): """Test extracting from a malformed distinguished name.""" dn_empty_val = "/C=US/O=/CN=John Doe" - self.assertEqual(extract_field(dn_empty_val, 'organization'), '') + self.assertEqual(extract_field(dn_empty_val, "organization"), "") dn_no_equals = "/C=US/O/CN=John Doe" - self.assertIsNone(extract_field(dn_no_equals, 'organization')) + self.assertIsNone(extract_field(dn_no_equals, "organization")) def test_distinguished_name_to_user_basic(self): """Test basic conversion from distinguished name to user dictionary.""" user_dict = distinguished_name_to_user(TEST_USER_ID) expected = { - 'distinguished_name': TEST_USER_ID, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, + "distinguished_name": TEST_USER_ID, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, } self.assertEqual(user_dict, expected) @@ -287,12 +335,12 @@ def test_distinguished_name_to_user_with_na(self): dn = "%s/dummy=NA" % TEST_USER_ID user_dict = distinguished_name_to_user(dn) expected = { - 'distinguished_name': dn, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'dummy': '' + "distinguished_name": dn, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "dummy": "", } self.assertEqual(user_dict, expected) @@ -301,28 +349,29 @@ def test_distinguished_name_to_user_with_custom_field(self): dn = "%s/dummy=proj1" % TEST_USER_ID user_dict = distinguished_name_to_user(dn) expected = { - 'distinguished_name': dn, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'dummy': 'proj1' + "distinguished_name": dn, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "dummy": "proj1", } self.assertEqual(user_dict, expected) def test_distinguished_name_to_user_empty_and_malformed(self): """Test behavior with empty and malformed distinguished names.""" # Empty DN - self.assertEqual(distinguished_name_to_user(""), - {'distinguished_name': ''}) + self.assertEqual( + distinguished_name_to_user(""), {"distinguished_name": ""} + ) # Malformed part (no '=') dn_malformed = "/C=US/O/CN=John Doe" user_dict_malformed = distinguished_name_to_user(dn_malformed) expected_malformed = { - 'distinguished_name': dn_malformed, - 'country': 'US', - 'full_name': TEST_FULL_NAME + "distinguished_name": dn_malformed, + "country": "US", + "full_name": TEST_FULL_NAME, } self.assertEqual(user_dict_malformed, expected_malformed) @@ -330,45 +379,49 @@ def test_distinguished_name_to_user_empty_and_malformed(self): dn_empty_val = "/C=DK/O=/CN=John Doe" user_dict_empty_val = distinguished_name_to_user(dn_empty_val) expected_empty_val = { - 'distinguished_name': dn_empty_val, - 'country': TEST_COUNTRY, - 'organization': '', - 'full_name': TEST_FULL_NAME + "distinguished_name": dn_empty_val, + "country": TEST_COUNTRY, + "organization": "", + "full_name": TEST_FULL_NAME, } self.assertEqual(user_dict_empty_val, expected_empty_val) def test_fill_distinguished_name_from_fields(self): """Test filling distinguished_name from other user fields.""" user = { - 'full_name': 'Jane Doe', - 'organization': 'Test Corp', - 'country': TEST_COUNTRY, - 'email': 'jane.doe@example.com' + "full_name": "Jane Doe", + "organization": "Test Corp", + "country": TEST_COUNTRY, + "email": "jane.doe@example.com", } fill_distinguished_name(user) - expected_dn = "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" \ - "/emailAddress=jane.doe@example.com" - self.assertEqual(user['distinguished_name'], expected_dn) + expected_dn = ( + "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" + "/emailAddress=jane.doe@example.com" + ) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_distinguished_name_with_gdp(self): """Test filling distinguished_name with a GDP project field.""" user = { - 'full_name': 'Jane Doe', - 'organization': 'Test Corp', - 'country': TEST_COUNTRY, - gdp_distinguished_field: 'project_x' + "full_name": "Jane Doe", + "organization": "Test Corp", + "country": TEST_COUNTRY, + gdp_distinguished_field: "project_x", } fill_distinguished_name(user) - expected_dn = "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" \ - "/emailAddress=NA/GDP=project_x" - self.assertEqual(user['distinguished_name'], expected_dn) + expected_dn = ( + "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" + "/emailAddress=NA/GDP=project_x" + ) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_distinguished_name_already_exists(self): """Test that an existing distinguished_name is not overwritten.""" user = { - 'distinguished_name': TEST_USER_ID, - 'full_name': 'Jane Doe', - 'country': 'US' + "distinguished_name": TEST_USER_ID, + "full_name": "Jane Doe", + "country": "US", } original_user = user.copy() returned_user = fill_distinguished_name(user) @@ -380,24 +433,21 @@ def test_fill_distinguished_name_empty_user(self): user = {} fill_distinguished_name(user) expected_dn = "/C=NA/ST=NA/L=NA/O=NA/OU=NA/CN=NA/emailAddress=NA" - self.assertEqual(user['distinguished_name'], expected_dn) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_user_completes_dict(self): """Test that fill_user adds missing fields and preserves existing ones.""" - user = { - 'full_name': TEST_FULL_NAME, - 'extra_field': 'extra_value' - } + user = {"full_name": TEST_FULL_NAME, "extra_field": "extra_value"} fill_user(user) # Check that existing values are preserved - self.assertEqual(user['full_name'], TEST_FULL_NAME) - self.assertEqual(user['extra_field'], 'extra_value') + self.assertEqual(user["full_name"], TEST_FULL_NAME) + self.assertEqual(user["extra_field"], "extra_value") # Check that missing standard fields are added with empty strings - self.assertEqual(user['organization'], '') - self.assertEqual(user['country'], '') + self.assertEqual(user["organization"], "") + self.assertEqual(user["country"], "") # Check that all standard keys are present for key, _ in cert_field_order: @@ -409,7 +459,7 @@ def test_fill_user_with_empty_dict(self): fill_user(user) self.assertEqual(len(user), len(cert_field_order)) for key, _ in cert_field_order: - self.assertEqual(user[key], '') + self.assertEqual(user[key], "") def test_fill_user_modifies_in_place_and_returns_self(self): """Test that fill_user modifies the dictionary in-place and returns @@ -421,158 +471,171 @@ def test_fill_user_modifies_in_place_and_returns_self(self): def test_canonical_user_transformations(self): """Test canonical_user applies all transformations correctly.""" user_dict = { - 'full_name': ' john doe ', - 'email': 'John@DoE.ORG', - 'country': 'dk', - 'state': 'vt', - 'organization': ' Test Org ', - 'extra_field': 'should be removed', - 'id': 123 + "full_name": " john doe ", + "email": "John@DoE.ORG", + "country": "dk", + "state": "vt", + "organization": " Test Org ", + "extra_field": "should be removed", + "id": 123, } - limit_fields = ['full_name', 'email', - 'country', 'state', 'organization', 'id'] + limit_fields = [ + "full_name", + "email", + "country", + "state", + "organization", + "id", + ] canonical = canonical_user(self.configuration, user_dict, limit_fields) expected = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'country': TEST_COUNTRY, - 'state': TEST_STATE, - 'organization': TEST_ORGANIZATION, - 'id': 123 + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "country": TEST_COUNTRY, + "state": TEST_STATE, + "organization": TEST_ORGANIZATION, + "id": 123, } self.assertEqual(canonical, expected) - self.assertNotIn('extra_field', canonical) + self.assertNotIn("extra_field", canonical) def test_canonical_user_with_peers_legacy(self): """Test canonical_user_with_peers with legacy peers list""" - self.configuration.site_peers_explicit_fields = ['email', 'full_name'] + self.configuration.site_peers_explicit_fields = ["email", "full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers": [ + "/C=DK/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_email'], - 'alice@example.com, bob@example.com') - self.assertEqual(canonical['peers_full_name'], 'Alice, Bob') + self.assertEqual( + canonical["peers_email"], "alice@example.com, bob@example.com" + ) + self.assertEqual(canonical["peers_full_name"], "Alice, Bob") def test_canonical_user_with_peers_explicit(self): """Test canonical_user_with_peers with explicit peers fields""" - self.configuration.site_peers_explicit_fields = ['email', 'full_name'] + self.configuration.site_peers_explicit_fields = ["email", "full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers_email': 'custom@example.com', - 'peers_full_name': 'Custom Name', - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers_email": "custom@example.com", + "peers_full_name": "Custom Name", + "peers": [ + "/C=DK/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_email'], 'custom@example.com') - self.assertEqual(canonical['peers_full_name'], 'Custom Name') + self.assertEqual(canonical["peers_email"], "custom@example.com") + self.assertEqual(canonical["peers_full_name"], "Custom Name") def test_canonical_user_with_peers_mixed(self): """Test canonical_user_with_peers with mixed explicit and legacy peers""" self.configuration.site_peers_explicit_fields = [ - 'email', 'organization'] + "email", + "organization", + ] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers_organization': TEST_ORGANIZATION, - 'peers': [ - '/C=DK/O=Legacy Org/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers_organization": TEST_ORGANIZATION, + "peers": [ + "/C=DK/O=Legacy Org/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email', 'organization'] + limit_fields = ["full_name", "email", "organization"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) # Explicit field should be preserved - self.assertEqual(canonical['peers_organization'], 'Test Org') + self.assertEqual(canonical["peers_organization"], "Test Org") # Legacy peers should be converted for email - self.assertEqual(canonical['peers_email'], - 'alice@example.com, bob@example.com') + self.assertEqual( + canonical["peers_email"], "alice@example.com, bob@example.com" + ) def test_canonical_user_with_peers_empty(self): """Test canonical_user_with_peers with no peers data""" - self.configuration.site_peers_explicit_fields = ['email'] - user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL - } - limit_fields = ['full_name', 'email'] + self.configuration.site_peers_explicit_fields = ["email"] + user_dict = {"full_name": TEST_FULL_NAME, "email": TEST_EMAIL} + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertNotIn('peers_email', canonical) - self.assertNotIn('peers', canonical) + self.assertNotIn("peers_email", canonical) + self.assertNotIn("peers", canonical) def test_canonical_user_with_peers_no_explicit_fields(self): """Test canonical_user_with_peers with no peer fields configured""" self.configuration.site_peers_explicit_fields = [] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': 'john@example.com', - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": "john@example.com", + "peers": ["/C=DK/CN=Alice/emailAddress=alice@example.com"], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) # Should not create any peer fields - self.assertNotIn('peers_email', canonical) - self.assertNotIn('peers_full_name', canonical) + self.assertNotIn("peers_email", canonical) + self.assertNotIn("peers_full_name", canonical) def test_canonical_user_with_peers_special_chars(self): """Test canonical_user_with_peers handles special characters in DNs""" - self.configuration.site_peers_explicit_fields = ['full_name'] + self.configuration.site_peers_explicit_fields = ["full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'peers': [ - '/C=DK/CN=Jérôme Müller', - '/C=DK/CN=O‘‘Reilly', - '/C=DK/CN=Alice "Ace" Smith' - ] + "full_name": TEST_FULL_NAME, + "peers": [ + "/C=DK/CN=Jérôme Müller", + "/C=DK/CN=O‘‘Reilly", + '/C=DK/CN=Alice "Ace" Smith', + ], } - limit_fields = ['full_name'] + limit_fields = ["full_name"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_full_name'], - 'Jérôme Müller, O‘‘Reilly, Alice "Ace" Smith') + self.assertEqual( + canonical["peers_full_name"], + 'Jérôme Müller, O‘‘Reilly, Alice "Ace" Smith', + ) def test_canonical_user_unicode_name(self): """Test canonical_user with unicode characters in full_name.""" # Using a name that title() might mess up without unicode conversion - user_dict = {'full_name': u'josé de la vega'} - limit_fields = ['full_name'] + user_dict = {"full_name": "josé de la vega"} + limit_fields = ["full_name"] canonical = canonical_user(self.configuration, user_dict, limit_fields) - self.assertEqual(canonical['full_name'], u'José De La Vega') + self.assertEqual(canonical["full_name"], "José De La Vega") def test_canonical_user_empty_input(self): """Test canonical_user with empty inputs.""" self.assertEqual(canonical_user(self.configuration, {}, []), {}) - self.assertEqual(canonical_user(self.configuration, {'a': 1}, []), {}) - self.assertEqual(canonical_user(self.configuration, {}, ['a']), {}) + self.assertEqual(canonical_user(self.configuration, {"a": 1}, []), {}) + self.assertEqual(canonical_user(self.configuration, {}, ["a"]), {}) def test_generate_https_urls_single_method_cgi(self): """Test generate_https_urls with a single method and cgi-bin.""" - self.configuration.site_login_methods = ['migcert'] + self.configuration.site_login_methods = ["migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" expected = "https://mig.cert/cgi-bin/script.py" result = generate_https_urls(self.configuration, template, {}) @@ -581,7 +644,7 @@ def test_generate_https_urls_single_method_cgi(self): def test_generate_https_urls_single_method_wsgi(self): """Test generate_https_urls with a single method and wsgi-bin.""" self.configuration.site_enable_wsgi = True - self.configuration.site_login_methods = ['migcert'] + self.configuration.site_login_methods = ["migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" expected = "https://mig.cert/wsgi-bin/script.py" result = generate_https_urls(self.configuration, template, {}) @@ -590,29 +653,32 @@ def test_generate_https_urls_single_method_wsgi(self): def test_generate_https_urls_multiple_methods(self): """Test generate_https_urls with multiple methods.""" template = "%(auto_base)s/%(auto_bin)s/script.py" - self.configuration.site_login_methods = ['migcert', 'extoidc'] + self.configuration.site_login_methods = ["migcert", "extoidc"] result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://mig.cert/cgi-bin/script.py" expected_url2 = "https://ext.oidc/cgi-bin/script.py" expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_generate_https_urls_with_helper_dict(self): """Test generate_https_urls with a helper_dict.""" - self.configuration.site_login_methods = ['extoid'] + self.configuration.site_login_methods = ["extoid"] template = "%(auto_base)s/%(auto_bin)s/%(script)s" - helper = {'script': 'login.py'} + helper = {"script": "login.py"} result = generate_https_urls(self.configuration, template, helper) self.assertEqual(result, "https://ext.oid/cgi-bin/login.py") def test_generate_https_urls_method_enabled_but_url_missing(self): """Test that methods with no configured URL are skipped.""" self.configuration.migserver_https_ext_cert_url = "" # URL is missing - self.configuration.site_login_methods = ['migcert', 'extcert'] + self.configuration.site_login_methods = ["migcert", "extcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) self.assertEqual(result, "https://mig.cert/cgi-bin/script.py") @@ -626,7 +692,7 @@ def test_generate_https_urls_no_methods_enabled(self): def test_generate_https_urls_respects_order(self): """Test that the order of site_login_methods is respected.""" - self.configuration.site_login_methods = ['extoidc', 'migcert'] + self.configuration.site_login_methods = ["extoidc", "migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://ext.oidc/cgi-bin/script.py" @@ -634,14 +700,20 @@ def test_generate_https_urls_respects_order(self): expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_generate_https_urls_avoids_duplicates(self): """Test that duplicate URLs are not generated.""" self.configuration.site_login_methods = [ - 'migcert', 'extoidc', 'migcert'] + "migcert", + "extoidc", + "migcert", + ] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://mig.cert/cgi-bin/script.py" @@ -649,22 +721,35 @@ def test_generate_https_urls_avoids_duplicates(self): expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_auth_type_description_all(self): """Test auth_type_description returns full dict when requested""" from mig.shared.defaults import keyword_all + result = auth_type_description(self.configuration, keyword_all) - expected_keys = ['migoid', 'migoidc', 'migcert', 'extoid', 'extoidc', - 'extcert'] + expected_keys = [ + "migoid", + "migoidc", + "migcert", + "extoid", + "extoidc", + "extcert", + ] self.assertEqual(sorted(result.keys()), sorted(expected_keys)) def test_auth_type_description_individual(self): """Test auth_type_description returns expected strings for each type""" - from mig.shared.defaults import AUTH_CERTIFICATE, AUTH_OPENID_CONNECT, \ - AUTH_OPENID_V2 + from mig.shared.defaults import ( + AUTH_CERTIFICATE, + AUTH_OPENID_CONNECT, + AUTH_OPENID_V2, + ) # Setup titles in configuration self.configuration.user_mig_oid_title = "MiG OpenID" @@ -673,31 +758,37 @@ def test_auth_type_description_individual(self): self.configuration.user_ext_cert_title = "External Certificate" test_cases = [ - ('migoid', 'MiG OpenID %s' % AUTH_OPENID_V2), - ('migoidc', 'MiG OpenID %s' % AUTH_OPENID_CONNECT), - ('migcert', 'MiG Certificate %s' % AUTH_CERTIFICATE), - ('extoid', 'External OpenID %s' % AUTH_OPENID_V2), - ('extoidc', 'External OpenID %s' % AUTH_OPENID_CONNECT), - ('extcert', 'External Certificate %s' % AUTH_CERTIFICATE), + ("migoid", "MiG OpenID %s" % AUTH_OPENID_V2), + ("migoidc", "MiG OpenID %s" % AUTH_OPENID_CONNECT), + ("migcert", "MiG Certificate %s" % AUTH_CERTIFICATE), + ("extoid", "External OpenID %s" % AUTH_OPENID_V2), + ("extoidc", "External OpenID %s" % AUTH_OPENID_CONNECT), + ("extcert", "External Certificate %s" % AUTH_CERTIFICATE), ] - for (auth_type, expected) in test_cases: + for auth_type, expected in test_cases: result = auth_type_description(self.configuration, auth_type) self.assertEqual(result, expected) def test_auth_type_description_unknown(self): """Test auth_type_description returns 'UNKNOWN' for invalid types""" - self.assertEqual(auth_type_description(self.configuration, 'invalid'), - 'UNKNOWN') - self.assertEqual(auth_type_description( - self.configuration, ''), 'UNKNOWN') - self.assertEqual(auth_type_description( - self.configuration, None), 'UNKNOWN') + self.assertEqual( + auth_type_description(self.configuration, "invalid"), "UNKNOWN" + ) + self.assertEqual( + auth_type_description(self.configuration, ""), "UNKNOWN" + ) + self.assertEqual( + auth_type_description(self.configuration, None), "UNKNOWN" + ) def test_auth_type_description_empty_titles(self): """Test auth_type_description handles empty titles in configuration""" - from mig.shared.defaults import AUTH_CERTIFICATE, AUTH_OPENID_CONNECT, \ - AUTH_OPENID_V2 + from mig.shared.defaults import ( + AUTH_CERTIFICATE, + AUTH_OPENID_CONNECT, + AUTH_OPENID_V2, + ) self.configuration.user_mig_oid_title = "" self.configuration.user_mig_cert_title = "" @@ -705,15 +796,15 @@ def test_auth_type_description_empty_titles(self): self.configuration.user_ext_cert_title = "" test_cases = [ - ('migoid', ' %s' % AUTH_OPENID_V2), - ('migoidc', ' %s' % AUTH_OPENID_CONNECT), - ('migcert', ' %s' % AUTH_CERTIFICATE), - ('extoid', ' %s' % AUTH_OPENID_V2), - ('extoidc', ' %s' % AUTH_OPENID_CONNECT), - ('extcert', ' %s' % AUTH_CERTIFICATE), + ("migoid", " %s" % AUTH_OPENID_V2), + ("migoidc", " %s" % AUTH_OPENID_CONNECT), + ("migcert", " %s" % AUTH_CERTIFICATE), + ("extoid", " %s" % AUTH_OPENID_V2), + ("extoidc", " %s" % AUTH_OPENID_CONNECT), + ("extcert", " %s" % AUTH_CERTIFICATE), ] - for (auth_type, expected) in test_cases: + for auth_type, expected in test_cases: result = auth_type_description(self.configuration, auth_type) self.assertEqual(result, expected) @@ -721,11 +812,16 @@ def test_allow_script_gdp_enabled_anonymous_allowed(self): """Test allow_script with GDP enabled, anonymous user, and script allowed.""" self.configuration.site_enable_gdp = True - script_name = valid_gdp_anon_scripts[0] if valid_gdp_anon_scripts \ - else 'allowed_script.py' # Use a valid script or a default + script_name = ( + valid_gdp_anon_scripts[0] + if valid_gdp_anon_scripts + else "allowed_script.py" + ) # Use a valid script or a default if not valid_gdp_anon_scripts: - print("WARNING: valid_gdp_anon_scripts is empty. Using " - "'allowed_script.py' which may cause a test failure.") + print( + "WARNING: valid_gdp_anon_scripts is empty. Using " + "'allowed_script.py' which may cause a test failure." + ) allow, msg = allow_script(self.configuration, script_name, None) self.assertTrue(allow) self.assertEqual(msg, "") @@ -734,29 +830,41 @@ def test_allow_script_gdp_enabled_anonymous_disallowed(self): """Test allow_script with GDP enabled, anonymous user, and script disallowed.""" self.configuration.site_enable_gdp = True - script_name = 'disallowed_script.py' + script_name = "disallowed_script.py" # Ensure the script is not in valid_gdp_anon_scripts if script_name in valid_gdp_anon_scripts: valid_gdp_anon_scripts.remove(script_name) allow, msg = allow_script(self.configuration, script_name, None) self.assertFalse(allow) - self.assertEqual(msg, "anonoymous access to functionality disabled " - "by site configuration!") + self.assertEqual( + msg, + "anonoymous access to functionality disabled " + "by site configuration!", + ) def test_allow_script_gdp_enabled_authenticated_allowed(self): """Test allow_script with GDP enabled, authenticated user, and script allowed.""" self.configuration.site_enable_gdp = True - script_name = valid_gdp_auth_scripts[0] if valid_gdp_auth_scripts \ - else valid_gdp_anon_scripts[0] if valid_gdp_anon_scripts \ - else 'allowed_script.py' + script_name = ( + valid_gdp_auth_scripts[0] + if valid_gdp_auth_scripts + else ( + valid_gdp_anon_scripts[0] + if valid_gdp_anon_scripts + else "allowed_script.py" + ) + ) if not valid_gdp_auth_scripts and not valid_gdp_anon_scripts: - print("WARNING: valid_gdp_auth_scripts and " - "valid_gdp_anon_scripts are empty. Using " - "'allowed_script.py' which may cause a test failure.") + print( + "WARNING: valid_gdp_auth_scripts and " + "valid_gdp_anon_scripts are empty. Using " + "'allowed_script.py' which may cause a test failure." + ) allow, msg = allow_script( - self.configuration, script_name, 'test_client') + self.configuration, script_name, "test_client" + ) self.assertTrue(allow) self.assertEqual(msg, "") @@ -764,7 +872,7 @@ def test_allow_script_gdp_enabled_authenticated_disallowed(self): """Test allow_script with GDP enabled, authenticated user, and script disallowed.""" self.configuration.site_enable_gdp = True - script_name = 'disallowed_script.py' + script_name = "disallowed_script.py" # Ensure the script is not in valid_gdp_auth_scripts or # valid_gdp_anon_scripts @@ -774,98 +882,108 @@ def test_allow_script_gdp_enabled_authenticated_disallowed(self): valid_gdp_anon_scripts.remove(script_name) allow, msg = allow_script( - self.configuration, script_name, 'test_client') + self.configuration, script_name, "test_client" + ) self.assertFalse(allow) - self.assertEqual(msg, "all access to functionality disabled by site " - "configuration!") + self.assertEqual( + msg, + "all access to functionality disabled by site " "configuration!", + ) def test_allow_script_gdp_disabled(self): """Test allow_script with GDP disabled.""" self.configuration.site_enable_gdp = False - allow, msg = allow_script(self.configuration, 'any_script.py', - 'test_client') + allow, msg = allow_script( + self.configuration, "any_script.py", "test_client" + ) self.assertTrue(allow) self.assertEqual(msg, "") def test_allow_script_gdp_disabled_anonymous(self): """Test allow_script with GDP disabled and anonymous user.""" self.configuration.site_enable_gdp = False - allow, msg = allow_script(self.configuration, 'any_script.py', None) + allow, msg = allow_script(self.configuration, "any_script.py", None) self.assertTrue(allow) self.assertEqual(msg, "") def test_requested_page_normal(self): """Test requested_page with basic environment""" fake_env = { - 'SCRIPT_NAME': '/cgi-bin/home.py', - 'REQUEST_URI': '/cgi-bin/home.py' + "SCRIPT_NAME": "/cgi-bin/home.py", + "REQUEST_URI": "/cgi-bin/home.py", } - self.assertEqual(requested_page(fake_env), '/cgi-bin/home.py') + self.assertEqual(requested_page(fake_env), "/cgi-bin/home.py") def test_requested_page_name_only(self): """Test requested_page with name_only argument""" fake_env = { - 'BACKEND_NAME': 'search.py', - 'PATH_INFO': '/cgi-bin/search.py/path' + "BACKEND_NAME": "search.py", + "PATH_INFO": "/cgi-bin/search.py/path", } - result = requested_page(fake_env, name_only=True, fallback='fallback') - self.assertEqual(result, 'search.py') + result = requested_page(fake_env, name_only=True, fallback="fallback") + self.assertEqual(result, "search.py") def test_requested_page_strip_extension(self): """Test requested_page with strip_ext argument""" - fake_env = {'REQUEST_URI': '/cgi-bin/file.py?query=val'} + fake_env = {"REQUEST_URI": "/cgi-bin/file.py?query=val"} result = requested_page(fake_env, strip_ext=True) - self.assertEqual(result, '/cgi-bin/file') + self.assertEqual(result, "/cgi-bin/file") def test_requested_page_priority(self): """Test environment variable priority order""" _init_env = { - 'BACKEND_NAME': 'backend.py', - 'SCRIPT_URL': '/cgi-bin/script_url.py', - 'SCRIPT_URI': 'https://host/cgi-bin/script_uri.py', - 'PATH_INFO': '/cgi-bin/path_info.py', - 'REQUEST_URI': '/cgi-bin/req_uri.py' + "BACKEND_NAME": "backend.py", + "SCRIPT_URL": "/cgi-bin/script_url.py", + "SCRIPT_URI": "https://host/cgi-bin/script_uri.py", + "PATH_INFO": "/cgi-bin/path_info.py", + "REQUEST_URI": "/cgi-bin/req_uri.py", } - priority_order = ['BACKEND_NAME', 'SCRIPT_URL', 'SCRIPT_URI', - 'PATH_INFO', 'REQUEST_URI'] + priority_order = [ + "BACKEND_NAME", + "SCRIPT_URL", + "SCRIPT_URI", + "PATH_INFO", + "REQUEST_URI", + ] for var in priority_order: # Reset fake_env each time fake_env = dict([pair for pair in _init_env.items()]) current_env = {var: fake_env[var]} - if var != 'SCRIPT_URI': + if var != "SCRIPT_URI": expected = fake_env[var] else: - expected = 'https://host/cgi-bin/script_uri.py' + expected = "https://host/cgi-bin/script_uri.py" result = requested_page(current_env) - self.assertEqual(result, expected, - "failed priority for %s" % var) + self.assertEqual(result, expected, "failed priority for %s" % var) # Remove higher priority variables one by one - for higher_var in priority_order[:priority_order.index(var)]: + for higher_var in priority_order[: priority_order.index(var)]: del fake_env[higher_var] result = requested_page(fake_env) - self.assertEqual(result, fake_env[var], - "failed fallthrough to %s" % var) + self.assertEqual( + result, fake_env[var], "failed fallthrough to %s" % var + ) def test_requested_page_unsafe_filter(self): """Test unsafe character filtering""" test_cases = [ - ('/cgi-bin/unsafe.py' - fake_env = {'REQUEST_URI': dangerous} + dangerous = "/cgi-bin/unsafe.py" + fake_env = {"REQUEST_URI": dangerous} unsafe_result = requested_page(fake_env, include_unsafe=True) self.assertEqual(unsafe_result, dangerous) @@ -874,39 +992,39 @@ def test_requested_page_include_unsafe(self): def test_requested_page_query_stripping(self): """Test removal of query parameters""" - test_input = '/cgi-bin/script.py?query=value¶m=data' - fake_env = {'REQUEST_URI': test_input} + test_input = "/cgi-bin/script.py?query=value¶m=data" + fake_env = {"REQUEST_URI": test_input} result = requested_page(fake_env) - self.assertEqual(result, '/cgi-bin/script.py') + self.assertEqual(result, "/cgi-bin/script.py") def test_requested_page_fallback(self): """Test fallback to default""" fake_env = {} - fallback = 'special.py' + fallback = "special.py" result = requested_page(fake_env, fallback=fallback) self.assertEqual(result, fallback) def test_requested_page_fallback_despite_os_environ_value(self): """Test fallback to default""" fake_env = {} - fallback = 'special.py' - os.environ['BACKEND_NAME'] = 'BOGUS' + fallback = "special.py" + os.environ["BACKEND_NAME"] = "BOGUS" result = requested_page(fake_env, fallback=fallback) - del os.environ['BACKEND_NAME'] + del os.environ["BACKEND_NAME"] self.assertEqual(result, fallback) def test_requested_url_base_normal(self): """Test requested_url_base with basic complete URL""" - fake_env = {'SCRIPT_URI': 'https://example.com/path/to/script.py'} + fake_env = {"SCRIPT_URI": "https://example.com/path/to/script.py"} result = requested_url_base(fake_env) - expected = 'https://example.com' + expected = "https://example.com" self.assertEqual(result, expected) def test_requested_url_base_custom_field(self): """Test requested_url_base with custom uri_field parameter""" - fake_env = {'CUSTOM_FIELD_URI': 'http://server.org:8001/base/'} - result = requested_url_base(fake_env, uri_field='CUSTOM_FIELD_URI') - expected = 'http://server.org:8001' + fake_env = {"CUSTOM_FIELD_URI": "http://server.org:8001/base/"} + result = requested_url_base(fake_env, uri_field="CUSTOM_FIELD_URI") + expected = "http://server.org:8001" self.assertEqual(result, expected) # TODO: adjust tested function to bail out on missing uri_field @@ -915,56 +1033,55 @@ def test_requested_url_base_missing(self): """Test requested_url_base when uri_field not present""" fake_env = {} result = requested_url_base(fake_env) - self.assertEqual(result, '') + self.assertEqual(result, "") def test_requested_url_base_safe_filter(self): """Test unsafe character filtering in url base""" - test_url = 'https://user:pass@evil.com/' - fake_env = {'SCRIPT_URI': test_url} + test_url = "https://user:pass@evil.com/" + fake_env = {"SCRIPT_URI": test_url} safe_result = requested_url_base(fake_env) - expected_safe = 'https://user:passevil.com' + expected_safe = "https://user:passevil.com" self.assertEqual(safe_result, expected_safe) def test_requested_url_base_include_unsafe(self): """Test include_unsafe argument behavior""" - test_url = 'http://[::1]?' - fake_env = {'SCRIPT_URI': test_url} + test_url = "http://[::1]?" + fake_env = {"SCRIPT_URI": test_url} unsafe_result = requested_url_base(fake_env, include_unsafe=True) - self.assertEqual(unsafe_result, 'http://[::1]?') + self.assertEqual(unsafe_result, "http://[::1]?") safe_result = requested_url_base(fake_env, include_unsafe=False) - self.assertEqual(safe_result, 'http://::1markup') + self.assertEqual(safe_result, "http://::1markup") safe_result = requested_url_base(fake_env) - self.assertEqual(safe_result, 'http://::1markup') + self.assertEqual(safe_result, "http://::1markup") def test_requested_url_base_split_valid_edge_cases(self): """Test URL base splitting on valid edge cases""" test_cases = [ - ('https://site.com', 'https://site.com'), - ('http://a/single/slash', 'http://a'), - ('file:///absolute/path', 'file://'), - ('invalid.proto://double/slash', 'invalid.proto://double') + ("https://site.com", "https://site.com"), + ("http://a/single/slash", "http://a"), + ("file:///absolute/path", "file://"), + ("invalid.proto://double/slash", "invalid.proto://double"), ] - for (input_url, expected) in test_cases: - fake_env = {'SCRIPT_URI': input_url} + for input_url, expected in test_cases: + fake_env = {"SCRIPT_URI": input_url} result = requested_url_base(fake_env) - self.assertEqual(result, expected, - "failed for %s" % input_url) + self.assertEqual(result, expected, "failed for %s" % input_url) # TODO: adjust function to bail out on invalid URLs and enable next @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_split_invalid_edge_cases(self): """Test URL base splitting on invalid edge cases""" test_cases = [ - ('', ''), - ('/', '/'), - ('/single', '/single'), - ('/double/slash', '/double/slash'), - ('invalid.proto:/1st/2nd/slash', 'invalid.proto:/1st/2nd/slash'), - ('invalid.proto://double/slash', 'invalid.proto://double') + ("", ""), + ("/", "/"), + ("/single", "/single"), + ("/double/slash", "/double/slash"), + ("invalid.proto:/1st/2nd/slash", "invalid.proto:/1st/2nd/slash"), + ("invalid.proto://double/slash", "invalid.proto://double"), ] - for (input_url, expected) in test_cases: - fake_env = {'SCRIPT_URI': input_url} + for input_url, expected in test_cases: + fake_env = {"SCRIPT_URI": input_url} try: result = requested_url_base(fake_env) except ValueError: @@ -975,7 +1092,7 @@ def test_requested_url_base_split_invalid_edge_cases(self): @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_relative_path(self): """Test relative paths in URL""" - fake_env = {'SCRIPT_URI': '/cgi-bin/script.py'} + fake_env = {"SCRIPT_URI": "/cgi-bin/script.py"} try: result = requested_url_base(fake_env) except ValueError: @@ -986,10 +1103,10 @@ def test_requested_url_base_relative_path(self): @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_special_chars(self): """Test handling of special characters in URL""" - test_url = 'http://üñîçøðê.net/path' - fake_env = {'SCRIPT_URI': test_url} + test_url = "http://üñîçøðê.net/path" + fake_env = {"SCRIPT_URI": test_url} result = requested_url_base(fake_env) - self.assertEqual(result, 'http://üñîçøðê.net') + self.assertEqual(result, "http://üñîçøðê.net") def test_verify_local_url_direct_match(self): """Test verify_local_url with direct match to known site URL""" @@ -1005,7 +1122,9 @@ def test_verify_local_url_subpath_match(self): def test_verify_local_url_public_alias(self): """Test verify_local_url with public alias domain""" - self.configuration.migserver_public_alias_url = "https://grid.example.org" + self.configuration.migserver_public_alias_url = ( + "https://grid.example.org" + ) test_url = "https://grid.example.org/cgi-bin/file.py" self.assertTrue(verify_local_url(self.configuration, test_url)) @@ -1017,103 +1136,122 @@ def test_verify_local_url_absolute_path(self): def test_verify_local_url_relative_path(self): """Test verify_local_url with relative path""" test_url = "subdir/script.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_external_domain(self): """Test verify_local_url rejects external domains""" test_url = "https://evil.com/malicious.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_invalid_url(self): """Test verify_local_url rejects invalid/malformed URLs""" test_url = "javascript:alert('xss')" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_missing_https(self): """Test verify_local_url with HTTP when only HTTPS supported""" test_url = "http://mig.cert/cgi-bin/home.py" self.configuration.migserver_https_mig_cert_url = "https://mig.cert" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_different_port(self): """Test verify_local_url rejects same domain with different port""" self.configuration.migserver_https_ext_cert_url = "https://ext.cert:443" test_url = "https://ext.cert:444/cgi-bin/file.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_invisible_path_file(self): """Test invisible_path detects names in invisible files""" - invisible_filename = '.htaccess' - visible_filename = 'visible.txt' + invisible_filename = ".htaccess" + visible_filename = "visible.txt" # Test with invisible filename - self.assertTrue(invisible_path('/some/path/%s' % invisible_filename)) + self.assertTrue(invisible_path("/some/path/%s" % invisible_filename)) self.assertTrue(invisible_path(invisible_filename)) self.assertTrue(invisible_path(invisible_filename, True)) # Test with visible filename self.assertFalse(invisible_path(visible_filename)) - self.assertFalse(invisible_path('/some/path/%s' % visible_filename)) + self.assertFalse(invisible_path("/some/path/%s" % visible_filename)) def test_invisible_path_dir(self): """Test invisible_path detects paths in invisible dir""" - invisible_dirname = '.vgridscm' - visible_dirname = 'somedir' + invisible_dirname = ".vgridscm" + visible_dirname = "somedir" # Test different forms of invisible directory path self.assertTrue(invisible_path(invisible_dirname)) - self.assertTrue(invisible_path('/%s' % invisible_dirname)) - self.assertTrue(invisible_path('/parent/%s' % invisible_dirname)) - self.assertTrue(invisible_path('%s/sub' % invisible_dirname)) - self.assertTrue(invisible_path('/%s/file' % invisible_dirname)) + self.assertTrue(invisible_path("/%s" % invisible_dirname)) + self.assertTrue(invisible_path("/parent/%s" % invisible_dirname)) + self.assertTrue(invisible_path("%s/sub" % invisible_dirname)) + self.assertTrue(invisible_path("/%s/file" % invisible_dirname)) # Test visible directory self.assertFalse(invisible_path(visible_dirname)) - self.assertFalse(invisible_path('/%s' % visible_dirname)) - self.assertFalse(invisible_path('/parent/%s' % visible_dirname)) + self.assertFalse(invisible_path("/%s" % visible_dirname)) + self.assertFalse(invisible_path("/parent/%s" % visible_dirname)) def test_invisible_path_vgrid_exception(self): """Test allow_vgrid_scripts excludes valid vgrid xgi scripts""" - invisible_dirname = '.vgridscm' - vgrid_script = '.vgridscm/cgi-bin/hgweb.cgi' + invisible_dirname = ".vgridscm" + vgrid_script = ".vgridscm/cgi-bin/hgweb.cgi" test_patterns = [ - '/%s/%s' % (invisible_dirname, vgrid_script), - '/root/%s/sub/%s' % (invisible_dirname, vgrid_script), - '/%s/prefix%ssuffix' % (invisible_dirname, vgrid_script), - '/%s/similar_script.py' % invisible_dirname, - '/path/to/%s' % vgrid_script - + "/%s/%s" % (invisible_dirname, vgrid_script), + "/root/%s/sub/%s" % (invisible_dirname, vgrid_script), + "/%s/prefix%ssuffix" % (invisible_dirname, vgrid_script), + "/%s/similar_script.py" % invisible_dirname, + "/path/to/%s" % vgrid_script, ] test_expects = [False, False, False, True, False] test_iter = zip(test_patterns, test_expects) - for (i, (path, expected)) in enumerate(test_iter): + for i, (path, expected) in enumerate(test_iter): self.assertEqual( invisible_path(path, allow_vgrid_scripts=True), expected, "test case %d: path %r should %sbe invisible with scripts" - % (i, path, "" if expected else "not ") + % (i, path, "" if expected else "not "), ) # Should still be invisible without exception flag @@ -1122,23 +1260,24 @@ def test_invisible_path_vgrid_exception(self): invisible_path(path, allow_vgrid_scripts=False), expect_no_exception, "test case %d: path %r should %sbe invisible without scripts" - % (i, path, "" if expect_no_exception else "not ") + % (i, path, "" if expect_no_exception else "not "), ) def test_invisible_path_edge_cases(self): """Test invisible_path handles edge cases""" from mig.shared.defaults import _user_invisible_dirs + invisible_dirname = _user_invisible_dirs[0] # Empty path - self.assertFalse(invisible_path('')) - self.assertFalse(invisible_path('', allow_vgrid_scripts=True)) + self.assertFalse(invisible_path("")) + self.assertFalse(invisible_path("", allow_vgrid_scripts=True)) # Root path - self.assertFalse(invisible_path('/')) + self.assertFalse(invisible_path("/")) # Path that only contains invisible directory substring - substring_path = '/prefix%ssuffix/file' % invisible_dirname + substring_path = "/prefix%ssuffix/file" % invisible_dirname self.assertFalse(invisible_path(substring_path)) def test_client_alias(self): @@ -1176,21 +1315,21 @@ def test_get_short_id_with_gdp(self): def test_get_user_id_x509_format(self): """Test get_user_id returns DN for X509 format""" self.configuration.site_user_id_format = "X509" - user = {'distinguished_name': TEST_USER_ID} + user = {"distinguished_name": TEST_USER_ID} result = get_user_id(self.configuration, user) self.assertEqual(result, TEST_USER_ID) def test_get_user_id_uuid_format(self): """Test get_user_id returns UUID when configured""" self.configuration.site_user_id_format = "UUID" - user = {'unique_id': "123e4567-e89b-12d3-a456-426614174000"} + user = {"unique_id": "123e4567-e89b-12d3-a456-426614174000"} result = get_user_id(self.configuration, user) self.assertEqual(result, "123e4567-e89b-12d3-a456-426614174000") def test_get_client_id(self): """Test get_client_id extracts DN from user dict""" test_dn = "/C=US/CN=Alice" - user = {'distinguished_name': test_dn, 'other': 'field'} + user = {"distinguished_name": test_dn, "other": "field"} result = get_client_id(user) self.assertEqual(result, test_dn) @@ -1204,14 +1343,8 @@ def test_hexlify_unhexlify_roundtrip(self): def test_is_gdp_user_detection(self): """Test is_gdp_user detects GDP project presence""" - self.assertTrue(is_gdp_user( - self.configuration, - "/GDP_PROJ=12345" - )) - self.assertFalse(is_gdp_user( - self.configuration, - "/CN=Regular User" - )) + self.assertTrue(is_gdp_user(self.configuration, "/GDP_PROJ=12345")) + self.assertFalse(is_gdp_user(self.configuration, "/CN=Regular User")) def test_sandbox_resource_identification(self): """Test sandbox_resource identifies sandboxes""" @@ -1235,8 +1368,8 @@ def test_invisible_dir_detection(self): def test_requested_backend_extraction(self): """Test requested_backend extracts backend name from environ""" test_env = { - 'BACKEND_NAME': '/cgi-bin/fileman.py', - 'PATH_TRANSLATED': '/wsgi-bin/fileman.py' + "BACKEND_NAME": "/cgi-bin/fileman.py", + "PATH_TRANSLATED": "/wsgi-bin/fileman.py", } result = requested_backend(test_env) self.assertEqual(result, "fileman") @@ -1258,15 +1391,16 @@ def test_get_xgi_bin_wsgi_vs_cgi(self): """Test get_xgi_bin returns correct script bin based on config""" # Test WSGI enabled self.configuration.site_enable_wsgi = True - self.assertEqual(get_xgi_bin(self.configuration), 'wsgi-bin') + self.assertEqual(get_xgi_bin(self.configuration), "wsgi-bin") # Test WSGI disabled self.configuration.site_enable_wsgi = False - self.assertEqual(get_xgi_bin(self.configuration), 'cgi-bin') + self.assertEqual(get_xgi_bin(self.configuration), "cgi-bin") # Test legacy force - self.assertEqual(get_xgi_bin(self.configuration, force_legacy=True), - 'cgi-bin') + self.assertEqual( + get_xgi_bin(self.configuration, force_legacy=True), "cgi-bin" + ) def test_valid_dir_input(self): """Test valid_dir_input prevents path traversal attempts""" @@ -1277,11 +1411,11 @@ def test_valid_dir_input(self): ("../illegal", False), ("/absolute", False), ] - for (relative_path, expected) in test_cases: + for relative_path, expected in test_cases: self.assertEqual( valid_dir_input(base, relative_path), expected, - "failed for %s" % relative_path + "failed for %s" % relative_path, ) def test_user_base_dir(self): @@ -1306,32 +1440,56 @@ def test_brief_list(self): # List longer than max_entries gets shortened long_list = list(range(15)) - expected_long = [0, 1, 2, 3, 4, ' ... shortened ... ', 10, 11, 12, - 13, 14] + expected_long = [ + 0, + 1, + 2, + 3, + 4, + " ... shortened ... ", + 10, + 11, + 12, + 13, + 14, + ] self.assertEqual(brief_list(long_list), expected_long) # Custom max_entries with odd number custom_odd_list = list(range(10)) - expected_odd = [0, 1, 2, 3, ' ... shortened ... ', 6, 7, 8, 9] + expected_odd = [0, 1, 2, 3, " ... shortened ... ", 6, 7, 8, 9] self.assertEqual(brief_list(custom_odd_list, 9), expected_odd) # Range objects should be handled properly input_range = range(20) - expected_range = [0, 1, 2, 3, 4, ' ... shortened ... ', 15, 16, 17, - 18, 19] + expected_range = [ + 0, + 1, + 2, + 3, + 4, + " ... shortened ... ", + 15, + 16, + 17, + 18, + 19, + ] self.assertEqual(brief_list(input_range), expected_range) # Edge case - max_entries=2 - self.assertEqual(brief_list([1, 2, 3, 4], 2), - [1, ' ... shortened ... ', 4]) + self.assertEqual( + brief_list([1, 2, 3, 4], 2), [1, " ... shortened ... ", 4] + ) # Very small max_entries - self.assertEqual(brief_list([1, 2, 3, 4], 3), - [1, ' ... shortened ... ', 4]) + self.assertEqual( + brief_list([1, 2, 3, 4], 3), [1, " ... shortened ... ", 4] + ) # Non-integer input - str_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - expected_str = ['a', 'b', 'c', ' ... shortened ... ', 'e', 'f', 'g'] + str_list = ["a", "b", "c", "d", "e", "f", "g"] + expected_str = ["a", "b", "c", " ... shortened ... ", "e", "f", "g"] self.assertEqual(brief_list(str_list, 7), str_list) # At max_entries # TODO: fix tested function to handle these and enable test @@ -1339,12 +1497,14 @@ def test_brief_list(self): def test_brief_list_edge_cases(self): """Test brief_list helper function for compact list on edge cases""" # Edge case - max_entries=1 - self.assertEqual(brief_list([1, 2, 3], 1), [' ... shortened ... ']) + self.assertEqual(brief_list([1, 2, 3], 1), [" ... shortened ... "]) # Edge case - even short number of max_entries - str_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - self.assertEqual(brief_list(str_list, 6), - ['a', 'b', ' ... shortened ... ', 'f', 'g']) + str_list = ["a", "b", "c", "d", "e", "f", "g"] + self.assertEqual( + brief_list(str_list, 6), + ["a", "b", " ... shortened ... ", "f", "g"], + ) class TestMigSharedBase__legacy(MigTestCase): @@ -1353,15 +1513,17 @@ class TestMigSharedBase__legacy(MigTestCase): # TODO: migrate all legacy self-check functionality into the above? def test_existing_main(self): """Run built-in self-tests and check output""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % - (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): diff --git a/tests/test_mig_shared_cloud.py b/tests/test_mig_shared_cloud.py index 241a79970..00e473b99 100644 --- a/tests/test_mig_shared_cloud.py +++ b/tests/test_mig_shared_cloud.py @@ -34,63 +34,74 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.cloud import cloud_load_instance, cloud_save_instance, \ - allowed_cloud_images - -DUMMY_USER = 'dummy-user' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +from mig.shared.cloud import ( + allowed_cloud_images, + cloud_load_instance, + cloud_save_instance, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) + +DUMMY_USER = "dummy-user" +DUMMY_SETTINGS_DIR = "dummy_user_settings" DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) DUMMY_CLOUD = "CLOUD" -DUMMY_FLAVOR = 'openstack' -DUMMY_LABEL = 'dummy-label' -DUMMY_IMAGE = 'dummy-image' -DUMMY_HEX_ID = 'deadbeef-dead-beef-dead-beefdeadbeef' - -DUMMY_CLOUD_SPEC = {'service_title': 'CLOUDTITLE', 'service_name': 'CLOUDNAME', - 'service_desc': 'A Cloud for migrid site', - 'service_provider_flavor': 'openstack', - 'service_hosts': 'https://myopenstack-cloud.org:5000/v3', - 'service_rules_of_conduct': 'rules-of-conduct.pdf', - 'service_max_user_instances': '0', - 'service_max_user_instances_map': {DUMMY_USER: '1'}, - 'service_allowed_images': DUMMY_IMAGE, - 'service_allowed_images_map': {DUMMY_USER: 'ALL'}, - 'service_user_map': {DUMMY_IMAGE, 'user'}, - 'service_image_alias_map': {DUMMY_IMAGE.lower(): - DUMMY_IMAGE}, - 'service_flavor_id': DUMMY_HEX_ID, - 'service_flavor_id_map': {DUMMY_USER: DUMMY_HEX_ID}, - 'service_network_id': DUMMY_HEX_ID, - 'service_key_id_map': {}, - 'service_sec_group_id': DUMMY_HEX_ID, - 'service_floating_network_id': DUMMY_HEX_ID, - 'service_availability_zone': 'myopenstack', - 'service_jumphost_address': 'jumphost.somewhere.org', - 'service_jumphost_user': 'cloud', - 'service_jumphost_manage_keys_script': - 'cloud_manage_keys.py', - 'service_jumphost_manage_keys_coding': 'base16', - 'service_network_id_map': {}, - 'service_sec_group_id_map': {}, - 'service_floating_network_id_map': {}, - 'service_availability_zone_map': {}, - 'service_jumphost_address_map': {}, - 'service_jumphost_user_map': {}} -DUMMY_CONF = FakeConfiguration(user_settings=DUMMY_SETTINGS_PATH, - site_cloud_access=[('distinguished_name', '.*')], - cloud_services=[DUMMY_CLOUD_SPEC]) - -DUMMY_INSTANCE_ID = '%s:%s:%s' % (DUMMY_USER, DUMMY_LABEL, DUMMY_HEX_ID) +DUMMY_FLAVOR = "openstack" +DUMMY_LABEL = "dummy-label" +DUMMY_IMAGE = "dummy-image" +DUMMY_HEX_ID = "deadbeef-dead-beef-dead-beefdeadbeef" + +DUMMY_CLOUD_SPEC = { + "service_title": "CLOUDTITLE", + "service_name": "CLOUDNAME", + "service_desc": "A Cloud for migrid site", + "service_provider_flavor": "openstack", + "service_hosts": "https://myopenstack-cloud.org:5000/v3", + "service_rules_of_conduct": "rules-of-conduct.pdf", + "service_max_user_instances": "0", + "service_max_user_instances_map": {DUMMY_USER: "1"}, + "service_allowed_images": DUMMY_IMAGE, + "service_allowed_images_map": {DUMMY_USER: "ALL"}, + "service_user_map": {DUMMY_IMAGE, "user"}, + "service_image_alias_map": {DUMMY_IMAGE.lower(): DUMMY_IMAGE}, + "service_flavor_id": DUMMY_HEX_ID, + "service_flavor_id_map": {DUMMY_USER: DUMMY_HEX_ID}, + "service_network_id": DUMMY_HEX_ID, + "service_key_id_map": {}, + "service_sec_group_id": DUMMY_HEX_ID, + "service_floating_network_id": DUMMY_HEX_ID, + "service_availability_zone": "myopenstack", + "service_jumphost_address": "jumphost.somewhere.org", + "service_jumphost_user": "cloud", + "service_jumphost_manage_keys_script": "cloud_manage_keys.py", + "service_jumphost_manage_keys_coding": "base16", + "service_network_id_map": {}, + "service_sec_group_id_map": {}, + "service_floating_network_id_map": {}, + "service_availability_zone_map": {}, + "service_jumphost_address_map": {}, + "service_jumphost_user_map": {}, +} +DUMMY_CONF = FakeConfiguration( + user_settings=DUMMY_SETTINGS_PATH, + site_cloud_access=[("distinguished_name", ".*")], + cloud_services=[DUMMY_CLOUD_SPEC], +) + +DUMMY_INSTANCE_ID = "%s:%s:%s" % (DUMMY_USER, DUMMY_LABEL, DUMMY_HEX_ID) DUMMY_INSTANCE_DICT = { DUMMY_INSTANCE_ID: { - 'INSTANCE_LABEL': DUMMY_LABEL, - 'INSTANCE_IMAGE': DUMMY_IMAGE, - 'INSTANCE_ID': DUMMY_INSTANCE_ID, - 'IMAGE_ID': DUMMY_IMAGE, - 'CREATED_TIMESTAMP': "%d" % time.time(), - 'USER_CERT': DUMMY_USER + "INSTANCE_LABEL": DUMMY_LABEL, + "INSTANCE_IMAGE": DUMMY_IMAGE, + "INSTANCE_ID": DUMMY_INSTANCE_ID, + "IMAGE_ID": DUMMY_IMAGE, + "CREATED_TIMESTAMP": "%d" % time.time(), + "USER_CERT": DUMMY_USER, } } @@ -102,38 +113,46 @@ def test_cloud_save_load(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) cleanpath(DUMMY_SETTINGS_DIR, self) - save_status = cloud_save_instance(DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, - DUMMY_LABEL, DUMMY_INSTANCE_DICT) + save_status = cloud_save_instance( + DUMMY_CONF, + DUMMY_USER, + DUMMY_CLOUD, + DUMMY_LABEL, + DUMMY_INSTANCE_DICT, + ) self.assertTrue(save_status) - saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, - '%s.state' % DUMMY_CLOUD) + saved_path = os.path.join( + DUMMY_SETTINGS_PATH, DUMMY_USER, "%s.state" % DUMMY_CLOUD + ) self.assertTrue(os.path.exists(saved_path)) - instance = cloud_load_instance(DUMMY_CONF, DUMMY_USER, - DUMMY_CLOUD, DUMMY_LABEL) + instance = cloud_load_instance( + DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, DUMMY_LABEL + ) # NOTE: instance should be a non-empty dict at this point self.assertTrue(isinstance(instance, dict)) # print(instance) self.assertTrue(DUMMY_INSTANCE_ID in instance) instance_dict = instance[DUMMY_INSTANCE_ID] - self.assertEqual(instance_dict['INSTANCE_LABEL'], DUMMY_LABEL) - self.assertEqual(instance_dict['INSTANCE_IMAGE'], DUMMY_IMAGE) - self.assertEqual(instance_dict['INSTANCE_ID'], DUMMY_INSTANCE_ID) - self.assertEqual(instance_dict['IMAGE_ID'], DUMMY_IMAGE) - self.assertEqual(instance_dict['USER_CERT'], DUMMY_USER) + self.assertEqual(instance_dict["INSTANCE_LABEL"], DUMMY_LABEL) + self.assertEqual(instance_dict["INSTANCE_IMAGE"], DUMMY_IMAGE) + self.assertEqual(instance_dict["INSTANCE_ID"], DUMMY_INSTANCE_ID) + self.assertEqual(instance_dict["IMAGE_ID"], DUMMY_IMAGE) + self.assertEqual(instance_dict["USER_CERT"], DUMMY_USER) - @unittest.skip('Work in progress - currently requires remote openstack') + @unittest.skip("Work in progress - currently requires remote openstack") def test_cloud_allowed_images(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) cleanpath(DUMMY_SETTINGS_DIR, self) - allowed_images = allowed_cloud_images(DUMMY_CONF, DUMMY_USER, - DUMMY_CLOUD, DUMMY_FLAVOR) + allowed_images = allowed_cloud_images( + DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, DUMMY_FLAVOR + ) self.assertTrue(isinstance(allowed_images, list)) print(allowed_images) self.assertTrue(DUMMY_IMAGE in allowed_images) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_compat.py b/tests/test_mig_shared_compat.py index 05f2de2bc..faad707f7 100644 --- a/tests/test_mig_shared_compat.py +++ b/tests/test_mig_shared_compat.py @@ -31,13 +31,13 @@ import os import sys +from mig.shared.compat import PY2, ensure_native_string from tests.support import MigTestCase, testmain -from mig.shared.compat import PY2, ensure_native_string +DUMMY_BYTECHARS = b"DEADBEEF" +DUMMY_BYTESRAW = binascii.unhexlify("DEADBEEF") # 4 bytes +DUMMY_UNICODE = "UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®" -DUMMY_BYTECHARS = b'DEADBEEF' -DUMMY_BYTESRAW = binascii.unhexlify('DEADBEEF') # 4 bytes -DUMMY_UNICODE = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®' class MigSharedCompat__ensure_native_string(MigTestCase): """Unit test helper for the migrid code pointed to in class name""" @@ -45,7 +45,7 @@ class MigSharedCompat__ensure_native_string(MigTestCase): def test_char_bytes_conversion(self): actual = ensure_native_string(DUMMY_BYTECHARS) self.assertIs(type(actual), str) - self.assertEqual(actual, 'DEADBEEF') + self.assertEqual(actual, "DEADBEEF") def test_raw_bytes_conversion(self): with self.assertRaises(UnicodeDecodeError): @@ -60,5 +60,5 @@ def test_unicode_conversion(self): self.assertEqual(actual, DUMMY_UNICODE) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_configuration.py b/tests/test_mig_shared_configuration.py index 1108ea3e0..c02b33559 100644 --- a/tests/test_mig_shared_configuration.py +++ b/tests/test_mig_shared_configuration.py @@ -31,16 +31,23 @@ import os import unittest -from tests.support import MigTestCase, TEST_DATA_DIR, PY2, testmain +from mig.shared.configuration import ( + _CONFIGURATION_ARGUMENTS, + _CONFIGURATION_PROPERTIES, + Configuration, +) +from tests.support import PY2, TEST_DATA_DIR, MigTestCase, testmain from tests.support.fixturesupp import FixtureAssertMixin -from mig.shared.configuration import Configuration, \ - _CONFIGURATION_ARGUMENTS, _CONFIGURATION_PROPERTIES - def _to_dict(obj): - return {k: v for k, v in inspect.getmembers(obj) - if not (k.startswith('__') or inspect.ismethod(v) or inspect.isfunction(v))} + return { + k: v + for k, v in inspect.getmembers(obj) + if not ( + k.startswith("__") or inspect.ismethod(v) or inspect.isfunction(v) + ) + } class MigSharedConfiguration__static_definitions(MigTestCase): @@ -50,8 +57,9 @@ def test_consistent_parameters(self): configuration_defaults_keys = set(_CONFIGURATION_PROPERTIES.keys()) mismatched = _CONFIGURATION_ARGUMENTS - configuration_defaults_keys - self.assertEqual(len(mismatched), 0, - "configuration defaults do not match arguments") + self.assertEqual( + len(mismatched), 0, "configuration defaults do not match arguments" + ) class MigSharedConfiguration__loaded_configurations(MigTestCase): @@ -59,19 +67,23 @@ class MigSharedConfiguration__loaded_configurations(MigTestCase): def test_argument_new_user_default_ui_is_replaced(self): test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) - self.assertEqual(configuration.new_user_default_ui, 'V3') + self.assertEqual(configuration.new_user_default_ui, "V3") def test_argument_storage_protocols(self): test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # TODO: add a test to cover filtering of a mix of valid+invalid protos # self.assertEqual(configuration.storage_protocols, ['xxx', 'yyy', 'zzz']) @@ -81,90 +93,110 @@ def test_argument_storage_protocols(self): def test_argument_wwwserve_max_bytes(self): test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.wwwserve_max_bytes, 43211234) def test_argument_include_sections(self): """Test that include_sections path default is set""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised.conf') + TEST_DATA_DIR, "MiGserver--customised.conf" + ) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) - self.assertEqual(configuration.include_sections, - '/home/mig/mig/server/MiGserver.d') + self.assertEqual( + configuration.include_sections, "/home/mig/mig/server/MiGserver.d" + ) def test_argument_custom_include_sections(self): """Test that include_sections path override is correctly applied""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") self.assertTrue(os.path.isdir(test_conf_section_dir)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) - self.assertEqual(configuration.include_sections, - test_conf_section_dir) + self.assertEqual(configuration.include_sections, test_conf_section_dir) def test_argument_include_sections_quota(self): """Test that QUOTA conf section overrides are correctly applied""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'quota.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "quota.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) - self.assertEqual(configuration.quota_backend, 'dummy') + self.assertEqual(configuration.quota_backend, "dummy") self.assertEqual(configuration.quota_user_limit, 4242) self.assertEqual(configuration.quota_vgrid_limit, 4242424242) def test_argument_include_sections_cloud_misty(self): """Test that CLOUD_MISTY conf section is correctly applied""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'cloud_misty.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "cloud_misty.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.cloud_services, list) self.assertTrue(configuration.cloud_services) self.assertIsInstance(configuration.cloud_services[0], dict) - self.assertTrue(configuration.cloud_services[0].get('service_name', - False)) - self.assertEqual(configuration.cloud_services[0]['service_name'], - 'MISTY') - self.assertEqual(configuration.cloud_services[0]['service_desc'], - 'MISTY service') - self.assertEqual(configuration.cloud_services[0]['service_provider_flavor'], - 'nostack') + self.assertTrue( + configuration.cloud_services[0].get("service_name", False) + ) + self.assertEqual( + configuration.cloud_services[0]["service_name"], "MISTY" + ) + self.assertEqual( + configuration.cloud_services[0]["service_desc"], "MISTY service" + ) + self.assertEqual( + configuration.cloud_services[0]["service_provider_flavor"], + "nostack", + ) def test_argument_include_sections_global_accepted(self): """Test that peripheral GLOBAL conf overrides are accepted (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'global.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "global.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertEqual(configuration.admin_email, "admin@somewhere.org") @@ -176,93 +208,105 @@ def test_argument_include_sections_global_accepted(self): def test_argument_include_sections_global_rejected(self): """Test that core GLOBAL conf overrides are rejected (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'global.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "global.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Run through the snippet values and check that override didn't succeed # and then that default is left set. The former _could_ be left out but # is kept explicit for clarity in case something breaks by changes. - self.assertNotEqual(configuration.include_sections, '/tmp/MiGserver.d') + self.assertNotEqual(configuration.include_sections, "/tmp/MiGserver.d") self.assertEqual(configuration.include_sections, test_conf_section_dir) - self.assertNotEqual(configuration.mig_path, '/tmp/mig/mig') - self.assertEqual(configuration.mig_path, '/home/mig/mig') - self.assertNotEqual(configuration.logfile, '/tmp/mig.log') - self.assertEqual(configuration.logfile, 'mig.log') - self.assertNotEqual(configuration.loglevel, 'warning') - self.assertEqual(configuration.loglevel, 'info') - self.assertNotEqual(configuration.server_fqdn, 'somewhere.org') - self.assertEqual(configuration.server_fqdn, '') - self.assertNotEqual(configuration.migserver_public_url, - 'https://somewhere.org') - self.assertEqual(configuration.migserver_public_url, '') - self.assertNotEqual(configuration.migserver_https_sid_url, - 'https://somewhere.org') - self.assertEqual(configuration.migserver_https_sid_url, '') - self.assertNotEqual(configuration.user_openid_address, 'somewhere.org') - self.assertNotEqual(configuration.user_openid_address, 'somewhere.org') - self.assertEqual(configuration.user_openid_address, '') + self.assertNotEqual(configuration.mig_path, "/tmp/mig/mig") + self.assertEqual(configuration.mig_path, "/home/mig/mig") + self.assertNotEqual(configuration.logfile, "/tmp/mig.log") + self.assertEqual(configuration.logfile, "mig.log") + self.assertNotEqual(configuration.loglevel, "warning") + self.assertEqual(configuration.loglevel, "info") + self.assertNotEqual(configuration.server_fqdn, "somewhere.org") + self.assertEqual(configuration.server_fqdn, "") + self.assertNotEqual( + configuration.migserver_public_url, "https://somewhere.org" + ) + self.assertEqual(configuration.migserver_public_url, "") + self.assertNotEqual( + configuration.migserver_https_sid_url, "https://somewhere.org" + ) + self.assertEqual(configuration.migserver_https_sid_url, "") + self.assertNotEqual(configuration.user_openid_address, "somewhere.org") + self.assertNotEqual(configuration.user_openid_address, "somewhere.org") + self.assertEqual(configuration.user_openid_address, "") self.assertNotEqual(configuration.user_openid_port, 4242) self.assertEqual(configuration.user_openid_port, 8443) - self.assertNotEqual(configuration.user_openid_key, '/tmp/openid.key') - self.assertEqual(configuration.user_openid_key, '') - self.assertNotEqual(configuration.user_openid_log, '/tmp/openid.log') - self.assertEqual(configuration.user_openid_log, - '/home/mig/state/log/openid.log') + self.assertNotEqual(configuration.user_openid_key, "/tmp/openid.key") + self.assertEqual(configuration.user_openid_key, "") + self.assertNotEqual(configuration.user_openid_log, "/tmp/openid.log") + self.assertEqual( + configuration.user_openid_log, "/home/mig/state/log/openid.log" + ) def test_argument_include_sections_site_accepted(self): """Test that peripheral SITE conf overrides are accepted (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'site.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "site.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) - self.assertEqual(configuration.short_title, 'ACME Site') - self.assertEqual(configuration.new_user_default_ui, 'V3') - self.assertEqual(configuration.site_password_legacy_policy, 'MEDIUM') - self.assertEqual(configuration.site_support_text, - 'Custom support text') - self.assertEqual(configuration.site_privacy_text, - 'Custom privacy text') - self.assertEqual(configuration.site_peers_notice, - 'Custom peers notice') - self.assertEqual(configuration.site_peers_contact_hint, - 'Custom peers contact hint') + self.assertEqual(configuration.short_title, "ACME Site") + self.assertEqual(configuration.new_user_default_ui, "V3") + self.assertEqual(configuration.site_password_legacy_policy, "MEDIUM") + self.assertEqual(configuration.site_support_text, "Custom support text") + self.assertEqual(configuration.site_privacy_text, "Custom privacy text") + self.assertEqual(configuration.site_peers_notice, "Custom peers notice") + self.assertEqual( + configuration.site_peers_contact_hint, "Custom peers contact hint" + ) self.assertIsInstance(configuration.site_freeze_admins, list) self.assertTrue(len(configuration.site_freeze_admins) == 1) - self.assertTrue('BOFH' in configuration.site_freeze_admins) - self.assertEqual(configuration.site_freeze_to_tape, - 'Custom freeze to tape') - self.assertEqual(configuration.site_freeze_doi_text, - 'Custom freeze doi text') - self.assertEqual(configuration.site_freeze_doi_url, - 'https://somewhere.org/mint-doi') - self.assertEqual(configuration.site_freeze_doi_url_field, - 'archiveurl') + self.assertTrue("BOFH" in configuration.site_freeze_admins) + self.assertEqual( + configuration.site_freeze_to_tape, "Custom freeze to tape" + ) + self.assertEqual( + configuration.site_freeze_doi_text, "Custom freeze doi text" + ) + self.assertEqual( + configuration.site_freeze_doi_url, "https://somewhere.org/mint-doi" + ) + self.assertEqual(configuration.site_freeze_doi_url_field, "archiveurl") def test_argument_include_sections_site_rejected(self): """Test that core SITE conf overrides are rejected (policy)""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'site.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "site.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertEqual(configuration.site_enable_openid, False) @@ -279,56 +323,63 @@ def test_argument_include_sections_site_rejected(self): def test_argument_include_sections_with_invalid_conf_filename(self): """Test that conf snippet with missing .conf extension gets ignored""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'dummy') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join(test_conf_section_dir, "dummy") self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Conf only contains SETTINGS section which is ignored due to mismatch self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.language, list) - self.assertFalse('Pig Latin' in configuration.language) - self.assertEqual(configuration.language, ['English']) + self.assertFalse("Pig Latin" in configuration.language) + self.assertEqual(configuration.language, ["English"]) def test_argument_include_sections_with_section_name_mismatch(self): """Test that conf section must match filename""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'section-mismatch.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "section-mismatch.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Conf only contains SETTINGS section which is ignored due to mismatch self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.language, list) - self.assertFalse('Pig Latin' in configuration.language) - self.assertEqual(configuration.language, ['English']) + self.assertFalse("Pig Latin" in configuration.language) + self.assertEqual(configuration.language, ["English"]) def test_argument_include_sections_multi_ignores_other_sections(self): """Test that conf section must match filename and others are ignored""" test_conf_file = os.path.join( - TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf') - test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d') - test_conf_section_file = os.path.join(test_conf_section_dir, - 'multi.conf') + TEST_DATA_DIR, "MiGserver--customised-include_sections.conf" + ) + test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d") + test_conf_section_file = os.path.join( + test_conf_section_dir, "multi.conf" + ) self.assertTrue(os.path.isfile(test_conf_section_file)) configuration = Configuration( - test_conf_file, skip_log=True, disable_auth_log=True) + test_conf_file, skip_log=True, disable_auth_log=True + ) # Conf contains MULTI and SETTINGS sections and latter must be ignored self.assertEqual(configuration.include_sections, test_conf_section_dir) self.assertIsInstance(configuration.language, list) - self.assertFalse('Spanglish' in configuration.language) - self.assertEqual(configuration.language, ['English']) + self.assertFalse("Spanglish" in configuration.language) + self.assertEqual(configuration.language, ["English"]) # TODO: rename file to valid section name we can check and enable next? # self.assertEqual(configuration.multi, 'blabla') @@ -339,15 +390,16 @@ class MigSharedConfiguration__new_instance(MigTestCase, FixtureAssertMixin): @unittest.skipIf(PY2, "Python 3 only") def test_default_object(self): prepared_fixture = self.prepareFixtureAssert( - 'mig_shared_configuration--new', fixture_format='json') + "mig_shared_configuration--new", fixture_format="json" + ) configuration = Configuration(None) # TODO: the following work-around default values set for these on the # instance that no longer make total sense but fiddling with them # is better as a follow-up. - configuration.certs_path = '/some/place/certs' - configuration.state_path = '/some/place/state' - configuration.mig_path = '/some/place/mig' + configuration.certs_path = "/some/place/certs" + configuration.state_path = "/some/place/state" + configuration.mig_path = "/some/place/mig" actual_values = _to_dict(configuration) @@ -358,11 +410,11 @@ def test_object_isolation(self): configuration_2 = Configuration(None) # change one of the configuration objects - configuration_1.default_page.append('foobar') + configuration_1.default_page.append("foobar") # check the other was not affected - self.assertEqual(configuration_2.default_page, ['']) + self.assertEqual(configuration_2.default_page, [""]) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_fileio.py b/tests/test_mig_shared_fileio.py index f43f013e8..277da3a7c 100644 --- a/tests/test_mig_shared_fileio.py +++ b/tests/test_mig_shared_fileio.py @@ -35,52 +35,53 @@ # Imports of the code under test import mig.shared.fileio as fileio + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -DUMMY_BYTES = binascii.unhexlify('DEADBEEF') # 4 bytes +DUMMY_BYTES = binascii.unhexlify("DEADBEEF") # 4 bytes DUMMY_BYTES_LENGTH = 4 -DUMMY_UNICODE = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®' +DUMMY_UNICODE = "UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®" DUMMY_UNICODE_LENGTH = len(DUMMY_UNICODE) -DUMMY_TEXT = 'dummy' -DUMMY_TWICE = 'dummy - dummy' -DUMMY_TESTDIR = 'fileio' -DUMMY_SUBDIR = 'subdir' -DUMMY_FILE_ONE = 'file1.txt' -DUMMY_FILE_TWO = 'file2.txt' -DUMMY_FILE_MISSING = 'missing.txt' -DUMMY_FILE_RO = 'readonly.txt' -DUMMY_FILE_WO = 'writeonly.txt' -DUMMY_FILE_RW = 'readwrite.txt' -DUMMY_DIRECTORY_NESTED = 'nested/dir/structure' -DUMMY_DIRECTORY_EMPTY = 'empty_dir' -DUMMY_DIRECTORY_MOVE_SRC = 'move_dir_src' -DUMMY_DIRECTORY_MOVE_DST = 'move_dir_dst' -DUMMY_DIRECTORY_REMOVE = 'remove_dir' -DUMMY_DIRECTORY_CHECKACCESS = 'check_access' -DUMMY_DIRECTORY_MAKEDIRSREC = 'makedirs_rec' -DUMMY_DIRECTORY_COPYRECSRC = 'copy_dir_src' -DUMMY_DIRECTORY_COPYRECDST = 'copy_dir_dst' -DUMMY_DIRECTORY_REMOVEREC = 'remove_rec' +DUMMY_TEXT = "dummy" +DUMMY_TWICE = "dummy - dummy" +DUMMY_TESTDIR = "fileio" +DUMMY_SUBDIR = "subdir" +DUMMY_FILE_ONE = "file1.txt" +DUMMY_FILE_TWO = "file2.txt" +DUMMY_FILE_MISSING = "missing.txt" +DUMMY_FILE_RO = "readonly.txt" +DUMMY_FILE_WO = "writeonly.txt" +DUMMY_FILE_RW = "readwrite.txt" +DUMMY_DIRECTORY_NESTED = "nested/dir/structure" +DUMMY_DIRECTORY_EMPTY = "empty_dir" +DUMMY_DIRECTORY_MOVE_SRC = "move_dir_src" +DUMMY_DIRECTORY_MOVE_DST = "move_dir_dst" +DUMMY_DIRECTORY_REMOVE = "remove_dir" +DUMMY_DIRECTORY_CHECKACCESS = "check_access" +DUMMY_DIRECTORY_MAKEDIRSREC = "makedirs_rec" +DUMMY_DIRECTORY_COPYRECSRC = "copy_dir_src" +DUMMY_DIRECTORY_COPYRECDST = "copy_dir_dst" +DUMMY_DIRECTORY_REMOVEREC = "remove_rec" # File/dir paths for move/copy operations -DUMMY_FILE_MOVE_SRC = 'move_src' -DUMMY_FILE_MOVE_DST = 'move_dst' -DUMMY_FILE_COPY_SRC = 'copy_src' -DUMMY_FILE_COPY_DST = 'copy_dst' -DUMMY_FILE_WRITECHUNK = 'write_chunk' -DUMMY_FILE_WRITEFILE = 'write_file' -DUMMY_FILE_WRITEFILELINES = 'write_file_lines' -DUMMY_FILE_READFILE = 'read_file' -DUMMY_FILE_READFILELINES = 'read_file_lines' -DUMMY_FILE_READHEADLINES = 'read_head_lines' -DUMMY_FILE_READTAILLINES = 'read_tail_lines' -DUMMY_FILE_DELETEFILE = 'delete_file' -DUMMY_FILE_GETFILESIZE = 'get_file_size' -DUMMY_FILE_MAKESYMLINKSRC = 'link_src' -DUMMY_FILE_MAKESYMLINKDST = 'link_target' -DUMMY_FILE_DELETESYMLINKSRC = 'link_src' -DUMMY_FILE_DELETESYMLINKDST = 'link_target' -DUMMY_FILE_TOUCH = 'touch_file' +DUMMY_FILE_MOVE_SRC = "move_src" +DUMMY_FILE_MOVE_DST = "move_dst" +DUMMY_FILE_COPY_SRC = "copy_src" +DUMMY_FILE_COPY_DST = "copy_dst" +DUMMY_FILE_WRITECHUNK = "write_chunk" +DUMMY_FILE_WRITEFILE = "write_file" +DUMMY_FILE_WRITEFILELINES = "write_file_lines" +DUMMY_FILE_READFILE = "read_file" +DUMMY_FILE_READFILELINES = "read_file_lines" +DUMMY_FILE_READHEADLINES = "read_head_lines" +DUMMY_FILE_READTAILLINES = "read_tail_lines" +DUMMY_FILE_DELETEFILE = "delete_file" +DUMMY_FILE_GETFILESIZE = "get_file_size" +DUMMY_FILE_MAKESYMLINKSRC = "link_src" +DUMMY_FILE_MAKESYMLINKDST = "link_target" +DUMMY_FILE_DELETESYMLINKSRC = "link_src" +DUMMY_FILE_DELETESYMLINKDST = "link_target" +DUMMY_FILE_TOUCH = "touch_file" assert isinstance(DUMMY_BYTES, bytes) @@ -90,12 +91,13 @@ class MigSharedFileio__temporary_umask(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_file_one = os.path.join(self.tmp_base, DUMMY_FILE_ONE) try: @@ -121,63 +123,63 @@ def before_each(self): def test_creates_new_file_with_temporary_umask_777(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o777): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o000) def test_creates_new_file_with_temporary_umask_277(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o277): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o400) def test_creates_new_file_with_temporary_umask_227(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o227): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o440) def test_creates_new_file_with_temporary_umask_077(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o077): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o600) def test_creates_new_file_with_temporary_umask_027(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o027): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o640) def test_creates_new_file_with_temporary_umask_007(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o007): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o660) def test_creates_new_file_with_temporary_umask_022(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o022): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o644) def test_creates_new_file_with_temporary_umask_002(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o002): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o664) def test_creates_new_file_with_temporary_umask_000(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o000): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o666) @@ -257,8 +259,9 @@ def test_restores_original_umask_after_exit(self): current_umask = os.umask(original_umask) # Retrieve and reset # Cleanup: Restore environment os.umask(current_umask) - self.assertEqual(current_umask, 0o022, - "Failed to restore original umask") + self.assertEqual( + current_umask, 0o022, "Failed to restore original umask" + ) finally: os.umask(original_umask) # Ensure cleanup @@ -267,20 +270,20 @@ def test_nested_temporary_umask(self): original_umask = os.umask(0o022) try: with fileio.temporary_umask(0o027): # Outer context - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode1 = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode1, 0o640) # 666 & ~027 = 640 with fileio.temporary_umask(0o077): # Inner context - open(self.tmp_file_two, 'w').close() + open(self.tmp_file_two, "w").close() mode2 = os.stat(self.tmp_file_two).st_mode & 0o777 self.assertEqual(mode2, 0o600) # 666 & ~077 # Back to outer context umask os.remove(self.tmp_file_one) - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode1_after = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode1_after, 0o640) # 666 & ~027 # Back to original umask - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode_original = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode_original, 0o640) # 666 & ~002 finally: @@ -309,7 +312,7 @@ def test_restores_umask_after_exception(self): def test_umask_does_not_affect_existing_files(self): """Test temporary_umask doesn't modify existing file permissions""" - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() os.chmod(self.tmp_file_one, 0o644) # Explicit permissions with fileio.temporary_umask(0o077): # Shouldn't affect existing file # Change permissions inside context @@ -324,12 +327,13 @@ class MigSharedFileio__write_chunk(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_WRITECHUNK) @@ -338,16 +342,18 @@ def test_return_false_on_invalid_data(self): self.logger.forgive_errors() # NOTE: we make sure to disable any forced stringification here - did_succeed = fileio.write_chunk(self.tmp_path, 1234, 0, self.logger, - force_string=False) + did_succeed = fileio.write_chunk( + self.tmp_path, 1234, 0, self.logger, force_string=False + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_offset(self): """Test write_chunk returns False with negative offset value""" self.logger.forgive_errors() - did_succeed = fileio.write_chunk(self.tmp_path, DUMMY_BYTES, -42, - self.logger) + did_succeed = fileio.write_chunk( + self.tmp_path, DUMMY_BYTES, -42, self.logger + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_dir(self): @@ -368,7 +374,7 @@ def test_store_bytes(self): """Test write_chunk stores byte data correctly at offset 0""" fileio.write_chunk(self.tmp_path, DUMMY_BYTES, 0, self.logger) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) @@ -379,42 +385,52 @@ def test_store_bytes_at_offset(self): fileio.write_chunk(self.tmp_path, DUMMY_BYTES, offset, self.logger) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH + offset) - self.assertEqual(content[0:3], bytearray([0, 0, 0]), - "expected a hole was left") + self.assertEqual( + content[0:3], bytearray([0, 0, 0]), "expected a hole was left" + ) self.assertEqual(content[3:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_bytes_in_text_mode(self): """Test write_chunk stores byte data in text mode""" - fileio.write_chunk(self.tmp_path, DUMMY_BYTES, 0, self.logger, - mode="r+") + fileio.write_chunk( + self.tmp_path, DUMMY_BYTES, 0, self.logger, mode="r+" + ) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode(self): """Test write_chunk stores unicode data in text mode""" - fileio.write_chunk(self.tmp_path, DUMMY_UNICODE, 0, self.logger, - mode='r+') + fileio.write_chunk( + self.tmp_path, DUMMY_UNICODE, 0, self.logger, mode="r+" + ) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode_in_binary_mode(self): """Test write_chunk stores unicode data in binary mode""" - fileio.write_chunk(self.tmp_path, DUMMY_UNICODE, 0, self.logger, - mode='r+b') + fileio.write_chunk( + self.tmp_path, DUMMY_UNICODE, 0, self.logger, mode="r+b" + ) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) @@ -425,12 +441,13 @@ class MigSharedFileio__write_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) # NOTE: we inject sub-directory to test with missing and existing self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) @@ -441,24 +458,25 @@ def test_return_false_on_invalid_data(self): self.logger.forgive_errors() # NOTE: we make sure to disable any forced stringification here - did_succeed = fileio.write_file(1234, self.tmp_path, self.logger, - force_string=False) + did_succeed = fileio.write_file( + 1234, self.tmp_path, self.logger, force_string=False + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_dir(self): """Test write_file returns False when path is a directory""" self.logger.forgive_errors() ensure_dirs_exist(self.tmp_path) - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, - self.logger) + did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger) self.assertFalse(did_succeed) def test_return_false_on_missing_dir(self): """Test write_file returns False on missing parent dir""" self.logger.forgive_errors() - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, - self.logger, make_parent=False) + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, make_parent=False + ) self.assertFalse(did_succeed) def test_creates_directory(self): @@ -466,7 +484,7 @@ def test_creates_directory(self): # TODO: temporarily use empty string to avoid any byte/unicode issues # did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, # self.logger) - did_succeed = fileio.write_file('', self.tmp_path, self.logger) + did_succeed = fileio.write_file("", self.tmp_path, self.logger) self.assertTrue(did_succeed) path_kind = self.assertPathExists(self.tmp_path) @@ -475,49 +493,59 @@ def test_creates_directory(self): # TODO: replace next test once we have auto adjust mode in write helper def test_store_bytes_with_manual_adjust_mode(self): """Test write_file stores byte data in with manual adjust mode call""" - mode = 'w' + mode = "w" mode = fileio._auto_adjust_mode(DUMMY_BYTES, mode) - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger, - mode=mode) + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, mode=mode + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_bytes_in_text_mode(self): """Test write_file stores byte data when opening in text mode""" - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger, - mode="w") + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, mode="w" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode(self): """Test write_file stores unicode string when opening in text mode""" - did_succeed = fileio.write_file(DUMMY_UNICODE, self.tmp_path, - self.logger, mode='w') + did_succeed = fileio.write_file( + DUMMY_UNICODE, self.tmp_path, self.logger, mode="w" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode_in_binary_mode(self): """Test write_file handles unicode strings when opening in binary mode""" - did_succeed = fileio.write_file(DUMMY_UNICODE, self.tmp_path, - self.logger, mode='wb') + did_succeed = fileio.write_file( + DUMMY_UNICODE, self.tmp_path, self.logger, mode="wb" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) @@ -528,12 +556,13 @@ class MigSharedFileio__write_file_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) # NOTE: we inject sub-directory to test with missing and existing self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) @@ -542,8 +571,7 @@ def before_each(self): def test_write_lines(self): """Test write_file_lines writes lines to a file""" test_lines = ["line1\n", "line2\n", "line3"] - result = fileio.write_file_lines( - test_lines, self.tmp_path, self.logger) + result = fileio.write_file_lines(test_lines, self.tmp_path, self.logger) self.assertTrue(result) # Verify with read_file_lines @@ -559,8 +587,7 @@ def test_invalid_data(self): def test_creates_directory(self): """Test write_file_lines creates parent directory when needed""" test_lines = ["test line"] - result = fileio.write_file_lines( - test_lines, self.tmp_path, self.logger) + result = fileio.write_file_lines(test_lines, self.tmp_path, self.logger) self.assertTrue(result) path_kind = self.assertPathExists(self.tmp_path) @@ -571,14 +598,16 @@ def test_return_false_on_invalid_dir(self): self.logger.forgive_errors() ensure_dirs_exist(self.tmp_path) result = fileio.write_file_lines( - [DUMMY_TEXT], self.tmp_path, self.logger) + [DUMMY_TEXT], self.tmp_path, self.logger + ) self.assertFalse(result) def test_return_false_on_missing_dir(self): """Test write_file_lines fails when parent directory missing""" self.logger.forgive_errors() - result = fileio.write_file_lines([DUMMY_TEXT], self.tmp_path, self.logger, - make_parent=False) + result = fileio.write_file_lines( + [DUMMY_TEXT], self.tmp_path, self.logger, make_parent=False + ) self.assertFalse(result) @@ -587,40 +616,43 @@ class MigSharedFileio__read_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READFILE) def test_reads_bytes(self): """Test read_file returns byte content with binary mode""" - with open(self.tmp_path, 'wb') as fh: + with open(self.tmp_path, "wb") as fh: fh.write(DUMMY_BYTES) - content = fileio.read_file(self.tmp_path, self.logger, mode='rb') + content = fileio.read_file(self.tmp_path, self.logger, mode="rb") self.assertEqual(content, DUMMY_BYTES) def test_reads_text(self): """Test read_file returns text with text mode""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write(DUMMY_UNICODE) - content = fileio.read_file(self.tmp_path, self.logger, mode='r') + content = fileio.read_file(self.tmp_path, self.logger, mode="r") self.assertEqual(content, DUMMY_UNICODE) def test_allows_missing_file(self): """Test read_file returns None with allow_missing=True""" content = fileio.read_file( - 'missing.txt', self.logger, allow_missing=True) + "missing.txt", self.logger, allow_missing=True + ) self.assertIsNone(content) def test_reports_missing_file(self): """Test read_file returns None with allow_missing=False""" self.logger.forgive_errors() content = fileio.read_file( - 'missing.txt', self.logger, allow_missing=False) + "missing.txt", self.logger, allow_missing=False + ) self.assertIsNone(content) def test_handles_directory_path(self): @@ -636,31 +668,32 @@ class MigSharedFileio__read_file_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READFILELINES) def test_returns_empty_list_for_empty_file(self): """Test read_file_lines returns empty list for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_file_lines(self.tmp_path, self.logger) self.assertEqual(lines, []) def test_reads_lines_from_file(self): """Test read_file_lines returns lines from text file""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3") lines = fileio.read_file_lines(self.tmp_path, self.logger) self.assertEqual(lines, ["line1\n", "line2\n", "line3"]) def test_none_for_missing_file(self): self.logger.forgive_errors() - lines = fileio.read_file_lines('missing.txt', self.logger) + lines = fileio.read_file_lines("missing.txt", self.logger) self.assertIsNone(lines) @@ -669,18 +702,19 @@ class MigSharedFileio__get_file_size(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_GETFILESIZE) def test_returns_file_size(self): """Test get_file_size returns correct file size""" - with open(self.tmp_path, 'wb') as fh: + with open(self.tmp_path, "wb") as fh: fh.write(DUMMY_BYTES) size = fileio.get_file_size(self.tmp_path, self.logger) self.assertEqual(size, DUMMY_BYTES_LENGTH) @@ -688,7 +722,7 @@ def test_returns_file_size(self): def test_handles_missing_file(self): """Test get_file_size returns -1 for missing file""" self.logger.forgive_errors() - size = fileio.get_file_size('missing.txt', self.logger) + size = fileio.get_file_size("missing.txt", self.logger) self.assertEqual(size, -1) def test_handles_directory(self): @@ -707,18 +741,19 @@ class MigSharedFileio__delete_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_DELETEFILE) def test_deletes_existing_file(self): """Test delete_file removes existing file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() result = fileio.delete_file(self.tmp_path, self.logger) self.assertTrue(result) self.assertFalse(os.path.exists(self.tmp_path)) @@ -726,15 +761,16 @@ def test_deletes_existing_file(self): def test_handles_missing_file_with_allow_missing(self): """Test delete_file succeeds with allow_missing=True""" result = fileio.delete_file( - 'missing.txt', self.logger, allow_missing=True) + "missing.txt", self.logger, allow_missing=True + ) self.assertTrue(result) def test_false_for_missing_file_without_allow_missing(self): """Test delete_file returns False with allow_missing=False""" self.logger.forgive_errors() - result = fileio.delete_file('missing.txt', - self.logger, - allow_missing=False) + result = fileio.delete_file( + "missing.txt", self.logger, allow_missing=False + ) self.assertFalse(result) @@ -743,39 +779,40 @@ class MigSharedFileio__read_head_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READHEADLINES) def test_reads_requested_lines(self): """Test read_head_lines returns requested number of lines""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3\nline4") lines = fileio.read_head_lines(self.tmp_path, 2, self.logger) self.assertEqual(lines, ["line1\n", "line2\n"]) def test_returns_all_lines_when_requested_more(self): """Test read_head_lines returns all lines when file has fewer""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2") lines = fileio.read_head_lines(self.tmp_path, 5, self.logger) self.assertEqual(lines, ["line1\n", "line2"]) def test_returns_empty_list_for_empty_file(self): """Test read_head_lines returns empty for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_head_lines(self.tmp_path, 3, self.logger) self.assertEqual(lines, []) def test_empty_for_missing_file(self): """Test read_head_lines returns [] for missing file""" self.logger.forgive_errors() - lines = fileio.read_head_lines('missing.txt', 3, self.logger) + lines = fileio.read_head_lines("missing.txt", 3, self.logger) self.assertEqual(lines, []) @@ -784,39 +821,40 @@ class MigSharedFileio__read_tail_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READTAILLINES) def test_reads_requested_lines(self): """Test read_tail_lines returns requested number of lines""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3\nline4") lines = fileio.read_tail_lines(self.tmp_path, 2, self.logger) self.assertEqual(lines, ["line3\n", "line4"]) def test_returns_all_lines_when_requested_more(self): """Test read_tail_lines returns all lines when file has fewer""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2") lines = fileio.read_tail_lines(self.tmp_path, 5, self.logger) self.assertEqual(lines, ["line1\n", "line2"]) def test_returns_empty_list_for_empty_file(self): """Test read_tail_lines returns empty for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_tail_lines(self.tmp_path, 3, self.logger) self.assertEqual(lines, []) def test_empty_for_missing_file(self): """Test read_tail_lines returns [] for missing file""" self.logger.forgive_errors() - lines = fileio.read_tail_lines('missing.txt', 3, self.logger) + lines = fileio.read_tail_lines("missing.txt", 3, self.logger) self.assertEqual(lines, []) @@ -825,49 +863,52 @@ class MigSharedFileio__make_symlink(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) ensure_dirs_exist(self.tmp_dir) self.tmp_link = os.path.join(self.tmp_dir, DUMMY_FILE_MAKESYMLINKSRC) - self.tmp_target = os.path.join(self.tmp_dir, - DUMMY_FILE_MAKESYMLINKDST) - with open(self.tmp_target, 'w') as fh: + self.tmp_target = os.path.join(self.tmp_dir, DUMMY_FILE_MAKESYMLINKDST) + with open(self.tmp_target, "w") as fh: fh.write(DUMMY_TEXT) def test_creates_symlink(self): """Test make_symlink creates working symlink""" result = fileio.make_symlink( - self.tmp_target, self.tmp_link, self.logger) + self.tmp_target, self.tmp_link, self.logger + ) self.assertTrue(result) self.assertTrue(os.path.islink(self.tmp_link)) self.assertEqual(os.readlink(self.tmp_link), self.tmp_target) def test_force_overwrites_existing_link(self): """Test make_symlink force replaces existing link""" - os.symlink('/dummy', self.tmp_link) - result = fileio.make_symlink(self.tmp_target, self.tmp_link, - self.logger, force=True) + os.symlink("/dummy", self.tmp_link) + result = fileio.make_symlink( + self.tmp_target, self.tmp_link, self.logger, force=True + ) self.assertTrue(result) self.assertEqual(os.readlink(self.tmp_link), self.tmp_target) def test_fails_on_existing_link_without_force(self): """Test make_symlink fails on existing link without force""" self.logger.forgive_errors() - os.symlink('/dummy', self.tmp_link) - result = fileio.make_symlink(self.tmp_target, self.tmp_link, self.logger, - force=False) + os.symlink("/dummy", self.tmp_link) + result = fileio.make_symlink( + self.tmp_target, self.tmp_link, self.logger, force=False + ) self.assertFalse(result) def test_handles_nonexistent_target(self): """Test make_symlink still creates broken symlink""" self.logger.forgive_errors() - broken_target = self.tmp_target + '-nonexistent' + broken_target = self.tmp_target + "-nonexistent" result = fileio.make_symlink(broken_target, self.tmp_link, self.logger) self.assertTrue(result) self.assertTrue(os.path.islink(self.tmp_link)) @@ -879,20 +920,21 @@ class MigSharedFileio__delete_symlink(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) ensure_dirs_exist(self.tmp_dir) - self.tmp_link = os.path.join(self.tmp_dir, - DUMMY_FILE_DELETESYMLINKSRC) - self.tmp_target = os.path.join(self.tmp_dir, - DUMMY_FILE_DELETESYMLINKDST) - with open(self.tmp_target, 'w') as fh: + self.tmp_link = os.path.join(self.tmp_dir, DUMMY_FILE_DELETESYMLINKSRC) + self.tmp_target = os.path.join( + self.tmp_dir, DUMMY_FILE_DELETESYMLINKDST + ) + with open(self.tmp_target, "w") as fh: fh.write(DUMMY_TEXT) def create_symlink(self, target=None, link=None): @@ -915,33 +957,36 @@ def test_handles_missing_file_with_allow_missing(self): # First make sure file doesn't exist if os.path.exists(self.tmp_link): os.remove(self.tmp_link) - result = fileio.delete_symlink(self.tmp_link, self.logger, - allow_missing=True) + result = fileio.delete_symlink( + self.tmp_link, self.logger, allow_missing=True + ) self.assertTrue(result) def test_handles_missing_symlink_without_allow_missing(self): """Test delete_symlink fails with allow_missing=False""" self.logger.forgive_errors() - result = fileio.delete_symlink('missing_symlink', self.logger, - allow_missing=False) + result = fileio.delete_symlink( + "missing_symlink", self.logger, allow_missing=False + ) self.assertFalse(result) @unittest.skip("TODO: implement check in tested function and enable again") def test_rejects_regular_file(self): """Test delete_symlink returns False when path is a regular file""" - with open(self.tmp_link, 'w') as fh: + with open(self.tmp_link, "w") as fh: fh.write(DUMMY_TEXT) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: result = fileio.delete_symlink(self.tmp_link, self.logger) self.assertFalse(result) - self.assertTrue(any('Could not remove' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Could not remove" in msg for msg in log_capture.output) + ) def test_deletes_broken_symlink(self): """Test delete_symlink removes broken symlink""" # Create broken symlink - broken_target = self.tmp_target + '-nonexistent' + broken_target = self.tmp_target + "-nonexistent" self.create_symlink(broken_target) self.assertTrue(os.path.islink(self.tmp_link)) # Now delete it @@ -954,12 +999,13 @@ class MigSharedFileio__touch(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_TOUCH) @@ -971,11 +1017,13 @@ def test_creates_new_file(self): self.assertTrue(os.path.exists(self.tmp_path)) self.assertTrue(os.path.isfile(self.tmp_path)) - @unittest.skip("TODO: fix invalid open 'r+w' in tested function and enable again") + @unittest.skip( + "TODO: fix invalid open 'r+w' in tested function and enable again" + ) def test_updates_timestamp_on_existing_file(self): """Test touch updates timestamp on existing file""" # Create initial file - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write(DUMMY_TEXT) orig_mtime = os.path.getmtime(self.tmp_path) time.sleep(0.1) @@ -984,7 +1032,9 @@ def test_updates_timestamp_on_existing_file(self): new_mtime = os.path.getmtime(self.tmp_path) self.assertNotEqual(orig_mtime, new_mtime) - @unittest.skip("TODO: fix handling of directory in tested function and enable again") + @unittest.skip( + "TODO: fix handling of directory in tested function and enable again" + ) def test_succeeds_on_directory(self): """Test touch succeeds for existing directory and updates timestamp""" ensure_dirs_exist(self.tmp_path) @@ -999,7 +1049,7 @@ def test_succeeds_on_directory(self): def test_fails_on_missing_parent(self): """Test touch fails when parent directory doesn't exist""" self.logger.forgive_errors() - nested_path = os.path.join(self.tmp_path, 'missing', DUMMY_FILE_ONE) + nested_path = os.path.join(self.tmp_path, "missing", DUMMY_FILE_ONE) result = fileio.touch(nested_path, self.configuration) self.assertFalse(result) self.assertFalse(os.path.exists(nested_path)) @@ -1010,12 +1060,13 @@ class MigSharedFileio__remove_dir(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVE) # NOTE: we prepare tmp_path as directory here @@ -1032,7 +1083,7 @@ def test_fails_on_nonempty_directory(self): """Test remove_dir returns False for non-empty directory""" self.logger.forgive_errors() # Add a file to the directory - with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) result = fileio.remove_dir(self.tmp_path, self.configuration) self.assertFalse(result) @@ -1043,7 +1094,7 @@ def test_fails_on_file(self): self.logger.forgive_errors() # Add a file to the directory file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with open(file_path, 'w') as fh: + with open(file_path, "w") as fh: fh.write(DUMMY_TEXT) result = fileio.remove_dir(file_path, self.configuration) self.assertFalse(result) @@ -1055,12 +1106,13 @@ class MigSharedFileio__remove_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVEREC) # Create a nested directory structure with files @@ -1069,10 +1121,11 @@ def before_each(self): # └── subdir/ # └── file2.txt ensure_dirs_exist(os.path.join(self.tmp_path, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_path, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_path, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_removes_directory_recursively(self): @@ -1085,7 +1138,7 @@ def test_removes_directory_recursively(self): def test_removes_directory_recursively_with_symlink(self): """Test remove_rec removes directory and contents with symlink""" link_src = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - link_dst = os.path.join(self.tmp_path, DUMMY_FILE_ONE + '.lnk') + link_dst = os.path.join(self.tmp_path, DUMMY_FILE_ONE + ".lnk") os.symlink(link_src, link_dst) self.assertTrue(os.path.exists(self.tmp_path)) result = fileio.remove_rec(self.tmp_path, self.configuration) @@ -1095,7 +1148,7 @@ def test_removes_directory_recursively_with_symlink(self): def test_removes_directory_recursively_with_broken_symlink(self): """Test remove_rec removes directory and contents with broken symlink""" link_src = os.path.join(self.tmp_path, DUMMY_FILE_MISSING) - link_dst = os.path.join(self.tmp_path, DUMMY_FILE_MISSING + '.lnk') + link_dst = os.path.join(self.tmp_path, DUMMY_FILE_MISSING + ".lnk") os.symlink(link_src, link_dst) self.assertTrue(os.path.exists(self.tmp_path)) result = fileio.remove_rec(self.tmp_path, self.configuration) @@ -1115,11 +1168,12 @@ def test_removes_directory_recursively_despite_readonly(self): def test_rejects_regular_file(self): """Test remove_rec returns False when path is a regular file""" file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: result = fileio.remove_rec(file_path, self.configuration) self.assertFalse(result) - self.assertTrue(any('Could not remove' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Could not remove" in msg for msg in log_capture.output) + ) self.assertTrue(os.path.exists(file_path)) @@ -1128,22 +1182,24 @@ class MigSharedFileio__move_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_FILE_MOVE_SRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_FILE_MOVE_DST) - with open(self.tmp_src, 'w') as fh: + with open(self.tmp_src, "w") as fh: fh.write(DUMMY_TEXT) def test_moves_file(self): """Test move_file successfully moves a file""" - success, msg = fileio.move_file(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_file( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) self.assertFalse(os.path.exists(self.tmp_src)) @@ -1152,13 +1208,14 @@ def test_moves_file(self): def test_overwrites_existing_destination(self): """Test move_file overwrites existing destination file""" # Create initial destination file - with open(self.tmp_dst, 'w') as fh: + with open(self.tmp_dst, "w") as fh: fh.write(DUMMY_TWICE) - success, msg = fileio.move_file(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_file( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) - with open(self.tmp_dst, 'r') as fh: + with open(self.tmp_dst, "r") as fh: content = fh.read() self.assertEqual(content, DUMMY_TEXT) @@ -1168,12 +1225,13 @@ class MigSharedFileio__move_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVE) self.tmp_src = os.path.join(self.tmp_base, DUMMY_DIRECTORY_MOVE_SRC) @@ -1184,43 +1242,54 @@ def before_each(self): # └── subdir/ # └── file2.txt ensure_dirs_exist(os.path.join(self.tmp_src, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_src, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_src, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_moves_directory_recursively(self): """Test move_rec moves directory and contents""" - result = fileio.move_rec(self.tmp_src, self.tmp_dst, - self.configuration) + result = fileio.move_rec(self.tmp_src, self.tmp_dst, self.configuration) self.assertTrue(result) self.assertFalse(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) # Verify structure - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, - DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, DUMMY_SUBDIR, - DUMMY_FILE_TWO))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO) + ) + ) def test_extends_existing_destination(self): """Test move_rec extends existing destination directory""" # Create initial destination with some content ensure_dirs_exist(os.path.join(self.tmp_dst, DUMMY_TESTDIR)) - success, msg = fileio.move_rec(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_rec( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) # Verify structure with new src subdir and existing dir new_sub = os.path.basename(DUMMY_DIRECTORY_MOVE_SRC) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, new_sub, - DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, new_sub, - DUMMY_SUBDIR, - DUMMY_FILE_TWO))) - self.assertTrue(os.path.exists( - os.path.join(self.tmp_dst, DUMMY_TESTDIR))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, new_sub, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join( + self.tmp_dst, new_sub, DUMMY_SUBDIR, DUMMY_FILE_TWO + ) + ) + ) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_TESTDIR)) + ) class MigSharedFileio__copy_file(MigTestCase): @@ -1228,23 +1297,25 @@ class MigSharedFileio__copy_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_FILE_COPY_SRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_FILE_COPY_DST) - with open(self.tmp_src, 'w') as fh: + with open(self.tmp_src, "w") as fh: fh.write(DUMMY_TEXT) def test_copies_file(self): """Test copy_file successfully copies a file""" result = fileio.copy_file( - self.tmp_src, self.tmp_dst, self.configuration) + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(result) self.assertTrue(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) @@ -1252,12 +1323,13 @@ def test_copies_file(self): def test_overwrites_existing_destination(self): """Test copy_file overwrites existing destination file""" # Create initial destination file - with open(self.tmp_dst, 'w') as fh: + with open(self.tmp_dst, "w") as fh: fh.write(DUMMY_TWICE) result = fileio.copy_file( - self.tmp_src, self.tmp_dst, self.configuration) + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(result) - with open(self.tmp_dst, 'r') as fh: + with open(self.tmp_dst, "r") as fh: content = fh.read() self.assertEqual(content, DUMMY_TEXT) @@ -1267,36 +1339,41 @@ class MigSharedFileio__copy_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_DIRECTORY_COPYRECSRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_DIRECTORY_COPYRECDST) # Create a nested directory structure with files ensure_dirs_exist(self.tmp_src) ensure_dirs_exist(os.path.join(self.tmp_src, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_src, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_src, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_copies_directory_recursively(self): """Test copy_rec copies directory and contents""" - result = fileio.copy_rec( - self.tmp_src, self.tmp_dst, self.configuration) + result = fileio.copy_rec(self.tmp_src, self.tmp_dst, self.configuration) self.assertTrue(result) self.assertTrue(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) # Verify structure - self.assertTrue(os.path.exists(os.path.join( - self.tmp_dst, DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join( - self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO) + ) + ) class MigSharedFileio__check_empty_dir(MigTestCase): @@ -1304,20 +1381,20 @@ class MigSharedFileio__check_empty_dir(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.empty_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_EMPTY) - self.nonempty_path = os.path.join( - self.tmp_base, DUMMY_DIRECTORY_NESTED) + self.nonempty_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_NESTED) ensure_dirs_exist(self.empty_path) # Create non-empty directory structure ensure_dirs_exist(self.nonempty_path) - with open(os.path.join(self.nonempty_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.nonempty_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) def test_returns_true_for_empty(self): @@ -1340,20 +1417,21 @@ class MigSharedFileio__makedirs_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) - self.tmp_path = os.path.join(self.tmp_base, - DUMMY_DIRECTORY_MAKEDIRSREC) + self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_MAKEDIRSREC) def test_creates_directory_path(self): """Test makedirs_rec creates nested directories""" - nested_path = os.path.join(self.tmp_path, DUMMY_TESTDIR, DUMMY_SUBDIR, - DUMMY_TESTDIR) + nested_path = os.path.join( + self.tmp_path, DUMMY_TESTDIR, DUMMY_SUBDIR, DUMMY_TESTDIR + ) result = fileio.makedirs_rec(nested_path, self.configuration) self.assertTrue(result) self.assertTrue(os.path.exists(nested_path)) @@ -1370,7 +1448,7 @@ def test_fails_for_file_path(self): # Create a file at the path ensure_dirs_exist(self.tmp_path) file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with open(file_path, 'w') as fh: + with open(file_path, "w") as fh: fh.write(DUMMY_TEXT) result = fileio.makedirs_rec(file_path, self.configuration) self.assertFalse(result) @@ -1381,26 +1459,26 @@ class MigSharedFileio__check_access(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) - self.tmp_dir = os.path.join(self.tmp_base, - DUMMY_DIRECTORY_CHECKACCESS) + self.tmp_dir = os.path.join(self.tmp_base, DUMMY_DIRECTORY_CHECKACCESS) ensure_dirs_exist(self.tmp_dir) self.writeonly_file = os.path.join(self.tmp_dir, DUMMY_FILE_WO) self.readonly_file = os.path.join(self.tmp_dir, DUMMY_FILE_RO) self.readwrite_file = os.path.join(self.tmp_dir, DUMMY_FILE_RW) # Create test files with different permissions - with open(self.writeonly_file, 'w') as fh: + with open(self.writeonly_file, "w") as fh: fh.write(DUMMY_TEXT) - with open(self.readonly_file, 'w') as fh: + with open(self.readonly_file, "w") as fh: fh.write(DUMMY_TEXT) - with open(self.readwrite_file, 'w') as fh: + with open(self.readwrite_file, "w") as fh: fh.write(DUMMY_TEXT) # Set permissions @@ -1414,14 +1492,13 @@ def test_check_read_access_file(self): """Test check_read_access with readable file""" self.assertTrue(fileio.check_read_access(self.readwrite_file)) self.assertTrue(fileio.check_read_access(self.readonly_file)) - self.assertTrue(fileio.check_read_access(self.tmp_dir, - parent_dir=True)) + self.assertTrue(fileio.check_read_access(self.tmp_dir, parent_dir=True)) # Super-user has access to read and write all files! if os.getuid() == 0: self.assertTrue(fileio.check_read_access(self.writeonly_file)) else: self.assertFalse(fileio.check_read_access(self.writeonly_file)) - self.assertFalse(fileio.check_read_access('/invalid/path')) + self.assertFalse(fileio.check_read_access("/invalid/path")) def test_check_write_access_file(self): """Test check_write_access with writable file""" @@ -1432,7 +1509,7 @@ def test_check_write_access_file(self): self.assertTrue(fileio.check_write_access(self.readonly_file)) else: self.assertFalse(fileio.check_write_access(self.readonly_file)) - self.assertFalse(fileio.check_write_access('/invalid/path')) + self.assertFalse(fileio.check_write_access("/invalid/path")) def test_check_read_access_with_parent(self): """Test check_read_access with parent_dir True""" @@ -1448,78 +1525,108 @@ def test_check_write_access_with_parent(self): def test_check_readable(self): """Test check_readable wrapper function""" - self.assertTrue(fileio.check_readable(self.configuration, - self.readwrite_file)) - self.assertTrue(fileio.check_readable(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_readable(self.configuration, self.readwrite_file) + ) + self.assertTrue( + fileio.check_readable(self.configuration, self.readonly_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_readable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_readable(self.configuration, self.writeonly_file) + ) else: - self.assertFalse(fileio.check_readable(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readable(self.configuration, - '/invalid/path')) + self.assertFalse( + fileio.check_readable(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readable(self.configuration, "/invalid/path") + ) def test_check_writable(self): """Test check_writable wrapper function""" - self.assertTrue(fileio.check_writable(self.configuration, - self.readwrite_file)) - self.assertTrue(fileio.check_writable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_writable(self.configuration, self.readwrite_file) + ) + self.assertTrue( + fileio.check_writable(self.configuration, self.writeonly_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_writable(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_writable(self.configuration, self.readonly_file) + ) else: - self.assertFalse(fileio.check_writable(self.configuration, - self.readonly_file)) - self.assertFalse(fileio.check_writable(self.configuration, - "/no/such/file")) + self.assertFalse( + fileio.check_writable(self.configuration, self.readonly_file) + ) + self.assertFalse( + fileio.check_writable(self.configuration, "/no/such/file") + ) def test_check_readonly(self): """Test check_readonly wrapper function""" # Super-user has access to read and write all files! if os.getuid() == 0: # Test with read-only file path - self.assertFalse(fileio.check_readonly(self.configuration, - self.readonly_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readonly_file) + ) # Test with writable file - self.assertFalse(fileio.check_readonly(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readonly(self.configuration, - self.readwrite_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readwrite_file) + ) else: # Test with read-only file path - self.assertTrue(fileio.check_readonly(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_readonly(self.configuration, self.readonly_file) + ) # Test with writable file - self.assertFalse(fileio.check_readonly(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readonly(self.configuration, - self.readwrite_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readwrite_file) + ) def test_check_readwritable(self): """Test check_readwritable wrapper function""" - self.assertTrue(fileio.check_readwritable(self.configuration, - self.readwrite_file)) + self.assertTrue( + fileio.check_readwritable(self.configuration, self.readwrite_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_readwritable(self.configuration, - self.readonly_file)) - self.assertTrue(fileio.check_readwritable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_readwritable( + self.configuration, self.readonly_file + ) + ) + self.assertTrue( + fileio.check_readwritable( + self.configuration, self.writeonly_file + ) + ) else: - self.assertFalse(fileio.check_readwritable(self.configuration, - self.readonly_file)) - self.assertFalse(fileio.check_readwritable(self.configuration, - self.writeonly_file)) - - self.assertFalse(fileio.check_readwritable(self.configuration, - "/invalid/file")) + self.assertFalse( + fileio.check_readwritable( + self.configuration, self.readonly_file + ) + ) + self.assertFalse( + fileio.check_readwritable( + self.configuration, self.writeonly_file + ) + ) + + self.assertFalse( + fileio.check_readwritable(self.configuration, "/invalid/file") + ) def test_special_cases(self): """Test various special cases for access checks""" @@ -1533,11 +1640,13 @@ def test_special_cases(self): self.assertFalse(fileio.check_write_access(missing_path)) # Check with custom follow_symlink=False - self.assertTrue(fileio.check_read_access(self.readwrite_file, - follow_symlink=False)) - self.assertTrue(fileio.check_read_access(self.tmp_dir, True, - follow_symlink=False)) + self.assertTrue( + fileio.check_read_access(self.readwrite_file, follow_symlink=False) + ) + self.assertTrue( + fileio.check_read_access(self.tmp_dir, True, follow_symlink=False) + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_filemarks.py b/tests/test_mig_shared_filemarks.py index a640deba9..41a3a95c8 100644 --- a/tests/test_mig_shared_filemarks.py +++ b/tests/test_mig_shared_filemarks.py @@ -36,11 +36,12 @@ # Imports of the code under test from mig.shared.filemarks import get_filemark, reset_filemark, update_filemark + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -TEST_MARKS_DIR = 'TestMarks' -TEST_MARKS_FILE = 'file.mark' +TEST_MARKS_DIR = "TestMarks" +TEST_MARKS_FILE = "file.mark" class TestMigSharedFilemarks(MigTestCase): @@ -48,7 +49,7 @@ class TestMigSharedFilemarks(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def _prepare_mark_for_test(self, mark_name=None, timestamp=None): """Prepare test for mark_name with timestamp in default location""" @@ -57,7 +58,7 @@ def _prepare_mark_for_test(self, mark_name=None, timestamp=None): if timestamp is None: timestamp = time.time() self.marks_path = os.path.join(self.marks_base, mark_name) - open(self.marks_path, 'w').close() + open(self.marks_path, "w").close() os.utime(self.marks_path, (timestamp, timestamp)) return timestamp @@ -69,8 +70,9 @@ def _verify_mark_after_test(self, mark_name, timestamp): def before_each(self): """Setup fake configuration and temp dir before each test.""" - self.marks_base = os.path.join(self.configuration.mig_system_run, - TEST_MARKS_DIR) + self.marks_base = os.path.join( + self.configuration.mig_system_run, TEST_MARKS_DIR + ) ensure_dirs_exist(self.marks_base) self.marks_path = os.path.join(self.marks_base, TEST_MARKS_FILE) @@ -78,8 +80,9 @@ def test_update_filemark_create(self): """Test update_filemark creates mark file with timestamp""" timestamp = 4242 self.assertFalse(os.path.isfile(self.marks_path)) - update_result = update_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE, timestamp) + update_result = update_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE, timestamp + ) self.assertTrue(update_result) self.assertTrue(os.path.isfile(self.marks_path)) self.assertEqual(os.path.getmtime(self.marks_path), timestamp) @@ -89,8 +92,9 @@ def test_update_filemark_timestamp(self): timestamp = 424242 self._prepare_mark_for_test(TEST_MARKS_FILE, 4242) - update_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE, timestamp) + update_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE, timestamp + ) self.assertTrue(os.path.isfile(self.marks_path)) self.assertEqual(os.path.getmtime(self.marks_path), timestamp) @@ -98,8 +102,9 @@ def test_update_filemark_delete(self): """Test update_filemark deletes mark files with negative timestamp""" self._prepare_mark_for_test(TEST_MARKS_FILE) - delete_result = update_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE, -1) + delete_result = update_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE, -1 + ) self.assertTrue(delete_result) self.assertFalse(os.path.exists(self.marks_path)) @@ -108,23 +113,26 @@ def test_get_filemark_existing(self): timestamp = 4242 self._prepare_mark_for_test(TEST_MARKS_FILE, timestamp) - retrieved = get_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + retrieved = get_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertEqual(retrieved, timestamp) def test_get_filemark_missing(self): """Test get_filemark returns None for missing mark files""" self.assertFalse(os.path.isfile(self.marks_path)) - retrieved = get_filemark(self.configuration, self.marks_base, - 'missing.mark') + retrieved = get_filemark( + self.configuration, self.marks_base, "missing.mark" + ) self.assertIsNone(retrieved) def test_reset_filemark_single(self): """Test reset_filemark updates single mark timestamp to 0""" self._prepare_mark_for_test(TEST_MARKS_FILE) - reset_result = reset_filemark(self.configuration, self.marks_base, - [TEST_MARKS_FILE]) + reset_result = reset_filemark( + self.configuration, self.marks_base, [TEST_MARKS_FILE] + ) self.assertTrue(reset_result) self._verify_mark_after_test(TEST_MARKS_FILE, 0) @@ -133,18 +141,20 @@ def test_reset_filemark_delete(self): """Test reset_filemark deletes marks with delete=True""" self._prepare_mark_for_test(TEST_MARKS_FILE) - reset_result = reset_filemark(self.configuration, self.marks_base, - [TEST_MARKS_FILE], delete=True) + reset_result = reset_filemark( + self.configuration, self.marks_base, [TEST_MARKS_FILE], delete=True + ) self.assertTrue(reset_result) - retrieved = get_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + retrieved = get_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertIsNone(retrieved) self.assertFalse(os.path.exists(self.marks_path)) def test_reset_filemark_all(self): """Test reset_filemark resets all marks when mark_list=None""" - marks = ['mark1', 'mark2', 'mark3'] + marks = ["mark1", "mark2", "mark3"] for mark in marks: self._prepare_mark_for_test(mark) @@ -157,15 +167,22 @@ def test_reset_filemark_all(self): def test_update_filemark_fails_when_file_prevents_directory(self): """Test update_filemark fails when file prevents create directory""" # Create a file in the way to prevent subdir creation - self._prepare_mark_for_test('obstruct') - - with self.assertLogs(level='ERROR') as log_capture: - result = update_filemark(self.configuration, self.marks_base, - os.path.join('obstruct', 'test.mark'), - time.time()) + self._prepare_mark_for_test("obstruct") + + with self.assertLogs(level="ERROR") as log_capture: + result = update_filemark( + self.configuration, + self.marks_base, + os.path.join("obstruct", "test.mark"), + time.time(), + ) self.assertFalse(result) - self.assertTrue(any('in the way' in msg or 'could not create' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "in the way" in msg or "could not create" in msg + for msg in log_capture.output + ) + ) @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user") def test_update_filemark_directory_perms_failure(self): @@ -173,13 +190,17 @@ def test_update_filemark_directory_perms_failure(self): # Create a read-only parent directory to prevent subdir creation os.chmod(self.marks_base, stat.S_IRUSR) # Remove write permissions - with self.assertLogs(level='ERROR') as log_capture: - result = update_filemark(self.configuration, self.marks_base, - os.path.join('noaccess', 'test.mark'), - time.time()) + with self.assertLogs(level="ERROR") as log_capture: + result = update_filemark( + self.configuration, + self.marks_base, + os.path.join("noaccess", "test.mark"), + time.time(), + ) self.assertFalse(result) - self.assertTrue(any('Permission denied' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Permission denied" in msg for msg in log_capture.output) + ) @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user") def test_get_filemark_permission_denied(self): @@ -188,8 +209,9 @@ def test_get_filemark_permission_denied(self): # Remove read permissions through parent dir os.chmod(self.marks_base, 0) - result = get_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + result = get_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertIsNone(result) # Restore permissions so cleanup works os.chmod(self.marks_base, stat.S_IRWXU) @@ -198,20 +220,23 @@ def test_reset_filemark_string_mark_list(self): """Test reset_filemark handles single string mark_list""" self._prepare_mark_for_test(TEST_MARKS_FILE) - reset_result = reset_filemark(self.configuration, self.marks_base, - TEST_MARKS_FILE) + reset_result = reset_filemark( + self.configuration, self.marks_base, TEST_MARKS_FILE + ) self.assertTrue(reset_result) self._verify_mark_after_test(TEST_MARKS_FILE, 0) def test_reset_filemark_invalid_mark_list(self): """Test reset_filemark fails with invalid mark_list type""" - with self.assertLogs(level='ERROR') as log_capture: - reset_result = reset_filemark(self.configuration, self.marks_base, - {'invalid': 'type'}) + with self.assertLogs(level="ERROR") as log_capture: + reset_result = reset_filemark( + self.configuration, self.marks_base, {"invalid": "type"} + ) self.assertFalse(reset_result) - self.assertTrue(any('invalid mark list' in msg for msg in - log_capture.output)) + self.assertTrue( + any("invalid mark list" in msg for msg in log_capture.output) + ) def test_reset_filemark_all_missing_dir(self): """Test reset_filemark handles missing directory when mark_list=None""" @@ -222,41 +247,48 @@ def test_reset_filemark_all_missing_dir(self): @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user") def test_reset_filemark_partial_perms_failure(self): """Test reset_filemark with partial failure due to permissions""" - valid_mark = 'valid.mark' - invalid_mark = 'invalid.mark' + valid_mark = "valid.mark" + invalid_mark = "invalid.mark" invalid_path = os.path.join(self.marks_base, invalid_mark) # Create both marks but remove access to the latter self._prepare_mark_for_test(valid_mark) self._prepare_mark_for_test(invalid_mark) os.chmod(invalid_path, stat.S_IRUSR) # Remove write permissions - with self.assertLogs(level='ERROR') as log_capture: - reset_result = reset_filemark(self.configuration, self.marks_base, - [valid_mark, invalid_mark]) + with self.assertLogs(level="ERROR") as log_capture: + reset_result = reset_filemark( + self.configuration, self.marks_base, [valid_mark, invalid_mark] + ) self.assertFalse(reset_result) # Should fail due to partial failure - self.assertTrue(any('Permission denied' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Permission denied" in msg for msg in log_capture.output) + ) self._verify_mark_after_test(valid_mark, 0) def test_reset_filemark_partial_file_prevents_directory_failure(self): """Test reset_filemark with partial failure due to a file in the way""" - valid_mark = 'valid.mark' - invalid_mark = os.path.join('obstruct', 'invalid.mark') + valid_mark = "valid.mark" + invalid_mark = os.path.join("obstruct", "invalid.mark") # Create valid mark and a file to prevent the invalid mark self._prepare_mark_for_test(valid_mark) # Create a file in the way to prevent subdir creation - self._prepare_mark_for_test('obstruct') + self._prepare_mark_for_test("obstruct") - with self.assertLogs(level='ERROR') as log_capture: - reset_result = reset_filemark(self.configuration, self.marks_base, - [valid_mark, invalid_mark]) + with self.assertLogs(level="ERROR") as log_capture: + reset_result = reset_filemark( + self.configuration, self.marks_base, [valid_mark, invalid_mark] + ) self.assertFalse(reset_result) # Should fail due to partial failure - self.assertTrue(any('in the way' in msg or 'could not create' in msg - for msg in log_capture.output)) + self.assertTrue( + any( + "in the way" in msg or "could not create" in msg + for msg in log_capture.output + ) + ) self._verify_mark_after_test(valid_mark, 0) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_functionality_cat.py b/tests/test_mig_shared_functionality_cat.py index 619c3dd4e..a8081c105 100644 --- a/tests/test_mig_shared_functionality_cat.py +++ b/tests/test_mig_shared_functionality_cat.py @@ -28,17 +28,25 @@ """Unit tests of the MiG functionality file implementing the cat backend""" from __future__ import print_function + import importlib import os import shutil import sys import unittest -from tests.support import MIG_BASE, PY2, TEST_DATA_DIR, MigTestCase, testmain, \ - temppath, ensure_dirs_exist - from mig.shared.base import client_id_dir -from mig.shared.functionality.cat import _main as submain, main as realmain +from mig.shared.functionality.cat import _main as submain +from mig.shared.functionality.cat import main as realmain +from tests.support import ( + MIG_BASE, + PY2, + TEST_DATA_DIR, + MigTestCase, + ensure_dirs_exist, + temppath, + testmain, +) def create_http_environ(configuration): @@ -47,139 +55,169 @@ def create_http_environ(configuration): """ environ = {} - environ['MIG_CONF'] = configuration.config_file - environ['HTTP_HOST'] = 'localhost' - environ['PATH_INFO'] = '/' - environ['REMOTE_ADDR'] = '127.0.0.1' - environ['SCRIPT_URI'] = ''.join(('https://', environ['HTTP_HOST'], - environ['PATH_INFO'])) + environ["MIG_CONF"] = configuration.config_file + environ["HTTP_HOST"] = "localhost" + environ["PATH_INFO"] = "/" + environ["REMOTE_ADDR"] = "127.0.0.1" + environ["SCRIPT_URI"] = "".join( + ("https://", environ["HTTP_HOST"], environ["PATH_INFO"]) + ) return environ def _only_output_objects(output_objects, with_object_type=None): - return [o for o in output_objects if o['object_type'] == with_object_type] + return [o for o in output_objects if o["object_type"] == with_object_type] class MigSharedFunctionalityCat(MigTestCase): """Wrap unit tests for the corresponding module""" - TEST_CLIENT_ID = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_CLIENT_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): - self.test_user_dir = self._provision_test_user(self, self.TEST_CLIENT_ID) + self.test_user_dir = self._provision_test_user( + self, self.TEST_CLIENT_ID + ) self.test_environ = create_http_environ(self.configuration) def assertSingleOutputObject(self, output_objects, with_object_type=None): assert with_object_type is not None - found_objects = _only_output_objects(output_objects, - with_object_type=with_object_type) + found_objects = _only_output_objects( + output_objects, with_object_type=with_object_type + ) self.assertEqual(len(found_objects), 1) return found_objects[0] def test_file_serving_a_single_file_match(self): - with open(os.path.join(self.test_user_dir, 'foobar.txt'), 'w'): + with open(os.path.join(self.test_user_dir, "foobar.txt"), "w"): pass payload = { - 'path': ['foobar.txt'], + "path": ["foobar.txt"], } - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) # NOTE: start entry with headers and actual content self.assertEqual(len(output_objects), 2) - self.assertSingleOutputObject(output_objects, - with_object_type='file_output') + self.assertSingleOutputObject( + output_objects, with_object_type="file_output" + ) def test_file_serving_at_limit(self): test_binary_file = os.path.realpath( - os.path.join(TEST_DATA_DIR, 'loading.gif')) + os.path.join(TEST_DATA_DIR, "loading.gif") + ) test_binary_file_size = os.stat(test_binary_file).st_size - with open(test_binary_file, 'rb') as fh_test_file: + with open(test_binary_file, "rb") as fh_test_file: test_binary_file_data = fh_test_file.read() - shutil.copyfile(test_binary_file, os.path.join( - self.test_user_dir, 'loading.gif')) + shutil.copyfile( + test_binary_file, os.path.join(self.test_user_dir, "loading.gif") + ) payload = { - 'output_format': ['file'], - 'path': ['loading.gif'], + "output_format": ["file"], + "path": ["loading.gif"], } self.configuration.wwwserve_max_bytes = test_binary_file_size - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) self.assertEqual(len(output_objects), 2) - relevant_obj = self.assertSingleOutputObject(output_objects, - with_object_type='file_output') - self.assertEqual(len(relevant_obj['lines']), 1) - self.assertEqual(relevant_obj['lines'][0], test_binary_file_data) + relevant_obj = self.assertSingleOutputObject( + output_objects, with_object_type="file_output" + ) + self.assertEqual(len(relevant_obj["lines"]), 1) + self.assertEqual(relevant_obj["lines"][0], test_binary_file_data) def test_file_serving_over_limit_without_storage_protocols(self): - test_binary_file = os.path.realpath(os.path.join(TEST_DATA_DIR, - 'loading.gif')) + test_binary_file = os.path.realpath( + os.path.join(TEST_DATA_DIR, "loading.gif") + ) test_binary_file_size = os.stat(test_binary_file).st_size - with open(test_binary_file, 'rb') as fh_test_file: + with open(test_binary_file, "rb") as fh_test_file: test_binary_file_data = fh_test_file.read() - shutil.copyfile(test_binary_file, os.path.join(self.test_user_dir, - 'loading.gif')) + shutil.copyfile( + test_binary_file, os.path.join(self.test_user_dir, "loading.gif") + ) payload = { - 'output_format': ['file'], - 'path': ['loading.gif'], + "output_format": ["file"], + "path": ["loading.gif"], } # NOTE: override default storage_protocols to empty in this test self.configuration.storage_protocols = [] self.configuration.wwwserve_max_bytes = test_binary_file_size - 1 - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) # NOTE: start entry with headers and actual error message self.assertEqual(len(output_objects), 2) - relevant_obj = self.assertSingleOutputObject(output_objects, - with_object_type='error_text') - self.assertEqual(relevant_obj['text'], - "Site configuration prevents web serving contents " - "bigger than 3896 bytes") + relevant_obj = self.assertSingleOutputObject( + output_objects, with_object_type="error_text" + ) + self.assertEqual( + relevant_obj["text"], + "Site configuration prevents web serving contents " + "bigger than 3896 bytes", + ) def test_file_serving_over_limit_with_storage_protocols_sftp(self): - test_binary_file = os.path.realpath(os.path.join(TEST_DATA_DIR, - 'loading.gif')) + test_binary_file = os.path.realpath( + os.path.join(TEST_DATA_DIR, "loading.gif") + ) test_binary_file_size = os.stat(test_binary_file).st_size - with open(test_binary_file, 'rb') as fh_test_file: + with open(test_binary_file, "rb") as fh_test_file: test_binary_file_data = fh_test_file.read() - shutil.copyfile(test_binary_file, os.path.join(self.test_user_dir, - 'loading.gif')) + shutil.copyfile( + test_binary_file, os.path.join(self.test_user_dir, "loading.gif") + ) payload = { - 'output_format': ['file'], - 'path': ['loading.gif'], + "output_format": ["file"], + "path": ["loading.gif"], } - self.configuration.storage_protocols = ['sftp'] + self.configuration.storage_protocols = ["sftp"] self.configuration.wwwserve_max_bytes = test_binary_file_size - 1 - (output_objects, status) = submain(self.configuration, self.logger, - client_id=self.TEST_CLIENT_ID, - user_arguments_dict=payload, - environ=self.test_environ) + output_objects, status = submain( + self.configuration, + self.logger, + client_id=self.TEST_CLIENT_ID, + user_arguments_dict=payload, + environ=self.test_environ, + ) # NOTE: start entry with headers and actual error message - relevant_obj = self.assertSingleOutputObject(output_objects, - with_object_type='error_text') - self.assertEqual(relevant_obj['text'], - "Site configuration prevents web serving contents " - "bigger than 3896 bytes - please use better " - "alternatives (SFTP) to retrieve large data") + relevant_obj = self.assertSingleOutputObject( + output_objects, with_object_type="error_text" + ) + self.assertEqual( + relevant_obj["text"], + "Site configuration prevents web serving contents " + "bigger than 3896 bytes - please use better " + "alternatives (SFTP) to retrieve large data", + ) @unittest.skipIf(PY2, "Python 3 only") def test_main_passes_environ(self): @@ -187,17 +225,21 @@ def test_main_passes_environ(self): result = realmain(self.TEST_CLIENT_ID, {}, self.test_environ) except Exception as unexpectedexc: raise AssertionError( - "saw unexpected exception: %s" % (unexpectedexc,)) + "saw unexpected exception: %s" % (unexpectedexc,) + ) - (output_objects, status) = result - self.assertEqual(status[1], 'Client error') + output_objects, status = result + self.assertEqual(status[1], "Client error") - error_text_objects = _only_output_objects(output_objects, - with_object_type='error_text') + error_text_objects = _only_output_objects( + output_objects, with_object_type="error_text" + ) relevant_obj = error_text_objects[2] self.assertEqual( - relevant_obj['text'], 'Input arguments were rejected - not allowed for this script!') + relevant_obj["text"], + "Input arguments were rejected - not allowed for this script!", + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_functionality_datatransfer.py b/tests/test_mig_shared_functionality_datatransfer.py index 3fe9e5fdc..6887fd5a1 100644 --- a/tests/test_mig_shared_functionality_datatransfer.py +++ b/tests/test_mig_shared_functionality_datatransfer.py @@ -28,18 +28,19 @@ """Unit tests of the MiG functionality file implementing the datatransfer backend""" from __future__ import print_function + import os import mig.shared.returnvalues as returnvalues -from mig.shared.defaults import CSRF_MINIMAL from mig.shared.base import client_id_dir -from mig.shared.functionality.datatransfer import _main as submain, main as realmain - +from mig.shared.defaults import CSRF_MINIMAL +from mig.shared.functionality.datatransfer import _main as submain +from mig.shared.functionality.datatransfer import main as realmain from tests.support import ( MigTestCase, - testmain, - temppath, ensure_dirs_exist, + temppath, + testmain, ) @@ -66,25 +67,27 @@ def _only_output_objects(output_objects, with_object_type=None): class MigSharedFunctionalityDataTransfer(MigTestCase): """Wrap unit tests for the corresponding module""" - TEST_CLIENT_ID = ( - "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" - ) + TEST_CLIENT_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" def _provide_configuration(self): return "testconfig" def before_each(self): - self.test_user_dir = self._provision_test_user(self, self.TEST_CLIENT_ID) + self.test_user_dir = self._provision_test_user( + self, self.TEST_CLIENT_ID + ) self.test_environ = create_http_environ(self.configuration) def test_default_disabled_site_transfer(self): self.assertFalse(self.configuration.site_enable_transfers) result = realmain(self.TEST_CLIENT_ID, {}, self.test_environ) - (output_objects, status) = result + output_objects, status = result self.assertEqual(status, returnvalues.OK) - text_objects = _only_output_objects(output_objects, with_object_type="text") + text_objects = _only_output_objects( + output_objects, with_object_type="text" + ) self.assertEqual(len(text_objects), 1) self.assertIn("text", text_objects[0]) text_object = text_objects[0]["text"] @@ -95,7 +98,7 @@ def test_show_action_enabled_site_transfer(self): payload = {"action": ["show"]} self.configuration.site_enable_transfers = True - (output_objects, status) = submain( + output_objects, status = submain( self.configuration, self.logger, client_id=self.TEST_CLIENT_ID, @@ -105,17 +108,22 @@ def test_show_action_enabled_site_transfer(self): self.assertEqual(status, returnvalues.OK) # We don't expect any text messages here - text_objects = _only_output_objects(output_objects, with_object_type="text") + text_objects = _only_output_objects( + output_objects, with_object_type="text" + ) self.assertEqual(len(text_objects), 0) def test_deltransfer_without_transfer_id(self): non_existing_transfer_id = "non-existing-transfer-id" - payload = {"action": ["deltransfer"], "transfer_id": [non_existing_transfer_id]} + payload = { + "action": ["deltransfer"], + "transfer_id": [non_existing_transfer_id], + } self.configuration.site_enable_transfers = True self.configuration.site_csrf_protection = CSRF_MINIMAL self.test_environ["REQUEST_METHOD"] = "post" - (output_objects, status) = submain( + output_objects, status = submain( self.configuration, self.logger, client_id=self.TEST_CLIENT_ID, @@ -129,7 +137,8 @@ def test_deltransfer_without_transfer_id(self): ) self.assertEqual(len(error_text_objects), 1) self.assertEqual( - error_text_objects[0]["text"], "existing transfer_id is required for delete" + error_text_objects[0]["text"], + "existing transfer_id is required for delete", ) def test_redotransfer_without_transfer_id(self): @@ -142,7 +151,7 @@ def test_redotransfer_without_transfer_id(self): self.configuration.site_csrf_protection = CSRF_MINIMAL self.test_environ["REQUEST_METHOD"] = "post" - (output_objects, status) = submain( + output_objects, status = submain( self.configuration, self.logger, client_id=self.TEST_CLIENT_ID, diff --git a/tests/test_mig_shared_install.py b/tests/test_mig_shared_install.py index b06f3c928..e67793879 100644 --- a/tests/test_mig_shared_install.py +++ b/tests/test_mig_shared_install.py @@ -27,21 +27,28 @@ """Unit tests for the migrid module pointed to in the filename""" -from past.builtins import basestring import binascii -from configparser import ConfigParser, NoSectionError, NoOptionError import difflib import io import os import pwd import sys +from configparser import ConfigParser, NoOptionError, NoSectionError -from tests.support import MIG_BASE, TEST_OUTPUT_DIR, MigTestCase, \ - testmain, temppath, cleanpath, is_path_within -from tests.support.fixturesupp import fixturepath +from past.builtins import basestring from mig.shared.defaults import keyword_auto from mig.shared.install import determine_timezone, generate_confs +from tests.support import ( + MIG_BASE, + TEST_OUTPUT_DIR, + MigTestCase, + cleanpath, + is_path_within, + temppath, + testmain, +) +from tests.support.fixturesupp import fixturepath class DummyPwInfo: @@ -70,40 +77,43 @@ class MigSharedInstall__determine_timezone(MigTestCase): def test_determines_tz_utc_fallback(self): timezone = determine_timezone( - _environ={}, _path_exists=lambda _: False, _print=noop) + _environ={}, _path_exists=lambda _: False, _print=noop + ) - self.assertEqual(timezone, 'UTC') + self.assertEqual(timezone, "UTC") def test_determines_tz_via_environ(self): - example_environ = { - 'TZ': 'Example/Enviromnent' - } + example_environ = {"TZ": "Example/Enviromnent"} timezone = determine_timezone(_environ=example_environ) - self.assertEqual(timezone, 'Example/Enviromnent') + self.assertEqual(timezone, "Example/Enviromnent") def test_determines_tz_via_localtime(self): def exists_localtime(value): - saw_call = value == '/etc/localtime' + saw_call = value == "/etc/localtime" exists_localtime.was_called = saw_call return saw_call + exists_localtime.was_called = False timezone = determine_timezone( - _environ={}, _path_exists=exists_localtime) + _environ={}, _path_exists=exists_localtime + ) self.assertTrue(exists_localtime.was_called) self.assertIsNotNone(timezone) def test_determines_tz_via_timedatectl(self): def exists_timedatectl(value): - saw_call = value == '/usr/bin/timedatectl' + saw_call = value == "/usr/bin/timedatectl" exists_timedatectl.was_called = saw_call return saw_call + exists_timedatectl.was_called = False timezone = determine_timezone( - _environ={}, _path_exists=exists_timedatectl, _print=noop) + _environ={}, _path_exists=exists_timedatectl, _print=noop + ) self.assertTrue(exists_timedatectl.was_called) self.assertIsNotNone(timezone) @@ -131,46 +141,48 @@ def assertConfigKey(self, generated, section, key, expected): self.assertEqual(actual, expected) def test_creates_output_directory_and_adds_active_symlink(self): - symlink_path = temppath('confs', self) - cleanpath('confs-foobar', self) + symlink_path = temppath("confs", self) + cleanpath("confs-foobar", self) - generate_confs(self.output_path, destination_suffix='-foobar') + generate_confs(self.output_path, destination_suffix="-foobar") - path_kind = self.assertPathExists('confs-foobar') + path_kind = self.assertPathExists("confs-foobar") self.assertEqual(path_kind, "dir") - path_kind = self.assertPathExists('confs') + path_kind = self.assertPathExists("confs") self.assertEqual(path_kind, "symlink") def test_creates_output_directory_and_repairs_active_symlink(self): - expected_generated_dir = cleanpath('confs-foobar', self) - symlink_path = temppath('confs', self) - nowhere_path = temppath('confs-nowhere', self) + expected_generated_dir = cleanpath("confs-foobar", self) + symlink_path = temppath("confs", self) + nowhere_path = temppath("confs-nowhere", self) # arrange pre-existing symlink pointing nowhere os.symlink(nowhere_path, symlink_path) - generate_confs(self.output_path, destination_suffix='-foobar') + generate_confs(self.output_path, destination_suffix="-foobar") generated_dir = os.path.realpath(symlink_path) self.assertEqual(generated_dir, expected_generated_dir) - def test_creates_output_directory_containing_a_standard_local_configuration(self): + def test_creates_output_directory_containing_a_standard_local_configuration( + self, + ): fixture_dir = fixturepath("confs-stdlocal") - expected_generated_dir = cleanpath('confs-stdlocal', self) - symlink_path = temppath('confs', self) + expected_generated_dir = cleanpath("confs-stdlocal", self) + symlink_path = temppath("confs", self) generate_confs( self.output_path, - destination_suffix='-stdlocal', - user='testuser', - group='testgroup', - mig_code='/home/mig/mig', - mig_certs='/home/mig/certs', - mig_state='/home/mig/state', - timezone='Test/Place', - crypto_salt='_TEST_CRYPTO_SALT'.zfill(32), - digest_salt='_TEST_DIGEST_SALT'.zfill(32), - seafile_secret='_test-seafile-secret='.zfill(44), - seafile_ccnetid='_TEST_SEAFILE_CCNETID'.zfill(40), + destination_suffix="-stdlocal", + user="testuser", + group="testgroup", + mig_code="/home/mig/mig", + mig_certs="/home/mig/certs", + mig_state="/home/mig/state", + timezone="Test/Place", + crypto_salt="_TEST_CRYPTO_SALT".zfill(32), + digest_salt="_TEST_DIGEST_SALT".zfill(32), + seafile_secret="_test-seafile-secret=".zfill(44), + seafile_ccnetid="_TEST_SEAFILE_CCNETID".zfill(40), _getpwnam=create_dummy_gpwnam(4321, 1234), ) @@ -193,10 +205,11 @@ def test_kwargs_for_paths_auto(self): def capture_defaulted(*args, **kwargs): capture_defaulted.kwargs = kwargs return args[0] + capture_defaulted.kwargs = None - (options, _) = generate_confs( - '/some/arbitrary/path', + options, _ = generate_confs( + "/some/arbitrary/path", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=capture_defaulted, _writefiles=noop, @@ -204,123 +217,139 @@ def capture_defaulted(*args, **kwargs): ) defaulted = capture_defaulted.kwargs - self.assertPathWithin(defaulted['mig_certs'], MIG_BASE) - self.assertPathWithin(defaulted['mig_state'], MIG_BASE) + self.assertPathWithin(defaulted["mig_certs"], MIG_BASE) + self.assertPathWithin(defaulted["mig_state"], MIG_BASE) def test_creates_output_files_with_datasafety(self): fixture_dir = fixturepath("confs-stdlocal") - expected_generated_dir = cleanpath('confs-stdlocal', self) - symlink_path = temppath('confs', self) + expected_generated_dir = cleanpath("confs-stdlocal", self) + symlink_path = temppath("confs", self) generate_confs( self.output_path, destination=symlink_path, - destination_suffix='-stdlocal', - datasafety_link='TEST_DATASAFETY_LINK', - datasafety_text='TEST_DATASAFETY_TEXT', + destination_suffix="-stdlocal", + datasafety_link="TEST_DATASAFETY_LINK", + datasafety_text="TEST_DATASAFETY_TEXT", _getpwnam=create_dummy_gpwnam(4321, 1234), ) - actual_file = self.assertFileExists('confs-stdlocal/MiGserver.conf') + actual_file = self.assertFileExists("confs-stdlocal/MiGserver.conf") self.assertConfigKey( - actual_file, 'SITE', 'datasafety_link', expected='TEST_DATASAFETY_LINK') + actual_file, + "SITE", + "datasafety_link", + expected="TEST_DATASAFETY_LINK", + ) self.assertConfigKey( - actual_file, 'SITE', 'datasafety_text', expected='TEST_DATASAFETY_TEXT') + actual_file, + "SITE", + "datasafety_text", + expected="TEST_DATASAFETY_TEXT", + ) def test_creates_output_files_with_permanent_freeze(self): fixture_dir = fixturepath("confs-stdlocal") - expected_generated_dir = cleanpath('confs-stdlocal', self) - symlink_path = temppath('confs', self) + expected_generated_dir = cleanpath("confs-stdlocal", self) + symlink_path = temppath("confs", self) - for arg_val in ('yes', 'no', 'foo bar baz'): + for arg_val in ("yes", "no", "foo bar baz"): generate_confs( self.output_path, destination=symlink_path, - destination_suffix='-stdlocal', + destination_suffix="-stdlocal", permanent_freeze=arg_val, _getpwnam=create_dummy_gpwnam(4321, 1234), ) - actual_file = self.assertFileExists('confs-stdlocal/MiGserver.conf') + actual_file = self.assertFileExists("confs-stdlocal/MiGserver.conf") self.assertConfigKey( - actual_file, 'SITE', 'permanent_freeze', expected=arg_val) + actual_file, "SITE", "permanent_freeze", expected=arg_val + ) def test_options_for_source_auto(self): - (options, _) = generate_confs( - '/some/arbitrary/path', + options, _ = generate_confs( + "/some/arbitrary/path", source=keyword_auto, _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - expected_template_dir = os.path.join(MIG_BASE, 'mig/install') + expected_template_dir = os.path.join(MIG_BASE, "mig/install") - self.assertEqual(options['template_dir'], expected_template_dir) + self.assertEqual(options["template_dir"], expected_template_dir) def test_options_for_source_relative(self): - (options, _) = generate_confs( - '/current/working/directory/mig/install', - source='.', + options, _ = generate_confs( + "/current/working/directory/mig/install", + source=".", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['template_dir'], - '/current/working/directory/mig/install') + self.assertEqual( + options["template_dir"], "/current/working/directory/mig/install" + ) def test_options_for_destination_auto(self): - (options, _) = generate_confs( - '/some/arbitrary/path', + options, _ = generate_confs( + "/some/arbitrary/path", destination=keyword_auto, - destination_suffix='_suffix', + destination_suffix="_suffix", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['destination_link'], - '/some/arbitrary/path/confs') - self.assertEqual(options['destination_dir'], - '/some/arbitrary/path/confs_suffix') + self.assertEqual( + options["destination_link"], "/some/arbitrary/path/confs" + ) + self.assertEqual( + options["destination_dir"], "/some/arbitrary/path/confs_suffix" + ) def test_options_for_destination_relative(self): - (options, _) = generate_confs( - '/current/working/directory', - destination='generate-confs', - destination_suffix='_suffix', + options, _ = generate_confs( + "/current/working/directory", + destination="generate-confs", + destination_suffix="_suffix", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['destination_link'], - '/current/working/directory/generate-confs') - self.assertEqual(options['destination_dir'], - '/current/working/directory/generate-confs_suffix') + self.assertEqual( + options["destination_link"], + "/current/working/directory/generate-confs", + ) + self.assertEqual( + options["destination_dir"], + "/current/working/directory/generate-confs_suffix", + ) def test_options_for_destination_absolute(self): - (options, _) = generate_confs( - '/current/working/directory', - destination='/some/other/place/confs', - destination_suffix='_suffix', + options, _ = generate_confs( + "/current/working/directory", + destination="/some/other/place/confs", + destination_suffix="_suffix", _getpwnam=create_dummy_gpwnam(4321, 1234), _prepare=noop, _writefiles=noop, _instructions=noop, ) - self.assertEqual(options['destination_link'], - '/some/other/place/confs') - self.assertEqual(options['destination_dir'], - '/some/other/place/confs_suffix') + self.assertEqual(options["destination_link"], "/some/other/place/confs") + self.assertEqual( + options["destination_dir"], "/some/other/place/confs_suffix" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_jupyter.py b/tests/test_mig_shared_jupyter.py index 3eb53bf63..461282df8 100644 --- a/tests/test_mig_shared_jupyter.py +++ b/tests/test_mig_shared_jupyter.py @@ -34,9 +34,14 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain from mig.shared.jupyter import gen_openid_template +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) def noop(*args): @@ -48,7 +53,8 @@ class MigSharedJupyter(MigTestCase): def test_jupyter_gen_openid_template_openid_auth(self): filled = gen_openid_template( - "/some-jupyter-url", "MyDefine", "OpenID", _print=noop) + "/some-jupyter-url", "MyDefine", "OpenID", _print=noop + ) expected = """ @@ -63,7 +69,8 @@ def test_jupyter_gen_openid_template_openid_auth(self): def test_jupyter_gen_openid_template_oidc_auth(self): filled = gen_openid_template( - "/some-jupyter-url", "MyDefine", "openid-connect", _print=noop) + "/some-jupyter-url", "MyDefine", "openid-connect", _print=noop + ) expected = """ @@ -79,26 +86,26 @@ def test_jupyter_gen_openid_template_oidc_auth(self): def test_jupyter_gen_openid_template_invalid_url_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template(None, "MyDefine", - "openid-connect") + filled = gen_openid_template(None, "MyDefine", "openid-connect") def test_jupyter_gen_openid_template_invalid_define_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template("/some-jupyter-url", None, - "no-such-auth-type") + filled = gen_openid_template( + "/some-jupyter-url", None, "no-such-auth-type" + ) def test_jupyter_gen_openid_template_missing_auth_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template("/some-jupyter-url", "MyDefine", - None) + filled = gen_openid_template("/some-jupyter-url", "MyDefine", None) def test_jupyter_gen_openid_template_invalid_auth_type(self): with self.assertRaises(AssertionError): - filled = gen_openid_template("/some-jupyter-url", "MyDefine", - "no-such-auth-type") + filled = gen_openid_template( + "/some-jupyter-url", "MyDefine", "no-such-auth-type" + ) # TODO: add more coverage of module -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_localfile.py b/tests/test_mig_shared_localfile.py index 62a5112aa..be9718bd7 100644 --- a/tests/test_mig_shared_localfile.py +++ b/tests/test_mig_shared_localfile.py @@ -27,21 +27,19 @@ """Unit tests for the migrid module pointed to in the filename""" -from contextlib import contextmanager import errno import fcntl import os import sys +from contextlib import contextmanager -sys.path.append(os.path.realpath( - os.path.join(os.path.dirname(__file__), ".."))) - -from tests.support import MigTestCase, temppath, testmain +sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))) -from mig.shared.serverfile import LOCK_EX from mig.shared.localfile import LocalFile +from mig.shared.serverfile import LOCK_EX +from tests.support import MigTestCase, temppath, testmain -DUMMY_FILE = 'some_file' +DUMMY_FILE = "some_file" @contextmanager @@ -72,7 +70,7 @@ def assertPathLockedExclusive(self, file_path): # we were errantly able to acquire a lock, mark errored reraise = AssertionError("RERAISE_MUST_UNLOCK") except Exception as maybe_err: - if getattr(maybe_err, 'errno', None) == errno.EAGAIN: + if getattr(maybe_err, "errno", None) == errno.EAGAIN: # this is the expected exception - the logic tried to lock # a file that was (as we intended) already locked, meaning # this assertion has succeeded so we do not need to raise @@ -83,17 +81,19 @@ def assertPathLockedExclusive(self, file_path): if reraise is not None: # if marked errored and locked, cleanup the lock we acquired but shouldn't - if str(reraise) == 'RERAISE_MUST_UNLOCK': + if str(reraise) == "RERAISE_MUST_UNLOCK": fcntl.flock(conflicting_f, fcntl.LOCK_NB | fcntl.LOCK_UN) # raise a user-friendly error to avoid nested raise raise AssertionError( - "expected locked file: %s" % self.pretty_display_path(file_path)) + "expected locked file: %s" + % self.pretty_display_path(file_path) + ) def test_localfile_locking(self): some_file = temppath(DUMMY_FILE, self) - with managed_localfile(LocalFile(some_file, 'w')) as lfd: + with managed_localfile(LocalFile(some_file, "w")) as lfd: lfd.lock(LOCK_EX) self.assertEqual(lfd.get_lock_mode(), LOCK_EX) @@ -101,5 +101,5 @@ def test_localfile_locking(self): self.assertPathLockedExclusive(some_file) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_pwcrypto.py b/tests/test_mig_shared_pwcrypto.py index 784c6963e..42a7440a2 100644 --- a/tests/test_mig_shared_pwcrypto.py +++ b/tests/test_mig_shared_pwcrypto.py @@ -32,65 +32,87 @@ import sys import unittest -from tests.support import MigTestCase, FakeConfiguration, \ - cleanpath, temppath, testmain - -from mig.shared.defaults import POLICY_NONE, POLICY_WEAK, POLICY_MEDIUM, \ - POLICY_HIGH, POLICY_MODERN, POLICY_CUSTOM, PASSWORD_POLICIES -from mig.shared.pwcrypto import main as pwcrypto_main +from mig.shared.defaults import ( + PASSWORD_POLICIES, + POLICY_CUSTOM, + POLICY_HIGH, + POLICY_MEDIUM, + POLICY_MODERN, + POLICY_NONE, + POLICY_WEAK, +) from mig.shared.pwcrypto import * +from mig.shared.pwcrypto import main as pwcrypto_main +from tests.support import ( + FakeConfiguration, + MigTestCase, + cleanpath, + temppath, + testmain, +) DUMMY_USER = "dummy-user" DUMMY_ID = "dummy-id" # NOTE: these passwords are not and should not ever be used outside unit tests -DUMMY_WEAK_PW = 'foobar' -DUMMY_MEDIUM_PW = 'QZFnCp7h' -DUMMY_HIGH_PW = 'QZFnp7I-GZ' -DUMMY_MODERN_PW = 'QZFnCp7hmI1G' -DUMMY_GENERATED_PW = '7hmI1GnCpQZF' +DUMMY_WEAK_PW = "foobar" +DUMMY_MEDIUM_PW = "QZFnCp7h" +DUMMY_HIGH_PW = "QZFnp7I-GZ" +DUMMY_MODERN_PW = "QZFnCp7hmI1G" +DUMMY_GENERATED_PW = "7hmI1GnCpQZF" DUMMY_WEAK_PW_MD5 = "3858f62230ac3c915f300c664312c63f" -DUMMY_WEAK_PW_SHA256 = \ +DUMMY_WEAK_PW_SHA256 = ( "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2" -DUMMY_WEAK_PW_PBKDF2 = \ +) +DUMMY_WEAK_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$epib2rEg/HYTQZFnCp7hmIGZ6rzHnViy" -DUMMY_MEDIUM_PW_PBKDF2 = \ +) +DUMMY_MEDIUM_PW_PBKDF2 = ( "PBKDF2$sha256$10000$ebQHnDX1rzY9Rizb$0vUJ9/4ThhsN4cRaKYmOj4N0YKEsozTr" -DUMMY_HIGH_PW_PBKDF2 = \ +) +DUMMY_HIGH_PW_PBKDF2 = ( "PBKDF2$sha256$10000$HR+KcqLyQe3v0WSk$CtxMAomi8JHiI7gWc/PH5Ey00zW1Now3" +) DUMMY_MODERN_PW_MD5 = "a06d169a171ef7d4383b212457162d93" -DUMMY_MODERN_PW_SHA256 = \ +DUMMY_MODERN_PW_SHA256 = ( "d293dcb9762c87641ea1decbfe76d84ec51b13d6a1e688cdf1a838eebc5bb1a9" -DUMMY_MODERN_PW_PBKDF2 = \ +) +DUMMY_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf10n58FHrn1pjX" -DUMMY_MODERN_PW_DIGEST = \ - "DIGEST$custom$CONFSALT$64756D6D792D7265616C6D3A64756D6D792D7520DE71261F96A2FE48A67DD0877F2A2C" +) +DUMMY_MODERN_PW_DIGEST = "DIGEST$custom$CONFSALT$64756D6D792D7265616C6D3A64756D6D792D7520DE71261F96A2FE48A67DD0877F2A2C" DUMMY_MODERN_DIGEST_SCRAMBLE = "53BB031C1F96A2FE48A67DD0877F2A2C" DUMMY_MODERN_PW_SCRAMBLE = "53BB031C1F96A2FE48A67DD0877F2A2C" -DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED = b'xRsT1qHmiM3xqDjuvFuxqQ==.g4-Gt83uRrdvVWwX0SF1iMza3NyKJbp2sEYVkw==.ICAgIG1pZ3JpZCBhdXRoZW50aWNhdGVkMjA1MDAxMDE=' -DUMMY_MODERN_PW_RESET_TOKEN = b'gAAAAABo63hYqeHE7Db93pMxWn1sWzj2Z-6td2UhA5gKYa4KV096ndV-AO0pp6hrR9jXKcwWAouLCMiNC0BRudeCAYHoBii15lTRbP9b7JzvJjeusbidjRxqcJg0om6bbtSK1Rz_RBTq_jhdAk4v-7PccWlZ15dVJ4j-X3X4zSsBWIOR5y6Y-bA=' -DUMMY_METHOD = 'dummy-method' -DUMMY_OPERATION = 'dummy-operation' -DUMMY_ARGS = {'dummy-key': 'dummy-val'} -DUMMY_CSRF_TOKEN = '351cc47e0cd5c155fa4c4d3d0a6f1ee8f20eeb293ba13d59ede9d2a789687d3d' -DUMMY_CSRF_TRUST_TOKEN = '466c0bacd045a060a201c4e08c749c2e19743613422e0381ab0a57706c9fa2b8' -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED = b"xRsT1qHmiM3xqDjuvFuxqQ==.g4-Gt83uRrdvVWwX0SF1iMza3NyKJbp2sEYVkw==.ICAgIG1pZ3JpZCBhdXRoZW50aWNhdGVkMjA1MDAxMDE=" +DUMMY_MODERN_PW_RESET_TOKEN = b"gAAAAABo63hYqeHE7Db93pMxWn1sWzj2Z-6td2UhA5gKYa4KV096ndV-AO0pp6hrR9jXKcwWAouLCMiNC0BRudeCAYHoBii15lTRbP9b7JzvJjeusbidjRxqcJg0om6bbtSK1Rz_RBTq_jhdAk4v-7PccWlZ15dVJ4j-X3X4zSsBWIOR5y6Y-bA=" +DUMMY_METHOD = "dummy-method" +DUMMY_OPERATION = "dummy-operation" +DUMMY_ARGS = {"dummy-key": "dummy-val"} +DUMMY_CSRF_TOKEN = ( + "351cc47e0cd5c155fa4c4d3d0a6f1ee8f20eeb293ba13d59ede9d2a789687d3d" +) +DUMMY_CSRF_TRUST_TOKEN = ( + "466c0bacd045a060a201c4e08c749c2e19743613422e0381ab0a57706c9fa2b8" +) +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" # TODO: adjust password reset token helpers to handle configured services # it currently silently fails if not in migoid(c) or migcert # DUMMY_SERVICE = 'dummy-svc' -DUMMY_SERVICE = 'migoid' -DUMMY_REALM = 'dummy-realm' -DUMMY_PATH = 'dummy-path' -DUMMY_PATH_MD5 = 'd19033877452e8c217d3cddebbc37419' -DUMMY_SALT = b'53BB031C4ECCE4900BD64AB8EA361B6B' -DUMMY_ENTROPY = b'\xd2\x93\xdc\xb9v,\x87d\x1e\xa1\xde\xcb\xfev\xd8N\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9' -DUMMY_FERNET_KEY = 'NDg3OTcyNzE1NTQ2Nzc3ODYxNjc0NjRFRDZGMjNFQzY=' -DUMMY_AESGCM_KEY = b'48797271554677786167464ED6F23EC6' -DUMMY_AESGCM_STATIC_IV = b'\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9' -DUMMY_AESGCM_AAD_PREFIX = b'\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa9' -DUMMY_AESGCM_AAD = b' \xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa920500101' +DUMMY_SERVICE = "migoid" +DUMMY_REALM = "dummy-realm" +DUMMY_PATH = "dummy-path" +DUMMY_PATH_MD5 = "d19033877452e8c217d3cddebbc37419" +DUMMY_SALT = b"53BB031C4ECCE4900BD64AB8EA361B6B" +DUMMY_ENTROPY = b"\xd2\x93\xdc\xb9v,\x87d\x1e\xa1\xde\xcb\xfev\xd8N\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9" +DUMMY_FERNET_KEY = "NDg3OTcyNzE1NTQ2Nzc3ODYxNjc0NjRFRDZGMjNFQzY=" +DUMMY_AESGCM_KEY = b"48797271554677786167464ED6F23EC6" +DUMMY_AESGCM_STATIC_IV = ( + b"\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9" +) +DUMMY_AESGCM_AAD_PREFIX = b"\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa9" +DUMMY_AESGCM_AAD = b" \xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa920500101" # NOTE: we avoid any percent expansion values of actual date here to freeze AAD -DUMMY_FIXED_TIMESTAMP = '20500101' +DUMMY_FIXED_TIMESTAMP = "20500101" class MigSharedPwCrypto(MigTestCase): @@ -99,13 +121,15 @@ class MigSharedPwCrypto(MigTestCase): def before_each(self): test_user_home = temppath(DUMMY_HOME_DIR, self, ensure_dir=True) test_user_settings = cleanpath( - DUMMY_SETTINGS_DIR, self, ensure_dir=True) + DUMMY_SETTINGS_DIR, self, ensure_dir=True + ) # make two requisite root folders for the dummy user os.mkdir(os.path.join(test_user_home, DUMMY_USER)) os.mkdir(os.path.join(test_user_settings, DUMMY_USER)) # now create a configuration self.dummy_conf = FakeConfiguration( - user_home=test_user_home, user_settings=test_user_settings, + user_home=test_user_home, + user_settings=test_user_settings, site_password_policy="%s:12" % POLICY_MODERN, site_password_legacy_policy=POLICY_MEDIUM, site_password_cracklib=False, @@ -120,9 +144,11 @@ def before_each(self): # 'FakeConfiguration' has no 'site_password_legacy_policy' member # (no-member) unless we explicitly (re-)init it here self.dummy_conf.site_password_legacy_policy = getattr( - self.dummy_conf, 'site_password_legacy_policy', POLICY_NONE) - self.assertEqual(self.dummy_conf.site_password_legacy_policy, - POLICY_MEDIUM) + self.dummy_conf, "site_password_legacy_policy", POLICY_NONE + ) + self.assertEqual( + self.dummy_conf.site_password_legacy_policy, POLICY_MEDIUM + ) def test_best_crypt_salt(self): """Test selection of best salt based on salt availability in @@ -131,13 +157,13 @@ def test_best_crypt_salt(self): expected = DUMMY_SALT actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "best crypt salt not found") - self.dummy_conf.site_crypto_salt = '' + self.dummy_conf.site_crypto_salt = "" actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "2nd best crypt salt not found") - self.dummy_conf.site_password_salt = '' + self.dummy_conf.site_password_salt = "" actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "3rd best crypt salt not found") - self.dummy_conf.site_digest_salt = '' + self.dummy_conf.site_digest_salt = "" actual = None try: actual = best_crypt_salt(self.dummy_conf) @@ -154,10 +180,10 @@ def test_password_requirements(self): self.assertEqual(expected[2], result[2], "failed pw req errors") expected = (8, 3, []) result = password_requirements( - self.dummy_conf.site_password_legacy_policy) + self.dummy_conf.site_password_legacy_policy + ) self.assertEqual(expected[0], result[0], "failed legacy pw req chars") - self.assertEqual(expected[1], result[1], - "failed legacy pw req classes") + self.assertEqual(expected[1], result[1], "failed legacy pw req classes") self.assertEqual(expected[2], result[2], "failed legacy pw req errors") def test_parse_password_policy(self): @@ -174,56 +200,60 @@ def test_parse_password_policy(self): def test_assure_password_strength(self): """Test assure password strength""" try: - allow_weak = assure_password_strength(self.dummy_conf, - DUMMY_WEAK_PW) + allow_weak = assure_password_strength( + self.dummy_conf, DUMMY_WEAK_PW + ) except ValueError as vae: allow_weak = False self.assertFalse(allow_weak, "allowed weak pw") try: - allow_weak = assure_password_strength(self.dummy_conf, - DUMMY_WEAK_PW, - allow_legacy=True) + allow_weak = assure_password_strength( + self.dummy_conf, DUMMY_WEAK_PW, allow_legacy=True + ) except ValueError as vae: allow_weak = False self.assertFalse(allow_weak, "allowed weak pw with legacy") # NOTE: only allow medium with legacy try: - allow_medium = assure_password_strength(self.dummy_conf, - DUMMY_MEDIUM_PW) + allow_medium = assure_password_strength( + self.dummy_conf, DUMMY_MEDIUM_PW + ) except ValueError as vae: allow_medium = False self.assertFalse(allow_medium, "allowed medium pw without legacy") try: - allow_medium = assure_password_strength(self.dummy_conf, - DUMMY_MEDIUM_PW, - allow_legacy=True) + allow_medium = assure_password_strength( + self.dummy_conf, DUMMY_MEDIUM_PW, allow_legacy=True + ) except ValueError as vae: allow_medium = False self.assertTrue(allow_medium, "refused medium pw with legacy") # NOTE: only allow high with legacy - not long enough for modern try: - allow_high = assure_password_strength(self.dummy_conf, - DUMMY_HIGH_PW) + allow_high = assure_password_strength( + self.dummy_conf, DUMMY_HIGH_PW + ) except ValueError as vae: allow_high = False self.assertFalse(allow_high, "allowed high pw without legacy") try: - allow_high = assure_password_strength(self.dummy_conf, - DUMMY_HIGH_PW, - allow_legacy=True) + allow_high = assure_password_strength( + self.dummy_conf, DUMMY_HIGH_PW, allow_legacy=True + ) except ValueError as vae: allow_high = False self.assertTrue(allow_high, "refused high pw with legacy") try: - allow_modern = assure_password_strength(self.dummy_conf, - DUMMY_MODERN_PW) + allow_modern = assure_password_strength( + self.dummy_conf, DUMMY_MODERN_PW + ) except ValueError as vae: allow_modern = False self.assertTrue(allow_modern, "refused modern pw") try: - allow_modern = assure_password_strength(self.dummy_conf, - DUMMY_MODERN_PW, - allow_legacy=True) + allow_modern = assure_password_strength( + self.dummy_conf, DUMMY_MODERN_PW, allow_legacy=True + ) except ValueError as vae: allow_modern = False self.assertTrue(allow_modern, "refused modern pw with legacy") @@ -278,7 +308,7 @@ def test_make_hash_fixed_seed(self): random seed. """ expected = DUMMY_MODERN_PW_PBKDF2 - actual = make_hash(DUMMY_MODERN_PW, _urandom=lambda vlen: b'0' * vlen) + actual = make_hash(DUMMY_MODERN_PW, _urandom=lambda vlen: b"0" * vlen) self.assertEqual(actual, expected, "mismatch hashing string") def test_make_hash_constant_string(self): @@ -286,17 +316,25 @@ def test_make_hash_constant_string(self): random seed. I.e. the value may differ across interpreter invocations but remains constant in same interpreter. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) self.assertEqual(first, second, "same seed hashing is not constant") def test_check_hash_reject_weak(self): """Test basic hash checking of a constant weak complexity password""" expected = DUMMY_WEAK_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_WEAK_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_WEAK_PW, + expected, + strict_policy=True, + ) self.assertFalse(result, "check hash should fail on weak pw") def test_check_hash_reject_medium_without_legacy(self): @@ -304,9 +342,15 @@ def test_check_hash_reject_medium_without_legacy(self): without legacy password support. """ expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertFalse(result, "check hash strict should fail on medium pw") def test_check_hash_accept_medium_with_legacy(self): @@ -314,9 +358,15 @@ def test_check_hash_accept_medium_with_legacy(self): with legacy password support. """ expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True, - allow_legacy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + allow_legacy=True, + ) self.assertTrue(result, "check hash with legacy must accept medium pw") def test_check_hash_accept_high(self): @@ -325,9 +375,15 @@ def test_check_hash_accept_high(self): """ expected = DUMMY_HIGH_PW_PBKDF2 self.dummy_conf.site_password_policy = POLICY_HIGH - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_HIGH_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_HIGH_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertTrue(result, "check hash must accept high complexity pw") def test_check_hash_accept_modern(self): @@ -335,32 +391,57 @@ def test_check_hash_accept_modern(self): without legacy password support. """ expected = DUMMY_MODERN_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertTrue(result, "check hash must accept modern complexity pw") def test_check_hash_fixed(self): """Test basic hash checking of a fixed string""" expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + ) self.assertFalse(result, "check hash should reject medium pw") - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=False, - allow_legacy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=False, + allow_legacy=True, + ) self.assertTrue(result, "check hash failed medium pw when not strict") expected = DUMMY_MODERN_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + strict_policy=True, + ) self.assertTrue(result, "check hash failed modern pw") def test_check_hash_random(self): """Test basic hashing and hash checking of a random string""" random_pw = generate_random_password(self.dummy_conf) expected = make_hash(random_pw) - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, expected) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, random_pw, expected + ) self.assertTrue(result, "mismatch in random hash and check") def test_make_hash_variation(self): @@ -368,10 +449,12 @@ def test_make_hash_variation(self): I.e. the value likely remains constant in same interpreter but differs across interpreter invocations. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen]) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen] + ) self.assertNotEqual(first, second, "varying seed hashing is constant") def test_check_hash_despite_variation(self): @@ -379,15 +462,19 @@ def test_check_hash_despite_variation(self): I.e. the hash value differs across interpreter invocations but testing the same password against each succeeds. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen]) - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, first) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen] + ) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, DUMMY_MODERN_PW, first + ) self.assertTrue(result, "mismatch in 1st random password hash check") - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, second) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, DUMMY_MODERN_PW, second + ) self.assertTrue(result, "mismatch in 2nd random password hash check") def test_scramble_digest_fixed(self): @@ -405,16 +492,23 @@ def test_unscramble_digest_fixed(self): def test_make_digest_fixed(self): """Test basic digest of a fixed string""" expected = DUMMY_MODERN_PW_DIGEST - result = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) + result = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) self.assertEqual(expected, result, "mismatch in fixed digest") def test_check_digest_fixed(self): """Test basic digest checking of a fixed string""" expected = DUMMY_MODERN_PW_DIGEST - result = check_digest(self.dummy_conf, DUMMY_SERVICE, DUMMY_REALM, - DUMMY_USER, DUMMY_MODERN_PW, expected, - DUMMY_SALT) + result = check_digest( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_REALM, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + DUMMY_SALT, + ) self.assertTrue(result, "mismatch in fixed digest check") def test_check_digest_random(self): @@ -422,8 +516,15 @@ def test_check_digest_random(self): random_pw = generate_random_password(self.dummy_conf) random_salt = base64.b16encode(os.urandom(16)) expected = make_digest(DUMMY_REALM, DUMMY_USER, random_pw, random_salt) - result = check_digest(self.dummy_conf, DUMMY_SERVICE, DUMMY_REALM, - DUMMY_USER, random_pw, expected, random_salt) + result = check_digest( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_REALM, + DUMMY_USER, + random_pw, + expected, + random_salt, + ) self.assertTrue(result, "mismatch in random digest check") def test_digest_constant_string(self): @@ -431,10 +532,12 @@ def test_digest_constant_string(self): random seed. I.e. the value may differ across interpreter invocations but remains constant in same interpreter. """ - first = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) - second = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) + first = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) + second = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) self.assertEqual(first, second, "basic digest is not constant") def test_scramble_password_fixed(self): @@ -458,8 +561,14 @@ def test_make_scramble_fixed(self): def test_check_scramble_fixed(self): """Test basic scramble checking of a fixed string""" expected = DUMMY_MODERN_PW_SCRAMBLE - result = check_scramble(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, DUMMY_SALT) + result = check_scramble( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + DUMMY_SALT, + ) self.assertTrue(result, "mismatch in fixed scramble check") def test_check_scramble_random(self): @@ -467,8 +576,14 @@ def test_check_scramble_random(self): random_pw = generate_random_password(self.dummy_conf) random_salt = base64.b16encode(os.urandom(16)) expected = make_scramble(DUMMY_MODERN_PW, random_salt) - result = check_scramble(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, random_salt) + result = check_scramble( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + random_salt, + ) self.assertTrue(result, "mismatch in random scramble check") def test_scramble_constant_string(self): @@ -480,8 +595,7 @@ def test_scramble_constant_string(self): self.assertEqual(first, second, "basic scramble is not constant") def test_prepare_fernet_key(self): - """Test basic fernet secret key preparation on a fixed string. - """ + """Test basic fernet secret key preparation on a fixed string.""" expected = DUMMY_FERNET_KEY result = prepare_fernet_key(self.dummy_conf) self.assertEqual(expected, result, "failed prepare fernet key") @@ -494,8 +608,7 @@ def test_fernet_encrypt_decrypt(self): self.assertEqual(random_pw, result, "failed fernet enc+dec") def test_prepare_aesgcm_key(self): - """Test basic aesgcm secret key preparation on a fixed string. - """ + """Test basic aesgcm secret key preparation on a fixed string.""" expected = DUMMY_AESGCM_KEY result = prepare_aesgcm_key(self.dummy_conf) self.assertEqual(expected, result, "failed prepare aesgcm key") @@ -520,8 +633,11 @@ def test_prepare_aesgcm_aad_fixed(self): entropy and date value. """ expected = DUMMY_AESGCM_AAD - result = prepare_aesgcm_aad(self.dummy_conf, DUMMY_AESGCM_AAD_PREFIX, - aad_stamp=DUMMY_FIXED_TIMESTAMP) + result = prepare_aesgcm_aad( + self.dummy_conf, + DUMMY_AESGCM_AAD_PREFIX, + aad_stamp=DUMMY_FIXED_TIMESTAMP, + ) self.assertEqual(expected, result, "failed prepare aesgcm aad") def test_aesgcm_encrypt_static_iv_fixed(self): @@ -529,9 +645,12 @@ def test_aesgcm_encrypt_static_iv_fixed(self): initialization vector and date helper. """ expected = DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED - result = aesgcm_encrypt_password(self.dummy_conf, DUMMY_MODERN_PW, - init_vector=DUMMY_AESGCM_STATIC_IV, - aad_stamp=DUMMY_FIXED_TIMESTAMP) + result = aesgcm_encrypt_password( + self.dummy_conf, + DUMMY_MODERN_PW, + init_vector=DUMMY_AESGCM_STATIC_IV, + aad_stamp=DUMMY_FIXED_TIMESTAMP, + ) self.assertEqual(expected, result, "failed fixed aesgcm static iv enc") def test_aesgcm_decrypt_static_iv_fixed(self): @@ -539,9 +658,11 @@ def test_aesgcm_decrypt_static_iv_fixed(self): initialization vector. """ expected = DUMMY_MODERN_PW - result = aesgcm_decrypt_password(self.dummy_conf, - DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED, - init_vector=DUMMY_AESGCM_STATIC_IV) + result = aesgcm_decrypt_password( + self.dummy_conf, + DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED, + init_vector=DUMMY_AESGCM_STATIC_IV, + ) self.assertEqual(expected, result, "failed fixed aesgcm static iv den") def test_aesgcm_encrypt_decrypt_static_iv(self): @@ -551,10 +672,12 @@ def test_aesgcm_encrypt_decrypt_static_iv(self): random_pw = generate_random_password(self.dummy_conf) entropy = make_safe_hash(random_pw, False) static_iv = prepare_aesgcm_iv(self.dummy_conf, iv_entropy=entropy) - expected = aesgcm_encrypt_password(self.dummy_conf, random_pw, - init_vector=static_iv) - result = aesgcm_decrypt_password(self.dummy_conf, expected, - init_vector=static_iv) + expected = aesgcm_encrypt_password( + self.dummy_conf, random_pw, init_vector=static_iv + ) + result = aesgcm_decrypt_password( + self.dummy_conf, expected, init_vector=static_iv + ) self.assertEqual(random_pw, result, "failed aesgcm static iv enc+dec") def test_make_encrypt_decrypt(self): @@ -582,97 +705,127 @@ def test_check_encrypt(self): """Test basic password simple encrypt and decrypt on a random string""" random_pw = generate_random_password(self.dummy_conf) # IMPORTANT: only aesgcm_static generates constant enc value! - encrypted = make_encrypt(self.dummy_conf, random_pw, - algo="fernet") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='fernet') + encrypted = make_encrypt(self.dummy_conf, random_pw, algo="fernet") + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="fernet", + ) self.assertFalse(result, "invalid match in fernet encrypt check") encrypted = make_encrypt(self.dummy_conf, random_pw, algo="aesgcm") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='aesgcm') + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="aesgcm", + ) self.assertFalse(result, "invalid match in aesgcm encrypt check") - encrypted = make_encrypt(self.dummy_conf, random_pw, - algo="aesgcm_static") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='aesgcm_static') + encrypted = make_encrypt( + self.dummy_conf, random_pw, algo="aesgcm_static" + ) + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="aesgcm_static", + ) self.assertTrue(result, "mismatch in aesgcm_static encrypt check") def test_assure_reset_supported(self): """Test basic password reset token check for a fixed user and auth""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 - result = assure_reset_supported(self.dummy_conf, dummy_user, - DUMMY_SERVICE) + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 + result = assure_reset_supported( + self.dummy_conf, dummy_user, DUMMY_SERVICE + ) self.assertTrue(result, "failed assure reset supported") # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_generate_reset_token_fixed(self): """Test basic password reset token generate for a fixed string""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 timestamp = 42 expected = DUMMY_MODERN_PW_RESET_TOKEN - result = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) - self.assertEqual(expected, result, - "failed generate password reset token") + result = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) + self.assertEqual( + expected, result, "failed generate password reset token" + ) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_parse_reset_token_fixed(self): """Test basic password reset token parse for a fixed string""" timestamp = 42 - result = parse_reset_token(self.dummy_conf, - DUMMY_MODERN_PW_RESET_TOKEN, - DUMMY_SERVICE) + result = parse_reset_token( + self.dummy_conf, DUMMY_MODERN_PW_RESET_TOKEN, DUMMY_SERVICE + ) self.assertEqual(result[0], timestamp, "failed parse token time") - self.assertEqual(result[1], DUMMY_MODERN_PW_PBKDF2, - "failed parse token hash") + self.assertEqual( + result[1], DUMMY_MODERN_PW_PBKDF2, "failed parse token hash" + ) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_verify_reset_token_fixed(self): """Test basic password reset token verify for a fixed string""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 timestamp = 42 - result = verify_reset_token(self.dummy_conf, dummy_user, - DUMMY_MODERN_PW_RESET_TOKEN, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, + dummy_user, + DUMMY_MODERN_PW_RESET_TOKEN, + DUMMY_SERVICE, + timestamp, + ) self.assertTrue(result, "failed password reset token handling") def test_password_reset_token_generate_and_verify(self): """Test basic password reset token generate and verify helper""" random_pw = generate_random_password(self.dummy_conf) hashed_pw = make_hash(random_pw) - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = hashed_pw + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = hashed_pw timestamp = 42 - expected = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) + expected = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) parsed = parse_reset_token(self.dummy_conf, expected, DUMMY_SERVICE) self.assertEqual(parsed[0], timestamp, "failed parse token time") self.assertEqual(parsed[1], hashed_pw, "failed parse token hash") - result = verify_reset_token(self.dummy_conf, dummy_user, expected, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, dummy_user, expected, DUMMY_SERVICE, timestamp + ) self.assertTrue(result, "failed password reset token handling") def test_password_reset_token_verify_expired(self): """Test basic password reset token verify failure after it expired""" random_pw = generate_random_password(self.dummy_conf) hashed_pw = make_hash(random_pw) - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = hashed_pw + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = hashed_pw timestamp = 42 - expected = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) + expected = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) parsed = parse_reset_token(self.dummy_conf, expected, DUMMY_SERVICE) self.assertEqual(parsed[0], timestamp, "failed parse token time") self.assertEqual(parsed[1], hashed_pw, "failed parse token hash") timestamp = 4242 - result = verify_reset_token(self.dummy_conf, dummy_user, expected, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, dummy_user, expected, DUMMY_SERVICE, timestamp + ) self.assertFalse(result, "failed password reset token expiry check") def test_make_csrf_token_fixed(self): @@ -680,8 +833,9 @@ def test_make_csrf_token_fixed(self): client id. """ expected = DUMMY_CSRF_TOKEN - result = make_csrf_token(self.dummy_conf, DUMMY_METHOD, - DUMMY_OPERATION, DUMMY_ID) + result = make_csrf_token( + self.dummy_conf, DUMMY_METHOD, DUMMY_OPERATION, DUMMY_ID + ) self.assertEqual(expected, result, "failed make csrf token") def test_make_csrf_trust_token_fixed(self): @@ -689,8 +843,13 @@ def test_make_csrf_trust_token_fixed(self): client id and args. """ expected = DUMMY_CSRF_TRUST_TOKEN - result = make_csrf_trust_token(self.dummy_conf, DUMMY_METHOD, - DUMMY_OPERATION, DUMMY_ARGS, DUMMY_ID) + result = make_csrf_trust_token( + self.dummy_conf, + DUMMY_METHOD, + DUMMY_OPERATION, + DUMMY_ARGS, + DUMMY_ID, + ) self.assertEqual(expected, result, "failed make csrf trust token") def test_generate_random_password(self): @@ -705,8 +864,9 @@ def test_generate_random_password_fixed_seed(self): """Test basic generate password is constant with fixed random seed""" expected = DUMMY_GENERATED_PW result = generate_random_password(self.dummy_conf) - self.assertEqual(expected, result, - "failed generate password with fixed seed") + self.assertEqual( + expected, result, "failed generate password with fixed seed" + ) # TODO: migrate remaining inline checks from module here instead def test_existing_main(self): @@ -715,9 +875,11 @@ def raise_on_error_exit(exit_code): if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -726,5 +888,5 @@ def record_last_print(value): pwcrypto_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_safeeval.py b/tests/test_mig_shared_safeeval.py index 4bec4aa90..c7e827905 100644 --- a/tests/test_mig_shared_safeeval.py +++ b/tests/test_mig_shared_safeeval.py @@ -30,13 +30,11 @@ import os import sys -from tests.support import MigTestCase, testmain - from mig.shared.safeeval import * - +from tests.support import MigTestCase, testmain PWD_STR = os.getcwd() -PWD_BYTES = PWD_STR.encode('utf8') +PWD_BYTES = PWD_STR.encode("utf8") class MigSharedSafeeval(MigTestCase): @@ -44,49 +42,60 @@ class MigSharedSafeeval(MigTestCase): def test_subprocess_call(self): """Check that pwd call without args succeeds""" - retval = subprocess_call(['pwd'], stdout=subprocess_pipe) + retval = subprocess_call(["pwd"], stdout=subprocess_pipe) self.assertEqual(retval, 0, "unexpected subprocess call pwd retval") def test_subprocess_call_invalid(self): """Check that pwd call with invalid arg fails""" - retval = subprocess_call(['pwd', '-h'], stderr=subprocess_pipe) - self.assertNotEqual(retval, 0, - "unexpected subprocess call nosuchcommand retval") + retval = subprocess_call(["pwd", "-h"], stderr=subprocess_pipe) + self.assertNotEqual( + retval, 0, "unexpected subprocess call nosuchcommand retval" + ) def test_subprocess_check_output(self): """Check that pwd command output matches getcwd as bytes""" - data = subprocess_check_output(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_pipe).strip() - self.assertEqual(data, PWD_BYTES, - "mismatch in subprocess check pwd output") + data = subprocess_check_output( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_pipe + ).strip() + self.assertEqual( + data, PWD_BYTES, "mismatch in subprocess check pwd output" + ) def test_subprocess_check_output_text(self): """Check that pwd command output matches getcwd as string""" - data = subprocess_check_output(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_pipe, - text=True).strip() - self.assertEqual(data, PWD_STR, - "mismatch in subprocess check pwd output") + data = subprocess_check_output( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_pipe, text=True + ).strip() + self.assertEqual( + data, PWD_STR, "mismatch in subprocess check pwd output" + ) def test_subprocess_popen(self): """Check that pwd popen output matches getcwd as bytes""" - proc = subprocess_popen(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_stdout) + proc = subprocess_popen( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_stdout + ) retval = proc.wait() data = proc.stdout.read().strip() - self.assertEqual(data, PWD_BYTES, - "mismatch in subprocess popen pwd output") + self.assertEqual( + data, PWD_BYTES, "mismatch in subprocess popen pwd output" + ) def test_subprocess_popen_text(self): """Check that pwd popen output matches getcwd as string""" orig = os.getcwd() - proc = subprocess_popen(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_stdout, text=True) + proc = subprocess_popen( + ["pwd"], + stdout=subprocess_pipe, + stderr=subprocess_stdout, + text=True, + ) retval = proc.wait() data = proc.stdout.read().strip() - self.assertEqual(data, PWD_STR, - "mismatch in subprocess popen pwd output") + self.assertEqual( + data, PWD_STR, "mismatch in subprocess popen pwd output" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_safeinput.py b/tests/test_mig_shared_safeinput.py index 4a01dddd8..15e8b91d6 100644 --- a/tests/test_mig_shared_safeinput.py +++ b/tests/test_mig_shared_safeinput.py @@ -30,15 +30,26 @@ import base64 import codecs import sys + from past.builtins import basestring, unicode +from mig.shared.safeinput import ( + VALID_NAME_CHARACTERS, + InputException, + filter_commonname, +) +from mig.shared.safeinput import main as safeinput_main +from mig.shared.safeinput import ( + valid_alphanumeric, + valid_base_url, + valid_commonname, + valid_complex_url, + valid_path, + valid_printable, + valid_url, +) from tests.support import MigTestCase, testmain -from mig.shared.safeinput import main as safeinput_main, InputException, \ - filter_commonname, valid_alphanumeric, valid_commonname, valid_path, \ - valid_printable, valid_base_url, valid_url, valid_complex_url, \ - VALID_NAME_CHARACTERS - PY2 = sys.version_info[0] == 2 @@ -46,18 +57,18 @@ def as_string_of_unicode(value): assert isinstance(value, basestring) if not is_string_of_unicode(value): assert PY2, "unreachable unless Python 2" - return unicode(codecs.decode(value, 'utf8')) + return unicode(codecs.decode(value, "utf8")) return value def is_string_of_unicode(value): - return type(value) == type(u'') + return type(value) == type("") def _hex_wrap(val): """Insert a clearly marked hex representation of val""" # Please keep aligned with helper in mig/shared/functionality/autocreate.py - return ".X%s" % base64.b16encode(val.encode('utf8')).decode('utf8') + return ".X%s" % base64.b16encode(val.encode("utf8")).decode("utf8") class TestMigSharedSafeInput(MigTestCase): @@ -69,7 +80,7 @@ class TestMigSharedSafeInput(MigTestCase): PRINTABLE_CHARS = "abc123!@#" ACCENTED_VALID = "Renée Müller" ACCENTED_INVALID_EXOTIC = "Źaćâř" - DECOMPOSED_UNICODE = u"å" # a + combining ring above + DECOMPOSED_UNICODE = "å" # a + combining ring above # Commonname specific test constants APOSTROPHE_FULL_NAME = "John O'Connor" @@ -77,26 +88,28 @@ class TestMigSharedSafeInput(MigTestCase): APOSTROPHE_FULL_NAME_HEX = "John O.X27Connor" COMMONNAME_PERMITTED = ( - 'Firstname Lastname', - 'Test Æøå', - 'Test Überh4x0r', - 'Harry S. Truman', - u'Unicode æøå') + "Firstname Lastname", + "Test Æøå", + "Test Überh4x0r", + "Harry S. Truman", + "Unicode æøå", + ) COMMONNAME_PROHIBITED = ( "Invalid D'Angelo", - 'Test Maybe Invalid Źacãŕ', - 'Test Invalid ?', - 'Test HTML Invalid ') + "Test Maybe Invalid Źacãŕ", + "Test Invalid ?", + "Test HTML Invalid ", + ) - BASE_URL = 'https://www.migrid.org' - REGULAR_URL = 'https://www.migrid.org/wsgi-bin/ls.py?path=README&flags=v' - COMPLEX_URL = 'https://www.migrid.org/abc123@some.org/ls.py?path=R+D#HERE' - INVALID_URL = 'https://www.migrid.org/¾½§' + BASE_URL = "https://www.migrid.org" + REGULAR_URL = "https://www.migrid.org/wsgi-bin/ls.py?path=README&flags=v" + COMPLEX_URL = "https://www.migrid.org/abc123@some.org/ls.py?path=R+D#HERE" + INVALID_URL = "https://www.migrid.org/¾½§" def _provide_configuration(self): """Provide test configuration""" - return 'testconfig' + return "testconfig" def test_commonname_valid(self): """Test valid_commonname with acceptable and prohibited names""" @@ -130,7 +143,7 @@ def test_commonname_filter(self): self.assertTrue(len(filtered_cn) < len(test_cn_unicode)) # With default skip all chars in filtered_cn must be in original overlap = [i for i in filtered_cn if i in test_cn_unicode] - self.assertEqual(''.join(overlap), filtered_cn) + self.assertEqual("".join(overlap), filtered_cn) def test_commonname_filter_hexlify_illegal(self): """Test filter_commonname with hex encoding of illegal chars""" @@ -145,21 +158,23 @@ def test_commonname_filter_hexlify_illegal(self): filtered_cn = filter_commonname(test_cn, illegal_handler=_hex_wrap) # Invalid should be replaced with hexlify illegal_handler self.assertNotEqual(filtered_cn, test_cn_unicode) - self.assertIn('.X', filtered_cn) + self.assertIn(".X", filtered_cn) self.assertTrue(len(filtered_cn) > len(test_cn_unicode)) def test_filter_commonname_apostrophe_name_skip_illegal(self): """Test apostrophe handling with skip illegal_handler""" - result = filter_commonname(self.APOSTROPHE_FULL_NAME, - illegal_handler=None) + result = filter_commonname( + self.APOSTROPHE_FULL_NAME, illegal_handler=None + ) self.assertNotEqual(result, self.APOSTROPHE_FULL_NAME) self.assertNotIn("'", result) self.assertEqual(result, self.APOSTROPHE_FULL_NAME_SKIP) def test_filter_commonname_apostrophe_name_hexlify_illegal(self): """Test apostrophe handling with hex encode illegal_handler""" - result = filter_commonname(self.APOSTROPHE_FULL_NAME, - illegal_handler=_hex_wrap) + result = filter_commonname( + self.APOSTROPHE_FULL_NAME, illegal_handler=_hex_wrap + ) self.assertNotEqual(result, self.APOSTROPHE_FULL_NAME) self.assertNotIn("'", result) self.assertEqual(result, self.APOSTROPHE_FULL_NAME_HEX) @@ -283,14 +298,17 @@ class TestMigSharedSafeInput__legacy(MigTestCase): # TODO: migrate all legacy self-check functionality into the above? def test_existing_main(self): """Run built-in self-tests and check output""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -300,5 +318,5 @@ def record_last_print(value): safeinput_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_serial.py b/tests/test_mig_shared_serial.py index c577e0b4e..c42dc0d8b 100644 --- a/tests/test_mig_shared_serial.py +++ b/tests/test_mig_shared_serial.py @@ -30,12 +30,17 @@ import os import sys +from mig.shared.serial import * from tests.support import MigTestCase, temppath, testmain -from mig.shared.serial import * class BasicSerial(MigTestCase): - BASIC_OBJECT = {'abc': 123, 'def': 'def', 'ghi': 42.0, 'accented': 'TéstÆøå'} + BASIC_OBJECT = { + "abc": 123, + "def": "def", + "ghi": 42.0, + "accented": "TéstÆøå", + } def test_pickle_string(self): orig = BasicSerial.BASIC_OBJECT @@ -49,5 +54,6 @@ def test_pickle_file(self): data = load(tmp_path) self.assertEqual(data, orig, "mismatch pickling string") -if __name__ == '__main__': + +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_settings.py b/tests/test_mig_shared_settings.py index 87b9e02e1..dabcbbc9e 100644 --- a/tests/test_mig_shared_settings.py +++ b/tests/test_mig_shared_settings.py @@ -32,23 +32,31 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.settings import load_settings, update_settings, \ - parse_and_save_settings +from mig.shared.settings import ( + load_settings, + parse_and_save_settings, + update_settings, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) DUMMY_USER = "dummy-user" -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_SETTINGS_DIR = "dummy_user_settings" DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) -DUMMY_SYSTEM_FILES_DIR = 'dummy_system_files' +DUMMY_SYSTEM_FILES_DIR = "dummy_system_files" DUMMY_SYSTEM_FILES_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SYSTEM_FILES_DIR) -DUMMY_TMP_DIR = 'dummy_tmp' -DUMMY_TMP_FILE = 'settings.mRSL' +DUMMY_TMP_DIR = "dummy_tmp" +DUMMY_TMP_FILE = "settings.mRSL" DUMMY_TMP_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_TMP_DIR) DUMMY_MRSL_PATH = os.path.join(DUMMY_TMP_PATH, DUMMY_TMP_FILE) -DUMMY_USER_INTERFACE = ['V3', 'V42'] -DUMMY_DEFAULT_UI = 'V42' +DUMMY_USER_INTERFACE = ["V3", "V42"] +DUMMY_DEFAULT_UI = "V42" DUMMY_INIT_MRSL = """ ::EMAIL:: john@doe.org @@ -65,10 +73,12 @@ ::SITE_USER_MENU:: people """ -DUMMY_CONF = FakeConfiguration(user_settings=DUMMY_SETTINGS_PATH, - mig_system_files=DUMMY_SYSTEM_FILES_PATH, - user_interface=DUMMY_USER_INTERFACE, - new_user_default_ui=DUMMY_DEFAULT_UI) +DUMMY_CONF = FakeConfiguration( + user_settings=DUMMY_SETTINGS_PATH, + mig_system_files=DUMMY_SYSTEM_FILES_PATH, + user_interface=DUMMY_USER_INTERFACE, + new_user_default_ui=DUMMY_DEFAULT_UI, +) class MigSharedSettings(MigTestCase): @@ -82,28 +92,31 @@ def test_settings_save_load(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, 'settings') + saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, "settings") self.assertTrue(os.path.exists(saved_path)) settings = load_settings(DUMMY_USER, DUMMY_CONF) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(settings, dict)) - self.assertEqual(settings['EMAIL'], ['john@doe.org']) - self.assertEqual(settings['SITE_USER_MENU'], - ['sharelinks', 'people', 'peers']) + self.assertEqual(settings["EMAIL"], ["john@doe.org"]) + self.assertEqual( + settings["SITE_USER_MENU"], ["sharelinks", "people", "peers"] + ) # NOTE: we no longer auto save default values for optional vars for key in settings.keys(): - self.assertTrue(key in ['EMAIL', 'SITE_USER_MENU']) + self.assertTrue(key in ["EMAIL", "SITE_USER_MENU"]) # Any saved USER_INTERFACE value must match configured default if set - self.assertEqual(settings.get('USER_INTERFACE', DUMMY_DEFAULT_UI), - DUMMY_DEFAULT_UI) + self.assertEqual( + settings.get("USER_INTERFACE", DUMMY_DEFAULT_UI), DUMMY_DEFAULT_UI + ) def test_settings_replace(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) @@ -113,25 +126,27 @@ def test_settings_replace(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_UPDATE_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) updated = load_settings(DUMMY_USER, DUMMY_CONF) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(updated, dict)) - self.assertEqual(updated['EMAIL'], ['jane@doe.org']) - self.assertEqual(updated['SITE_USER_MENU'], ['people']) + self.assertEqual(updated["EMAIL"], ["jane@doe.org"]) + self.assertEqual(updated["SITE_USER_MENU"], ["people"]) def test_update_settings(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) @@ -141,20 +156,21 @@ def test_update_settings(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - changes = {'EMAIL': ['john@doe.org', 'jane@doe.org']} + changes = {"EMAIL": ["john@doe.org", "jane@doe.org"]} defaults = {} updated = update_settings(DUMMY_USER, DUMMY_CONF, changes, defaults) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(updated, dict)) - self.assertEqual(updated['EMAIL'], ['john@doe.org', 'jane@doe.org']) + self.assertEqual(updated["EMAIL"], ["john@doe.org", "jane@doe.org"]) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_ssh.py b/tests/test_mig_shared_ssh.py index 7aa4ba3ec..954bd8c1a 100644 --- a/tests/test_mig_shared_ssh.py +++ b/tests/test_mig_shared_ssh.py @@ -32,10 +32,18 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.ssh import supported_pub_key_parsers, parse_pub_key, \ - generate_ssh_rsa_key_pair +from mig.shared.ssh import ( + generate_ssh_rsa_key_pair, + parse_pub_key, + supported_pub_key_parsers, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) class MigSharedSsh(MigTestCase): @@ -45,11 +53,11 @@ def test_ssh_key_generate_and_parse(self): parsers = supported_pub_key_parsers() # NOTE: should return a non-empty dict of algos and parsers self.assertTrue(parsers) - self.assertTrue('ssh-rsa' in parsers) + self.assertTrue("ssh-rsa" in parsers) # Generate common sized keys and parse the result for keysize in (2048, 3072, 4096): - (priv_key, pub_key) = generate_ssh_rsa_key_pair(size=keysize) + priv_key, pub_key = generate_ssh_rsa_key_pair(size=keysize) self.assertTrue(priv_key) self.assertTrue(pub_key) @@ -57,15 +65,16 @@ def test_ssh_key_generate_and_parse(self): try: parsed = parse_pub_key(pub_key) except ValueError as vae: - #print("Error in parsing pub key: %r" % vae) + # print("Error in parsing pub key: %r" % vae) parsed = None self.assertIsNotNone(parsed) - (priv_key, pub_key) = generate_ssh_rsa_key_pair(size=keysize, - encode_utf8=True) + priv_key, pub_key = generate_ssh_rsa_key_pair( + size=keysize, encode_utf8=True + ) self.assertTrue(priv_key) self.assertTrue(pub_key) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_transferfunctions.py b/tests/test_mig_shared_transferfunctions.py index fd8c7d5cc..53665259a 100644 --- a/tests/test_mig_shared_transferfunctions.py +++ b/tests/test_mig_shared_transferfunctions.py @@ -31,16 +31,27 @@ import sys import tempfile -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, temppath, testmain -from mig.shared.transferfunctions import get_transfers_path, \ - load_data_transfers, create_data_transfer, delete_data_transfer, \ - lock_data_transfers, unlock_data_transfers +from mig.shared.transferfunctions import ( + create_data_transfer, + delete_data_transfer, + get_transfers_path, + load_data_transfers, + lock_data_transfers, + unlock_data_transfers, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + temppath, + testmain, +) DUMMY_USER = "dummy-user" DUMMY_ID = "dummy-id" -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" def noop(*args, **kwargs): @@ -55,13 +66,15 @@ class MigSharedTransferfunctions(MigTestCase): def before_each(self): test_user_home = temppath(DUMMY_HOME_DIR, self, ensure_dir=True) test_user_settings = cleanpath( - DUMMY_SETTINGS_DIR, self, ensure_dir=True) + DUMMY_SETTINGS_DIR, self, ensure_dir=True + ) # make two requisite root folders for the dummy user os.mkdir(os.path.join(test_user_home, DUMMY_USER)) os.mkdir(os.path.join(test_user_settings, DUMMY_USER)) # now create a configuration - self.dummy_conf = FakeConfiguration(user_home=test_user_home, - user_settings=test_user_settings) + self.dummy_conf = FakeConfiguration( + user_home=test_user_home, user_settings=test_user_settings + ) def test_transfers_basic_locking_shared(self): dummy_conf = self.dummy_conf @@ -82,9 +95,11 @@ def test_transfers_basic_locking_ro_to_rw_exclusive(self): # Non-blocking exclusive locking of shared lock must fail ro_lock = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) rw_lock = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) self.assertTrue(ro_lock) self.assertFalse(rw_lock) @@ -99,9 +114,11 @@ def test_transfers_basic_locking_exclusive(self): rw_lock = lock_data_transfers(transfers_path, exclusive=True) # Non-blocking repeated shared or exclusive locking must fail ro_lock_again = lock_data_transfers( - transfers_path, exclusive=False, blocking=False) + transfers_path, exclusive=False, blocking=False + ) rw_lock_again = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) self.assertTrue(rw_lock) self.assertFalse(ro_lock_again) @@ -112,18 +129,19 @@ def test_transfers_basic_locking_exclusive(self): def test_create_and_delete_transfer(self): dummy_conf = self.dummy_conf - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}) + success, out = create_data_transfer( + dummy_conf, DUMMY_USER, {"transfer_id": DUMMY_ID} + ) self.assertTrue(success and DUMMY_ID in out) - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and transfers.get(DUMMY_ID, None)) - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID) + success, out = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID) self.assertTrue(success and out == DUMMY_ID) - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and transfers.get(DUMMY_ID, None) is None) @@ -131,38 +149,48 @@ def test_transfers_shared_read_locking(self): dummy_conf = self.dummy_conf transfers_path = get_transfers_path(dummy_conf, DUMMY_USER) # Init a dummy transfer to read and delete later - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) # take a shared ro lock up front ro_lock = lock_data_transfers(transfers_path, exclusive=False) # cases: - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and DUMMY_ID in transfers) # Create with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Delete with repeated locking should fail - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=True, blocking=False) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=True, blocking=False + ) self.assertFalse(success) # Verify unchanged - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and DUMMY_ID in transfers) # Unlock all to leave critical section and allow clean up unlock_data_transfers(ro_lock) # Delete with locking should be fine again - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=True) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=True + ) self.assertTrue(success and out == DUMMY_ID) def test_transfers_exclusive_write_locking(self): @@ -174,40 +202,51 @@ def test_transfers_exclusive_write_locking(self): # cases: # Non-blocking load with repeated locking should fail - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER, - do_lock=True, blocking=False) + success, transfers = load_data_transfers( + dummy_conf, DUMMY_USER, do_lock=True, blocking=False + ) self.assertFalse(success) # Load without repeated locking should be fine - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER, - do_lock=False) + success, transfers = load_data_transfers( + dummy_conf, DUMMY_USER, do_lock=False + ) self.assertTrue(success) # Non-blocking create with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Create without repeated locking should be fine - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=False) + success, out = create_data_transfer( + dummy_conf, DUMMY_USER, {"transfer_id": DUMMY_ID}, do_lock=False + ) self.assertTrue(success) # Non-blocking delete with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Delete without repeated locking should be fine - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=False) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=False + ) self.assertTrue(success) unlock_data_transfers(rw_lock) -if __name__ == '__main__': +if __name__ == "__main__": testmain(failfast=True) diff --git a/tests/test_mig_shared_url.py b/tests/test_mig_shared_url.py index 1f029775d..b81cf9799 100644 --- a/tests/test_mig_shared_url.py +++ b/tests/test_mig_shared_url.py @@ -32,8 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import MigTestCase, FakeConfiguration, testmain from mig.shared.url import _get_site_urls, check_local_site_url +from tests.support import FakeConfiguration, MigTestCase, testmain def _generate_dynamic_site_urls(url_base_list): @@ -42,40 +42,58 @@ def _generate_dynamic_site_urls(url_base_list): """ site_urls = [] for url_base in url_base_list: - site_urls += ['%s' % url_base, '%s/' % url_base, - '%s/wsgi-bin/home.py' % url_base, - '%s/wsgi-bin/logout.py' % url_base, - '%s/wsgi-bin/logout.py?return_url=' % url_base, - '%s/wsgi-bin/logout.py?return_url=%s' % (url_base, - ENC_URL) - ] + site_urls += [ + "%s" % url_base, + "%s/" % url_base, + "%s/wsgi-bin/home.py" % url_base, + "%s/wsgi-bin/logout.py" % url_base, + "%s/wsgi-bin/logout.py?return_url=" % url_base, + "%s/wsgi-bin/logout.py?return_url=%s" % (url_base, ENC_URL), + ] return site_urls -DUMMY_CONF = FakeConfiguration(migserver_http_url='http://myfqdn.org', - migserver_https_url='https://myfqdn.org', - migserver_https_mig_cert_url='', - migserver_https_ext_cert_url='', - migserver_https_mig_oid_url='', - migserver_https_ext_oid_url='', - migserver_https_mig_oidc_url='', - migserver_https_ext_oidc_url='', - migserver_https_sid_url='', - migserver_public_url='', - migserver_public_alias_url='') -ENC_URL = 'https%3A%2F%2Fsomewhere.org%2Fsub%0A' +DUMMY_CONF = FakeConfiguration( + migserver_http_url="http://myfqdn.org", + migserver_https_url="https://myfqdn.org", + migserver_https_mig_cert_url="", + migserver_https_ext_cert_url="", + migserver_https_mig_oid_url="", + migserver_https_ext_oid_url="", + migserver_https_mig_oidc_url="", + migserver_https_ext_oidc_url="", + migserver_https_sid_url="", + migserver_public_url="", + migserver_public_alias_url="", +) +ENC_URL = "https%3A%2F%2Fsomewhere.org%2Fsub%0A" # Static site-anchored and dynamic full URLs for local site URL check success -LOCAL_SITE_URLS = ['', 'abc', 'abc.txt', '/', '/bla', '/bla#anchor', - '/bla/', '/bla/#anchor', '/bla/bla', '/bla/bla/bla', - '//bla//', './bla', './bla/', './bla/bla', - './bla/bla/bla', 'logout.py', 'logout.py?bla=', - '/cgi-sid/logout.py', '/cgi-sid/logout.py?bla=bla', - '/cgi-sid/logout.py?return_url=%s' % ENC_URL, - ] +LOCAL_SITE_URLS = [ + "", + "abc", + "abc.txt", + "/", + "/bla", + "/bla#anchor", + "/bla/", + "/bla/#anchor", + "/bla/bla", + "/bla/bla/bla", + "//bla//", + "./bla", + "./bla/", + "./bla/bla", + "./bla/bla/bla", + "logout.py", + "logout.py?bla=", + "/cgi-sid/logout.py", + "/cgi-sid/logout.py?bla=bla", + "/cgi-sid/logout.py?return_url=%s" % ENC_URL, +] LOCAL_BASE_URLS = _get_site_urls(DUMMY_CONF) LOCAL_SITE_URLS += _generate_dynamic_site_urls(LOCAL_SITE_URLS) # Dynamic full URLs for local site URL check failure -REMOTE_BASE_URLS = ['https://someevilsite.com', 'ftp://someevilsite.com'] +REMOTE_BASE_URLS = ["https://someevilsite.com", "ftp://someevilsite.com"] REMOTE_SITE_URLS = _generate_dynamic_site_urls(REMOTE_BASE_URLS) @@ -85,15 +103,19 @@ class BasicUrl(MigTestCase): def test_valid_local_site_urls(self): """Check known valid static and dynamic URLs""" for url in LOCAL_SITE_URLS: - self.assertTrue(check_local_site_url(DUMMY_CONF, url), - "Local site url should succeed for %s" % url) + self.assertTrue( + check_local_site_url(DUMMY_CONF, url), + "Local site url should succeed for %s" % url, + ) def test_invalid_local_site_urls(self): """Check known invalid URLs""" for url in REMOTE_SITE_URLS: - self.assertFalse(check_local_site_url(DUMMY_CONF, url), - "Local site url should fail for %s" % url) + self.assertFalse( + check_local_site_url(DUMMY_CONF, url), + "Local site url should fail for %s" % url, + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_useradm.py b/tests/test_mig_shared_useradm.py index e5d1a3b16..be04ff21b 100644 --- a/tests/test_mig_shared_useradm.py +++ b/tests/test_mig_shared_useradm.py @@ -38,62 +38,88 @@ from past.builtins import basestring # Imports required for the unit test wrapping -from mig.shared.defaults import DEFAULT_USER_ID_FORMAT, htaccess_filename, \ - keyword_auto +from mig.shared.defaults import ( + DEFAULT_USER_ID_FORMAT, + htaccess_filename, + keyword_auto, +) + # Imports of the code under test -from mig.shared.useradm import _ensure_dirs_needed_for_userdb, \ - assure_current_htaccess, create_user +from mig.shared.useradm import ( + _ensure_dirs_needed_for_userdb, + assure_current_htaccess, + create_user, +) + # Imports required for the unit tests themselves -from tests.support import MIG_BASE, TEST_OUTPUT_DIR, FakeConfiguration, \ - MigTestCase, cleanpath, ensure_dirs_exist, is_path_within, testmain +from tests.support import ( + MIG_BASE, + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + ensure_dirs_exist, + is_path_within, + testmain, +) from tests.support.fixturesupp import FixtureAssertMixin from tests.support.picklesupp import PickleAssertMixin -DUMMY_USER = 'dummy-user' -DUMMY_STALE_USER = 'dummy-stale-user' -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' -DUMMY_MRSL_FILES_DIR = 'dummy_mrsl_files' -DUMMY_RESOURCE_PENDING_DIR = 'dummy_resource_pending' -DUMMY_CACHE_DIR = 'dummy_user_cache' +DUMMY_USER = "dummy-user" +DUMMY_STALE_USER = "dummy-stale-user" +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" +DUMMY_MRSL_FILES_DIR = "dummy_mrsl_files" +DUMMY_RESOURCE_PENDING_DIR = "dummy_resource_pending" +DUMMY_CACHE_DIR = "dummy_user_cache" DUMMY_HOME_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_HOME_DIR) DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) DUMMY_MRSL_FILES_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_MRSL_FILES_DIR) -DUMMY_RESOURCE_PENDING_PATH = os.path.join(TEST_OUTPUT_DIR, - DUMMY_RESOURCE_PENDING_DIR) +DUMMY_RESOURCE_PENDING_PATH = os.path.join( + TEST_OUTPUT_DIR, DUMMY_RESOURCE_PENDING_DIR +) DUMMY_CACHE_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_CACHE_DIR) -DUMMY_USER_DICT = {'distinguished_name': DUMMY_USER, - 'short_id': '%s@my.org' % DUMMY_USER} -DUMMY_REL_HTACCESS_PATH = os.path.join(DUMMY_HOME_DIR, DUMMY_USER, - htaccess_filename) -DUMMY_HTACCESS_PATH = DUMMY_REL_HTACCESS_PATH.replace(DUMMY_HOME_DIR, - DUMMY_HOME_PATH) -DUMMY_REL_HTACCESS_BACKUP_PATH = os.path.join(DUMMY_CACHE_DIR, DUMMY_USER, - "%s.old" % htaccess_filename) +DUMMY_USER_DICT = { + "distinguished_name": DUMMY_USER, + "short_id": "%s@my.org" % DUMMY_USER, +} +DUMMY_REL_HTACCESS_PATH = os.path.join( + DUMMY_HOME_DIR, DUMMY_USER, htaccess_filename +) +DUMMY_HTACCESS_PATH = DUMMY_REL_HTACCESS_PATH.replace( + DUMMY_HOME_DIR, DUMMY_HOME_PATH +) +DUMMY_REL_HTACCESS_BACKUP_PATH = os.path.join( + DUMMY_CACHE_DIR, DUMMY_USER, "%s.old" % htaccess_filename +) DUMMY_HTACCESS_BACKUP_PATH = DUMMY_REL_HTACCESS_BACKUP_PATH.replace( - DUMMY_CACHE_DIR, DUMMY_CACHE_PATH) + DUMMY_CACHE_DIR, DUMMY_CACHE_PATH +) DUMMY_REQUIRE_USER = 'require user "%s"' % DUMMY_USER DUMMY_REQUIRE_STALE_USER = 'require user "%s"' % DUMMY_STALE_USER -DUMMY_CONF = FakeConfiguration(user_home=DUMMY_HOME_PATH, - user_settings=DUMMY_SETTINGS_PATH, - user_cache=DUMMY_CACHE_PATH, - mrsl_files_dir=DUMMY_MRSL_FILES_PATH, - resource_pending=DUMMY_RESOURCE_PENDING_PATH, - site_user_id_format=DEFAULT_USER_ID_FORMAT, - short_title='dummysite', - support_email='support@dummysite.org', - user_openid_providers=['dummyoidprovider.org'], - ) - - -class TestMigSharedUsedadm_create_user(MigTestCase, - FixtureAssertMixin, - PickleAssertMixin): +DUMMY_CONF = FakeConfiguration( + user_home=DUMMY_HOME_PATH, + user_settings=DUMMY_SETTINGS_PATH, + user_cache=DUMMY_CACHE_PATH, + mrsl_files_dir=DUMMY_MRSL_FILES_PATH, + resource_pending=DUMMY_RESOURCE_PENDING_PATH, + site_user_id_format=DEFAULT_USER_ID_FORMAT, + short_title="dummysite", + support_email="support@dummysite.org", + user_openid_providers=["dummyoidprovider.org"], +) + + +class TestMigSharedUsedadm_create_user( + MigTestCase, FixtureAssertMixin, PickleAssertMixin +): """Coverage of useradm create_user function.""" - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" TEST_USER_DN_GDP = "%s/GDP" % (TEST_USER_DN,) - TEST_USER_PASSWORD_HASH = 'PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix' + TEST_USER_PASSWORD_HASH = ( + "PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix" + ) def before_each(self): configuration = self.configuration @@ -101,73 +127,82 @@ def before_each(self): _ensure_dirs_needed_for_userdb(self.configuration) self.expected_user_db_home = os.path.normpath( - configuration.user_db_home) + configuration.user_db_home + ) self.expected_user_db_file = os.path.join( - self.expected_user_db_home, 'MiG-users.db') + self.expected_user_db_home, "MiG-users.db" + ) ensure_dirs_exist(self.configuration.mig_system_files) def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_user_db_is_created(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "user@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['password'] = "password" - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "user@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["password"] = "password" + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) # presence of user home path_kind = MigTestCase._absolute_path_kind(self.expected_user_db_home) - self.assertEqual(path_kind, 'dir') + self.assertEqual(path_kind, "dir") # presence of user db path_kind = MigTestCase._absolute_path_kind(self.expected_user_db_file) - self.assertEqual(path_kind, 'file') + self.assertEqual(path_kind, "file") def test_user_creation_records_a_user(self): def _adjust_user_dict_for_compare(user_obj): obj = dict(user_obj) - obj['created'] = 9999999999.9999999 - obj['expire'] = 9999999999.9999999 - obj['unique_id'] = '__UNIQUE_ID__' + obj["created"] = 9999999999.9999999 + obj["expire"] = 9999999999.9999999 + obj["unique_id"] = "__UNIQUE_ID__" return obj expected_user_id = self.TEST_USER_DN expected_user_password_hash = self.TEST_USER_PASSWORD_HASH user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = expected_user_password_hash - - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) - - pickled = self.assertPickledFile(self.expected_user_db_file, - apply_hints=['convert_dict_bytes_to_strings_kv']) + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = expected_user_password_hash + + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) + + pickled = self.assertPickledFile( + self.expected_user_db_file, + apply_hints=["convert_dict_bytes_to_strings_kv"], + ) self.assertIn(expected_user_id, pickled) - prepared = self.prepareFixtureAssert('MiG-users.db--example', - fixture_format='json') + prepared = self.prepareFixtureAssert( + "MiG-users.db--example", fixture_format="json" + ) # TODO: remove resetting the handful of keys here # this is done to allow the comparision to succeed actual_user_object = _adjust_user_dict_for_compare( - pickled[expected_user_id]) + pickled[expected_user_id] + ) expected_user_object = _adjust_user_dict_for_compare( - prepared.fixture_data[expected_user_id]) + prepared.fixture_data[expected_user_id] + ) self.maxDiff = None self.assertEqual(actual_user_object, expected_user_object) @@ -176,95 +211,117 @@ def test_user_creation_records_a_user_with_gdp(self): self.configuration.site_enable_gdp = True user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = self.TEST_USER_PASSWORD_HASH + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = self.TEST_USER_PASSWORD_HASH # explicitly setting set a DN suffixed user DN to force GDP - user_dict['distinguished_name'] = self.TEST_USER_DN_GDP + user_dict["distinguished_name"] = self.TEST_USER_DN_GDP try: - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) except: self.assertFalse(True, "should not be reached") def test_user_creation_and_renew_records_a_user(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = self.TEST_USER_PASSWORD_HASH + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = self.TEST_USER_PASSWORD_HASH try: - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) except: self.assertFalse(True, "should not be reached") try: - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) except: self.assertFalse(True, "should not be reached") def test_user_creation_fails_in_renew_when_locked(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "test@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['locality'] = "" - user_dict['organizational_unit'] = "" - user_dict['password'] = "" - user_dict['password_hash'] = self.TEST_USER_PASSWORD_HASH + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "test@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["locality"] = "" + user_dict["organizational_unit"] = "" + user_dict["password"] = "" + user_dict["password_hash"] = self.TEST_USER_PASSWORD_HASH # explicitly setting set a DN suffixed user DN to force GDP - user_dict['distinguished_name'] = self.TEST_USER_DN_GDP - user_dict['status'] = "locked" + user_dict["distinguished_name"] = self.TEST_USER_DN_GDP + user_dict["status"] = "locked" try: - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) except: self.assertFalse(True, "should not be reached") def test_user_creation_with_id_collission_fails(self): user_dict = {} - user_dict['full_name'] = "Test User" - user_dict['organization'] = "Test Org" - user_dict['state'] = "NA" - user_dict['country'] = "DK" - user_dict['email'] = "user@example.com" - user_dict['comment'] = "This is the create comment" - user_dict['password'] = "password" - user_dict['distinguished_name'] = self.TEST_USER_DN + user_dict["full_name"] = "Test User" + user_dict["organization"] = "Test Org" + user_dict["state"] = "NA" + user_dict["country"] = "DK" + user_dict["email"] = "user@example.com" + user_dict["comment"] = "This is the create comment" + user_dict["password"] = "password" + user_dict["distinguished_name"] = self.TEST_USER_DN try: - create_user(user_dict, self.configuration, - keyword_auto, default_renew=True) + create_user( + user_dict, self.configuration, keyword_auto, default_renew=True + ) except: self.assertFalse(True, "should not be reached") # NOTE: reset distinguished_name and introduce an ID conflict to test - del user_dict['distinguished_name'] - user_dict['organization'] = "Another Org" + del user_dict["distinguished_name"] + user_dict["organization"] = "Another Org" with self.assertRaises(Exception): - create_user(user_dict, self.configuration, keyword_auto, - default_renew=True, ask_renew=False) + create_user( + user_dict, + self.configuration, + keyword_auto, + default_renew=True, + ask_renew=False, + ) class MigSharedUseradm__assure_current_htaccess(MigTestCase): @@ -289,16 +346,15 @@ def assertHtaccessRequireUserClause(self, generated, expected): with io.open(generated) as htaccess_file: generated = htaccess_file.read() - generated_lines = generated.split('\n') + generated_lines = generated.split("\n") if not expected in generated_lines: raise AssertionError("no such require user line: %s" % expected) def test_skips_accounts_without_short_id(self): user_dict = {} user_dict.update(DUMMY_USER_DICT) - del user_dict['short_id'] - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + del user_dict["short_id"] + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) try: path_kind = self.assertPathExists(DUMMY_REL_HTACCESS_PATH) @@ -310,8 +366,7 @@ def test_skips_accounts_without_short_id(self): def test_creates_missing_htaccess_file(self): user_dict = {} user_dict.update(DUMMY_USER_DICT) - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) path_kind = self.assertPathExists(DUMMY_REL_HTACCESS_PATH) # File should exist here and be valid @@ -320,26 +375,26 @@ def test_creates_missing_htaccess_file(self): # Backup file should exist here and be empty self.assertEqual(path_kind, "file") - self.assertHtaccessRequireUserClause(DUMMY_HTACCESS_PATH, - DUMMY_REQUIRE_USER) + self.assertHtaccessRequireUserClause( + DUMMY_HTACCESS_PATH, DUMMY_REQUIRE_USER + ) def test_repairs_existing_stale_htaccess_file(self): user_dict = {} user_dict.update(DUMMY_USER_DICT) # Fake stale user ID directly through DN - user_dict['distinguished_name'] = DUMMY_STALE_USER - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + user_dict["distinguished_name"] = DUMMY_STALE_USER + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) # Verify stale - self.assertHtaccessRequireUserClause(DUMMY_HTACCESS_PATH, - DUMMY_REQUIRE_STALE_USER) + self.assertHtaccessRequireUserClause( + DUMMY_HTACCESS_PATH, DUMMY_REQUIRE_STALE_USER + ) # Reset stale user ID and retry user_dict = {} user_dict.update(DUMMY_USER_DICT) - assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, - False) + assure_current_htaccess(DUMMY_CONF, DUMMY_USER, user_dict, False, False) path_kind = self.assertPathExists(DUMMY_REL_HTACCESS_PATH) # File should exist here and be valid @@ -348,9 +403,10 @@ def test_repairs_existing_stale_htaccess_file(self): # Backup file should exist here and be empty self.assertEqual(path_kind, "file") - self.assertHtaccessRequireUserClause(DUMMY_HTACCESS_PATH, - DUMMY_REQUIRE_USER) + self.assertHtaccessRequireUserClause( + DUMMY_HTACCESS_PATH, DUMMY_REQUIRE_USER + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_userdb.py b/tests/test_mig_shared_userdb.py index 2efc524fc..083f14092 100644 --- a/tests/test_mig_shared_userdb.py +++ b/tests/test_mig_shared_userdb.py @@ -34,17 +34,26 @@ # Imports required for the unit test wrapping from mig.shared.base import distinguished_name_to_user from mig.shared.fileio import delete_file -from mig.shared.serial import loads, dumps +from mig.shared.serial import dumps, loads + # Imports of the code under test -from mig.shared.userdb import default_db_path, load_user_db, load_user_dict, \ - lock_user_db, save_user_db, save_user_dict, unlock_user_db, \ - update_user_dict +from mig.shared.userdb import ( + default_db_path, + load_user_db, + load_user_dict, + lock_user_db, + save_user_db, + save_user_dict, + unlock_user_db, + update_user_dict, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -TEST_USER_ID = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' -THIS_USER_ID = '/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=This User/emailAddress=this.user@here.org' -OTHER_USER_ID = '/C=DK/ST=NA/L=NA/O=Other Org/OU=NA/CN=Other User/emailAddress=other.user@there.org' +TEST_USER_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" +THIS_USER_ID = "/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=This User/emailAddress=this.user@here.org" +OTHER_USER_ID = "/C=DK/ST=NA/L=NA/O=Other Org/OU=NA/CN=Other User/emailAddress=other.user@there.org" class TestMigSharedUserDB(MigTestCase): @@ -52,7 +61,7 @@ class TestMigSharedUserDB(MigTestCase): def _provide_configuration(self): """Get test configuration""" - return 'testconfig' + return "testconfig" # Helper methods def _generate_sample_db(self, content=None): @@ -60,7 +69,7 @@ def _generate_sample_db(self, content=None): if content is None: sample_db = { TEST_USER_ID: distinguished_name_to_user(TEST_USER_ID), - THIS_USER_ID: distinguished_name_to_user(THIS_USER_ID) + THIS_USER_ID: distinguished_name_to_user(THIS_USER_ID), } else: sample_db = content @@ -78,10 +87,12 @@ def before_each(self): """Set up test configuration and reset user DB paths""" ensure_dirs_exist(self.configuration.user_db_home) ensure_dirs_exist(self.configuration.mig_server_home) - self.user_db_path = os.path.join(self.configuration.user_db_home, - "MiG-users.db") - self.legacy_db_path = os.path.join(self.configuration.mig_server_home, - "MiG-users.db") + self.user_db_path = os.path.join( + self.configuration.user_db_home, "MiG-users.db" + ) + self.legacy_db_path = os.path.join( + self.configuration.mig_server_home, "MiG-users.db" + ) # Clear any existing test DBs if os.path.exists(self.user_db_path): @@ -95,15 +106,15 @@ def before_each(self): def test_default_db_path(self): """Test default_db_path returns correct path structure""" - expected = os.path.join(self.configuration.user_db_home, - "MiG-users.db") + expected = os.path.join(self.configuration.user_db_home, "MiG-users.db") result = default_db_path(self.configuration) self.assertEqual(result, expected) # Test legacy path fallback - self.configuration.user_db_home = '/no-such-dir' - expected_legacy = os.path.join(self.configuration.mig_server_home, - "MiG-users.db") + self.configuration.user_db_home = "/no-such-dir" + expected_legacy = os.path.join( + self.configuration.mig_server_home, "MiG-users.db" + ) result = default_db_path(self.configuration) self.assertEqual(result, expected_legacy) @@ -149,8 +160,7 @@ def test_load_user_db_direct(self): def test_load_user_db_missing(self): """Test loading missing user database""" - db_path = os.path.join( - self.configuration.user_db_home, "no-such-db.db") + db_path = os.path.join(self.configuration.user_db_home, "no-such-db.db") try: loaded = load_user_db(db_path) except Exception: @@ -190,8 +200,9 @@ def test_load_user_dict_missing(self): """Test loading non-existent user from DB""" self._create_sample_db() try: - loaded = load_user_dict(self.logger, "no-such-user", - self.user_db_path) + loaded = load_user_dict( + self.logger, "no-such-user", self.user_db_path + ) except Exception: loaded = None self.assertIsNone(loaded) @@ -200,8 +211,9 @@ def test_load_user_dict_existing(self): """Test loading existing user from DB""" sample_db = self._create_sample_db() try: - test_user_data = load_user_dict(self.logger, TEST_USER_ID, - self.user_db_path) + test_user_data = load_user_dict( + self.logger, TEST_USER_ID, self.user_db_path + ) except Exception: test_user_data = None self.assertEqual(test_user_data, sample_db[TEST_USER_ID]) @@ -209,8 +221,9 @@ def test_load_user_dict_existing(self): def test_save_user_dict_new_user(self): """Test saving new user to database""" other_user = distinguished_name_to_user(OTHER_USER_ID) - save_status = save_user_dict(self.logger, OTHER_USER_ID, - other_user, self.user_db_path) + save_status = save_user_dict( + self.logger, OTHER_USER_ID, other_user, self.user_db_path + ) self.assertTrue(save_status) with open(self.user_db_path, "rb") as fh: @@ -223,8 +236,9 @@ def test_save_user_dict_update(self): sample_db = self._create_sample_db() changed = distinguished_name_to_user(THIS_USER_ID) changed.update({"Organization": "UPDATED", "new_field": "ADDED"}) - save_status = save_user_dict(self.logger, THIS_USER_ID, - changed, self.user_db_path) + save_status = save_user_dict( + self.logger, THIS_USER_ID, changed, self.user_db_path + ) self.assertTrue(save_status) with open(self.user_db_path, "rb") as fh: @@ -235,9 +249,12 @@ def test_save_user_dict_update(self): def test_update_user_dict(self): """Test update_user_dict with partial changes""" sample_db = self._create_sample_db() - updated = update_user_dict(self.logger, THIS_USER_ID, - {"Organization": "CHANGED"}, - self.user_db_path) + updated = update_user_dict( + self.logger, + THIS_USER_ID, + {"Organization": "CHANGED"}, + self.user_db_path, + ) self.assertEqual(updated["Organization"], "CHANGED") with open(self.user_db_path, "rb") as fh: @@ -249,8 +266,12 @@ def test_update_user_dict_requirements(self): """Test update_user_dict with invalid user ID""" self.logger.forgive_errors() try: - result = update_user_dict(self.logger, "no-such-user", - {"field": "test"}, self.user_db_path) + result = update_user_dict( + self.logger, + "no-such-user", + {"field": "test"}, + self.user_db_path, + ) except Exception: result = None self.assertIsNone(result) @@ -274,6 +295,7 @@ def delayed_load(): return loaded import threading + delayed_thread = threading.Thread(target=delayed_load) delayed_thread.start() time.sleep(0.2) @@ -308,7 +330,8 @@ def test_load_user_db_pickle_abi(self): def test_lock_user_db_invalid_path(self): """Test locking on non-existent database path""" invalid_path = os.path.join( - self.configuration.user_db_home, "missing", "MiG-users.db") + self.configuration.user_db_home, "missing", "MiG-users.db" + ) flock = lock_user_db(invalid_path) self.assertIsNone(flock) @@ -322,7 +345,7 @@ def test_unlock_user_db_invalid(self): def test_load_user_db_corrupted(self): """Test loading corrupted user database""" - with open(self.user_db_path, 'w') as fh: + with open(self.user_db_path, "w") as fh: fh.write("invalid pickle content") with self.assertRaises(Exception): load_user_db(self.user_db_path) @@ -354,8 +377,9 @@ def test_save_user_dict_invalid_id(self): """Test saving user with invalid characters in ID""" invalid_id = "../../invalid.user" user_dict = distinguished_name_to_user(TEST_USER_ID) - save_status = save_user_dict(self.logger, invalid_id, - user_dict, self.user_db_path) + save_status = save_user_dict( + self.logger, invalid_id, user_dict, self.user_db_path + ) self.assertFalse(save_status) def test_update_user_dict_empty_changes(self): @@ -364,8 +388,9 @@ def test_update_user_dict_empty_changes(self): self.logger.forgive_errors() sample_db = self._create_sample_db() original = sample_db[THIS_USER_ID].copy() - updated = update_user_dict(self.logger, THIS_USER_ID, {}, - self.user_db_path) + updated = update_user_dict( + self.logger, THIS_USER_ID, {}, self.user_db_path + ) self.assertEqual(updated, original) # TODO: adjust API to allow enabling the next test @@ -395,5 +420,5 @@ def test_load_user_db_allows_concurrent_read_access(self): unlock_user_db(flock) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_userio.py b/tests/test_mig_shared_userio.py index daa69444a..c17d49b0b 100644 --- a/tests/test_mig_shared_userio.py +++ b/tests/test_mig_shared_userio.py @@ -29,11 +29,11 @@ import os import sys -from past.builtins import basestring, unicode -from tests.support import MigTestCase, testmain +from past.builtins import basestring, unicode from mig.shared.userio import main as userio_main +from tests.support import MigTestCase, testmain class MigSharedUserIO(MigTestCase): @@ -45,9 +45,11 @@ def raise_on_error_exit(exit_code): if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -56,5 +58,5 @@ def record_last_print(value): userio_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_vgrid.py b/tests/test_mig_shared_vgrid.py index 465b74390..89250898d 100644 --- a/tests/test_mig_shared_vgrid.py +++ b/tests/test_mig_shared_vgrid.py @@ -35,41 +35,69 @@ # Imports required for the unit test wrapping from mig.shared.base import client_id_dir from mig.shared.serial import dump + # Imports of the code under test -from mig.shared.vgrid import get_vgrid_workflow_jobs, legacy_main, \ - vgrid_add_entities, vgrid_add_members, vgrid_add_owners, \ - vgrid_add_resources, vgrid_add_workflow_jobs, vgrid_allow_restrict_write, \ - vgrid_exists, vgrid_flat_name, vgrid_is_default, vgrid_is_member, \ - vgrid_is_owner, vgrid_is_owner_or_member, vgrid_is_trigger, vgrid_list, \ - vgrid_list_parents, vgrid_list_subvgrids, vgrid_list_vgrids, \ - vgrid_match_resources, vgrid_nest_sep, vgrid_remove_entities, \ - vgrid_restrict_write, vgrid_set_entities, vgrid_set_members, \ - vgrid_set_owners, vgrid_set_workflow_jobs, vgrid_settings +from mig.shared.vgrid import ( + get_vgrid_workflow_jobs, + legacy_main, + vgrid_add_entities, + vgrid_add_members, + vgrid_add_owners, + vgrid_add_resources, + vgrid_add_workflow_jobs, + vgrid_allow_restrict_write, + vgrid_exists, + vgrid_flat_name, + vgrid_is_default, + vgrid_is_member, + vgrid_is_owner, + vgrid_is_owner_or_member, + vgrid_is_trigger, + vgrid_list, + vgrid_list_parents, + vgrid_list_subvgrids, + vgrid_list_vgrids, + vgrid_match_resources, + vgrid_nest_sep, + vgrid_remove_entities, + vgrid_restrict_write, + vgrid_set_entities, + vgrid_set_members, + vgrid_set_owners, + vgrid_set_workflow_jobs, + vgrid_settings, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain class TestMigSharedVgrid(MigTestCase): """Unit tests for vgrid helpers""" + # Standard user IDs following X.500 DN format - TEST_OWNER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/'\ - 'emailAddress=owner@example.com' - TEST_MEMBER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/'\ - 'emailAddress=member@example.com' - TEST_OUTSIDER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/'\ - 'emailAddress=outsider@example.com' - TEST_RESOURCE_DN = 'test.example.org' - TEST_OWNER_DIR = \ - '+C=DK+ST=NA+L=NA+O=Test_Org+OU=NA+CN=Test_Owner+'\ - 'emailAddress=owner@example.com' - TEST_JOB_ID = '12345667890' + TEST_OWNER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/" + "emailAddress=owner@example.com" + ) + TEST_MEMBER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/" + "emailAddress=member@example.com" + ) + TEST_OUTSIDER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/" + "emailAddress=outsider@example.com" + ) + TEST_RESOURCE_DN = "test.example.org" + TEST_OWNER_DIR = ( + "+C=DK+ST=NA+L=NA+O=Test_Org+OU=NA+CN=Test_Owner+" + "emailAddress=owner@example.com" + ) + TEST_JOB_ID = "12345667890" def _provide_configuration(self): """Return configuration to use""" - return 'testconfig' + return "testconfig" def before_each(self): """Create test environment for vgrid tests""" @@ -83,30 +111,33 @@ def before_each(self): ensure_dirs_exist(self.configuration.workflows_home) ensure_dirs_exist(self.configuration.workflows_db_home) ensure_dirs_exist(self.configuration.mig_system_files) - self.configuration.site_vgrid_label = 'VGridLabel' - self.configuration.vgrid_owners = 'owners.pck' - self.configuration.vgrid_members = 'members.pck' - self.configuration.vgrid_resources = 'resources.pck' - self.configuration.vgrid_settings = 'settings.pck' - self.configuration.vgrid_workflow_job_queue = 'jobqueue.pck' + self.configuration.site_vgrid_label = "VGridLabel" + self.configuration.vgrid_owners = "owners.pck" + self.configuration.vgrid_members = "members.pck" + self.configuration.vgrid_resources = "resources.pck" + self.configuration.vgrid_settings = "settings.pck" + self.configuration.vgrid_workflow_job_queue = "jobqueue.pck" self.configuration.site_enable_workflows = True # Default vgrid for comparison - self.default_vgrid = 'Generic' + self.default_vgrid = "Generic" # Create test vgrid structure using ensure_dirs_exist - self.test_vgrid = 'testvgrid' + self.test_vgrid = "testvgrid" self.test_vgrid_path = os.path.join( - self.configuration.vgrid_home, self.test_vgrid) + self.configuration.vgrid_home, self.test_vgrid + ) ensure_dirs_exist(self.test_vgrid_path) - vgrid_add_owners(self.configuration, self.test_vgrid, - [self.TEST_OWNER_DN]) + vgrid_add_owners( + self.configuration, self.test_vgrid, [self.TEST_OWNER_DN] + ) vgrid_add_members(self.configuration, self.test_vgrid, []) # Nested sub-VGrid - self.test_subvgrid = os.path.join(self.test_vgrid, 'subvgrid') + self.test_subvgrid = os.path.join(self.test_vgrid, "subvgrid") self.test_subvgrid_path = os.path.join( - self.configuration.vgrid_home, self.test_subvgrid) + self.configuration.vgrid_home, self.test_subvgrid + ) ensure_dirs_exist(self.test_subvgrid_path) vgrid_set_owners(self.configuration, self.test_vgrid, []) vgrid_set_members(self.configuration, self.test_vgrid, []) @@ -114,32 +145,32 @@ def before_each(self): def test_vgrid_is_default_match(self): """Test default vgrid detection with match""" - self.assertTrue(vgrid_is_default('Generic')) + self.assertTrue(vgrid_is_default("Generic")) self.assertTrue(vgrid_is_default(None)) - self.assertTrue(vgrid_is_default('')) + self.assertTrue(vgrid_is_default("")) def test_vgrid_is_default_mismatch(self): """Test default vgrid detection with mismatch""" self.assertFalse(vgrid_is_default(self.test_vgrid)) self.assertFalse(vgrid_is_default(self.test_subvgrid)) - self.assertFalse(vgrid_is_default('MiG')) + self.assertFalse(vgrid_is_default("MiG")) def test_vgrid_exists_for_existing(self): """Test vgrid existence checks for existing vgrids""" - self.assertTrue(vgrid_exists(self.configuration, 'Generic')) + self.assertTrue(vgrid_exists(self.configuration, "Generic")) self.assertTrue(vgrid_exists(self.configuration, None)) - self.assertTrue(vgrid_exists(self.configuration, '')) + self.assertTrue(vgrid_exists(self.configuration, "")) self.assertTrue(vgrid_exists(self.configuration, self.test_vgrid)) self.assertTrue(vgrid_exists(self.configuration, self.test_subvgrid)) def test_vgrid_exists_for_missing(self): """Test vgrid existence checks for missing vgrids""" - self.assertFalse(vgrid_exists(self.configuration, 'no_such_vgrid')) - self.assertFalse(vgrid_exists(self.configuration, 'no_such_vgrid/sub')) + self.assertFalse(vgrid_exists(self.configuration, "no_such_vgrid")) + self.assertFalse(vgrid_exists(self.configuration, "no_such_vgrid/sub")) # Parent exists but not child - yet, vgrid_exists defaults to recursive # and allow_missing so it will return True for ALL subvgrids of vgrids. - missing_child = os.path.join(self.test_subvgrid, 'missing_child') + missing_child = os.path.join(self.test_subvgrid, "missing_child") self.assertTrue(vgrid_exists(self.configuration, missing_child)) def test_vgrid_list_vgrids(self): @@ -152,15 +183,17 @@ def test_vgrid_list_vgrids(self): self.assertIn(self.test_subvgrid, all_vgrids) # Exclude default - status, no_default = vgrid_list_vgrids(self.configuration, - include_default=False) + status, no_default = vgrid_list_vgrids( + self.configuration, include_default=False + ) self.assertTrue(status) self.assertNotIn(self.default_vgrid, no_default) self.assertIn(self.test_vgrid, no_default) # Root filtering - status, root_vgrids = vgrid_list_vgrids(self.configuration, - root_vgrid=self.test_vgrid) + status, root_vgrids = vgrid_list_vgrids( + self.configuration, root_vgrid=self.test_vgrid + ) self.assertTrue(status) self.assertIn(self.test_subvgrid, root_vgrids) self.assertNotIn(self.test_vgrid, root_vgrids) @@ -168,134 +201,185 @@ def test_vgrid_list_vgrids(self): def test_vgrid_add_remove_entities(self): """Test entity management in vgrid""" # Test adding owner - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "owners", [self.TEST_OWNER_DN] + ) self.assertTrue(added, msg) time.sleep(0.1) # Ensure timestamp changes # Verify existence - self.assertTrue(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) + self.assertTrue( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) # Test removal without and with allow empty in turn - removed, msg = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN], False) + removed, msg = vgrid_remove_entities( + self.configuration, + self.test_vgrid, + "owners", + [self.TEST_OWNER_DN], + False, + ) self.assertFalse(removed, msg) - self.assertTrue(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) - removed, msg = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN], True) + self.assertTrue( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) + removed, msg = vgrid_remove_entities( + self.configuration, + self.test_vgrid, + "owners", + [self.TEST_OWNER_DN], + True, + ) self.assertTrue(removed, msg) - self.assertFalse(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) + self.assertFalse( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) def test_vgrid_settings_inheritance(self): """Test vgrid setting inheritance""" # Parent settings (MUST include required vgrid_name field) parent_settings = [ - ('vgrid_name', self.test_vgrid), - ('write_shared_files', 'None') + ("vgrid_name", self.test_vgrid), + ("write_shared_files", "None"), ] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', parent_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "settings", parent_settings + ) self.assertTrue(added, msg) # Verify inheritance - status, settings = vgrid_settings(self.test_subvgrid, self.configuration, - recursive=True, as_dict=True) + status, settings = vgrid_settings( + self.test_subvgrid, + self.configuration, + recursive=True, + as_dict=True, + ) self.assertTrue(status) - self.assertEqual(settings.get('write_shared_files'), 'None') + self.assertEqual(settings.get("write_shared_files"), "None") # Verify vgrid_name is preserved - self.assertEqual(settings['vgrid_name'], self.test_subvgrid) + self.assertEqual(settings["vgrid_name"], self.test_subvgrid) def test_vgrid_permission_checks(self): """Test owner/member permission verification""" # Setup owners and members - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'owners', [self.TEST_OWNER_DN]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "owners", [self.TEST_OWNER_DN] + ) self.assertTrue(added, msg) time.sleep(0.1) - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'members', [self.TEST_MEMBER_DN]) + added, msg = vgrid_add_entities( + self.configuration, + self.test_vgrid, + "members", + [self.TEST_MEMBER_DN], + ) self.assertTrue(added, msg) time.sleep(0.1) # Verify owner permissions - self.assertTrue(vgrid_is_owner(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) - self.assertTrue(vgrid_is_owner_or_member(self.test_vgrid, self.TEST_OWNER_DN, - self.configuration)) + self.assertTrue( + vgrid_is_owner( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) + self.assertTrue( + vgrid_is_owner_or_member( + self.test_vgrid, self.TEST_OWNER_DN, self.configuration + ) + ) # Verify member permissions - self.assertTrue(vgrid_is_member(self.test_vgrid, self.TEST_MEMBER_DN, - self.configuration)) - self.assertTrue(vgrid_is_owner_or_member(self.test_vgrid, self.TEST_MEMBER_DN, - self.configuration)) + self.assertTrue( + vgrid_is_member( + self.test_vgrid, self.TEST_MEMBER_DN, self.configuration + ) + ) + self.assertTrue( + vgrid_is_owner_or_member( + self.test_vgrid, self.TEST_MEMBER_DN, self.configuration + ) + ) # Verify non-member - self.assertFalse(vgrid_is_owner_or_member(self.test_vgrid, self.TEST_OUTSIDER_DN, - self.configuration)) + self.assertFalse( + vgrid_is_owner_or_member( + self.test_vgrid, self.TEST_OUTSIDER_DN, self.configuration + ) + ) def test_workflow_job_management(self): """Test workflow job queue handling""" job_entry = { - 'client_id': self.TEST_OWNER_DN, - 'job_id': self.TEST_JOB_ID + "client_id": self.TEST_OWNER_DN, + "job_id": self.TEST_JOB_ID, } - job_dir = os.path.join(self.configuration.mrsl_files_dir, - self.TEST_OWNER_DIR) + job_dir = os.path.join( + self.configuration.mrsl_files_dir, self.TEST_OWNER_DIR + ) ensure_dirs_exist(job_dir) - job_path = os.path.join(job_dir, '%s.mRSL' % self.TEST_JOB_ID) - dump({'job_id': self.TEST_JOB_ID, 'EXECUTE': 'uptime'}, job_path) + job_path = os.path.join(job_dir, "%s.mRSL" % self.TEST_JOB_ID) + dump({"job_id": self.TEST_JOB_ID, "EXECUTE": "uptime"}, job_path) # Add job - status, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'jobqueue', [job_entry]) + status, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "jobqueue", [job_entry] + ) self.assertTrue(status, msg) # TODO: adjust function to consistent return API? tuple vs list now. # List jobs - result = get_vgrid_workflow_jobs(self.configuration, - self.test_vgrid, True) + result = get_vgrid_workflow_jobs( + self.configuration, self.test_vgrid, True + ) if isinstance(result, tuple): status, msg = result jobs = [] else: - status, msg = True, '' + status, msg = True, "" jobs = result self.assertTrue(status) self.assertEqual(len(jobs), 1) - self.assertEqual(jobs[0]['job_id'], self.TEST_JOB_ID) + self.assertEqual(jobs[0]["job_id"], self.TEST_JOB_ID) # TODO: adjust function to consistent return API? tuple vs list now. # Remove job - result = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'jobqueue', [job_entry], True) + result = vgrid_remove_entities( + self.configuration, self.test_vgrid, "jobqueue", [job_entry], True + ) if isinstance(result, tuple): status, msg = result jobs = [] else: - status, msg = True, '' + status, msg = True, "" jobs = result self.assertTrue(status, msg) # TODO: adjust function to consistent return API? tuple vs list now. # Verify removal - result = get_vgrid_workflow_jobs(self.configuration, - self.test_vgrid, True) + result = get_vgrid_workflow_jobs( + self.configuration, self.test_vgrid, True + ) if isinstance(result, tuple): status, msg = result jobs = [] else: - status, msg = True, '' + status, msg = True, "" jobs = result self.assertTrue(status) self.assertEqual(len(jobs), 0) def test_vgrid_list_subvgrids(self): """Test retrieving subvgrids of given vgrid""" - status, subvgrids = vgrid_list_subvgrids(self.test_vgrid, - self.configuration) + status, subvgrids = vgrid_list_subvgrids( + self.test_vgrid, self.configuration + ) self.assertTrue(status) self.assertEqual(subvgrids, [self.test_subvgrid]) @@ -307,13 +391,18 @@ def test_vgrid_list_parents(self): def test_vgrid_match_resources(self): """Test resource filtering for vgrid""" - test_resources = ['res1', 'res2', 'invalid_res', self.TEST_RESOURCE_DN] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'resources', [self.TEST_RESOURCE_DN]) + test_resources = ["res1", "res2", "invalid_res", self.TEST_RESOURCE_DN] + added, msg = vgrid_add_entities( + self.configuration, + self.test_vgrid, + "resources", + [self.TEST_RESOURCE_DN], + ) self.assertTrue(added, msg) - matched = vgrid_match_resources(self.test_vgrid, test_resources, - self.configuration) + matched = vgrid_match_resources( + self.test_vgrid, test_resources, self.configuration + ) self.assertEqual(matched, [self.TEST_RESOURCE_DN]) # TODO: adjust API to allow enabling the next test @@ -321,25 +410,37 @@ def test_vgrid_match_resources(self): def test_vgrid_allow_restrict_write(self): """Test write restriction validation logic""" # Create parent-child structure - parent_write_setting = [('write_shared_files', 'none')] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', parent_write_setting) + parent_write_setting = [("write_shared_files", "none")] + added, msg = vgrid_add_entities( + self.configuration, + self.test_vgrid, + "settings", + parent_write_setting, + ) self.assertTrue(added, msg) # Child tries to set write_shared_files to "members" - result = vgrid_allow_restrict_write(self.test_subvgrid, 'members', - self.configuration, - auto_migrate=True) + result = vgrid_allow_restrict_write( + self.test_subvgrid, + "members", + self.configuration, + auto_migrate=True, + ) self.assertFalse(result) # Valid case: parent allows writes, child sets to "none" - parent_write_setting = [('write_shared_files', 'members')] - vgrid_set_entities(self.configuration, self.test_vgrid, - 'settings', parent_write_setting, True) + parent_write_setting = [("write_shared_files", "members")] + vgrid_set_entities( + self.configuration, + self.test_vgrid, + "settings", + parent_write_setting, + True, + ) - result = vgrid_allow_restrict_write(self.test_subvgrid, 'none', - self.configuration, - auto_migrate=True) + result = vgrid_allow_restrict_write( + self.test_subvgrid, "none", self.configuration, auto_migrate=True + ) self.assertTrue(result) # TODO: adjust API to allow enabling the next test @@ -347,289 +448,348 @@ def test_vgrid_allow_restrict_write(self): def test_vgrid_restrict_write(self): """Test write restriction enforcement""" # Setup test share - test_share = os.path.join(self.configuration.vgrid_files_home, - self.test_vgrid) + test_share = os.path.join( + self.configuration.vgrid_files_home, self.test_vgrid + ) ensure_dirs_exist(test_share) # Migrate to restricted mode - result = vgrid_restrict_write(self.test_vgrid, 'none', - self.configuration, auto_migrate=True) + result = vgrid_restrict_write( + self.test_vgrid, "none", self.configuration, auto_migrate=True + ) self.assertTrue(result) # Verify symlink points to readonly - flat_vgrid = self.test_vgrid.replace('/', vgrid_nest_sep) - read_path = os.path.join(self.configuration.vgrid_files_readonly, - flat_vgrid) - self.assertEqual(os.path.realpath(test_share), - os.path.realpath(read_path)) + flat_vgrid = self.test_vgrid.replace("/", vgrid_nest_sep) + read_path = os.path.join( + self.configuration.vgrid_files_readonly, flat_vgrid + ) + self.assertEqual( + os.path.realpath(test_share), os.path.realpath(read_path) + ) def test_vgrid_settings_scopes(self): """Test different vgrid settings lookup scopes""" # Local settings only - MUST include required vgrid_name field local_settings = [ - ('vgrid_name', self.test_subvgrid), - ('description', 'test subvgrid'), - ('write_shared_files', 'members'), + ("vgrid_name", self.test_subvgrid), + ("description", "test subvgrid"), + ("write_shared_files", "members"), ] - added, msg = vgrid_add_entities(self.configuration, self.test_subvgrid, - 'settings', local_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_subvgrid, "settings", local_settings + ) self.assertTrue(added, msg) # Check recursive vs direct status, recursive_settings = vgrid_settings( - self.test_subvgrid, self.configuration, recursive=True, as_dict=True) + self.test_subvgrid, + self.configuration, + recursive=True, + as_dict=True, + ) self.assertTrue(status) status, direct_settings = vgrid_settings( - self.test_subvgrid, self.configuration, recursive=False, as_dict=True) + self.test_subvgrid, + self.configuration, + recursive=False, + as_dict=True, + ) self.assertTrue(status) - self.assertEqual(direct_settings['description'], 'test subvgrid') - self.assertIn('write_shared_files', recursive_settings) # Inherited + self.assertEqual(direct_settings["description"], "test subvgrid") + self.assertIn("write_shared_files", recursive_settings) # Inherited def test_vgrid_add_owner_single(self): """Test vgrid_add_owners for initial owner""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN added, msg = vgrid_add_owners( - self.configuration, self.test_vgrid, [owner1]) + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1]) def test_vgrid_add_owner_with_rank(self): """Test owner ranking/ordering functionality""" - new_owner = '/C=DK/CN=New Owner' - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [new_owner], rank=0) # Add first owner + new_owner = "/C=DK/CN=New Owner" + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [new_owner], rank=0 + ) # Add first owner self.assertTrue(added, msg) - owners = vgrid_list(self.test_vgrid, 'owners', self.configuration)[1] + owners = vgrid_list(self.test_vgrid, "owners", self.configuration)[1] self.assertEqual(owners[0], new_owner) def test_vgrid_add_owners_twice(self): """Test vgrid_add_owners for initial and secondary owner""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2] + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1, owner2]) def test_vgrid_add_owners_twice_with_rank_zero(self): """Test vgrid_add_owners for two owners with 2nd inserted first""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2], rank=0) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2], rank=0 + ) self.assertTrue(added, msg) status, owners = vgrid_list( - self.test_vgrid, 'owners', self.configuration) + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner1]) def test_vgrid_add_owners_thrice_with_rank_one(self): """Test vgrid_add_owners for three owners with 3rd inserted in middle""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2] + ) self.assertTrue(added, msg) owner3 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner3], rank=1) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner3], rank=1 + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1, owner3, owner2]) def test_vgrid_add_owners_multiple(self): """Test vgrid_add_owners inserting list of owners at once""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) new_owners = [ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/' - 'emailAddress=owner5@example.com', - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/' - 'emailAddress=owner6@example.com' + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/" + "emailAddress=owner5@example.com", + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/" + "emailAddress=owner6@example.com", ] - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - new_owners) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, new_owners + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, new_owners) def test_vgrid_add_owners_repeat_inserts_only_one(self): """Test vgrid_add_owners inserting same owners twice does nothing""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(len(owners), 1) self.assertEqual(owners, [owner1]) def test_vgrid_add_owners_lifecycle(self): """Comprehensive life-cycle test for vgrid_add_owners functionality""" # Clear existing owners to start fresh - reset, msg = vgrid_set_owners(self.configuration, self.test_vgrid, [], - allow_empty=True) + reset, msg = vgrid_set_owners( + self.configuration, self.test_vgrid, [], allow_empty=True + ) self.assertTrue(reset, msg) # Test 1: Add initial owner owner1 = self.TEST_OWNER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) # Verify single owner - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner1]) # Test 2: Prepend new owner owner2 = self.TEST_MEMBER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner2], rank=0) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner2], rank=0 + ) self.assertTrue(added, msg) # Verify new order - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner1]) # Test 3: Append without rank owner3 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner3]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner3] + ) self.assertTrue(added, msg) # Verify append - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner1, owner3]) # Test 4: Insert at middle position - owner4 = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Four/'\ - 'emailAddress=owner4@example.com' - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner4], rank=1) + owner4 = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Four/" + "emailAddress=owner4@example.com" + ) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner4], rank=1 + ) self.assertTrue(added, msg) # Verify insertion - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(owners, [owner2, owner4, owner1, owner3]) # Test 5: Add multiple owners at once new_owners = [ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/' - 'emailAddress=owner5@example.com', - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/' - 'emailAddress=owner6@example.com' + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Five/" + "emailAddress=owner5@example.com", + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Owner Six/" + "emailAddress=owner6@example.com", ] - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - new_owners, rank=2) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, new_owners, rank=2 + ) self.assertTrue(added, msg) # Verify multi-insert - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) expected = [owner2, owner4] + new_owners + [owner1, owner3] self.assertEqual(owners, expected) # Test 6: Prevent duplicate owner addition pre_add_count = len(owners) - added, msg = vgrid_add_owners(self.configuration, self.test_vgrid, - [owner1]) + added, msg = vgrid_add_owners( + self.configuration, self.test_vgrid, [owner1] + ) self.assertTrue(added, msg) # Verify no duplicate added - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertEqual(len(owners), pre_add_count) def test_vgrid_is_trigger(self): """Test trigger rule detection""" test_rule = { - 'rule_id': 'test_rule', - 'vgrid_name': self.test_vgrid, - 'path': '*.txt', - 'changes': ['modified'], - 'run_as': self.TEST_OWNER_DN, - 'action': 'copy', - 'arguments': [], - 'match_files': True, - 'match_dirs': False, - 'match_recursive': False, + "rule_id": "test_rule", + "vgrid_name": self.test_vgrid, + "path": "*.txt", + "changes": ["modified"], + "run_as": self.TEST_OWNER_DN, + "action": "copy", + "arguments": [], + "match_files": True, + "match_dirs": False, + "match_recursive": False, } # Add trigger to vgrid with all required fields - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'triggers', [test_rule]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "triggers", [test_rule] + ) self.assertTrue(added, msg) - self.assertTrue(vgrid_is_trigger( - self.test_vgrid, 'test_rule', self.configuration)) + self.assertTrue( + vgrid_is_trigger(self.test_vgrid, "test_rule", self.configuration) + ) def test_vgrid_sharelink_operations(self): """Test sharelink add/remove cycles""" test_share = { - 'share_id': 'test_share', - 'path': '/test/path', - 'access': ['read'], # Must be list type - 'invites': [self.TEST_MEMBER_DN], # Required field - 'single_file': True, # Correct field name (was 'is_dir') - 'expire': '-1', # Optional but included for completeness - 'owner': self.TEST_OWNER_DN, - 'created_timestamp': datetime.datetime.now() + "share_id": "test_share", + "path": "/test/path", + "access": ["read"], # Must be list type + "invites": [self.TEST_MEMBER_DN], # Required field + "single_file": True, # Correct field name (was 'is_dir') + "expire": "-1", # Optional but included for completeness + "owner": self.TEST_OWNER_DN, + "created_timestamp": datetime.datetime.now(), } - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'sharelinks', [test_share]) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "sharelinks", [test_share] + ) self.assertTrue(added, msg) # Test removal - removed, msg = vgrid_remove_entities(self.configuration, self.test_vgrid, - 'sharelinks', ['test_share'], True) + removed, msg = vgrid_remove_entities( + self.configuration, + self.test_vgrid, + "sharelinks", + ["test_share"], + True, + ) self.assertTrue(removed, msg) # TODO: adjust API to allow enabling the next test @@ -637,287 +797,357 @@ def test_vgrid_sharelink_operations(self): def test_vgrid_settings_validation(self): """Test settings key validation""" invalid_settings = [ - ('vgrid_name', self.test_vgrid), # Required field - ('invalid_key', 'value') # Invalid extra field + ("vgrid_name", self.test_vgrid), # Required field + ("invalid_key", "value"), # Invalid extra field ] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', invalid_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "settings", invalid_settings + ) # Should not accept invalid key even with valid required fields self.assertFalse(added) self.assertIn("unknown settings key 'invalid_key'", msg) - status, settings = vgrid_list(self.test_vgrid, 'settings', - self.configuration) + status, settings = vgrid_list( + self.test_vgrid, "settings", self.configuration + ) # Should never save invalid key even with valid required fields self.assertTrue(status) - self.assertNotIn('invalid_key', settings[0]) + self.assertNotIn("invalid_key", settings[0]) def test_vgrid_entity_listing(self): """Test direct entity listing functions""" # Test empty members and one owner listing from init - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertTrue(status) self.assertEqual(len(members), 0) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertTrue(status) self.assertEqual(len(owners), 1) # Populate and verify - vgrid_add_owners(self.configuration, self.test_vgrid, - [self.TEST_OWNER_DN]) - status, owners = vgrid_list(self.test_vgrid, 'owners', - self.configuration) + vgrid_add_owners( + self.configuration, self.test_vgrid, [self.TEST_OWNER_DN] + ) + status, owners = vgrid_list( + self.test_vgrid, "owners", self.configuration + ) self.assertTrue(status) self.assertEqual(owners, [self.TEST_OWNER_DN]) def test_vgrid_add_members_single(self): """Test vgrid_add_owners for initial member""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1]) def test_vgrid_add_member_with_rank(self): """Test member ranking/ordering functionality""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1], rank=0) # Add first owner + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1], rank=0 + ) # Add first owner self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1]) def test_vgrid_add_members_twice(self): """Test vgrid_add_members for initial and secondary member""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2] + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1, member2]) def test_vgrid_add_members_twice_with_rank_zero(self): """Test vgrid_add_members for two members with 2nd inserted first""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2], rank=0) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2], rank=0 + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member1]) def test_vgrid_add_members_thrice_with_rank_one(self): """Test vgrid_add_members for three members with 3rd inserted in middle""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2] + ) self.assertTrue(added, msg) member3 = self.TEST_OWNER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member3], rank=1) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member3], rank=1 + ) self.assertTrue(added, msg) - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1, member3, member2]) def test_vgrid_add_members_lifecycle(self): """Comprehensive life-cycle test for vgrid_add_members functionality""" # Clear existing members to start fresh - reset, msg = vgrid_set_entities(self.configuration, self.test_vgrid, - 'members', [], allow_empty=True) + reset, msg = vgrid_set_entities( + self.configuration, + self.test_vgrid, + "members", + [], + allow_empty=True, + ) self.assertTrue(reset, msg) # Test 1: Add initial member member1 = self.TEST_MEMBER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) # Verify single member - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member1]) # Test 2: Prepend new member member2 = self.TEST_OUTSIDER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member2], rank=0) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member2], rank=0 + ) self.assertTrue(added, msg) # Verify new order - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member1]) # Test 3: Append without rank member3 = self.TEST_OWNER_DN - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member3]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member3] + ) self.assertTrue(added, msg) # Verify append - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member1, member3]) # Test 4: Insert at middle position - member4 = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Four/'\ - 'emailAddress=member4@example.com' - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member4], rank=1) + member4 = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Four/" + "emailAddress=member4@example.com" + ) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member4], rank=1 + ) self.assertTrue(added, msg) # Verify insertion - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(members, [member2, member4, member1, member3]) # Test 5: Add multiple members at once new_members = [ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Five/' - 'emailAddress=member5@example.com', - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Six/' - 'emailAddress=member6@example.com' + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Five/" + "emailAddress=member5@example.com", + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Member Six/" + "emailAddress=member6@example.com", ] - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - new_members, rank=2) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, new_members, rank=2 + ) self.assertTrue(added, msg) # Verify multi-insert - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) expected = [member2, member4] + new_members + [member1, member3] self.assertEqual(members, expected) # Test 6: Prevent duplicate member addition pre_add_count = len(members) - added, msg = vgrid_add_members(self.configuration, self.test_vgrid, - [member1]) + added, msg = vgrid_add_members( + self.configuration, self.test_vgrid, [member1] + ) self.assertTrue(added, msg) # Verify no duplicate added - status, members = vgrid_list(self.test_vgrid, 'members', - self.configuration) + status, members = vgrid_list( + self.test_vgrid, "members", self.configuration + ) self.assertEqual(len(members), pre_add_count) def test_flat_vgrid_name(self): """Test vgrid_flat_name conversion""" - nested_vgrid = 'testvgrid/sub' - expected_flat = '%s' % vgrid_nest_sep.join(['testvgrid', 'sub']) + nested_vgrid = "testvgrid/sub" + expected_flat = "%s" % vgrid_nest_sep.join(["testvgrid", "sub"]) converted = vgrid_flat_name(nested_vgrid, self.configuration) self.assertEqual(converted, expected_flat) - nested_vgrid = 'testvgrid/sub/test' - expected_flat = '%s' % vgrid_nest_sep.join( - ['testvgrid', 'sub', 'test']) + nested_vgrid = "testvgrid/sub/test" + expected_flat = "%s" % vgrid_nest_sep.join(["testvgrid", "sub", "test"]) converted = vgrid_flat_name(nested_vgrid, self.configuration) self.assertEqual(converted, expected_flat) def test_resource_signup_workflow(self): """Test full resource signup workflow""" # Sign up resource - added, msg = vgrid_add_resources(self.configuration, self.test_vgrid, - [self.TEST_RESOURCE_DN]) + added, msg = vgrid_add_resources( + self.configuration, self.test_vgrid, [self.TEST_RESOURCE_DN] + ) self.assertTrue(added, msg) # Verify visibility - matched = vgrid_match_resources(self.test_vgrid, [self.TEST_RESOURCE_DN], - self.configuration) + matched = vgrid_match_resources( + self.test_vgrid, [self.TEST_RESOURCE_DN], self.configuration + ) self.assertEqual(matched, [self.TEST_RESOURCE_DN]) def test_multi_level_inheritance(self): """Test settings propagation through multiple vgrid levels""" # Create grandchild vgrid - grandchild = os.path.join(self.test_subvgrid, 'grandchild') + grandchild = os.path.join(self.test_subvgrid, "grandchild") grandchild_path = os.path.join( - self.configuration.vgrid_home, grandchild) + self.configuration.vgrid_home, grandchild + ) ensure_dirs_exist(grandchild_path) # Set valid inherited setting at top level with required vgrid_name top_settings = [ - ('vgrid_name', self.test_vgrid), - ('hidden', True) # Valid inherited field with boolean value + ("vgrid_name", self.test_vgrid), + ("hidden", True), # Valid inherited field with boolean value ] - added, msg = vgrid_add_entities(self.configuration, self.test_vgrid, - 'settings', top_settings) + added, msg = vgrid_add_entities( + self.configuration, self.test_vgrid, "settings", top_settings + ) self.assertTrue(added, msg) # Verify grandchild inheritance using 'hidden' field inherit=true - status, settings = vgrid_settings(grandchild, self.configuration, - recursive=True, as_dict=True) + status, settings = vgrid_settings( + grandchild, self.configuration, recursive=True, as_dict=True + ) self.assertTrue(status) - self.assertEqual(settings.get('hidden'), True) + self.assertEqual(settings.get("hidden"), True) # Verify vgrid_name is preserved - self.assertEqual(settings['vgrid_name'], grandchild) + self.assertEqual(settings["vgrid_name"], grandchild) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires tweaking of funcion") def test_workflow_job_priority(self): """Test workflow job queue ordering and limits""" # Create max jobs + 1 - job_entries = [{ - 'vgrid_name': self.test_vgrid, - 'client_id': self.TEST_OWNER_DN, - 'job_id': str(i), - 'run_as': self.TEST_OWNER_DN, # Required field - 'exe': '/bin/echo', # Required job field - 'arguments': ['Test job'], # Required job field - } for i in range(101)] + job_entries = [ + { + "vgrid_name": self.test_vgrid, + "client_id": self.TEST_OWNER_DN, + "job_id": str(i), + "run_as": self.TEST_OWNER_DN, # Required field + "exe": "/bin/echo", # Required job field + "arguments": ["Test job"], # Required job field + } + for i in range(101) + ] added, msg = vgrid_add_workflow_jobs( - self.configuration, - self.test_vgrid, - job_entries + self.configuration, self.test_vgrid, job_entries ) self.assertTrue(added, msg) status, jobs = vgrid_list( - self.test_vgrid, 'jobqueue', self.configuration) + self.test_vgrid, "jobqueue", self.configuration + ) self.assertTrue(status) # Should stay at max 100 by removing oldest self.assertEqual(len(jobs), 100) - self.assertEqual(jobs[-1]['job_id'], '100') # Newest at end + self.assertEqual(jobs[-1]["job_id"], "100") # Newest at end class TestMigSharedVgrid__legacy_main(MigTestCase): @@ -925,14 +1155,17 @@ class TestMigSharedVgrid__legacy_main(MigTestCase): def test_existing_main(self): """Run the legacy self-tests directly in module""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -942,5 +1175,5 @@ def record_last_print(value): legacy_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_vgridaccess.py b/tests/test_mig_shared_vgridaccess.py index 1dbe2a086..6bfe086d3 100644 --- a/tests/test_mig_shared_vgridaccess.py +++ b/tests/test_mig_shared_vgridaccess.py @@ -35,92 +35,157 @@ import mig.shared.vgridaccess as vgridaccess from mig.shared.fileio import pickle, read_file from mig.shared.vgrid import vgrid_list, vgrid_set_entities, vgrid_settings -from mig.shared.vgridaccess import CONF, MEMBERS, OWNERS, RESOURCES, SETTINGS, \ - USERID, USERS, VGRIDS, check_resources_modified, check_vgrid_access, \ - check_vgrids_modified, fill_placeholder_cache, force_update_resource_map, \ - force_update_user_map, force_update_vgrid_map, get_re_provider_map, \ - get_resource_map, get_user_map, get_vgrid_map, get_vgrid_map_vgrids, \ - is_vgrid_parent_placeholder, load_resource_map, load_user_map, \ - load_vgrid_map, mark_vgrid_modified, refresh_resource_map, \ - refresh_user_map, refresh_vgrid_map, res_vgrid_access, \ - reset_resources_modified, reset_vgrids_modified, resources_using_re, \ - unmap_inheritance, unmap_resource, unmap_vgrid, user_allowed_res_confs, \ - user_allowed_res_exes, user_allowed_res_stores, user_allowed_res_units, \ - user_allowed_user_confs, user_owned_res_exes, user_owned_res_stores, \ - user_vgrid_access, user_visible_res_confs, user_visible_res_exes, \ - user_visible_res_stores, user_visible_user_confs, vgrid_inherit_map +from mig.shared.vgridaccess import ( + CONF, + MEMBERS, + OWNERS, + RESOURCES, + SETTINGS, + USERID, + USERS, + VGRIDS, + check_resources_modified, + check_vgrid_access, + check_vgrids_modified, + fill_placeholder_cache, + force_update_resource_map, + force_update_user_map, + force_update_vgrid_map, + get_re_provider_map, + get_resource_map, + get_user_map, + get_vgrid_map, + get_vgrid_map_vgrids, + is_vgrid_parent_placeholder, + load_resource_map, + load_user_map, + load_vgrid_map, + mark_vgrid_modified, + refresh_resource_map, + refresh_user_map, + refresh_vgrid_map, + res_vgrid_access, + reset_resources_modified, + reset_vgrids_modified, + resources_using_re, + unmap_inheritance, + unmap_resource, + unmap_vgrid, + user_allowed_res_confs, + user_allowed_res_exes, + user_allowed_res_stores, + user_allowed_res_units, + user_allowed_user_confs, + user_owned_res_exes, + user_owned_res_stores, + user_vgrid_access, + user_visible_res_confs, + user_visible_res_exes, + user_visible_res_stores, + user_visible_user_confs, + vgrid_inherit_map, +) from tests.support import MigTestCase, ensure_dirs_exist, testmain -from tests.support.usersupp import UserAssertMixin, TEST_USER_DN +from tests.support.usersupp import TEST_USER_DN, UserAssertMixin class TestMigSharedVgridAccess(MigTestCase, UserAssertMixin): """Unit tests for vgridaccess related helper functions""" - TEST_OWNER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/'\ - 'emailAddress=owner@example.org' - TEST_MEMBER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/'\ - 'emailAddress=member@example.org' - TEST_OUTSIDER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/'\ - 'emailAddress=outsider@example.com' - TEST_RESOURCE_ID = 'test.example.org.0' - TEST_VGRID_NAME = 'testvgrid' - - TEST_OWNER_UUID = 'ff326a2b984828d9b32077c9b0b35a05' - TEST_MEMBER_UUID = 'ea9aedcbe69db279ca3676f83de94669' - TEST_RESOURCE_ALIAS = '0835f310d6422c36e33eeb7d0d3e9cf5' + TEST_OWNER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/" + "emailAddress=owner@example.org" + ) + TEST_MEMBER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/" + "emailAddress=member@example.org" + ) + TEST_OUTSIDER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/" + "emailAddress=outsider@example.com" + ) + TEST_RESOURCE_ID = "test.example.org.0" + TEST_VGRID_NAME = "testvgrid" + + TEST_OWNER_UUID = "ff326a2b984828d9b32077c9b0b35a05" + TEST_MEMBER_UUID = "ea9aedcbe69db279ca3676f83de94669" + TEST_RESOURCE_ALIAS = "0835f310d6422c36e33eeb7d0d3e9cf5" # Default vgrid is initially set up without settings when force loaded - MINIMAL_VGRIDS = {'Generic': {OWNERS: [], MEMBERS: [], RESOURCES: [], - SETTINGS: []}} + MINIMAL_VGRIDS = { + "Generic": {OWNERS: [], MEMBERS: [], RESOURCES: [], SETTINGS: []} + } def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' - - def _create_vgrid(self, vgrid_name, *, owners=None, members=None, - resources=None, settings=None, triggers=None): + return "testconfig" + + def _create_vgrid( + self, + vgrid_name, + *, + owners=None, + members=None, + resources=None, + settings=None, + triggers=None + ): """Helper to create valid skeleton vgrid for testing""" vgrid_path = os.path.join(self.configuration.vgrid_home, vgrid_name) ensure_dirs_exist(vgrid_path) # Save vgrid owners, members, resources, settings and triggers if owners is None: owners = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'owners', owners, allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, vgrid_name, "owners", owners, allow_empty=True + ) self.assertEqual(success_and_msg, (True, "")) if members is None: members = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'members', members, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "members", + members, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if resources is None: resources = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'resources', resources, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "resources", + resources, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if settings is None: - settings = [('vgrid_name', vgrid_name)] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'settings', settings, - allow_empty=True) + settings = [("vgrid_name", vgrid_name)] + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "settings", + settings, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if triggers is None: triggers = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'triggers', triggers, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "triggers", + triggers, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) def _create_resource(self, res_name, owners, config=None): """Helper to create valid skeleton resource for testing""" res_path = os.path.join(self.configuration.resource_home, res_name) - res_owners_path = os.path.join(res_path, 'owners') - res_config_path = os.path.join(res_path, 'config') + res_owners_path = os.path.join(res_path, "owners") + res_config_path = os.path.join(res_path, "config") # Add resource skeleton with owners ensure_dirs_exist(res_path) if owners is None: @@ -129,10 +194,11 @@ def _create_resource(self, res_name, owners, config=None): self.assertTrue(saved) if config is None: # Make sure conf has one valid field - config = {'HOSTURL': res_name, - 'EXECONFIG': [{'name': 'exe', 'vgrid': ['Generic']}], - 'STORECONFIG': [{'name': 'exe', 'vgrid': ['Generic']}] - } + config = { + "HOSTURL": res_name, + "EXECONFIG": [{"name": "exe", "vgrid": ["Generic"]}], + "STORECONFIG": [{"name": "exe", "vgrid": ["Generic"]}], + } saved = pickle(config, res_config_path, self.logger) self.assertTrue(saved) @@ -234,8 +300,10 @@ def test_force_update_vgrid_map(self): updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) self.assertTrue(updated_vgrid_map) - self.assertNotEqual(len(vgrid_map_before.get(VGRIDS, {})), - len(updated_vgrid_map.get(VGRIDS, {}))) + self.assertNotEqual( + len(vgrid_map_before.get(VGRIDS, {})), + len(updated_vgrid_map.get(VGRIDS, {})), + ) self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) def test_refresh_user_map(self): @@ -339,7 +407,7 @@ def test_get_vgrid_map_vgrids(self): vgrid_list = get_vgrid_map_vgrids(self.configuration) self.assertTrue(isinstance(vgrid_list, list)) - self.assertEqual(['Generic'], vgrid_list) + self.assertEqual(["Generic"], vgrid_list) def test_user_owned_res_exes(self): """Test user_owned_res_exes returns owned execution nodes""" @@ -367,7 +435,8 @@ def test_user_allowed_res_units(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) allowed = user_allowed_res_units( - self.configuration, self.TEST_OWNER_DN, "exe") + self.configuration, self.TEST_OWNER_DN, "exe" + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) @@ -394,7 +463,8 @@ def test_user_allowed_res_stores(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) allowed = user_allowed_res_stores( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) @@ -419,23 +489,29 @@ def test_user_visible_res_stores(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) visible = user_visible_res_stores( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(visible, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, visible) def test_user_allowed_user_confs(self): """Test user_allowed_user_confs returns allowed user confs""" - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) - - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) + + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_user_map(self.configuration) allowed = user_allowed_user_confs( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_OWNER_UUID, allowed) self.assertIn(self.TEST_MEMBER_UUID, allowed) @@ -443,32 +519,36 @@ def test_user_allowed_user_confs(self): def test_fill_placeholder_cache(self): """Test fill_placeholder_cache populates cache""" cache = {} - fill_placeholder_cache(self.configuration, cache, [ - self.TEST_VGRID_NAME]) + fill_placeholder_cache( + self.configuration, cache, [self.TEST_VGRID_NAME] + ) self.assertIn(self.TEST_VGRID_NAME, cache) def test_is_vgrid_parent_placeholder(self): """Test is_vgrid_parent_placeholder detection""" - test_path = os.path.join(self.configuration.user_home, 'testvgrid') - result = is_vgrid_parent_placeholder(self.configuration, test_path, - test_path) + test_path = os.path.join(self.configuration.user_home, "testvgrid") + result = is_vgrid_parent_placeholder( + self.configuration, test_path, test_path + ) self.assertIsNone(result) def test_resources_using_re_notfound(self): """Test RE with no assigned resources returns empty list""" # Nonexistent RE should have no resources - res_list = resources_using_re(self.configuration, 'NoSuchRE') + res_list = resources_using_re(self.configuration, "NoSuchRE") self.assertEqual(res_list, []) def test_vgrid_inherit_map_single(self): """Test inheritance mapping with single vgrid""" - test_settings = [('vgrid_name', self.TEST_VGRID_NAME), - ('hidden', True)] + test_settings = [ + ("vgrid_name", self.TEST_VGRID_NAME), + ("hidden", True), + ] test_map = { VGRIDS: { self.TEST_VGRID_NAME: { SETTINGS: test_settings, - OWNERS: [self.TEST_OWNER_DN] + OWNERS: [self.TEST_OWNER_DN], } } } @@ -477,13 +557,13 @@ def test_vgrid_inherit_map_single(self): self.assertIn(self.TEST_VGRID_NAME, vgrid_data) settings_dict = dict(vgrid_data[self.TEST_VGRID_NAME][SETTINGS]) self.assertIs(type(settings_dict), dict) - self.assertEqual(settings_dict.get('hidden'), True) + self.assertEqual(settings_dict.get("hidden"), True) # TODO: move these two modified tests to a test_mig_shared_modified.py def test_check_vgrids_modified_initial(self): """Verify initial modified vgrids list marks ALL and empty on reset""" modified, stamp = check_vgrids_modified(self.configuration) - self.assertEqual(modified, ['ALL']) + self.assertEqual(modified, ["ALL"]) reset_vgrids_modified(self.configuration) modified, stamp = check_vgrids_modified(self.configuration) self.assertEqual(modified, []) @@ -506,10 +586,9 @@ def test_user_vgrid_access(self): self._provision_test_user(self, TEST_USER_DN) # Start with global access to default vgrid - allowed_vgrids = user_vgrid_access(self.configuration, - TEST_USER_DN) + allowed_vgrids = user_vgrid_access(self.configuration, TEST_USER_DN) - self.assertIn('Generic', allowed_vgrids) + self.assertIn("Generic", allowed_vgrids) self.assertTrue(len(allowed_vgrids), 1) # Create private vgrid self._create_vgrid(self.TEST_VGRID_NAME, owners=[TEST_USER_DN]) @@ -517,20 +596,21 @@ def test_user_vgrid_access(self): initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) - allowed_vgrids = user_vgrid_access(self.configuration, - TEST_USER_DN) + allowed_vgrids = user_vgrid_access(self.configuration, TEST_USER_DN) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) def test_res_vgrid_access(self): """Minimal test for resource vgrid participation""" # Only Generic access initially allowed_vgrids = res_vgrid_access( - self.configuration, self.TEST_RESOURCE_ID) - self.assertEqual(allowed_vgrids, ['Generic']) + self.configuration, self.TEST_RESOURCE_ID + ) + self.assertEqual(allowed_vgrids, ["Generic"]) # Add to vgrid self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, resources=[ - self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, resources=[self.TEST_RESOURCE_ID] + ) # Refresh maps to reflect new content initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) @@ -550,33 +630,40 @@ def test_vgrid_map_refresh(self): self._verify_vgrid_map_integrity(updated_vgrid_map) vgrids = updated_vgrid_map.get(VGRIDS, {}) self.assertIn(self.TEST_VGRID_NAME, vgrids) - self.assertEqual(vgrids[self.TEST_VGRID_NAME] - [OWNERS], [self.TEST_OWNER_DN]) + self.assertEqual( + vgrids[self.TEST_VGRID_NAME][OWNERS], [self.TEST_OWNER_DN] + ) def test_user_map_access(self): """Test user permissions through cached access maps""" # Add user as member - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Verify member access - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) def test_resource_map_update(self): """Verify resource visibility in cache""" # Check cached resource map does not yet contain entry - res_map_before, _ = load_resource_map(self.configuration, - caching=True) + res_map_before, _ = load_resource_map(self.configuration, caching=True) self.assertEqual(res_map_before, {}) # Add vgrid with assigned resource self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + resources=[self.TEST_RESOURCE_ID], + ) updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) @@ -594,11 +681,13 @@ def test_resource_map_update(self): def test_settings_inheritance(self): """Test inherited settings propagation through cached maps""" # Create top and sub vgrids with 'hidden' setting on top vgrid - top_settings = [('vgrid_name', self.TEST_VGRID_NAME), - ('hidden', True)] - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - settings=top_settings) - sub_vgrid = os.path.join(self.TEST_VGRID_NAME, 'subvgrid') + top_settings = [("vgrid_name", self.TEST_VGRID_NAME), ("hidden", True)] + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + settings=top_settings, + ) + sub_vgrid = os.path.join(self.TEST_VGRID_NAME, "subvgrid") self._create_vgrid(sub_vgrid) # Force refresh of cached map @@ -618,7 +707,7 @@ def test_settings_inheritance(self): self.assertTrue(top_settings_dict) # Verify hidden setting in cache - self.assertEqual(top_settings_dict.get('hidden'), True) + self.assertEqual(top_settings_dict.get("hidden"), True) # Retrieve sub vgrid settings from cached map sub_vgrid_data = vgrid_data.get(sub_vgrid, {}) @@ -626,10 +715,9 @@ def test_settings_inheritance(self): sub_settings_dict = dict(sub_vgrid_data.get(SETTINGS, [])) # Verify hidden setting unset without inheritance - self.assertFalse(sub_settings_dict.get('hidden')) + self.assertFalse(sub_settings_dict.get("hidden")) - inherited_map = vgrid_inherit_map( - self.configuration, updated_vgrid_map) + inherited_map = vgrid_inherit_map(self.configuration, updated_vgrid_map) vgrid_data = inherited_map.get(VGRIDS, {}) self.assertTrue(vgrid_data) @@ -639,12 +727,12 @@ def test_settings_inheritance(self): sub_settings_dict = dict(sub_vgrid_data.get(SETTINGS, [])) # Verify hidden setting inheritance - self.assertEqual(sub_settings_dict.get('hidden'), True) + self.assertEqual(sub_settings_dict.get("hidden"), True) def test_unmap_inheritance(self): """Test unmap_inheritance clears inherited mappings""" self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN]) - sub_vgrid = os.path.join(self.TEST_VGRID_NAME, 'subvgrid') + sub_vgrid = os.path.join(self.TEST_VGRID_NAME, "subvgrid") self._create_vgrid(sub_vgrid) # Force refresh of cached map @@ -653,8 +741,9 @@ def test_unmap_inheritance(self): self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) # Unmap and verify mark modified - unmap_inheritance(self.configuration, self.TEST_VGRID_NAME, - self.TEST_OWNER_DN) + unmap_inheritance( + self.configuration, self.TEST_VGRID_NAME, self.TEST_OWNER_DN + ) modified, stamp = check_vgrids_modified(self.configuration) self.assertEqual(modified, [self.TEST_VGRID_NAME, sub_vgrid]) @@ -662,8 +751,9 @@ def test_unmap_inheritance(self): def test_user_map_fields(self): """Verify user map includes complete profile/settings data""" # First add a couple of test users - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) # Force fresh user map initial_vgrid_map = force_update_vgrid_map(self.configuration) @@ -680,8 +770,11 @@ def test_resource_revoked_access(self): """Verify resource removal propagates through cached maps""" # First add resource and vgrid self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + resources=[self.TEST_RESOURCE_ID], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) @@ -697,9 +790,14 @@ def test_resource_revoked_access(self): self.assertIn(self.TEST_RESOURCE_ID, initial_map) # Remove resource assignment from vgrid - success_and_msg = vgrid_set_entities(self.configuration, self.TEST_VGRID_NAME, - 'resources', [], allow_empty=True) - self.assertEqual(success_and_msg, (True, '')) + success_and_msg = vgrid_set_entities( + self.configuration, + self.TEST_VGRID_NAME, + "resources", + [], + allow_empty=True, + ) + self.assertEqual(success_and_msg, (True, "")) updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) @@ -717,9 +815,9 @@ def test_resource_revoked_access(self): def test_non_recursive_inheritance(self): """Verify non-recursive map excludes nested vgrids""" # Create parent+child vgrids - parent_vgrid = 'parent' + parent_vgrid = "parent" self._create_vgrid(parent_vgrid, owners=[self.TEST_OWNER_DN]) - child_vgrid = os.path.join(parent_vgrid, 'child') + child_vgrid = os.path.join(parent_vgrid, "child") self._create_vgrid(child_vgrid, members=[self.TEST_MEMBER_DN]) # Force update to avoid auto caching and get non-recursive map @@ -733,21 +831,25 @@ def test_non_recursive_inheritance(self): # Child should still appear when non-recursive but just not inherit self.assertIn(child_vgrid, vgrid_map.get(VGRIDS, {})) # Check owners and members to verify they aren't inherited - self.assertEqual(vgrid_map[VGRIDS][parent_vgrid][OWNERS], - [self.TEST_OWNER_DN]) + self.assertEqual( + vgrid_map[VGRIDS][parent_vgrid][OWNERS], [self.TEST_OWNER_DN] + ) self.assertEqual(len(vgrid_map[VGRIDS][parent_vgrid][MEMBERS]), 0) self.assertEqual(len(vgrid_map[VGRIDS][child_vgrid][OWNERS]), 0) - self.assertEqual(vgrid_map[VGRIDS][child_vgrid][MEMBERS], - [self.TEST_MEMBER_DN]) + self.assertEqual( + vgrid_map[VGRIDS][child_vgrid][MEMBERS], [self.TEST_MEMBER_DN] + ) def test_hidden_setting_propagation(self): """Verify hidden=True propagates to not infect parent settings""" - parent_vgrid = 'parent' + parent_vgrid = "parent" self._create_vgrid(parent_vgrid, owners=[self.TEST_OWNER_DN]) - child_vgrid = os.path.join(parent_vgrid, 'child') - self._create_vgrid(child_vgrid, owners=[self.TEST_OWNER_DN], - settings=[('vgrid_name', child_vgrid), - ('hidden', True)]) + child_vgrid = os.path.join(parent_vgrid, "child") + self._create_vgrid( + child_vgrid, + owners=[self.TEST_OWNER_DN], + settings=[("vgrid_name", child_vgrid), ("hidden", True)], + ) # Verify parent remains visible in cache updated_vgrid_map = force_update_vgrid_map(self.configuration) @@ -756,64 +858,79 @@ def test_hidden_setting_propagation(self): self.assertIn(child_vgrid, updated_vgrid_map.get(VGRIDS, {})) parent_data = updated_vgrid_map.get(VGRIDS, {}).get(parent_vgrid, {}) parent_settings = dict(parent_data.get(SETTINGS, [])) - self.assertNotEqual(parent_settings.get('hidden'), True) + self.assertNotEqual(parent_settings.get("hidden"), True) def test_default_vgrid_access(self): """Verify special access rules for default vgrid""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Even non-member should have access to default vgrid - participant = check_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN, - 'Generic') + participant = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, "Generic" + ) self.assertFalse(participant) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN) - self.assertIn('Generic', allowed_vgrids) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN + ) + self.assertIn("Generic", allowed_vgrids) # Invalid vgrid should not allow any participation or access - participant = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - 'invalid-vgrid-name') + participant = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, "invalid-vgrid-name" + ) self.assertFalse(participant) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_MEMBER_DN) - self.assertNotIn('invalid-vgrid-name', allowed_vgrids) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_MEMBER_DN + ) + self.assertNotIn("invalid-vgrid-name", allowed_vgrids) def test_general_vgrid_access(self): """Verify general access rules for vgrids""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Test vgrid must allow owner and members access - allowed = check_vgrid_access(self.configuration, self.TEST_OWNER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OWNER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OWNER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OWNER_DN + ) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_MEMBER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_MEMBER_DN + ) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) # Test vgrid must reject allow outsider access - allowed = check_vgrid_access(self.configuration, self.TEST_OUTSIDER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, self.TEST_VGRID_NAME + ) self.assertFalse(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN + ) self.assertNotIn(self.TEST_VGRID_NAME, allowed_vgrids) def test_user_allowed_res_confs(self): @@ -821,44 +938,49 @@ def test_user_allowed_res_confs(self): # Create test user and add test resource to vgrid self._provision_test_user(self, TEST_USER_DN) self._create_resource(self.TEST_RESOURCE_ID, [TEST_USER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[TEST_USER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[TEST_USER_DN], + resources=[self.TEST_RESOURCE_ID], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) # Owner should be allowed access - allowed = user_allowed_res_confs(self.configuration, - TEST_USER_DN) + allowed = user_allowed_res_confs(self.configuration, TEST_USER_DN) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) def test_user_visible_res_confs(self): """Minimal test for user_visible_res_confs""" # Owner should see owned resources even without vgrid access - self._create_resource(self.TEST_RESOURCE_ID, - owners=[self.TEST_OWNER_DN]) + self._create_resource( + self.TEST_RESOURCE_ID, owners=[self.TEST_OWNER_DN] + ) force_update_resource_map(self.configuration) - visible = user_visible_res_confs( - self.configuration, self.TEST_OWNER_DN) + visible = user_visible_res_confs(self.configuration, self.TEST_OWNER_DN) self.assertIn(self.TEST_RESOURCE_ALIAS, visible) def test_user_visible_user_confs(self): """Minimal test for user_visible_user_confs""" # Owners should see themselves in auto map # NOTE: use provision users to skip fixtures here - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) force_update_user_map(self.configuration) visible = user_visible_user_confs( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertIn(self.TEST_OWNER_UUID, visible) def test_get_re_provider_map(self): """Test RE provider map includes test resource""" - test_re = 'Python' - res_config = {'RUNTIMEENVIRONMENT': [(test_re, '/python/path')]} - self._create_resource(self.TEST_RESOURCE_ID, [ - self.TEST_OWNER_DN], res_config) + test_re = "Python" + res_config = {"RUNTIMEENVIRONMENT": [(test_re, "/python/path")]} + self._create_resource( + self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN], res_config + ) # Update maps to include new resource force_update_resource_map(self.configuration) @@ -870,10 +992,11 @@ def test_get_re_provider_map(self): def test_resources_using_re(self): """Test finding resources with specific runtime environment""" - test_re = 'Bash' - res_config = {'RUNTIMEENVIRONMENT': [(test_re, '/bash/path')]} - self._create_resource(self.TEST_RESOURCE_ID, [ - self.TEST_OWNER_DN], res_config) + test_re = "Bash" + res_config = {"RUNTIMEENVIRONMENT": [(test_re, "/bash/path")]} + self._create_resource( + self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN], res_config + ) # Refresh resource map force_update_resource_map(self.configuration) @@ -909,28 +1032,32 @@ def test_unmap_resource(self): def test_access_nonexistent_vgrid(self): """Ensure checks fail cleanly for non-existent vgrid""" - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - 'no-such-vgrid') + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, "no-such-vgrid" + ) self.assertFalse(allowed) # Should not appear in allowed vgrids allowed_vgrids = user_vgrid_access( - self.configuration, self.TEST_MEMBER_DN) - self.assertNotIn('no-such-vgrid', allowed_vgrids) + self.configuration, self.TEST_MEMBER_DN + ) + self.assertNotIn("no-such-vgrid", allowed_vgrids) def test_empty_member_access(self): """Verify members-only vgrid rejects outsiders""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, owners=[], members=[self.TEST_MEMBER_DN] + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Outsider should be blocked despite no owners - allowed = check_vgrid_access(self.configuration, self.TEST_OUTSIDER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, self.TEST_VGRID_NAME + ) self.assertFalse(allowed) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_unittest_testcore.py b/tests/test_mig_unittest_testcore.py index b27a74d33..127f77fe7 100644 --- a/tests/test_mig_unittest_testcore.py +++ b/tests/test_mig_unittest_testcore.py @@ -31,28 +31,28 @@ import os import sys -from tests.support import MigTestCase, testmain - from mig.unittest.testcore import main as testcore_main +from tests.support import MigTestCase, testmain class MigUnittestTestcore(MigTestCase): def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_existing_main(self): def raise_on_error_exit(exit_code, identifying_message=None): if exit_code != 0: if identifying_message is None: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) - print("") # account for wrapped tests printing to console + print("") # account for wrapped tests printing to console testcore_main(self.configuration, _exit=raise_on_error_exit) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_wsgibin.py b/tests/test_mig_wsgibin.py index 1d0f9ecdb..3c487ad60 100644 --- a/tests/test_mig_wsgibin.py +++ b/tests/test_mig_wsgibin.py @@ -38,12 +38,24 @@ # Imports required for the unit test wrapping import mig.shared.returnvalues as returnvalues -from mig.shared.base import allow_script, brief_list, client_dir_id, \ - client_id_dir, get_short_id, invisible_path +from mig.shared.base import ( + allow_script, + brief_list, + client_dir_id, + client_id_dir, + get_short_id, + invisible_path, +) from mig.shared.compat import SimpleNamespace + # Imports required for the unit tests themselves -from tests.support import MIG_BASE, MigTestCase, ensure_dirs_exist, \ - is_path_within, testmain +from tests.support import ( + MIG_BASE, + MigTestCase, + ensure_dirs_exist, + is_path_within, + testmain, +) from tests.support.snapshotsupp import SnapshotAssertMixin from tests.support.wsgisupp import WsgiAssertMixin, prepare_wsgi @@ -60,12 +72,12 @@ def __init__(self): def handle_decl(self, decl): try: - decltag, decltype = decl.split(' ') + decltag, decltype = decl.split(" ") except Exception: decltag = "" decltype = "" - if decltag.upper() == 'DOCTYPE': + if decltag.upper() == "DOCTYPE": self._saw_doctype = True else: decltype = "unknown" @@ -73,11 +85,11 @@ def handle_decl(self, decl): self._doctype = decltype def handle_starttag(self, tag, attrs): - if tag == 'html': + if tag == "html": if self._saw_tags: - tag_html = 'not_first' + tag_html = "not_first" else: - tag_html = 'was_first' + tag_html = "was_first" self._tag_html = tag_html self._saw_tags = True @@ -85,13 +97,13 @@ def assert_basics(self): if not self._saw_doctype: raise AssertionError("missing DOCTYPE") - if self._doctype != 'html': + if self._doctype != "html": raise AssertionError("non-html DOCTYPE") - if self._tag_html == 'none': + if self._tag_html == "none": raise AssertionError("missing ") - if self._tag_html != 'was_first': + if self._tag_html != "was_first": raise AssertionError("first tag seen was not ") @@ -110,13 +122,13 @@ def handle_data(self, *args, **kwargs): def handle_starttag(self, tag, attrs): DocumentBasicsHtmlParser.handle_starttag(self, tag, attrs) - if tag == 'title': + if tag == "title": self._within_title = True def handle_endtag(self, tag): DocumentBasicsHtmlParser.handle_endtag(self, tag) - if tag == 'title': + if tag == "title": self._within_title = False def title(self, trim_newlines=False): @@ -138,7 +150,7 @@ def _import_forcibly(module_name, relative_module_dir=None): that resides within a non-module directory. """ - module_path = os.path.join(MIG_BASE, 'mig') + module_path = os.path.join(MIG_BASE, "mig") if relative_module_dir is not None: module_path = os.path.join(module_path, relative_module_dir) sys.path.append(module_path) @@ -148,7 +160,7 @@ def _import_forcibly(module_name, relative_module_dir=None): # Imports of the code under test (indirect import needed here) -migwsgi = _import_forcibly('migwsgi', relative_module_dir='wsgi-bin') +migwsgi = _import_forcibly("migwsgi", relative_module_dir="wsgi-bin") class FakeBackend: @@ -160,8 +172,8 @@ class FakeBackend: def __init__(self): self.output_objects = [ - {'object_type': 'start'}, - {'object_type': 'title', 'text': 'ERROR'}, + {"object_type": "start"}, + {"object_type": "title", "text": "ERROR"}, ] self.return_value = returnvalues.ERROR @@ -175,6 +187,7 @@ def set_response(self, output_objects, returnvalue): def to_import_module(self): def _import_module(module_path): return self + return _import_module @@ -182,11 +195,11 @@ class MigWsgibin(MigTestCase, SnapshotAssertMixin, WsgiAssertMixin): """WSGI glue test cases""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/') + self.fake_wsgi = prepare_wsgi(self.configuration, "http://localhost/") self.application_args = ( self.fake_wsgi.environ, @@ -208,51 +221,45 @@ def assertHtmlTitle(self, value, title_text=None, trim_newlines=False): def test_top_level_request_returns_status_ok(self): wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) def test_objects_containing_only_title_has_expected_title(self): - output_objects = [ - {'object_type': 'title', 'text': 'TEST'} - ] + output_objects = [{"object_type": "title", "text": "TEST"}] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) - self.assertHtmlTitle(output, title_text='TEST', trim_newlines=True) + self.assertHtmlTitle(output, title_text="TEST", trim_newlines=True) def test_objects_containing_only_title_matches_snapshot(self): - output_objects = [ - {'object_type': 'title', 'text': 'TEST'} - ] + output_objects = [{"object_type": "title", "text": "TEST"}] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) - self.assertSnapshot(output, extension='html') + self.assertSnapshot(output, extension="html") -class MigWsgibin_output_objects(MigTestCase, WsgiAssertMixin, - SnapshotAssertMixin): +class MigWsgibin_output_objects( + MigTestCase, WsgiAssertMixin, SnapshotAssertMixin +): """Unit tests for output_object related part of wsgi functions.""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/') + self.fake_wsgi = prepare_wsgi(self.configuration, "http://localhost/") self.application_args = ( self.fake_wsgi.environ, @@ -273,51 +280,46 @@ def test_unknown_object_type_generates_valid_error_page(self): self.logger.forgive_errors() output_objects = [ { - 'object_type': 'nonexistent', # trigger error handling path + "object_type": "nonexistent", # trigger error handling path } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) - output, _ = self.assertWsgiResponse( - wsgi_result, self.fake_wsgi, 200) + output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) self.assertIsValidHtmlDocument(output) def test_objects_with_type_text(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) self.assertSnapshotOfHtmlContent(output) -class MigWsgibin_input_object(MigTestCase, WsgiAssertMixin, - SnapshotAssertMixin): +class MigWsgibin_input_object( + MigTestCase, WsgiAssertMixin, SnapshotAssertMixin +): """Unit tests for input_object related part of wsgi functions.""" - DUMMY_BYTES = 'dummyæøå-ßßß-value'.encode('utf-8') + DUMMY_BYTES = "dummyæøå-ßßß-value".encode("utf-8") def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() @@ -330,8 +332,9 @@ def _prepare_test(self, form_overrides=None, custom_env=None): # Set up a wsgi input with non-ascii bytes and open it in binary mode # If form_overrides is passed a list of tuples like [('key' 'val')] it # produces a fake_wsgi input on the form: b'key=val' - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/', - form=form_overrides) + self.fake_wsgi = prepare_wsgi( + self.configuration, "http://localhost/", form=form_overrides + ) # override the default environ fields from wsgisupp if custom_env: self.fake_wsgi.environ.update(custom_env) @@ -348,7 +351,7 @@ def _prepare_test(self, form_overrides=None, custom_env=None): # NOTE: enabled with underlying wsgi use of Fieldstorage fixed def test_put_text_plain_with_binary_input_succeeds(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = { "REQUEST_METHOD": "PUT", "CONTENT_TYPE": "text/plain", @@ -358,20 +361,16 @@ def test_put_text_plain_with_binary_input_succeeds(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must succeed with HTTP 200 when it parses input @@ -379,7 +378,7 @@ def test_put_text_plain_with_binary_input_succeeds(self): @unittest.skip("disabled with underlying wsgi use of Fieldstorage fixed") def test_put_text_plain_with_binary_input_fails(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = { "REQUEST_METHOD": "PUT", "CONTENT_TYPE": "text/plain", @@ -389,52 +388,44 @@ def test_put_text_plain_with_binary_input_fails(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) # TODO: can we add assertLogs to check error log explicitly? wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must fail with HTTP 500 from failing to parse input output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 500) def test_post_url_encoded_with_binary_input_succeeds(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = None self._prepare_test(test_form, test_env) output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must succeed with HTTP 200 when it parses input output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_support.py b/tests/test_support.py index 56b74f594..b68c4249c 100644 --- a/tests/test_support.py +++ b/tests/test_support.py @@ -28,15 +28,21 @@ """Unit tests for the tests module pointed to in the filename""" from __future__ import print_function + import os import sys import unittest -from tests.support import MigTestCase, PY2, testmain, temppath, \ - AssertOver, FakeConfiguration - from mig.shared.conf import get_configuration_object from mig.shared.configuration import Configuration +from tests.support import ( + PY2, + AssertOver, + FakeConfiguration, + MigTestCase, + temppath, + testmain, +) class InstrumentedAssertOver(AssertOver): @@ -62,6 +68,7 @@ def to_check_callable(self): def _wrapped_check_callable(): self._check_callable_called = True _check_callable() + self._check_callable = _wrapped_check_callable return _wrapped_check_callable @@ -71,8 +78,8 @@ class SupportTestCase(MigTestCase): def _class_attribute(self, name, **kwargs): cls = type(self) - if 'value' in kwargs: - setattr(cls, name, kwargs['value']) + if "value" in kwargs: + setattr(cls, name, kwargs["value"]) else: return getattr(cls, name, None) @@ -80,15 +87,17 @@ def test_requires_requesting_a_configuration(self): with self.assertRaises(AssertionError) as raised: self.configuration theexception = raised.exception - self.assertEqual(str(theexception), - "configuration access but testcase did not request it") + self.assertEqual( + str(theexception), + "configuration access but testcase did not request it", + ) @unittest.skipIf(PY2, "Python 3 only") def test_unclosed_files_are_recorded(self): tmp_path = temppath("support-unclosed", self) def open_without_close(): - with open(tmp_path, 'w'): + with open(tmp_path, "w"): pass open(tmp_path) return @@ -112,11 +121,13 @@ def assert_is_int(value): assert isinstance(value, int) attempt_wrapper = self.assert_over( - values=(1, 2, 3), _AssertOver=InstrumentedAssertOver) + values=(1, 2, 3), _AssertOver=InstrumentedAssertOver + ) # record the wrapper on the test case so the subsequent test can assert against it - self._class_attribute('surviving_attempt_wrapper', - value=attempt_wrapper) + self._class_attribute( + "surviving_attempt_wrapper", value=attempt_wrapper + ) with attempt_wrapper as attempt: attempt(assert_is_int) @@ -124,14 +135,15 @@ def assert_is_int(value): self.assertTrue(attempt_wrapper.has_check_callable()) # cleanup was recorded - self.assertIn(attempt_wrapper.get_check_callable(), - self._cleanup_checks) + self.assertIn( + attempt_wrapper.get_check_callable(), self._cleanup_checks + ) def test_when_asserting_over_multiple_values_after(self): # test name is purposefully after ..._recorded in sort order # such that we can check the check function was called correctly - attempt_wrapper = self._class_attribute('surviving_attempt_wrapper') + attempt_wrapper = self._class_attribute("surviving_attempt_wrapper") self.assertTrue(attempt_wrapper.was_check_callable_called()) @@ -139,7 +151,7 @@ class SupportTestCase_using_fakeconfig(MigTestCase): """Coverage of a MiG Testcase hat requests a fakeconfig""" def _provide_configuration(self): - return 'fakeconfig' + return "fakeconfig" def test_provides_a_fake_configuration(self): configuration = self.configuration @@ -157,10 +169,10 @@ class SupportTestCase_using_testconfig(MigTestCase): """Coverage of a MiG Testcase that requests a testconfig""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_provides_the_test_configuration(self): - expected_last_dir = 'testconfs-py2' if PY2 else 'testconfs-py3' + expected_last_dir = "testconfs-py2" if PY2 else "testconfs-py3" configuration = self.configuration @@ -173,5 +185,5 @@ def test_provides_the_test_configuration(self): self.assertTrue(config_file_last_dir, expected_last_dir) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_tests_support_assertover.py b/tests/test_tests_support_assertover.py index 702a80e94..ae9be10b2 100644 --- a/tests/test_tests_support_assertover.py +++ b/tests/test_tests_support_assertover.py @@ -35,7 +35,7 @@ def assert_a_thing(value): """A simple assert helper to test with""" - assert value.endswith(' thing'), "must end with a thing" + assert value.endswith(" thing"), "must end with a thing" class TestsSupportAssertOver(unittest.TestCase): @@ -44,7 +44,9 @@ class TestsSupportAssertOver(unittest.TestCase): def test_none_failing(self): saw_raise = False try: - with AssertOver(values=('some thing', 'other thing')) as value_block: + with AssertOver( + values=("some thing", "other thing") + ) as value_block: value_block(lambda _: assert_a_thing(_)) except Exception as exc: saw_raise = True @@ -52,13 +54,18 @@ def test_none_failing(self): def test_three_total_two_failing(self): with self.assertRaises(AssertionError) as raised: - with AssertOver(values=('some thing', 'other stuff', 'foobar')) as value_block: + with AssertOver( + values=("some thing", "other stuff", "foobar") + ) as value_block: value_block(lambda _: assert_a_thing(_)) theexception = raised.exception - self.assertEqual(str(theexception), """assertions raised for the following values: + self.assertEqual( + str(theexception), + """assertions raised for the following values: - <'other stuff'> : must end with a thing -- <'foobar'> : must end with a thing""") +- <'foobar'> : must end with a thing""", + ) def test_no_cases(self): with self.assertRaises(AssertionError) as raised: @@ -69,5 +76,5 @@ def test_no_cases(self): self.assertIsInstance(theexception, NoCasesError) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_tests_support_configsupp.py b/tests/test_tests_support_configsupp.py index 0e85cd2c1..8d8d23e2d 100644 --- a/tests/test_tests_support_configsupp.py +++ b/tests/test_tests_support_configsupp.py @@ -27,11 +27,10 @@ """Unit tests for the tests module pointed to in the filename""" +from mig.shared.configuration import Configuration from tests.support import MigTestCase, testmain from tests.support.configsupp import FakeConfiguration -from mig.shared.configuration import Configuration - class TestsSupportConfigsupp_FakeConfiguration(MigTestCase): """Check some basic behaviours of FakeConfiguration instances.""" @@ -43,13 +42,13 @@ def test_consistent_parameters(self): self.maxDiff = None self.assertEqual( Configuration.to_dict(default_configuration), - Configuration.to_dict(fake_configuration) + Configuration.to_dict(fake_configuration), ) def test_only_configuration_keys(self): with self.assertRaises(AssertionError): - FakeConfiguration(bar='1') + FakeConfiguration(bar="1") -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_tests_support_wsgisupp.py b/tests/test_tests_support_wsgisupp.py index 4adcb975d..03e5346e4 100644 --- a/tests/test_tests_support_wsgisupp.py +++ b/tests/test_tests_support_wsgisupp.py @@ -28,15 +28,15 @@ """Unit tests for the tests module pointed to in the filename""" import unittest -from mig.shared.compat import SimpleNamespace +from mig.shared.compat import SimpleNamespace from tests.support import AssertOver from tests.support.wsgisupp import prepare_wsgi def assert_a_thing(value): """A simple assert helper to test with""" - assert value.endswith(' thing'), "must end with a thing" + assert value.endswith(" thing"), "must end with a thing" class TestsSupportWsgisupp_prepare_wsgi(unittest.TestCase): @@ -44,56 +44,57 @@ class TestsSupportWsgisupp_prepare_wsgi(unittest.TestCase): def test_prepare_GET(self): configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, 'http://testhost/some/path') + environ, _ = prepare_wsgi(configuration, "http://testhost/some/path") - self.assertEqual(environ['MIG_CONF'], - '/path/to/the/confs/MiGserver.conf') - self.assertEqual(environ['HTTP_HOST'], 'testhost') - self.assertEqual(environ['PATH_INFO'], '/some/path') - self.assertEqual(environ['REQUEST_METHOD'], 'GET') + self.assertEqual( + environ["MIG_CONF"], "/path/to/the/confs/MiGserver.conf" + ) + self.assertEqual(environ["HTTP_HOST"], "testhost") + self.assertEqual(environ["PATH_INFO"], "/some/path") + self.assertEqual(environ["REQUEST_METHOD"], "GET") def test_prepare_GET_with_query(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, test_url, query={ - 'foo': 'true', - 'bar': 1 - }) + environ, _ = prepare_wsgi( + configuration, test_url, query={"foo": "true", "bar": 1} + ) - self.assertEqual(environ['QUERY_STRING'], 'foo=true&bar=1') + self.assertEqual(environ["QUERY_STRING"], "foo=true&bar=1") def test_prepare_POST(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, test_url, method='POST') + environ, _ = prepare_wsgi(configuration, test_url, method="POST") - self.assertEqual(environ['REQUEST_METHOD'], 'POST') + self.assertEqual(environ["REQUEST_METHOD"], "POST") def test_prepare_POST_with_headers(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) headers = { - 'Authorization': 'Basic XXXX', - 'Content-Length': 0, + "Authorization": "Basic XXXX", + "Content-Length": 0, } environ, _ = prepare_wsgi( - configuration, test_url, method='POST', headers=headers) + configuration, test_url, method="POST", headers=headers + ) - self.assertEqual(environ['CONTENT_LENGTH'], 0) - self.assertEqual(environ['HTTP_AUTHORIZATION'], 'Basic XXXX') + self.assertEqual(environ["CONTENT_LENGTH"], 0) + self.assertEqual(environ["HTTP_AUTHORIZATION"], "Basic XXXX") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()