diff --git a/.github/workflows/run_benchmark.yml b/.github/workflows/run_benchmark.yml index 32297c71..db66e6f2 100644 --- a/.github/workflows/run_benchmark.yml +++ b/.github/workflows/run_benchmark.yml @@ -1,31 +1,72 @@ name: Run SDGym Benchmark on: - workflow_dispatch: - schedule: - - cron: '0 5 1 * *' + workflow_call: + inputs: + modality: + required: true + type: string + secrets: + SDV_ENTERPRISE_USERNAME: + required: true + SDV_ENTERPRISE_LICENSE_KEY: + required: true + GCP_SERVICE_ACCOUNT_JSON: + required: true + AWS_ACCESS_KEY_ID: + required: true + AWS_SECRET_ACCESS_KEY: + required: true + SLACK_TOKEN: + required: true jobs: run-sdgym-benchmark: runs-on: ubuntu-latest + steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Set up latest Python - uses: actions/setup-python@v5 - with: - python-version-file: 'pyproject.toml' - - name: Install dependencies - run: | + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: "pyproject.toml" + + - name: Install dependencies + env: + USERNAME: ${{ secrets.SDV_ENTERPRISE_USERNAME }} + LICENSE_KEY: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }} + run: | + python -m venv venv + source venv/bin/activate + python -m pip install --upgrade pip - python -m pip install --no-cache-dir -e .[dev] - - - name: Run SDGym Benchmark - env: - SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }} - - run: invoke run-sdgym-benchmark + python -m pip install sdv-installer + python -c " + from sdv_installer.installation.installer import install_packages + install_packages( + username='${USERNAME}', + license_key='${LICENSE_KEY}', + package='sdv-enterprise', + ) + python -m pip install sdgym[all] + + echo "VIRTUAL_ENV=$(pwd)/venv" >> $GITHUB_ENV + echo "$(pwd)/venv/bin" >> $GITHUB_PATH + + - name: Run SDGym Benchmark + env: + GCP_SERVICE_ACCOUNT_JSON: ${{ secrets.GCP_SERVICE_ACCOUNT_JSON }} + GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} + GCP_ZONE: ${{ secrets.GCP_ZONE }} + SDV_ENTERPRISE_USERNAME: ${{ secrets.SDV_ENTERPRISE_USERNAME }} + SDV_ENTERPRISE_LICENSE_KEY: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} + run: | + export CREDENTIALS_FILEPATH=$(python -c "from sdgym._benchmark.credentials_utils import create_credentials_file; print(create_credentials_file())") + invoke run-sdgym-benchmark --modality "${{ inputs.modality }}" + rm -f "$CREDENTIALS_FILEPATH" diff --git a/.github/workflows/run_benchmark_multi_table.yaml b/.github/workflows/run_benchmark_multi_table.yaml new file mode 100644 index 00000000..38bc85b3 --- /dev/null +++ b/.github/workflows/run_benchmark_multi_table.yaml @@ -0,0 +1,13 @@ +name: Run SDGym Benchmark Multi-Table + +on: + workflow_dispatch: + schedule: + - cron: "0 5 1 * *" + +jobs: + call-run-sdgym-benchmark: + uses: ./.github/workflows/run_benchmark.yml + with: + modality: multi_table + secrets: inherit diff --git a/.github/workflows/run_benchmark_single_table.yml b/.github/workflows/run_benchmark_single_table.yml new file mode 100644 index 00000000..1b0b96e3 --- /dev/null +++ b/.github/workflows/run_benchmark_single_table.yml @@ -0,0 +1,13 @@ +name: Run SDGym Benchmark Single-Table + +on: + workflow_dispatch: + schedule: + - cron: "0 5 
1 * *" + +jobs: + call-run-sdgym-benchmark: + uses: ./.github/workflows/run_benchmark.yml + with: + modality: single_table + secrets: inherit diff --git a/.github/workflows/upload_benchmark_results.yml b/.github/workflows/upload_benchmark_results.yml index 5423ac6d..0f154ae2 100644 --- a/.github/workflows/upload_benchmark_results.yml +++ b/.github/workflows/upload_benchmark_results.yml @@ -1,91 +1,107 @@ -name: Upload SDGym Benchmark results +name: Upload SDGym Benchmark Results on: - workflow_run: - workflows: ["Run SDGym Benchmark"] - types: - - completed - workflow_dispatch: - schedule: - - cron: '0 6 * * *' + workflow_call: + inputs: + modality: + description: "Benchmark modality to upload" + required: true + type: string + secrets: + PYDRIVE_TOKEN: + required: true + AWS_ACCESS_KEY_ID: + required: true + AWS_SECRET_ACCESS_KEY: + required: true + GH_TOKEN: + required: true + SLACK_TOKEN: + required: true jobs: upload-sdgym-benchmark: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up latest Python - uses: actions/setup-python@v5 - with: - python-version-file: 'pyproject.toml' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install --no-cache-dir -e .[dev] - - - name: Upload SDGym Benchmark - env: - PYDRIVE_TOKEN: ${{ secrets.PYDRIVE_TOKEN }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - GITHUB_LOCAL_RESULTS_DIR: ${{ runner.temp }}/sdgym-leaderboard-files - run: | - invoke upload-benchmark-results - echo "GITHUB_LOCAL_RESULTS_DIR=$GITHUB_LOCAL_RESULTS_DIR" >> $GITHUB_ENV - - - name: Prepare files for commit - if: env.SKIP_UPLOAD != 'true' - run: | - mkdir pr-staging - echo "Looking for files in: $GITHUB_LOCAL_RESULTS_DIR" - ls -l "$GITHUB_LOCAL_RESULTS_DIR" - for f in "$GITHUB_LOCAL_RESULTS_DIR"/*; do - if [ -f "$f" ]; then - base=$(basename "$f") - cp "$f" "pr-staging/${base}" - fi - done - - echo "Files staged for PR:" - ls -l pr-staging - - - name: Checkout target repo (sdv-dev.github.io) - if: env.SKIP_UPLOAD != 'true' - run: | - git clone https://github.com/sdv-dev/sdv-dev.github.io.git target-repo - cd target-repo - git checkout gatsby-home - - - name: Copy results and commit - if: env.SKIP_UPLOAD != 'true' - env: - GH_TOKEN: ${{ secrets.GH_TOKEN }} - FOLDER_NAME: ${{ env.FOLDER_NAME }} - run: | - cp pr-staging/* target-repo/assets/sdgym-leaderboard-files/ - cd target-repo - git checkout gatsby-home - git config --local user.name "github-actions[bot]" - git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" - git add assets/ - git commit -m "Upload SDGym Benchmark Results ($FOLDER_NAME)" || echo "No changes to commit" - git remote set-url origin https://x-access-token:${GH_TOKEN}@github.com/sdv-dev/sdv-dev.github.io.git - git push origin gatsby-home - COMMIT_HASH=$(git rev-parse HEAD) - COMMIT_URL="https://github.com/sdv-dev/sdv-dev.github.io/commit/${COMMIT_HASH}" - echo "Commit URL: $COMMIT_URL" - echo "COMMIT_URL=$COMMIT_URL" >> $GITHUB_ENV - - - name: Send Slack notification - if: env.SKIP_UPLOAD != 'true' - env: - SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} - run: | - invoke notify-sdgym-benchmark-uploaded \ - --folder-name "$FOLDER_NAME" \ - --commit-url "$COMMIT_URL" + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up latest Python + uses: actions/setup-python@v5 + with: + python-version-file: "pyproject.toml" + + - name: Install dependencies + run: | + 
python -m pip install --upgrade pip + python -m pip install --no-cache-dir -e .[dev] + + - name: Upload SDGym Benchmark + env: + PYDRIVE_TOKEN: ${{ secrets.PYDRIVE_TOKEN }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + GITHUB_LOCAL_RESULTS_DIR: ${{ runner.temp }}/sdgym-leaderboard-files + run: | + invoke upload-benchmark-results --modality "${{ inputs.modality }}" + echo "GITHUB_LOCAL_RESULTS_DIR=$GITHUB_LOCAL_RESULTS_DIR" >> $GITHUB_ENV + + - name: Prepare files for commit + if: env.SKIP_UPLOAD != 'true' + run: | + set -euo pipefail + mkdir -p pr-staging + + echo "Looking for files in: $GITHUB_LOCAL_RESULTS_DIR" + ls -l "$GITHUB_LOCAL_RESULTS_DIR" || true + + shopt -s nullglob + for f in "$GITHUB_LOCAL_RESULTS_DIR"/*; do + [ -f "$f" ] && cp "$f" "pr-staging/$(basename "$f")" + done + + echo "Files staged for PR:" + ls -l pr-staging || true + + - name: Checkout target repo (sdv-dev.github.io) + if: env.SKIP_UPLOAD != 'true' + run: | + git clone https://github.com/sdv-dev/sdv-dev.github.io.git target-repo + cd target-repo + git checkout gatsby-home + + - name: Copy results and commit + if: env.SKIP_UPLOAD != 'true' + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + FOLDER_NAME: ${{ env.FOLDER_NAME }} + run: | + set -euo pipefail + + cp -f pr-staging/* target-repo/assets/sdgym-leaderboard-files/ || true + cd target-repo + + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + + git add assets/ + git commit -m "Upload SDGym Benchmark Results ($FOLDER_NAME) - Modality: ${{ inputs.modality }}" || echo "No changes to commit" + + git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/sdv-dev/sdv-dev.github.io.git" + git push origin gatsby-home + + COMMIT_HASH=$(git rev-parse HEAD) + COMMIT_URL="https://github.com/sdv-dev/sdv-dev.github.io/commit/${COMMIT_HASH}" + echo "COMMIT_URL=$COMMIT_URL" >> $GITHUB_ENV + + - name: Send Slack notification + if: env.SKIP_UPLOAD != 'true' + env: + SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} + run: | + invoke notify-sdgym-benchmark-uploaded \ + --folder-name "$FOLDER_NAME" \ + --commit-url "$COMMIT_URL" \ + --modality "${{ inputs.modality }}" diff --git a/.github/workflows/upload_benchmark_results_multi_table.yml b/.github/workflows/upload_benchmark_results_multi_table.yml new file mode 100644 index 00000000..e62d0971 --- /dev/null +++ b/.github/workflows/upload_benchmark_results_multi_table.yml @@ -0,0 +1,16 @@ +name: Upload SDGym Multi-Table Benchmark results + +on: + workflow_run: + workflows: ["Run SDGym Benchmark Multi-Table"] + types: [completed] + workflow_dispatch: + schedule: + - cron: "0 6 * * *" + +jobs: + call-upload-benchmark-results: + uses: ./.github/workflows/upload_benchmark_results.yml + with: + modality: multi_table + secrets: inherit diff --git a/.github/workflows/upload_benchmark_results_single_table.yml b/.github/workflows/upload_benchmark_results_single_table.yml new file mode 100644 index 00000000..3fd8817d --- /dev/null +++ b/.github/workflows/upload_benchmark_results_single_table.yml @@ -0,0 +1,16 @@ +name: Upload SDGym Single-Table Benchmark results + +on: + workflow_run: + workflows: ["Run SDGym Benchmark Single-Table"] + types: [completed] + workflow_dispatch: + schedule: + - cron: "0 6 * * *" + +jobs: + call-upload-benchmark-results: + uses: ./.github/workflows/upload_benchmark_results.yml + with: + modality: single_table + secrets: inherit diff --git 
a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 28831dbb..655c6fbc 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -179,7 +179,7 @@ def _get_user_data_script( log "======== Configure kernel OOM behavior ==========" sudo sysctl -w vm.panic_on_oom=1 - sudo sysctl -w kernel.panic=10 + sudo sysctl -w kernel.panic=0 log "======== Update and Install Dependencies ==========" sudo apt update -y @@ -428,7 +428,7 @@ def _benchmark_single_table_compute_gcp( limit_dataset_size=False, compute_quality_score=True, compute_diagnostic_score=True, - compute_privacy_score=True, + compute_privacy_score=False, sdmetrics=None, timeout=None, ): diff --git a/sdgym/_benchmark/config_utils.py b/sdgym/_benchmark/config_utils.py index 2c38fb89..c283ac5e 100644 --- a/sdgym/_benchmark/config_utils.py +++ b/sdgym/_benchmark/config_utils.py @@ -118,6 +118,6 @@ def validate_compute_config(config): def _make_instance_name(prefix): - day = datetime.now(timezone.utc).strftime('%Y_%m_%d_%H:%M') + day = datetime.now(timezone.utc).strftime('%Y%m%d-%H%M') suffix = uuid.uuid4().hex[:6] return f'{prefix}-{day}-{suffix}' diff --git a/sdgym/_benchmark/credentials_utils.py b/sdgym/_benchmark/credentials_utils.py index 708fef31..0306a89a 100644 --- a/sdgym/_benchmark/credentials_utils.py +++ b/sdgym/_benchmark/credentials_utils.py @@ -1,5 +1,7 @@ import json +import os import textwrap +from tempfile import NamedTemporaryFile CREDENTIAL_KEYS = { 'aws': {'aws_access_key_id', 'aws_secret_access_key'}, @@ -74,3 +76,29 @@ def sdv_install_cmd(credentials): python -c "from sdv_installer.installation.installer import install_packages; \\ install_packages(username='{username}', license_key='{license_key}', package='sdv-enterprise')" """) + + +def create_credentials_file(): + """Create a credentials file.""" + gcp_json = os.getenv('GCP_SERVICE_ACCOUNT_JSON') + credentials = { + 'aws': { + 'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'), + 'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'), + }, + 'gcp': { + **json.loads(gcp_json), + 'gcp_project': os.getenv('GCP_PROJECT_ID'), + 'gcp_zone': os.getenv('GCP_ZONE'), + }, + 'sdv': { + 'username': os.getenv('SDV_ENTERPRISE_USERNAME'), + 'license_key': os.getenv('SDV_ENTERPRISE_LICENSE_KEY'), + }, + } + + tmp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.json') + json.dump(credentials, tmp_file) + tmp_file.flush() + + return tmp_file.name diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index d73ca441..b7bceded 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -1572,8 +1572,9 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client): bucket_name = parsed_url.netloc path = parsed_url.path.lstrip('/') if parsed_url.path else '' filename = os.path.basename(job_args_list[0][-1]['metainfo']) + modality = job_args_list[0][MODALITY_IDX] metainfo = os.path.splitext(filename)[0] - job_args_key = f'job_args_list_{metainfo}.pkl.gz' + job_args_key = f'job_args_list_{modality}_{metainfo}.pkl.gz' job_args_key = f'{path}{job_args_key}' if path else job_args_key serialized_data = cloudpickle.dumps(job_args_list) @@ -1643,7 +1644,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@feature_branch/mutli_table_benchmark" + pip install sdgym[all] pip install s3fs echo "======== Write Script ===========" diff --git 
a/sdgym/result_explorer/result_explorer.py b/sdgym/result_explorer/result_explorer.py index 46f3f46c..409a2d5d 100644 --- a/sdgym/result_explorer/result_explorer.py +++ b/sdgym/result_explorer/result_explorer.py @@ -20,7 +20,7 @@ def _validate_local_path(path): _BASELINE_BY_MODALITY = { 'single_table': SYNTHESIZER_BASELINE, - 'multi_table': 'MultiTableUniformSynthesizer', + 'multi_table': 'IndependentSynthesizer', } diff --git a/sdgym/run_benchmark/run_benchmark.py b/sdgym/run_benchmark/run_benchmark.py index 5ae5c609..59263bb7 100644 --- a/sdgym/run_benchmark/run_benchmark.py +++ b/sdgym/run_benchmark/run_benchmark.py @@ -6,18 +6,36 @@ from botocore.exceptions import ClientError -from sdgym.benchmark import benchmark_single_table_aws +from sdgym._benchmark.benchmark import ( + _benchmark_multi_table_compute_gcp, + _benchmark_single_table_compute_gcp, +) from sdgym.run_benchmark.utils import ( KEY_DATE_FILE, OUTPUT_DESTINATION_AWS, - SYNTHESIZERS_SPLIT, + SYNTHESIZERS_SPLIT_MULTI_TABLE, + SYNTHESIZERS_SPLIT_SINGLE_TABLE, + _parse_args, get_result_folder_name, post_benchmark_launch_message, ) from sdgym.s3 import get_s3_client, parse_s3_path +MODALITY_TO_SETUP = { + 'single_table': { + 'method': _benchmark_single_table_compute_gcp, + 'synthesizers_split': SYNTHESIZERS_SPLIT_SINGLE_TABLE, + }, + 'multi_table': { + 'method': _benchmark_multi_table_compute_gcp, + 'synthesizers_split': SYNTHESIZERS_SPLIT_MULTI_TABLE, + }, +} + -def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str): +def append_benchmark_run( + aws_access_key_id, aws_secret_access_key, date_str, modality='single_table' +): """Append a new benchmark run to the benchmark dates file in S3.""" s3_client = get_s3_client( aws_access_key_id=aws_access_key_id, @@ -25,7 +43,7 @@ def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str): ) bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS) try: - object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}') + object = s3_client.get_object(Bucket=bucket, Key=f'{prefix}{modality}/{KEY_DATE_FILE}') body = object['Body'].read().decode('utf-8') data = json.loads(body) except ClientError as e: @@ -37,27 +55,29 @@ def append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str): data['runs'].append({'date': date_str, 'folder_name': get_result_folder_name(date_str)}) data['runs'] = sorted(data['runs'], key=lambda x: x['date']) s3_client.put_object( - Bucket=bucket, Key=f'{prefix}{KEY_DATE_FILE}', Body=json.dumps(data).encode('utf-8') + Bucket=bucket, + Key=f'{prefix}{modality}/{KEY_DATE_FILE}', + Body=json.dumps(data).encode('utf-8'), ) def main(): """Main function to run the benchmark and upload results.""" + args = _parse_args() aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID') aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY') date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d') - for synthesizer_group in SYNTHESIZERS_SPLIT: - benchmark_single_table_aws( + modality = args.modality + for synthesizer_group in MODALITY_TO_SETUP[modality]['synthesizers_split']: + MODALITY_TO_SETUP[modality]['method']( output_destination=OUTPUT_DESTINATION_AWS, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, + credential_filepath=os.getenv('CREDENTIALS_FILEPATH'), synthesizers=synthesizer_group, - compute_privacy_score=False, timeout=345600, # 4 days ) - append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str) - post_benchmark_launch_message(date_str) + 
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str, modality=modality) + post_benchmark_launch_message(date_str, compute_service='GCP', modality=modality) if __name__ == '__main__': diff --git a/sdgym/run_benchmark/upload_benchmark_results.py b/sdgym/run_benchmark/upload_benchmark_results.py index 29d29343..85de7d4e 100644 --- a/sdgym/run_benchmark/upload_benchmark_results.py +++ b/sdgym/run_benchmark/upload_benchmark_results.py @@ -16,7 +16,13 @@ from sdgym.result_explorer.result_explorer import ResultsExplorer from sdgym.result_writer import LocalResultsWriter -from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, get_df_to_plot +from sdgym.run_benchmark.utils import ( + MODALITY_TO_GDRIVE_LINK, + OUTPUT_DESTINATION_AWS, + _extract_google_file_id, + _parse_args, + get_df_to_plot, +) from sdgym.s3 import S3_REGION, parse_s3_path LOGGER = logging.getLogger(__name__) @@ -28,8 +34,9 @@ 'Column': 'top center', 'CopulaGAN': 'top center', 'RealTabFormer': 'bottom center', + 'HSA': 'bottom center', + 'Independent': 'top center', } -SDGYM_FILE_ID = '1W3tsGOOtbtTw3g0EVE0irLgY_TN_cy2W4ONiZQ57OPo' RESULT_FILENAME = 'SDGym Monthly Run.xlsx' @@ -45,17 +52,21 @@ def get_latest_run_from_file(s3_client, bucket, key): raise RuntimeError(f'Failed to read {key} from S3: {e}') -def write_uploaded_marker(s3_client, bucket, prefix, folder_name): +def write_uploaded_marker(s3_client, bucket, prefix, folder_name, modality='single_table'): """Write a marker file to indicate that the upload is complete.""" s3_client.put_object( - Bucket=bucket, Key=f'{prefix}{folder_name}/upload_complete.marker', Body=b'Upload complete' + Bucket=bucket, + Key=f'{prefix}{modality}/{folder_name}/upload_complete.marker', + Body=b'Upload complete', ) -def upload_already_done(s3_client, bucket, prefix, folder_name): +def upload_already_done(s3_client, bucket, prefix, folder_name, modality='single_table'): """Check if the upload has already been done by looking for the marker file.""" try: - s3_client.head_object(Bucket=bucket, Key=f'{prefix}{folder_name}/upload_complete.marker') + s3_client.head_object( + Bucket=bucket, Key=f'{prefix}{modality}/{folder_name}/upload_complete.marker' + ) return True except ClientError as e: if e.response['Error']['Code'] == '404': @@ -64,7 +75,9 @@ def upload_already_done(s3_client, bucket, prefix, folder_name): raise -def get_result_folder_name_and_s3_vars(aws_access_key_id, aws_secret_access_key): +def get_result_folder_name_and_s3_vars( + aws_access_key_id, aws_secret_access_key, modality='single_table' +): """Get the result folder name and S3 client variables.""" bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS) s3_client = boto3.client( @@ -73,7 +86,9 @@ def get_result_folder_name_and_s3_vars(aws_access_key_id, aws_secret_access_key) aws_secret_access_key=aws_secret_access_key, region_name=S3_REGION, ) - folder_infos = get_latest_run_from_file(s3_client, bucket, f'{prefix}_BENCHMARK_DATES.json') + folder_infos = get_latest_run_from_file( + s3_client, bucket, f'{prefix}{modality}/_BENCHMARK_DATES.json' + ) return folder_infos, s3_client, bucket, prefix @@ -109,14 +124,21 @@ def upload_to_drive(file_path, file_id): def upload_results( - aws_access_key_id, aws_secret_access_key, folder_infos, s3_client, bucket, prefix, github_env + aws_access_key_id, + aws_secret_access_key, + folder_infos, + s3_client, + bucket, + prefix, + github_env, + modality='single_table', ): """Upload benchmark results to S3, GDrive, and save locally.""" folder_name = 
folder_infos['folder_name'] run_date = folder_infos['date'] result_explorer = ResultsExplorer( OUTPUT_DESTINATION_AWS, - modality='single_table', + modality=modality, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, ) @@ -144,31 +166,38 @@ def upload_results( local_export_dir = temp_dir Path(local_export_dir).mkdir(parents=True, exist_ok=True) - local_file_path = str(Path(local_export_dir) / RESULT_FILENAME) - s3_key = f'{prefix}{RESULT_FILENAME}' - s3_client.download_file(bucket, s3_key, local_file_path) + local_file_path = str(Path(local_export_dir) / f'[{modality}] {RESULT_FILENAME}') + s3_key = f'{prefix}{modality}/{RESULT_FILENAME}' + try: + s3_client.download_file(bucket, s3_key, local_file_path) + except ClientError as e: + if not e.response['Error']['Code'] == '404': + raise + datas = { 'Wins': summary, f'{run_date}_Detailed_results': results, f'{run_date}_plot_data': df_to_plot, } local_results_writer.write_xlsx(datas, local_file_path) - upload_to_drive((local_file_path), SDGYM_FILE_ID) + upload_to_drive((local_file_path), _extract_google_file_id(MODALITY_TO_GDRIVE_LINK[modality])) s3_client.upload_file(local_file_path, bucket, s3_key) - write_uploaded_marker(s3_client, bucket, prefix, folder_name) + write_uploaded_marker(s3_client, bucket, prefix, folder_name, modality=modality) if temp_dir: shutil.rmtree(temp_dir) def main(): """Main function to upload benchmark results.""" + args = _parse_args() + modality = args.modality aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID') aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY') folder_infos, s3_client, bucket, prefix = get_result_folder_name_and_s3_vars( - aws_access_key_id, aws_secret_access_key + aws_access_key_id, aws_secret_access_key, modality=modality ) github_env = os.getenv('GITHUB_ENV') - if upload_already_done(s3_client, bucket, prefix, folder_infos['folder_name']): + if upload_already_done(s3_client, bucket, prefix, folder_infos['folder_name'], modality): LOGGER.warning('Benchmark results have already been uploaded. 
Exiting.') if github_env: with open(github_env, 'a') as env_file: @@ -184,6 +213,7 @@ def main(): bucket, prefix, github_env, + modality, ) diff --git a/sdgym/run_benchmark/utils.py b/sdgym/run_benchmark/utils.py index 25981224..6ab8e371 100644 --- a/sdgym/run_benchmark/utils.py +++ b/sdgym/run_benchmark/utils.py @@ -1,8 +1,9 @@ """Utils file for the run_benchmark module.""" +import argparse import os from datetime import datetime -from urllib.parse import quote_plus +from urllib.parse import parse_qs, quote_plus, urlparse import numpy as np from slack_sdk import WebClient @@ -10,7 +11,6 @@ from sdgym.s3 import parse_s3_path OUTPUT_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/' -UPLOAD_DESTINATION_AWS = 's3://sdgym-benchmark/Benchmarks/' DEBUG_SLACK_CHANNEL = 'sdv-alerts-debug' SLACK_CHANNEL = 'sdv-alerts' KEY_DATE_FILE = '_BENCHMARK_DATES.json' @@ -46,14 +46,22 @@ 'diamond-cross', 'diamond-x', ] +MODALITY_TO_GDRIVE_LINK = { + 'single_table': 'https://docs.google.com/spreadsheets/d/1W3tsGOOtbtTw3g0EVE0irLgY_TN_cy2W4ONiZQ57OPo/edit?usp=sharing', + 'multi_table': 'https://docs.google.com/spreadsheets/d/1srmXx2ddq025hqzAE4JRdebuoBfro_7wbgeUHUkMEMM/edit?usp=sharing', +} # The synthesizers inside the same list will be run by the same ec2 instance -SYNTHESIZERS_SPLIT = [ +SYNTHESIZERS_SPLIT_SINGLE_TABLE = [ ['UniformSynthesizer', 'ColumnSynthesizer', 'GaussianCopulaSynthesizer', 'TVAESynthesizer'], ['CopulaGANSynthesizer'], ['CTGANSynthesizer'], ['RealTabFormerSynthesizer'], ] +SYNTHESIZERS_SPLIT_MULTI_TABLE = [ + ['HMASynthesizer'], + ['HSASynthesizer', 'IndependentSynthesizer', 'MultiTableUniformSynthesizer'], +] def get_result_folder_name(date_str): @@ -91,26 +99,28 @@ def post_slack_message(channel, text): client.chat_postMessage(channel=channel, text=text) -def post_benchmark_launch_message(date_str): +def post_benchmark_launch_message(date_str, compute_service='AWS', modality='single_table'): """Post a message to the SDV Alerts Slack channel when the benchmark is launched.""" channel = SLACK_CHANNEL folder_name = get_result_folder_name(date_str) bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS) - url_link = get_s3_console_link(bucket, f'{prefix}{folder_name}/') - body = 'πŸƒ SDGym benchmark has been launched! EC2 Instances are running. ' + url_link = get_s3_console_link(bucket, f'{prefix}{modality}/{folder_name}/') + modality_text = modality.replace('_', '-') + body = f'πŸƒ SDGym {modality_text} benchmark has been launched on {compute_service}! ' body += f'Intermediate results can be found <{url_link}|here>.\n' post_slack_message(channel, body) -def post_benchmark_uploaded_message(folder_name, commit_url=None): +def post_benchmark_uploaded_message(folder_name, commit_url=None, modality='single_table'): """Post benchmark uploaded message to sdv-alerts slack channel.""" channel = SLACK_CHANNEL bucket, prefix = parse_s3_path(OUTPUT_DESTINATION_AWS) - url_link = get_s3_console_link(bucket, quote_plus(f'{prefix}SDGym Monthly Run.xlsx')) + modality_text = modality.replace('_', '-') + url_link = get_s3_console_link(bucket, quote_plus(f'{prefix}{modality}/SDGym Monthly Run.xlsx')) body = ( - f'πŸ€ΈπŸ»β€β™€οΈ SDGym benchmark results for *{folder_name}* are available! πŸ‹οΈβ€β™€οΈ\n' + f'πŸ€ΈπŸ»β€β™€οΈ SDGym {modality_text} benchmark results for *{folder_name}* are available! 
πŸ‹οΈβ€β™€οΈ\n' f'Check the results:\n' - f' - On GDrive: <{GDRIVE_LINK}|link>\n' + f' - On GDrive: <{MODALITY_TO_GDRIVE_LINK[modality]}|link>\n' f' - On S3: <{url_link}|link>\n' ) if commit_url: @@ -162,3 +172,27 @@ def get_df_to_plot(benchmark_result): df_to_plot = df_to_plot.rename(columns={'Adjusted_Quality_Score': 'Quality_Score'}) return df_to_plot.drop(columns=['Cumulative Quality Score']).reset_index(drop=True) + + +def _parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--modality', + choices=['single_table', 'multi_table'], + default='single_table', + help='Benchmark modality to run.', + ) + return parser.parse_args() + + +def _extract_google_file_id(google_drive_link: str) -> str: + parsed = urlparse(google_drive_link) + file_id = parse_qs(parsed.query).get('id') + if file_id: + return file_id[0] + + for marker in ('/d/', '/file/d/'): + if marker in parsed.path: + return parsed.path.split(marker, 1)[1].split('/', 1)[0] + + raise ValueError(f'Invalid Google Drive link format: {google_drive_link}') diff --git a/tasks.py b/tasks.py index 741dc3d4..b788827b 100644 --- a/tasks.py +++ b/tasks.py @@ -203,18 +203,18 @@ def rmdir(c, path): pass @task -def run_sdgym_benchmark(c): +def run_sdgym_benchmark(c, modality='single_table'): """Run the SDGym benchmark.""" - c.run('python sdgym/run_benchmark/run_benchmark.py') + c.run(f'python sdgym/run_benchmark/run_benchmark.py --modality {modality}') @task -def upload_benchmark_results(c): +def upload_benchmark_results(c, modality='single_table'): """Upload the benchmark results to S3.""" - c.run(f'python sdgym/run_benchmark/upload_benchmark_results.py') + c.run(f'python sdgym/run_benchmark/upload_benchmark_results.py --modality {modality}') @task -def notify_sdgym_benchmark_uploaded(c, folder_name, commit_url=None): +def notify_sdgym_benchmark_uploaded(c, folder_name, commit_url=None, modality='single_table'): """Notify Slack about the SDGym benchmark upload.""" from sdgym.run_benchmark.utils import post_benchmark_uploaded_message - post_benchmark_uploaded_message(folder_name, commit_url) \ No newline at end of file + post_benchmark_uploaded_message(folder_name, commit_url, modality) diff --git a/tests/integration/result_explorer/test_result_explorer.py b/tests/integration/result_explorer/test_result_explorer.py index 0d2fc0b4..a92e053b 100644 --- a/tests/integration/result_explorer/test_result_explorer.py +++ b/tests/integration/result_explorer/test_result_explorer.py @@ -103,8 +103,8 @@ def test_summarize_multi_table(): # Assert expected_summary = pd.DataFrame({ - 'Synthesizer': ['HMASynthesizer'], - '12_02_2025 - # datasets: 1 - sdgym version: 0.11.2.dev0': [1], + 'Synthesizer': ['HMASynthesizer', 'MultiTableUniformSynthesizer'], + '12_02_2025 - # datasets: 1 - sdgym version: 0.11.2.dev0': [0, 0], }) expected_results = ( pd @@ -115,9 +115,7 @@ def test_summarize_multi_table(): .sort_values(by=['Dataset', 'Synthesizer']) .reset_index(drop=True) ) - expected_results['Win'] = ( - expected_results['Synthesizer'] != 'MultiTableUniformSynthesizer' - ).astype('int64') + expected_results['Win'] = [0, 0] pd.testing.assert_frame_equal(summary, expected_summary) pd.testing.assert_frame_equal(results, expected_results) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 0ebede70..4c740f99 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,10 @@ """Tests for the ``tasks.py`` file.""" -from tasks import _get_extra_dependencies, _get_minimum_versions, _resolve_version_conflicts +from tasks import ( 
+ _get_extra_dependencies, + _get_minimum_versions, + _resolve_version_conflicts, +) def test_get_minimum_versions(): diff --git a/tests/unit/_benchmark/test_benchmark.py b/tests/unit/_benchmark/test_benchmark.py index e67e115b..6ab2870f 100644 --- a/tests/unit/_benchmark/test_benchmark.py +++ b/tests/unit/_benchmark/test_benchmark.py @@ -433,7 +433,7 @@ def test__benchmark_single_table_compute_gcp(mock_benchmark_compute): limit_dataset_size=limit_dataset_size, compute_quality_score=compute_quality_score, compute_diagnostic_score=compute_diagnostic_score, - compute_privacy_score=True, + compute_privacy_score=False, sdmetrics=sdmetrics, timeout=timeout, modality='single_table', @@ -464,7 +464,7 @@ def test__benchmark_single_table_compute_gcp_defaults(mock_benchmark_compute): limit_dataset_size=False, compute_quality_score=True, compute_diagnostic_score=True, - compute_privacy_score=True, + compute_privacy_score=False, sdmetrics=None, timeout=None, modality='single_table', diff --git a/tests/unit/_benchmark/test_config_utils.py b/tests/unit/_benchmark/test_config_utils.py index 73ab6e00..2a2e0a89 100644 --- a/tests/unit/_benchmark/test_config_utils.py +++ b/tests/unit/_benchmark/test_config_utils.py @@ -106,7 +106,7 @@ def test_make_instance_name(mock_datetime, mock_uuid): result = _make_instance_name('sdgym-run') # Assert - assert result == 'sdgym-run-2025_01_15_12:00-abcdef' + assert result == 'sdgym-run-20250115-1200-abcdef' @patch('sdgym._benchmark.config_utils._apply_compute_service_keymap') diff --git a/tests/unit/_benchmark/test_credentials_utils.py b/tests/unit/_benchmark/test_credentials_utils.py index 91c509f2..69422e36 100644 --- a/tests/unit/_benchmark/test_credentials_utils.py +++ b/tests/unit/_benchmark/test_credentials_utils.py @@ -1,8 +1,15 @@ import json +import os +from pathlib import Path +from unittest.mock import patch import pytest -from sdgym._benchmark.credentials_utils import get_credentials, sdv_install_cmd +from sdgym._benchmark.credentials_utils import ( + create_credentials_file, + get_credentials, + sdv_install_cmd, +) def test_get_credentials(tmp_path): @@ -66,3 +73,47 @@ def test_sdv_install_cmd(credentials, expected_cmd): # Assert assert cmd == expected_cmd + + +@patch.dict( + os.environ, + { + 'GCP_SERVICE_ACCOUNT_JSON': json.dumps({ + 'type': 'service_account', + 'project_id': 'my-project', + }), + 'AWS_ACCESS_KEY_ID': 'fake-access-key', + 'AWS_SECRET_ACCESS_KEY': 'fake-secret-key', + 'SDV_ENTERPRISE_USERNAME': 'fake-username', + 'SDV_ENTERPRISE_LICENSE_KEY': 'fake-license', + 'GCP_PROJECT_ID': 'sdgym-337614', + 'GCP_ZONE': 'us-central1-a', + }, +) +def test_create_credentials_file(tmp_path): + """Test the `create_credentials_file` method.""" + # Run + filepath = create_credentials_file() + + # Assert + assert Path(filepath).exists() + with open(filepath, 'r') as f: + data = json.load(f) + + assert data == { + 'aws': { + 'aws_access_key_id': 'fake-access-key', + 'aws_secret_access_key': 'fake-secret-key', + }, + 'gcp': { + 'type': 'service_account', + 'project_id': 'my-project', + 'gcp_project': 'sdgym-337614', + 'gcp_zone': 'us-central1-a', + }, + 'sdv': { + 'username': 'fake-username', + 'license_key': 'fake-license', + }, + } + os.remove(filepath) diff --git a/tests/unit/run_benchmark/test_run_benchmark.py b/tests/unit/run_benchmark/test_run_benchmark.py index aacab84e..07c4e1ea 100644 --- a/tests/unit/run_benchmark/test_run_benchmark.py +++ b/tests/unit/run_benchmark/test_run_benchmark.py @@ -2,10 +2,18 @@ from datetime import datetime, timezone from 
unittest.mock import Mock, call, patch +import pytest from botocore.exceptions import ClientError -from sdgym.run_benchmark.run_benchmark import append_benchmark_run, main -from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT +from sdgym.run_benchmark.run_benchmark import ( + append_benchmark_run, + main, +) +from sdgym.run_benchmark.utils import ( + OUTPUT_DESTINATION_AWS, + SYNTHESIZERS_SPLIT_MULTI_TABLE, + SYNTHESIZERS_SPLIT_SINGLE_TABLE, +) @patch('sdgym.run_benchmark.run_benchmark.get_s3_client') @@ -47,11 +55,11 @@ def test_append_benchmark_run(mock_get_result_folder_name, mock_parse_s3_path, m mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS) mock_get_result_folder_name.assert_called_once_with(date) mock_s3_client.get_object.assert_called_once_with( - Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json' + Bucket='my-bucket', Key='my-prefix/single_table/_BENCHMARK_DATES.json' ) mock_s3_client.put_object.assert_called_once_with( Bucket='my-bucket', - Key='my-prefix/_BENCHMARK_DATES.json', + Key='my-prefix/single_table/_BENCHMARK_DATES.json', Body=json.dumps(expected_data).encode('utf-8'), ) @@ -91,53 +99,84 @@ def test_append_benchmark_run_new_file( mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS) mock_get_result_folder_name.assert_called_once_with(date) mock_s3_client.get_object.assert_called_once_with( - Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json' + Bucket='my-bucket', Key='my-prefix/single_table/_BENCHMARK_DATES.json' ) mock_s3_client.put_object.assert_called_once_with( Bucket='my-bucket', - Key='my-prefix/_BENCHMARK_DATES.json', + Key='my-prefix/single_table/_BENCHMARK_DATES.json', Body=json.dumps(expected_data).encode('utf-8'), ) -@patch('sdgym.run_benchmark.run_benchmark.benchmark_single_table_aws') -@patch('sdgym.run_benchmark.run_benchmark.os.getenv') -@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run') +@pytest.mark.parametrize( + 'modality,synthesizer_split', + [ + ('single_table', SYNTHESIZERS_SPLIT_SINGLE_TABLE), + ('multi_table', SYNTHESIZERS_SPLIT_MULTI_TABLE), + ], +) @patch('sdgym.run_benchmark.run_benchmark.post_benchmark_launch_message') +@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run') +@patch('sdgym.run_benchmark.run_benchmark.os.getenv') +@patch('sdgym.run_benchmark.run_benchmark._parse_args') +@patch.dict( + 'sdgym.run_benchmark.run_benchmark.MODALITY_TO_SETUP', + values={ + 'single_table': { + 'method': Mock(name='mock_single_method'), + 'synthesizers_split': [], + }, + 'multi_table': { + 'method': Mock(name='mock_multi_method'), + 'synthesizers_split': [], + }, + }, + clear=True, +) def test_main( - mock_post_benchmark_launch_message, - mock_append_benchmark_run, + mock_parse_args, mock_getenv, - mock_benchmark_single_table_aws, + mock_append_benchmark_run, + mock_post_benchmark_launch_message, + modality, + synthesizer_split, ): - """Test the `main` method.""" + """Test the `main` function with both single_table and multi_table modalities.""" # Setup - mock_getenv.side_effect = ['my_access_key', 'my_secret_key'] + from sdgym.run_benchmark.run_benchmark import MODALITY_TO_SETUP + + mock_parse_args.return_value = Mock(modality=modality) + mock_getenv.side_effect = lambda key: { + 'AWS_ACCESS_KEY_ID': 'my_access_key', + 'AWS_SECRET_ACCESS_KEY': 'my_secret_key', + 'CREDENTIALS_FILEPATH': '/path/to/creds.json', + }.get(key) + MODALITY_TO_SETUP[modality]['synthesizers_split'] = synthesizer_split + mock_method = MODALITY_TO_SETUP[modality]['method'] date = 
datetime.now(timezone.utc).strftime('%Y-%m-%d') # Run main() # Assert - mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID') - mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY') - expected_calls = [] - for synthesizer in SYNTHESIZERS_SPLIT: - expected_calls.append( - call( - output_destination=OUTPUT_DESTINATION_AWS, - aws_access_key_id='my_access_key', - aws_secret_access_key='my_secret_key', - synthesizers=synthesizer, - compute_privacy_score=False, - timeout=345600, - ) + expected_calls = [ + call( + output_destination=OUTPUT_DESTINATION_AWS, + credential_filepath='/path/to/creds.json', + synthesizers=group, + timeout=345600, ) - - mock_benchmark_single_table_aws.assert_has_calls(expected_calls) + for group in synthesizer_split + ] + mock_method.assert_has_calls(expected_calls) mock_append_benchmark_run.assert_called_once_with( 'my_access_key', 'my_secret_key', date, + modality=modality, + ) + mock_post_benchmark_launch_message.assert_called_once_with( + date, + compute_service='GCP', + modality=modality, ) - mock_post_benchmark_launch_message.assert_called_once_with(date) diff --git a/tests/unit/run_benchmark/test_upload_benchmark_result.py b/tests/unit/run_benchmark/test_upload_benchmark_result.py index 3c776ebd..d0d292b1 100644 --- a/tests/unit/run_benchmark/test_upload_benchmark_result.py +++ b/tests/unit/run_benchmark/test_upload_benchmark_result.py @@ -7,7 +7,6 @@ from botocore.exceptions import ClientError from sdgym.run_benchmark.upload_benchmark_results import ( - SDGYM_FILE_ID, get_result_folder_name_and_s3_vars, main, upload_already_done, @@ -25,13 +24,16 @@ def test_write_uploaded_marker(): bucket = 'test-bucket' prefix = 'test-prefix/' run_name = 'test_run' + modality = 'single_table' # Run write_uploaded_marker(s3_client, bucket, prefix, run_name) # Assert s3_client.put_object.assert_called_once_with( - Bucket=bucket, Key=f'{prefix}{run_name}/upload_complete.marker', Body=b'Upload complete' + Bucket=bucket, + Key=f'{prefix}{modality}/{run_name}/upload_complete.marker', + Body=b'Upload complete', ) @@ -79,9 +81,9 @@ def test_get_result_folder_name_and_s3_vars( # Setup aws_access_key_id = 'my_access_key' aws_secret_access_key = 'my_secret_key' - expected_result = ('SDGym_results_10_01_2023', 's3_client', 'bucket', 'prefix') + expected_result = ('SDGym_results_10_01_2023', 's3_client', 'bucket', 'prefix/') mock_boto_client.return_value = 's3_client' - mock_parse_s3_path.return_value = ('bucket', 'prefix') + mock_parse_s3_path.return_value = ('bucket', 'prefix/') mock_get_latest_run_from_file.return_value = 'SDGym_results_10_01_2023' # Run @@ -97,7 +99,7 @@ def test_get_result_folder_name_and_s3_vars( ) mock_parse_s3_path.assert_called_once_with(mock_output_destination_aws) mock_get_latest_run_from_file.assert_called_once_with( - 's3_client', 'bucket', 'prefix_BENCHMARK_DATES.json' + 's3_client', 'bucket', 'prefix/single_table/_BENCHMARK_DATES.json' ) @@ -162,7 +164,9 @@ def test_upload_to_drive_file_not_found(tmp_path): @patch('sdgym.run_benchmark.upload_benchmark_results.os.environ.get') @patch('sdgym.run_benchmark.upload_benchmark_results.get_df_to_plot') @patch('sdgym.run_benchmark.upload_benchmark_results.upload_to_drive') +@patch('sdgym.run_benchmark.upload_benchmark_results._extract_google_file_id') def test_upload_results( + mock_extract_google_file_id, mock_upload_to_drive, mock_get_df_to_plot, mock_os_environ_get, @@ -191,7 +195,8 @@ def test_upload_results( '10_01_2023_Detailed_results': 'results', '10_01_2023_plot_data': 'df_to_plot', } - local_path = 
str(Path('/tmp/sdgym_results/SDGym Monthly Run.xlsx')) + local_path = str(Path('/tmp/sdgym_results/[single_table] SDGym Monthly Run.xlsx')) + mock_extract_google_file_id.return_value = 'google_file_id' # Run upload_results( @@ -205,7 +210,7 @@ def test_upload_results( ) # Assert - mock_upload_to_drive.assert_called_once_with(local_path, SDGYM_FILE_ID) + mock_upload_to_drive.assert_called_once_with(local_path, 'google_file_id') mock_logger.info.assert_called_once_with( f'Run {run_name} is complete! Proceeding with summarization...' ) @@ -217,7 +222,9 @@ def test_upload_results( ) result_explorer_instance.all_runs_complete.assert_called_once_with(run_name) result_explorer_instance.summarize.assert_called_once_with(run_name) - mock_write_uploaded_marker.assert_called_once_with(s3_client, bucket, prefix, run_name) + mock_write_uploaded_marker.assert_called_once_with( + s3_client, bucket, prefix, run_name, modality='single_table' + ) mock_local_results_writer.return_value.write_xlsx.assert_called_once_with(datas, local_path) mock_get_df_to_plot.assert_called_once_with('results') @@ -275,7 +282,9 @@ def test_upload_results_not_all_runs_complete( @patch('sdgym.run_benchmark.upload_benchmark_results.upload_already_done') @patch('sdgym.run_benchmark.upload_benchmark_results.LOGGER') @patch('sdgym.run_benchmark.upload_benchmark_results.os.getenv') +@patch('sdgym.run_benchmark.upload_benchmark_results._parse_args') def test_main_already_upload( + mock_parse_args, mock_getenv, mock_logger, mock_upload_already_done, @@ -284,6 +293,7 @@ def test_main_already_upload( ): """Test the `method` when results are already uploaded.""" # Setup + mock_parse_args.return_value = Mock(modality='single_table') mock_getenv.side_effect = ['my_access_key', 'my_secret_key', None] folder_infos = {'folder_name': 'SDGym_results_10_01_2023', 'date': '10_01_2023'} mock_get_result_folder_name_and_s3_vars.return_value = ( @@ -300,8 +310,9 @@ def test_main_already_upload( main() # Assert + mock_parse_args.assert_called_once() mock_get_result_folder_name_and_s3_vars.assert_called_once_with( - 'my_access_key', 'my_secret_key' + 'my_access_key', 'my_secret_key', modality='single_table' ) mock_logger.warning.assert_called_once_with(expected_log_message) mock_upload_results.assert_not_called() @@ -311,7 +322,9 @@ def test_main_already_upload( @patch('sdgym.run_benchmark.upload_benchmark_results.upload_results') @patch('sdgym.run_benchmark.upload_benchmark_results.upload_already_done') @patch('sdgym.run_benchmark.upload_benchmark_results.os.getenv') +@patch('sdgym.run_benchmark.upload_benchmark_results._parse_args') def test_main( + mock_parse_args, mock_getenv, mock_upload_already_done, mock_upload_results, @@ -319,6 +332,7 @@ def test_main( ): """Test the `main` method.""" # Setup + mock_parse_args.return_value = Mock(modality='single_table') mock_getenv.side_effect = ['my_access_key', 'my_secret_key', None] folder_infos = {'folder_name': 'SDGym_results_10_11_2024', 'date': '10_11_2024'} mock_get_result_folder_name_and_s3_vars.return_value = ( @@ -334,11 +348,18 @@ def test_main( # Assert mock_get_result_folder_name_and_s3_vars.assert_called_once_with( - 'my_access_key', 'my_secret_key' + 'my_access_key', 'my_secret_key', modality='single_table' ) mock_upload_already_done.assert_called_once_with( - 's3_client', 'bucket', 'prefix', folder_infos['folder_name'] + 's3_client', 'bucket', 'prefix', folder_infos['folder_name'], 'single_table' ) mock_upload_results.assert_called_once_with( - 'my_access_key', 'my_secret_key', folder_infos, 
's3_client', 'bucket', 'prefix', None + 'my_access_key', + 'my_secret_key', + folder_infos, + 's3_client', + 'bucket', + 'prefix', + None, + 'single_table', ) diff --git a/tests/unit/run_benchmark/test_utils.py b/tests/unit/run_benchmark/test_utils.py index a7309e03..2894f747 100644 --- a/tests/unit/run_benchmark/test_utils.py +++ b/tests/unit/run_benchmark/test_utils.py @@ -4,9 +4,9 @@ import pytest from sdgym.run_benchmark.utils import ( - GDRIVE_LINK, + MODALITY_TO_GDRIVE_LINK, OUTPUT_DESTINATION_AWS, - SLACK_CHANNEL, + _extract_google_file_id, _get_slack_client, get_df_to_plot, get_result_folder_name, @@ -95,7 +95,7 @@ def test_post_benchmark_launch_message( url = 'https://s3.console.aws.amazon.com/' mock_get_s3_console_link.return_value = url expected_body = ( - 'πŸƒ SDGym benchmark has been launched! EC2 Instances are running. ' + 'πŸƒ SDGym single-table benchmark has been launched on AWS! ' f'Intermediate results can be found <{url}|here>.\n' ) # Run @@ -104,8 +104,10 @@ def test_post_benchmark_launch_message( # Assert mock_get_result_folder_name.assert_called_once_with(date_str) mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS) - mock_get_s3_console_link.assert_called_once_with('my-bucket', f'my-prefix/{folder_name}/') - mock_post_slack_message.assert_called_once_with(SLACK_CHANNEL, expected_body) + mock_get_s3_console_link.assert_called_once_with( + 'my-bucket', f'my-prefix/single_table/{folder_name}/' + ) + mock_post_slack_message.assert_called_once_with('sdv-alerts', expected_body) @patch('sdgym.run_benchmark.utils.post_slack_message') @@ -123,9 +125,9 @@ def test_post_benchmark_uploaded_message( url = 'https://s3.console.aws.amazon.com/' mock_get_s3_console_link.return_value = url expected_body = ( - f'πŸ€ΈπŸ»β€β™€οΈ SDGym benchmark results for *{folder_name}* are available! πŸ‹οΈβ€β™€οΈ\n' + f'πŸ€ΈπŸ»β€β™€οΈ SDGym single-table benchmark results for *{folder_name}* are available! πŸ‹οΈβ€β™€οΈ\n' f'Check the results:\n' - f' - On GDrive: <{GDRIVE_LINK}|link>\n' + f' - On GDrive: <{MODALITY_TO_GDRIVE_LINK["single_table"]}|link>\n' f' - On S3: <{url}|link>\n' ) @@ -133,10 +135,10 @@ def test_post_benchmark_uploaded_message( post_benchmark_uploaded_message(folder_name) # Assert - mock_post_slack_message.assert_called_once_with(SLACK_CHANNEL, expected_body) + mock_post_slack_message.assert_called_once_with('sdv-alerts', expected_body) mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS) mock_get_s3_console_link.assert_called_once_with( - 'my-bucket', 'my-prefix%2FSDGym+Monthly+Run.xlsx' + 'my-bucket', 'my-prefix%2Fsingle_table%2FSDGym+Monthly+Run.xlsx' ) @@ -156,9 +158,9 @@ def test_post_benchmark_uploaded_message_with_commit( url = 'https://s3.console.aws.amazon.com/' mock_get_s3_console_link.return_value = url expected_body = ( - f'πŸ€ΈπŸ»β€β™€οΈ SDGym benchmark results for *{folder_name}* are available! πŸ‹οΈβ€β™€οΈ\n' + f'πŸ€ΈπŸ»β€β™€οΈ SDGym single-table benchmark results for *{folder_name}* are available! 
πŸ‹οΈβ€β™€οΈ\n' f'Check the results:\n' - f' - On GDrive: <{GDRIVE_LINK}|link>\n' + f' - On GDrive: <{MODALITY_TO_GDRIVE_LINK["single_table"]}|link>\n' f' - On S3: <{url}|link>\n' f' - On GitHub: <{commit_url}|link>\n' ) @@ -167,10 +169,10 @@ def test_post_benchmark_uploaded_message_with_commit( post_benchmark_uploaded_message(folder_name, commit_url) # Assert - mock_post_slack_message.assert_called_once_with(SLACK_CHANNEL, expected_body) + mock_post_slack_message.assert_called_once_with('sdv-alerts', expected_body) mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS) mock_get_s3_console_link.assert_called_once_with( - 'my-bucket', 'my-prefix%2FSDGym+Monthly+Run.xlsx' + 'my-bucket', 'my-prefix%2Fsingle_table%2FSDGym+Monthly+Run.xlsx' ) @@ -203,3 +205,31 @@ def test_get_df_to_plot(): 'Marker': ['circle', 'square', 'diamond'], }) pd.testing.assert_frame_equal(result, expected_result) + + +@pytest.mark.parametrize( + 'url', + [ + 'https://drive.google.com/file/d/1A2B3C4D5E6F7G8H9I0J/view?usp=sharing', + 'https://drive.google.com/open?id=1A2B3C4D5E6F7G8H9I0J', + 'https://docs.google.com/uc?id=1A2B3C4D5E6F7G8H9I0J&export=download', + ], +) +def test_extract_google_file_id(url): + """Test the `_extract_google_file_id` method.""" + # Run + file_id = _extract_google_file_id(url) + + # Assert + assert file_id == '1A2B3C4D5E6F7G8H9I0J' + + +def test_extract_google_file_id_invalid_url(): + """Test the `_extract_google_file_id` method with an invalid URL.""" + # Setup + invalid_url = 'https://example.com/some/invalid/url' + expected_message = 'Invalid Google Drive link format: https://example.com/some/invalid/url' + + # Run and Assert + with pytest.raises(ValueError, match=expected_message): + _extract_google_file_id(invalid_url)