diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py
index 020041d2..7964e45b 100644
--- a/sdgym/benchmark.py
+++ b/sdgym/benchmark.py
@@ -1,6 +1,6 @@
 """Main SDGym benchmarking module."""

-import concurrent
+import functools
 import logging
 import math
 import multiprocessing
@@ -16,11 +16,11 @@
 from datetime import datetime, timedelta
 from importlib.metadata import version
 from pathlib import Path
+from typing import Any, NamedTuple, Optional
 from urllib.parse import urlparse

 import boto3
 import cloudpickle
-import compress_pickle
 import numpy as np
 import pandas as pd
 import tqdm
@@ -42,15 +42,13 @@
 from sdgym.datasets import get_dataset_paths, load_dataset
 from sdgym.errors import BenchmarkError, SDGymError
 from sdgym.metrics import get_metrics
-from sdgym.progress import TqdmLogger, progress
+from sdgym.progress import TqdmLogger
 from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter
 from sdgym.s3 import (
     S3_PREFIX,
     S3_REGION,
     is_s3_path,
     parse_s3_path,
-    write_csv,
-    write_file,
 )
 from sdgym.synthesizers import UniformSynthesizer
 from sdgym.synthesizers.base import BaselineSynthesizer
@@ -59,7 +57,6 @@
     convert_metadata_to_sdmetrics,
     format_exception,
     get_duplicates,
-    get_num_gpus,
     get_size_of,
     get_synthesizers,
     get_utc_now,
@@ -94,18 +91,23 @@
 ]


-def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers):
-    if output_filepath and os.path.exists(output_filepath):
-        raise ValueError(
-            f'{output_filepath} already exists. Please provide a file that does not already exist.'
-        )
+class JobArgs(NamedTuple):
+    """Arguments needed to run a single synthesizer + dataset benchmark job."""

-    if detailed_results_folder and os.path.exists(detailed_results_folder):
-        raise ValueError(
-            f'{detailed_results_folder} already exists. '
-            'Please provide a folder that does not already exist.'
-        )
+    synthesizer: dict
+    data: Any
+    metadata: Any
+    metrics: Any
+    timeout: Optional[int]
+    compute_quality_score: bool
+    compute_diagnostic_score: bool
+    compute_privacy_score: bool
+    dataset_name: str
+    modality: str
+    output_directions: Optional[dict]
+
+def _validate_inputs(synthesizers, custom_synthesizers):
     duplicates = get_duplicates(synthesizers) if synthesizers else {}
     if custom_synthesizers:
         duplicates.update(get_duplicates(custom_synthesizers))
@@ -116,12 +118,6 @@ def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, cus
         )


-def _create_detailed_results_directory(detailed_results_folder):
-    if detailed_results_folder and not is_s3_path(detailed_results_folder):
-        detailed_results_folder = Path(detailed_results_folder)
-        os.makedirs(detailed_results_folder, exist_ok=True)
-
-
 def _get_metainfo_increment(top_folder, s3_client=None):
     increments = []
     first_file_message = 'No metainfo file found, starting from increment (0)'
