Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .github/workflows/run_benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,19 @@ jobs:
license_key='${LICENSE_KEY}',
)
"
python -m pip install "sdgym[all]"
python -m pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-532-define-yaml-files"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Revert before merging


echo "VIRTUAL_ENV=$(pwd)/venv" >> $GITHUB_ENV
echo "$(pwd)/venv/bin" >> $GITHUB_PATH

- name: Run SDGym Benchmark
env:
GCP_SERVICE_ACCOUNT_JSON: ${{ secrets.GCP_SERVICE_ACCOUNT_JSON }}
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
GCP_ZONE: ${{ secrets.GCP_ZONE }}
SDV_ENTERPRISE_USERNAME: ${{ secrets.SDV_ENTERPRISE_USERNAME }}
SDV_ENTERPRISE_LICENSE_KEY: ${{ secrets.SDV_ENTERPRISE_LICENSE_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
run: |
export CREDENTIALS_FILEPATH=$(python -c "from sdgym._benchmark.credentials_utils import create_credentials_file; print(create_credentials_file())")
invoke run-sdgym-benchmark --modality "${{ inputs.modality }}"
rm -f "$CREDENTIALS_FILEPATH"
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,13 @@ namespaces = false
'make.bat',
'*.jpg',
'*.png',
'*.gif'
'*.gif',
'*.yaml'
]
'sdgym' = [
'leaderboard.csv',
'synthesizer_descriptions.yaml'
'synthesizer_descriptions.yaml',
'_benchmark_launcher/*.yaml',
]

[tool.setuptools.exclude-package-data]
Expand Down
21 changes: 10 additions & 11 deletions sdgym/_benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
resolve_compute_config,
validate_compute_config,
)
from sdgym._benchmark.credentials_utils import get_credentials, sdv_install_cmd
from sdgym._benchmark.credentials_utils import sdv_install_cmd
from sdgym.benchmark import (
DEFAULT_MULTI_TABLE_DATASETS,
DEFAULT_MULTI_TABLE_SYNTHESIZERS,
Expand Down Expand Up @@ -349,7 +349,7 @@ def _run_on_gcp(

def _benchmark_compute_gcp(
output_destination,
credential_filepath,
credentials,
compute_config,
synthesizers,
sdv_datasets,
Expand All @@ -364,7 +364,6 @@ def _benchmark_compute_gcp(
):
"""Run the SDGym benchmark on datasets for the given modality."""
compute_config = resolve_compute_config('gcp', compute_config)
credentials = get_credentials(credential_filepath)
validate_compute_config(compute_config)

s3_client = _validate_output_destination(
Expand Down Expand Up @@ -419,7 +418,7 @@ def _benchmark_compute_gcp(

def _benchmark_single_table_compute_gcp(
output_destination,
credential_filepath,
credentials,
compute_config=None,
synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS,
Expand All @@ -436,8 +435,8 @@ def _benchmark_single_table_compute_gcp(
Args:
output_destination (str):
The S3 URI where results will be stored.
credential_filepath (str or Path):
Path to the credentials file for AWS, GCP and SDV-Enterprise.
credentials (dict):
The credentials for AWS, GCP and SDV-Enterprise.
compute_config (dict, optional):
The compute configuration for the GCP instance. If None, default settings will be used.
synthesizers (list of dict, optional):
Expand All @@ -461,7 +460,7 @@ def _benchmark_single_table_compute_gcp(
"""
return _benchmark_compute_gcp(
output_destination=output_destination,
credential_filepath=credential_filepath,
credentials=credentials,
compute_config=compute_config,
synthesizers=synthesizers,
sdv_datasets=sdv_datasets,
Expand All @@ -478,7 +477,7 @@ def _benchmark_single_table_compute_gcp(

def _benchmark_multi_table_compute_gcp(
output_destination,
credential_filepath,
credentials,
compute_config=None,
synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
Expand All @@ -494,8 +493,8 @@ def _benchmark_multi_table_compute_gcp(
Args:
output_destination (str):
The S3 URI where results will be stored.
credential_filepath (str or Path):
Path to the credentials file for AWS, GCP and SDV-Enterprise.
credentials (dict):
The credentials for AWS, GCP and SDV-Enterprise.
compute_config (dict, optional):
The compute configuration for the GCP instance. If None, default settings will be used.
synthesizers (list of dict, optional):
Expand All @@ -517,7 +516,7 @@ def _benchmark_multi_table_compute_gcp(
"""
return _benchmark_compute_gcp(
output_destination=output_destination,
credential_filepath=credential_filepath,
credentials=credentials,
compute_config=compute_config,
synthesizers=synthesizers,
sdv_datasets=sdv_datasets,
Expand Down
8 changes: 4 additions & 4 deletions sdgym/_benchmark/credentials_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def get_credentials(credential_filepath):
def sdv_install_cmd(credentials):
"""Return the shell command to install sdv-enterprise using sdv-installer."""
sdv_creds = credentials.get('sdv') or {}
username = sdv_creds.get('username')
license_key = sdv_creds.get('license_key')
username = sdv_creds.get('sdv_enterprise_username')
license_key = sdv_creds.get('sdv_enterprise_license_key')
if not (username and license_key):
return ''

Expand All @@ -92,8 +92,8 @@ def create_credentials_file():
'gcp_zone': os.getenv('GCP_ZONE'),
},
'sdv': {
'username': os.getenv('SDV_ENTERPRISE_USERNAME'),
'license_key': os.getenv('SDV_ENTERPRISE_LICENSE_KEY'),
'sdv_enterprise_username': os.getenv('SDV_ENTERPRISE_USERNAME'),
'sdv_enterprise_license_key': os.getenv('SDV_ENTERPRISE_LICENSE_KEY'),
},
}

Expand Down
6 changes: 6 additions & 0 deletions sdgym/_benchmark_launcher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Benchmark Launcher Module."""

from sdgym._benchmark_launcher.benchmark_config import BenchmarkConfig
from sdgym._benchmark_launcher.benchmark_launcher import BenchmarkLauncher

__all__ = ('BenchmarkConfig', 'BenchmarkLauncher')
193 changes: 193 additions & 0 deletions sdgym/_benchmark_launcher/_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
from sdgym._benchmark_launcher.utils import (
_AWS_CREDENTIAL_KEYS,
_GCP_SERVICE_ACCOUNT_REQUIRED_KEYS,
resolve_credentials,
)

# Parameter names that users may not set under ``method_params``:
# ``_validate_method_params`` reports an error when any of them is present there.
_INJECTED_PARAMS = {
    'credentials',
    'synthesizers',
    'sdv_datasets',
    'compute_config',
    'output_destination',
}


def _as_errors(value):
if value is None:
return []
if isinstance(value, list):
return [str(v) for v in value if v]

return [str(value)]


def _format_sectioned_errors(section_errors):
    """Render per-section validation errors as one multi-line message.

    Args:
        section_errors (dict):
            Maps a section name to its raw errors (anything ``_as_errors`` accepts).

    Returns:
        str:
            A header line followed by one ``[section]`` group of ``- error`` bullets
            for every section that has errors. Trailing whitespace is stripped.
    """
    lines = ['BenchmarkConfig validation failed:\n']
    for name, raw_errors in section_errors.items():
        messages = _as_errors(raw_errors)
        if messages:
            lines.append(f'[{name}]')
            lines += [f'- {message}' for message in messages]
            # Blank separator between sections; the final one is removed by rstrip().
            lines.append('')

    return '\n'.join(lines).rstrip()


def _validate_structure(config):
errors = []
if config.modality not in ('single_table', 'multi_table'):
errors.append(
f"modality: must be 'single_table' or 'multi_table'. Found: {config.modality!r}"
)

if config.credentials_filepath is not None and not isinstance(config.credentials_filepath, str):
errors.append('credentials_filepath must be a string or None.')

expected_types = {
'method_params': dict,
'compute': dict,
'instance_jobs': list,
}
for key, expected_type in expected_types.items():
value = getattr(config, key, None)
if value is None:
errors.append(f'{key}: is a required section but missing.')
elif not isinstance(value, expected_type):
errors.append(f'{key}: must be a {expected_type.__name__}. Found: {type(value)}')

compute = getattr(config, 'compute', None)
if isinstance(compute, dict):
service = compute.get('service')
if service not in ('gcp',):
errors.append(f"compute.service: must be 'gcp'. Found: {service!r}")

return sorted(errors)


def _validate_method_params(method_params, method_to_run):
    """Validate the ``method_params`` section of a benchmark config.

    Args:
        method_params (dict):
            The ``method_params`` section to check.
        method_to_run:
            Not used by the current checks; kept so the call signature stays stable.

    Returns:
        list[str]:
            Error messages in the order the checks run; empty when the section is valid.
    """
    errors = []
    timeout = method_params.get('timeout')
    if timeout is not None:
        # ``bool`` is a subclass of ``int``: without the explicit exclusion,
        # ``timeout: true`` would pass validation and act as a one-second timeout.
        if isinstance(timeout, bool) or not isinstance(timeout, int):
            errors.append(
                f'method_params.timeout: must be int seconds. Found: {timeout!r} ({type(timeout)})'
            )
        elif timeout <= 0:
            errors.append('method_params.timeout: must be > 0.')

    for key in ('compute_quality_score', 'compute_diagnostic_score', 'compute_privacy_score'):
        value = method_params.get(key)
        # ``None`` (flag absent or explicitly null) is accepted.
        if value is not None and not isinstance(value, bool):
            errors.append(f'method_params.{key}: must be bool. Found: {value!r} ({type(value)})')

    # Reject any parameter name listed in ``_INJECTED_PARAMS``.
    illegal = _INJECTED_PARAMS & set(method_params)
    if illegal:
        errors.append(
            f'method_params: must not define injected parameters {sorted(illegal)} '
            f'(resolved from credentials/instance_jobs).'
        )

    return errors


def _validate_instance_jobs(instance_jobs):
error_message = (
"Each job in 'instance_jobs' must be a dict with an 'output_destination' (string), "
"'synthesizers' (list of strings), and 'datasets' (list of strings or dict with "
"'include' and optional 'exclude')."
)
invalid_jobs = []
for job in instance_jobs:
if not isinstance(job, dict):
invalid_jobs.append(job)
continue

if 'datasets' not in job or 'synthesizers' not in job or 'output_destination' not in job:
invalid_jobs.append(job)
continue

synthesizers = job['synthesizers']
if not isinstance(synthesizers, list) or not all(isinstance(s, str) for s in synthesizers):
invalid_jobs.append(job)
continue

output_destination = job['output_destination']
if not isinstance(output_destination, str) or not output_destination:
invalid_jobs.append(job)
continue

datasets = job['datasets']
if isinstance(datasets, list):
if not all(isinstance(d, str) for d in datasets):
invalid_jobs.append(job)
continue

if isinstance(datasets, dict):
include = datasets.get('include')
exclude = datasets.get('exclude')
if not isinstance(include, list) or not all(isinstance(d, str) for d in include):
invalid_jobs.append(job)
continue

if exclude is not None and (
not isinstance(exclude, list) or not all(isinstance(d, str) for d in exclude)
):
invalid_jobs.append(job)
continue

invalid_jobs.append(job)

if not invalid_jobs:
return []

invalid_jobs_str = '\n'.join(str(job) for job in invalid_jobs)

return [f'{error_message}\nInvalid jobs:\n{invalid_jobs_str}']


def _validate_resolved_credentials(credentials):
errors = []
aws = credentials.get('aws', {})
if not isinstance(aws, dict):
errors.append("credentials['aws'] must be a dict.")
else:
if any(aws.values()):
for key in _AWS_CREDENTIAL_KEYS:
key = key.lower()
if aws.get(key) in (None, ''):
errors.append(f"credentials['aws']['{key}'] is missing or empty.")

sdv = credentials.get('sdv_enterprise', {})
if not isinstance(sdv, dict):
errors.append("credentials['sdv_enterprise'] must be a dict.")
else:
username = sdv.get('sdv_enterprise_username')
license_key = sdv.get('sdv_enterprise_license_key')
message = (
"credentials['sdv_enterprise'] require both 'sdv_enterprise_username' and "
"'sdv_enterprise_license_key' to be provided and non-empty if any SDV Enterprise"
' credential is provided.'
)
if bool(username) != bool(license_key):
errors.append(message)

gcp = credentials.get('gcp', {})
if not isinstance(gcp, dict):
errors.append("credentials['gcp'] must be a dict.")
else:
if gcp:
for key in _GCP_SERVICE_ACCOUNT_REQUIRED_KEYS:
if gcp.get(key) in (None, ''):
errors.append(f"credentials['gcp']['{key}'] is missing or empty.")

return sorted(errors)


def _validate_credentials(credentials_filepath):
if credentials_filepath is not None and not isinstance(credentials_filepath, str):
return ['credentials_filepath: must be a string path to the credentials file or None.']

credentials = resolve_credentials(credentials_filepath)
return _validate_resolved_credentials(credentials)
10 changes: 10 additions & 0 deletions sdgym/_benchmark_launcher/benchmark_base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
method_params:
timeout: 345600
compute_quality_score: true
compute_diagnostic_score: true
compute_privacy_score: false

compute:
service: 'gcp'
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, we only indicate which compute service to use. In a future issue, we will also move the compute config defined here into this YAML file.


credentials_filepath: null
Loading
Loading