@@ -238,7 +234,6 @@ def _generate_job_args_list(
     sdv_datasets,
     additional_datasets_folder,
     sdmetrics,
-    detailed_results_folder,
     timeout,
     output_destination,
     compute_quality_score,
@@ -306,21 +301,21 @@ def _generate_job_args_list(
                 'single_table', dataset, limit_dataset_size=limit_dataset_size
             )
             path = paths.get(dataset.name, {}).get(synthesizer['name'], None)
-            args = (
-                synthesizer,
-                data,
-                metadata_dict,
-                sdmetrics,
-                detailed_results_folder,
-                timeout,
-                compute_quality_score,
-                compute_diagnostic_score,
-                compute_privacy_score,
-                dataset.name,
-                'single_table',
-                path,
+            job_args_list.append(
+                JobArgs(
+                    synthesizer=synthesizer,
+                    data=data,
+                    metadata=metadata_dict,
+                    metrics=sdmetrics,
+                    timeout=timeout,
+                    compute_quality_score=compute_quality_score,
+                    compute_diagnostic_score=compute_diagnostic_score,
+                    compute_privacy_score=compute_privacy_score,
+                    dataset_name=dataset.name,
+                    modality='single_table',
+                    output_directions=path,
+                )
             )
-            job_args_list.append(args)

     return job_args_list
@@ -668,7 +663,6 @@ def _format_output(
     compute_quality_score,
     compute_diagnostic_score,
     compute_privacy_score,
-    cache_dir,
 ):
     evaluate_time = 0
     if 'quality_score_time' in output:
@@ -716,40 +710,24 @@
     if 'error' in output:
         scores['error'] = output['error']

-    if cache_dir:
-        cache_dir_name = str(cache_dir)
-        base_path = f'{cache_dir_name}/{name}_{dataset_name}'
-        if scores is not None:
-            write_csv(scores, f'{base_path}_scores.csv', None, None)
-        if 'synthetic_data' in output:
-            synthetic_data = compress_pickle.dumps(output['synthetic_data'], compression='gzip')
-            write_file(synthetic_data, f'{base_path}.data.gz', None, None)
-        if 'exception' in output:
-            exception = output['exception'].encode('utf-8')
-            write_file(exception, f'{base_path}_error.txt', None, None)
-
     return scores


-def _run_job(args):
+def _run_job(job_args, result_writer=None):
     # Reset random seed
     np.random.seed()

-    (
-        synthesizer,
-        data,
-        metadata,
-        metrics,
-        cache_dir,
-        timeout,
-        compute_quality_score,
-        compute_diagnostic_score,
-        compute_privacy_score,
-        dataset_name,
-        modality,
-        synthesizer_path,
-        result_writer,
-    ) = args
+    synthesizer = job_args.synthesizer
+    data = job_args.data
+    metadata = job_args.metadata
+    metrics = job_args.metrics
+    timeout = job_args.timeout
+    compute_quality_score = job_args.compute_quality_score
+    compute_diagnostic_score = job_args.compute_diagnostic_score
+    compute_privacy_score = job_args.compute_privacy_score
+    dataset_name = job_args.dataset_name
+    modality = job_args.modality
+    synthesizer_path = job_args.output_directions
     name = synthesizer['name']
     LOGGER.info(
@@ -800,7 +778,6 @@ def _run_job(args):
         compute_quality_score,
         compute_diagnostic_score,
         compute_privacy_score,
-        cache_dir,
     )
     if synthesizer_path and result_writer:
         result_writer.write_dataframe(scores, synthesizer_path['benchmark_result'])
@@ -808,49 +785,8 @@
     return scores


-def _run_on_dask(jobs, verbose):
-    """Run the tasks in parallel using dask."""
-    try:
-        import dask
-    except ImportError as ie:
-        ie.msg += (
-            '\n\nIt seems like `dask` is not installed.\n'
-            'Please install `dask` and `distributed` using:\n'
-            '\n    pip install dask distributed'
-        )
-        raise
-
-    scorer = dask.delayed(_run_job)
-    persisted = dask.persist(*[scorer(args) for args in jobs])
-    if verbose:
-        try:
-            progress(persisted)
-        except ValueError:
-            pass
-
-    return dask.compute(*persisted)
-
-
-def _run_jobs(multi_processing_config, job_args_list, show_progress, result_writer=None):
-    workers = 1
-    if multi_processing_config:
-        if multi_processing_config['package_name'] == 'dask':
-            workers = 'dask'
-            scores = _run_on_dask(job_args_list, show_progress)
-        else:
-            num_gpus = get_num_gpus()
-            if num_gpus > 0:
-                workers = num_gpus
-            else:
-                workers = multiprocessing.cpu_count()
-
-    job_args_list = [job_args + (result_writer,) for job_args in job_args_list]
-    if workers in (0, 1):
-        scores = map(_run_job, job_args_list)
-    elif workers != 'dask':
-        pool = concurrent.futures.ProcessPoolExecutor(workers)
-        scores = pool.map(_run_job, job_args_list)
-
+def _run_jobs(job_args_list, show_progress, result_writer=None):
+    scores = map(functools.partial(_run_job, result_writer=result_writer), job_args_list)
     if show_progress:
         scores = tqdm.tqdm(scores, total=len(job_args_list), position=0, leave=True)
     else:
@@ -862,9 +798,8 @@ def _run_jobs(multi_processing_config, job_args_list, show_progress, result_writ
         raise SDGymError('No valid Dataset/Synthesizer combination given.')

     scores = pd.concat(scores, ignore_index=True)
-    _add_adjusted_scores(scores=scores, timeout=job_args_list[0][5])
-    output_directions = job_args_list[0][-2]
-    result_writer = job_args_list[0][-1]
+    _add_adjusted_scores(scores=scores, timeout=job_args_list[0].timeout)
+    output_directions = job_args_list[0].output_directions
     if output_directions and result_writer:
         path = output_directions['results']
         result_writer.write_dataframe(scores, path, append=True)
@@ -903,15 +838,6 @@ def _get_empty_dataframe(
     return scores


-def _directory_exists(bucket_name, s3_file_path):
-    # Find the last occurrence of '/' in the file path
-    last_slash_index = s3_file_path.rfind('/')
-    directory_prefix = s3_file_path[: last_slash_index + 1]
-    s3_client = boto3.client('s3')
-    response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=directory_prefix, Delimiter='/')
-    return 'Contents' in response or 'CommonPrefixes' in response
-
-
 def _check_write_permissions(s3_client, bucket_name):
     s3_client = s3_client or boto3.client('s3')
     try:
@@ -926,157 +852,6 @@ def _check_write_permissions(s3_client, bucket_name):
     return write_permission


-def _create_sdgym_script(params, output_filepath):
-    # Confirm the path works
-    if not is_s3_path(output_filepath):
-        raise ValueError("""Invalid S3 path format.
-                         Expected 's3://<bucket_name>/<path_to_file>'.""")
-    bucket_name, key_prefix = parse_s3_path(output_filepath)
-    if not _directory_exists(bucket_name, key_prefix):
-        raise ValueError(f'Directories in {key_prefix} do not exist')
-    if not _check_write_permissions(None, bucket_name):
-        raise ValueError('No write permissions allowed for the bucket.')
-
-    # Add quotes to parameter strings
-    if params['additional_datasets_folder']:
-        params['additional_datasets_folder'] = "'" + params['additional_datasets_folder'] + "'"
-    if params['detailed_results_folder']:
-        params['detailed_results_folder'] = "'" + params['detailed_results_folder'] + "'"
-    if params['output_filepath']:
-        params['output_filepath'] = "'" + params['output_filepath'] + "'"
-
-    # Generate the output script to run on the e2 instance
-    synthesizers = params.get('synthesizers', [])
-    names = []
-    for synthesizer in synthesizers:
-        if isinstance(synthesizer, str):
-            names.append(synthesizer)
-        elif hasattr(synthesizer, '__name__'):
-            names.append(synthesizer.__name__)
-        else:
-            names.append(synthesizer.__class__.__name__)
-
-    all_names = '", "'.join(names)
-    synthesizer_string = f'synthesizers=["{all_names}"]'
-    # The indentation of the string is important for the python script
-    script_content = f"""import boto3
-from io import StringIO
-import sdgym
-
-results = sdgym.benchmark_single_table(
-    {synthesizer_string}, custom_synthesizers={params['custom_synthesizers']},
-    sdv_datasets={params['sdv_datasets']}, output_filepath={params['output_filepath']},
-    additional_datasets_folder={params['additional_datasets_folder']},
-    limit_dataset_size={params['limit_dataset_size']},
-    compute_quality_score={params['compute_quality_score']},
-    compute_diagnostic_score={params['compute_diagnostic_score']},
-    compute_privacy_score={params['compute_privacy_score']},
-    sdmetrics={params['sdmetrics']},
-    timeout={params['timeout']},
-    detailed_results_folder={params['detailed_results_folder']},
-    multi_processing_config={params['multi_processing_config']}
-)
-"""
-
-    return script_content
-
-
-def _create_instance_on_ec2(script_content):
-    ec2_client = boto3.client('ec2')
-    session = boto3.session.Session()
-    credentials = session.get_credentials()
-    print(f'This instance is being created in region: {session.region_name}')  # noqa
-    escaped_script = script_content.strip().replace('"', '\\"')
-
-    # User data script to install the library
-    user_data_script = f"""#!/bin/bash
-    sudo apt update -y
-    sudo apt install -y python3-pip python3-venv awscli
-    echo "======== Create Virtual Environment ============"
-    python3 -m venv ~/env
-    source ~/env/bin/activate
-    echo "======== Install Dependencies in venv ============"
-    pip install --upgrade pip
-    pip install sdgym[all]
-    pip install anyio
-    echo "======== Configure AWS CLI ============"
-    aws configure set aws_access_key_id {credentials.access_key}
-    aws configure set aws_secret_access_key {credentials.secret_key}
-    aws configure set region {session.region_name}
-    echo "======== Write Script ==========="
-    printf '%s\\n' "{escaped_script}" > ~/sdgym_script.py
-    echo "======== Run Script ==========="
-    python ~/sdgym_script.py
-
-    echo "======== Complete ==========="
-    INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
-    aws ec2 terminate-instances --instance-ids $INSTANCE_ID
-    """
-
-    response = ec2_client.run_instances(
-        ImageId='ami-080e1f13689e07408',
-        InstanceType='g4dn.4xlarge',
-        MinCount=1,
-        MaxCount=1,
-        UserData=user_data_script,
-        TagSpecifications=[
-            {'ResourceType': 'instance', 'Tags': [{'Key': 'Name', 'Value': 'SDGym_Temp'}]}
-        ],
-        BlockDeviceMappings=[
-            {
-                'DeviceName': '/dev/sda1',
-                'Ebs': {
-                    'VolumeSize': 32,  # Specify the desired size in GB
-                    'VolumeType': 'gp2',  # Change the volume type as needed
-                },
-            },
-        ],
-    )
-
-    # Wait until the instance is running before terminating
-    instance_id = response['Instances'][0]['InstanceId']
-    waiter = ec2_client.get_waiter('instance_status_ok')
-    waiter.wait(InstanceIds=[instance_id])
-    print(f'Job kicked off for SDGym on {instance_id}')  # noqa
-
-
-def _handle_deprecated_parameters(
-    output_filepath,
-    detailed_results_folder,
-    multi_processing_config,
-    run_on_ec2,
-    output_destination,
-):
-    """Handle deprecated parameters and issue warnings."""
-    parameters_to_deprecate = {
-        'output_filepath': output_filepath,
-        'detailed_results_folder': detailed_results_folder,
-        'multi_processing_config': multi_processing_config,
-        'run_on_ec2': run_on_ec2,
-    }
-    parameters = []
-    old_parameters_to_save = ('output_filepath', 'detailed_results_folder')
-    for name, value in parameters_to_deprecate.items():
-        if value is not None and value:
-            if name in old_parameters_to_save and output_destination is not None:
-                raise ValueError(
-                    f"The '{name}' parameter is deprecated and cannot be used together with "
-                    "'output_destination'. Please use only 'output_destination' to specify "
-                    'the output path.'
-                )
-
-            parameters.append(name)
-
-    if parameters:
-        parameters = "', '".join(sorted(parameters))
-        message = (
-            f"Parameters '{parameters}' are deprecated in the 'benchmark_single_table' "
-            "function. For saving results, please use the 'output_destination' parameter."
-            " For running SDGym remotely on AWS please use the 'benchmark_single_table_aws' method."
-        )
-        warnings.warn(message, FutureWarning)
-
-
 def _validate_output_destination(output_destination, aws_keys=None):
     """Validate the output destination parameter."""
     if output_destination is None and aws_keys is None:
@@ -1100,11 +875,11 @@ def _validate_output_destination(output_destination, aws_keys=None):


 def _write_metainfo_file(synthesizers, job_args_list, result_writer=None):
-    jobs = [[job[-3], job[0]['name']] for job in job_args_list]
-    if not job_args_list or not job_args_list[0][-1]:
+    jobs = [[job.dataset_name, job.synthesizer['name']] for job in job_args_list]
+    if not job_args_list or not job_args_list[0].output_directions:
         return

-    output_directions = job_args_list[0][-1]
+    output_directions = job_args_list[0].output_directions
     path = output_directions['metainfo']
     stem = Path(path).stem
     match = FILE_INCREMENT_PATTERN.search(stem)
@@ -1233,11 +1008,7 @@ def benchmark_single_table(
     sdmetrics=None,
     timeout=None,
     output_destination=None,
-    output_filepath=None,
-    detailed_results_folder=None,
     show_progress=False,
-    multi_processing_config=None,
-    run_on_ec2=False,
 ):
     """Run the SDGym benchmark on single-table datasets.

@@ -1292,75 +1063,38 @@ def benchmark_single_table(
                        /
                            synthesizer.pkl
                            synthetic_data.csv
-        output_filepath (str or ``None``):
-            A file path for where to write the output as a csv file. If ``None``, no output
-            is written. If run_on_ec2 flag output_filepath needs to be defined and
-            the filepath should be structured as: s3://{s3_bucket_name}/{path_to_file}
-            Please make sure the path exists and permissions are given.
-        detailed_results_folder (str or ``None``):
-            The folder for where to store the intermediary results. If ``None``, do not store
-            the intermediate results anywhere.
         show_progress (bool):
             Whether to use tqdm to keep track of the progress. Defaults to ``False``.
-        multi_processing_config (dict or ``None``):
-            The config to use if multi-processing is desired. For example,
-            {
-                'package_name': 'dask' or 'multiprocessing',
-                'num_workers': 4
-            }
-        run_on_ec2 (bool):
-            The flag is used to run the benchmark on an EC2 instance that will be created
-            by a script using the authentication of the current user. The EC2 instance
-            uses the LATEST released version of sdgym. Local changes or changes NOT
-            in the released version will NOT be used in the ec2 instance.

     Returns:
         pandas.DataFrame:
             A table containing one row per synthesizer + dataset + metric.
     """
-    _handle_deprecated_parameters(
-        output_filepath,
-        detailed_results_folder,
-        multi_processing_config,
-        run_on_ec2,
-        output_destination,
-    )
     _validate_output_destination(output_destination)
     if not synthesizers:
         synthesizers = []

     _ensure_uniform_included(synthesizers)
+    _validate_inputs(synthesizers, custom_synthesizers)
     result_writer = LocalResultsWriter()
-    if run_on_ec2:
-        print("This will create an instance for the current AWS user's account.")  # noqa
-        if output_filepath is not None:
-            script_content = _create_sdgym_script(dict(locals()), output_filepath)
-            _create_instance_on_ec2(script_content)
-        else:
-            raise ValueError('In order to run on EC2, please provide an S3 folder output.')
-        return None
-
-    _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers)
-    _create_detailed_results_directory(detailed_results_folder)
     job_args_list = _generate_job_args_list(
-        limit_dataset_size,
-        sdv_datasets,
-        additional_datasets_folder,
-        sdmetrics,
-        detailed_results_folder,
-        timeout,
-        output_destination,
-        compute_quality_score,
-        compute_diagnostic_score,
-        compute_privacy_score,
-        synthesizers,
-        custom_synthesizers,
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=sdmetrics,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=compute_privacy_score,
+        synthesizers=synthesizers,
+        custom_synthesizers=custom_synthesizers,
         s3_client=None,
     )
     _write_metainfo_file(synthesizers, job_args_list, result_writer)
     if job_args_list:
-        scores = _run_jobs(multi_processing_config, job_args_list, show_progress, result_writer)
+        scores = _run_jobs(job_args_list, show_progress, result_writer)

     # If no synthesizers/datasets are passed, return an empty dataframe
     else:
@@ -1371,11 +1105,8 @@ def benchmark_single_table(
             sdmetrics=sdmetrics,
         )

-    if output_filepath:
-        write_csv(scores, output_filepath, None, None)
-
     if output_destination and job_args_list:
-        metainfo_filename = job_args_list[0][-1]['metainfo']
+        metainfo_filename = job_args_list[0].output_directions['metainfo']
         _update_metainfo_file(metainfo_filename, result_writer)

     return scores
@@ -1433,7 +1164,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
     parsed_url = urlparse(output_destination)
     bucket_name = parsed_url.netloc
     path = parsed_url.path.lstrip('/') if parsed_url.path else ''
-    filename = os.path.basename(job_args_list[0][-1]['metainfo'])
+    filename = os.path.basename(job_args_list[0].output_directions['metainfo'])
     metainfo = os.path.splitext(filename)[0]
     job_args_key = f'job_args_list_{metainfo}.pkl'
     job_args_key = f'{path}{job_args_key}' if path else job_args_key
@@ -1451,7 +1182,6 @@ def _get_s3_script_content(
 import boto3
 import pickle
 from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
-from io import StringIO
 from sdgym.result_writer import S3ResultsWriter

 s3_client = boto3.client(
@@ -1464,8 +1194,8 @@ def _get_s3_script_content(
 job_args_list = pickle.loads(response['Body'].read())
 result_writer = S3ResultsWriter(s3_client=s3_client)
 _write_metainfo_file({synthesizers}, job_args_list, result_writer)
-scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
-metainfo_filename = job_args_list[0][-1]['metainfo']
+scores = _run_jobs(job_args_list, False, result_writer=result_writer)
+metainfo_filename = job_args_list[0].output_directions['metainfo']
 _update_metainfo_file(metainfo_filename, result_writer)
 s3_client.delete_object(Bucket='{bucket_name}', Key='{job_args_key}')
 """
@@ -1653,7 +1383,6 @@ def benchmark_single_table_aws(
         compute_diagnostic_score=compute_diagnostic_score,
         compute_privacy_score=compute_privacy_score,
         synthesizers=synthesizers,
-        detailed_results_folder=None,
         custom_synthesizers=None,
         s3_client=s3_client,
     )
diff --git a/sdgym/cli/__main__.py b/sdgym/cli/__main__.py
index c89ac2e9..2ef5c7bc 100644
--- a/sdgym/cli/__main__.py
+++ b/sdgym/cli/__main__.py
@@ -83,7 +83,7 @@ def _run(args):
         sdmetrics=args.sdmetrics,
         timeout=args.timeout,
         show_progress=args.progress,
-        output_filepath=args.output_path,
+        output_destination=args.output_destination,
     )

     if args.groupby:
@@ -170,19 +170,12 @@ def _get_parser():
         required=False,
         help='List of datasets to benchmark.',
     )
-    run.add_argument(
-        '-c',
-        '--cache-dir',
-        type=str,
-        required=False,
-        help='Directory where the intermediate results will be stored.',
-    )
     run.add_argument(
         '-o',
-        '--output-path',
+        '--output-destination',
         type=str,
         required=False,
-        help='Path to the CSV file where the report will be dumped',
+        help='Directory where the SDGym results folder will be written.',
     )
     run.add_argument(
         '-m',
@@ -220,11 +213,6 @@ def _get_parser():
     run.add_argument(
         '-p', '--progress', action='store_true', help='Print a progress bar using tqdm.'
     )
-    run.add_argument(
-        'run_on_ec2',
-        action='store_true',
-        help='Run job on created ec2 instance with environment aws variables',
-    )
     run.add_argument('-t', '--timeout', type=int, help='Maximum seconds to run for each dataset.')
     run.add_argument(
         '-g', '--groupby', nargs='+', help='Group scores leaderboard by the given fields.'
     )
diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py
index b73b0408..7990afb8 100644
--- a/tests/unit/test_benchmark.py
+++ b/tests/unit/test_benchmark.py
@@ -13,16 +13,14 @@

 from sdgym import benchmark_single_table
 from sdgym.benchmark import (
+    JobArgs,
     _add_adjusted_scores,
     _check_write_permissions,
-    _create_sdgym_script,
-    _directory_exists,
     _ensure_uniform_included,
     _fill_adjusted_scores_with_none,
     _format_output,
     _generate_job_args_list,
     _get_metainfo_increment,
-    _handle_deprecated_parameters,
     _setup_output_destination,
     _setup_output_destination_aws,
     _update_metainfo_file,
@@ -35,25 +33,6 @@
 from sdgym.s3 import S3_REGION


-@patch('sdgym.benchmark.os.path')
-def test_output_file_exists(path_mock):
-    """Test the benchmark function when the output path already exists."""
-    # Setup
-    path_mock.exists.return_value = True
-    output_filepath = 's3://test_output.csv'
-
-    # Run and assert
-    with pytest.raises(
-        ValueError,
-        match='test_output.csv already exists. Please provide a file that does not already exist.',
-    ):
-        benchmark_single_table(
-            synthesizers=['DataIdentity', 'ColumnSynthesizer', 'UniformSynthesizer'],
-            sdv_datasets=['student_placements'],
-            output_filepath=output_filepath,
-        )
-
-
 @patch('sdgym.benchmark.boto3.client')
 @patch('sdgym.benchmark.LOGGER')
 def test__get_metainfo_increment_aws(mock_logger, mock_client):
@@ -109,8 +88,7 @@ def test__get_metainfo_increment_local(mock_logger, tmp_path):


 @patch('sdgym.benchmark.tqdm.tqdm')
-@patch('sdgym.benchmark._handle_deprecated_parameters')
-def test_benchmark_single_table_deprecated_params(mock_handle_deprecated, tqdm_mock):
+def test_benchmark_single_table_progress_bar(tqdm_mock):
     """Test that the benchmarking function updates the progress bar on one line."""
     # Setup
     scores_mock = MagicMock()
@@ -132,7 +110,6 @@ def test_benchmark_single_table_deprecated_params(mock_handle_deprecated, tqdm_m
     )

     # Assert
-    mock_handle_deprecated.assert_called_once_with(None, None, None, False, None)
     tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True)


@@ -176,40 +153,6 @@ def test_benchmark_single_table_with_timeout(mock_multiprocessing, mock__score):
     pd.testing.assert_frame_equal(scores, expected_scores, check_dtype=False)


-@patch('sdgym.benchmark.boto3.client')
-def test__directory_exists(mock_client):
-    # Setup
-    mock_client.return_value.list_objects_v2.return_value = {
-        'Contents': [
-            {
-                'Key': 'example.txt',
-                'ETag': '"1234567890abcdef1234567890abcdef"',
-                'Size': 1024,
-                'StorageClass': 'STANDARD',
-            },
-            {
-                'Key': 'example_folder/',
-                'ETag': '"0987654321fedcba0987654321fedcba"',
-                'Size': 0,
-                'StorageClass': 'STANDARD',
-            },
-        ],
-        'CommonPrefixes': [
-            {'Prefix': 'example_folder/subfolder1/'},
-            {'Prefix': 'example_folder/subfolder2/'},
-        ],
-    }
-
-    # Run and Assert
-    assert _directory_exists('bucket', 'file_path/mock.csv')
-
-    # Setup Failure
-    mock_client.return_value.list_objects_v2.return_value = {}
-
-    # Run and Assert
-    assert not _directory_exists('bucket', 'file_path/mock.csv')
-
-
 def test__check_write_permissions():
     """Test the `_check_write_permissions` function."""
     # Setup
@@ -221,60 +164,6 @@ def test__check_write_permissions():
     assert not _check_write_permissions(mock_client, 'bucket')


-@patch('sdgym.benchmark._directory_exists')
-@patch('sdgym.benchmark._check_write_permissions')
-@patch('sdgym.benchmark.boto3.session.Session')
-@patch('sdgym.benchmark._create_instance_on_ec2')
-def test_run_ec2_flag(create_ec2_mock, session_mock, mock_write_permissions, mock_directory_exists):
-    """Test that the benchmarking function updates the progress bar on one line."""
-    # Setup
-    create_ec2_mock.return_value = MagicMock()
-    session_mock.get_credentials.return_value = MagicMock()
-    mock_write_permissions.return_value = True
-    mock_directory_exists.return_value = True
-
-    # Run
-    benchmark_single_table(run_on_ec2=True, output_filepath='s3://BucketName/path')
-
-    # Assert
-    create_ec2_mock.assert_called_once()
-
-    # Run
-    with pytest.raises(
-        ValueError, match=r'In order to run on EC2, please provide an S3 folder output.'
-    ):
-        benchmark_single_table(run_on_ec2=True)
-
-    # Assert
-    create_ec2_mock.assert_called_once()
-
-    # Run
-    with pytest.raises(
-        ValueError,
-        match=r"""Invalid S3 path format.
-         Expected 's3://<bucket_name>/<path_to_file>'.""",
-    ):
-        benchmark_single_table(run_on_ec2=True, output_filepath='Wrong_Format')
-
-    # Assert
-    create_ec2_mock.assert_called_once()
-
-    # Setup for failure in permissions
-    mock_write_permissions.return_value = False
-
-    # Run
-    with pytest.raises(ValueError, match=r'No write permissions allowed for the bucket.'):
-        benchmark_single_table(run_on_ec2=True, output_filepath='s3://BucketName/path')
-
-    # Setup for failure in directory exists
-    mock_write_permissions.return_value = True
-    mock_directory_exists.return_value = False
-
-    # Run
-    with pytest.raises(ValueError, match=r'Directories in mock/path do not exist'):
-        benchmark_single_table(run_on_ec2=True, output_filepath='s3://BucketName/mock/path')
-
-
 def test__ensure_uniform_included_adds_uniform(caplog):
     """Test that UniformSynthesizer gets added to the synthesizers list."""
     # Setup
@@ -320,59 +209,6 @@ def test__ensure_uniform_included_detects_uniform_string(caplog):
     assert all(expected_message not in record.message for record in caplog.records)


-@patch('sdgym.benchmark._directory_exists')
-@patch('sdgym.benchmark._check_write_permissions')
-@patch('sdgym.benchmark.boto3.session.Session')
-def test__create_sdgym_script(session_mock, mock_write_permissions, mock_directory_exists):
-    """Test that the created SDGym script contains the expected values."""
-    # Setup
-    session_mock.get_credentials.return_value = MagicMock()
-    test_params = {
-        'synthesizers': ['GaussianCopulaSynthesizer', 'CTGANSynthesizer'],
-        'custom_synthesizers': None,
-        'sdv_datasets': [
-            'adult',
-            'alarm',
-            'census',
-            'child',
-            'expedia_hotel_logs',
-            'insurance',
-            'intrusion',
-            'news',
-            'covtype',
-        ],
-        'limit_dataset_size': True,
-        'compute_quality_score': False,
-        'compute_privacy_score': False,
-        'compute_diagnostic_score': False,
-        'sdmetrics': None,
-        'timeout': 600,
-        'output_filepath': 's3://sdgym-results/address_comments.csv',
-        'detailed_results_folder': None,
-        'additional_datasets_folder': 'Details/',
-        'show_progress': False,
-        'multi_processing_config': None,
-        'dummy': True,
-    }
-    mock_write_permissions.return_value = True
-    mock_directory_exists.return_value = True
-
-    # Run
-    result = _create_sdgym_script(test_params, 's3://Bucket/Filepath')
-
-    # Assert
-    assert 'synthesizers=["GaussianCopulaSynthesizer", "CTGANSynthesizer"]' in result
-    assert 'detailed_results_folder=None' in result
-    assert "additional_datasets_folder='Details/'" in result
-    assert 'multi_processing_config=None' in result
-    assert 'sdmetrics=None' in result
-    assert 'timeout=600' in result
-    assert 'compute_quality_score=False' in result
-    assert 'compute_diagnostic_score=False' in result
-    assert 'compute_privacy_score=False' in result
-    assert 'import boto3' in result
-
-
 def test__format_output():
     """Test the method ``_format_output`` and confirm that metrics are properly computed."""
     # Setup
@@ -410,7 +246,7 @@ def test__format_output():
     }

     # Run
-    scores = _format_output(mock_output, 'mock_name', 'mock_dataset', True, True, True, False)
+    scores = _format_output(mock_output, 'mock_name', 'mock_dataset', True, True, True)

     # Assert
     expected_scores = pd.DataFrame({
@@ -431,53 +267,6 @@ def test__format_output():
     pd.testing.assert_frame_equal(scores, expected_scores)


-def test__handle_deprecated_parameters():
-    """Test the ``_handle_deprecated_parameters`` function."""
-    # Setup
-    output_filepath = 's3://BucketName/path'
-    detailed_results_folder = 'mock/path'
-    multi_processing_config = {'num_processes': 4}
-    run_on_ec2 = True
-    base_warning = (
- "are deprecated in the 'benchmark_single_table' function. For saving results, " - "please use the 'output_destination' parameter. For running SDGym remotely on AWS " - "please use the 'benchmark_single_table_aws' method." - ) - base_error = ( - "parameter is deprecated and cannot be used together with 'output_destination'. " - "Please use only 'output_destination' to specify the output path." - ) - - # Expected messages - expected_warning_1 = "Parameters 'detailed_results_folder', 'output_filepath' " + base_warning - expected_warning_2 = ( - "Parameters 'detailed_results_folder', 'multi_processing_config', " - "'output_filepath', 'run_on_ec2' " + base_warning - ) - expected_error_1 = f"The 'output_filepath' {base_error}" - expected_error_2 = f"The 'detailed_results_folder' {base_error}" - - # Run and Assert - _handle_deprecated_parameters(None, None, None, False, None) - with pytest.warns(FutureWarning, match=expected_warning_1): - _handle_deprecated_parameters(output_filepath, detailed_results_folder, None, False, None) - - with pytest.warns(FutureWarning, match=expected_warning_2): - _handle_deprecated_parameters( - output_filepath, detailed_results_folder, multi_processing_config, run_on_ec2, None - ) - - with pytest.raises(ValueError, match=expected_error_1): - _handle_deprecated_parameters( - output_filepath, None, multi_processing_config, run_on_ec2, 'output_destination' - ) - - with pytest.raises(ValueError, match=expected_error_2): - _handle_deprecated_parameters( - None, detailed_results_folder, multi_processing_config, run_on_ec2, 'output_destination' - ) - - def test__validate_output_destination(tmp_path): """Test the `_validate_output_destination` function.""" # Setup @@ -568,8 +357,32 @@ def test__write_metainfo_file(mock_datetime, tmp_path): file_name = {'metainfo': f'{output_destination}/metainfo.yaml'} result_writer = LocalResultsWriter() jobs = [ - ({'name': 'GaussianCopulaSynthesizer'}, 'adult', None, file_name), - ({'name': 'CTGANSynthesizer'}, 'census', None, None), + JobArgs( + synthesizer={'name': 'GaussianCopulaSynthesizer'}, + data=None, + metadata=None, + metrics=None, + timeout=None, + compute_quality_score=False, + compute_diagnostic_score=False, + compute_privacy_score=False, + dataset_name='adult', + modality='single_table', + output_directions=file_name, + ), + JobArgs( + synthesizer={'name': 'CTGANSynthesizer'}, + data=None, + metadata=None, + metrics=None, + timeout=None, + compute_quality_score=False, + compute_diagnostic_score=False, + compute_privacy_score=False, + dataset_name='census', + modality='single_table', + output_directions=None, + ), ] expected_jobs = [['adult', 'GaussianCopulaSynthesizer'], ['census', 'CTGANSynthesizer']] synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'RealTabFormerSynthesizer'] @@ -791,7 +604,6 @@ def test_benchmark_single_table_aws( compute_diagnostic_score=True, compute_privacy_score=True, synthesizers=synthesizers, - detailed_results_folder=None, custom_synthesizers=None, s3_client='s3_client_mock', ) @@ -849,7 +661,6 @@ def test_benchmark_single_table_aws_synthesizers_none( compute_diagnostic_score=True, compute_privacy_score=True, synthesizers=['UniformSynthesizer'], - detailed_results_folder=None, custom_synthesizers=None, s3_client='s3_client_mock', ) @@ -1025,7 +836,6 @@ def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_ sdv_datasets=None, additional_datasets_folder=str(local_root), sdmetrics=None, - detailed_results_folder=None, timeout=None, output_destination=None, 
         compute_quality_score=False,
@@ -1059,7 +869,6 @@ def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_moc
         sdv_datasets=None,
         additional_datasets_folder=s3_root,
         sdmetrics=None,
-        detailed_results_folder=None,
         timeout=None,
         output_destination=None,
         compute_quality_score=False,