Skip to content
156 changes: 96 additions & 60 deletions compute_worker/compute_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import json
import os
import traceback
import shutil
import signal
import socket
Expand All @@ -18,7 +19,6 @@
from rich.progress import Progress
from rich.pretty import pprint
import requests

import websockets
import yaml
from billiard.exceptions import SoftTimeLimitExceeded
Expand Down Expand Up @@ -89,22 +89,31 @@

def show_progress(line, progress):
try:
if "Status: Image is up to date" in line["status"]:
logger.info(line["status"])
status = line.get("status") or ""
layer_id = line.get("id")
detail = line.get("progressDetail") or {}
current = detail.get("current")
total = detail.get("total")

if "Status: Image is up to date" in status:
logger.info(status)

if not layer_id:
return

completed = False
if line["status"] == "Download complete":
if status == "Download complete":
description = (
f"[blue][Download complete, waiting for extraction {line['id']}]"
f"[blue][Download complete, waiting for extraction {layer_id}]"
)
completed = True
elif line["status"] == "Downloading":
description = f"[bold][Downloading {line['id']}]"
elif line["status"] == "Pull complete":
description = f"[green][Extraction complete {line['id']}]"
elif status == "Downloading":
description = f"[bold][Downloading {layer_id}]"
elif status == "Pull complete":
description = f"[green][Extraction complete {layer_id}]"
completed = True
elif line["status"] == "Extracting":
description = f"[blue][Extracting {line['id']}]"
elif status == "Extracting":
description = f"[blue][Extracting {layer_id}]"

else:
# skip other statuses, but show extraction progress
Expand All @@ -121,7 +130,7 @@ def show_progress(line, progress):
)
else:
tasks[task_id] = progress.add_task(
description, total=line["progressDetail"]["total"]
description, total=total
)
else:
if completed:
Expand All @@ -134,12 +143,12 @@ def show_progress(line, progress):
else:
progress.update(
tasks[task_id],
completed=line["progressDetail"]["current"],
total=line["progressDetail"]["total"],
completed=current,
total=total,
)
except Exception as e:
logger.error("There was an error showing the progress bar")
logger.error(e)
if os.environ.get("LOG_LEVEL", "info").lower() == "debug":
logger.exception("There was an error showing the progress bar")


# -----------------------------------------------
Expand Down Expand Up @@ -242,20 +251,25 @@ def rewrite_bundle_url_if_needed(url):
def run_wrapper(run_args):
logger.info(f"Received run arguments: \n {colorize_run_args(json.dumps(run_args))}")
run = Run(run_args)

try:
run.prepare()
run.start()
if run.is_scoring:
run.push_scores()
run.push_output()
except DockerImagePullException as e:
run._update_status(STATUS_FAILED, str(e))
except SubmissionException as e:
run._update_status(STATUS_FAILED, str(e))
except SoftTimeLimitExceeded:
run._update_status(STATUS_FAILED, "Soft time limit exceeded!")
except (DockerImagePullException, SubmissionException, SoftTimeLimitExceeded) as e:
run._update_status(STATUS_FAILED, traceback.format_exc())
raise
except Exception as e:
# Catch any exception to avoid getting stuck in Running status
run._update_status(STATUS_FAILED, traceback.format_exc())
raise
finally:
try:
# Try to push logs before cleanup
run.push_logs()
except Exception:
logger.exception("push_logs failed")
run.clean_up()


Expand Down Expand Up @@ -444,6 +458,22 @@ async def watch_detailed_results(self):
if file_path:
await self.send_detailed_results(file_path)

def push_logs(self):
"""Upload any collected logs, even in case of crash.
"""
try:
for kind, logs in (self.logs or {}).items():
for stream_key in ("stdout", "stderr"):
entry = logs.get(stream_key) if isinstance(logs, dict) else None
if not entry:
continue
location = entry.get("location")
data = entry.get("data") or b""
if location:
self._put_file(location, raw_data=data)
except Exception as e:
logger.exception(f"Failed best-effort log upload: {e}")

def get_detailed_results_file_path(self):
default_detailed_results_path = os.path.join(
self.output_dir, "detailed_results.html"
Expand All @@ -465,7 +495,7 @@ async def send_detailed_results(self, file_path):
)
websocket_url = f"{self.websocket_url}?kind=detailed_results"
logger.info(f"Connecting to {websocket_url} for detailed results")
# Wrap this with a Try ... Except otherwise a failure here will make the submission get stuck on Running
# Wrap this with a Try block to avoid getting stuck on Running
try:
websocket = await asyncio.wait_for(
websockets.connect(websocket_url), timeout=30.0
Expand All @@ -478,14 +508,8 @@ async def send_detailed_results(self, file_path):
)
)
except Exception as e:
logger.error(
f"This error might result in a Execution Time Exceeded error: {e}"
)
if os.environ.get("LOG_LEVEL", "info").lower() == "debug":
logger.exception(e)
raise SubmissionException(
"Could not connect to instance to update detailed result"
)
logger.exception(e)
return

def _get_stdout_stderr_file_names(self, run_args):
# run_args should be the run_args argument passed to __init__ from the run_wrapper.
Expand All @@ -511,7 +535,7 @@ def _update_submission(self, data):

logger.info(f"Updating submission @ {url} with data = {data}")

resp = self.requests_session.patch(url, data, timeout=150)
resp = self.requests_session.patch(url, data=data, timeout=150)
if resp.status_code == 200:
logger.info("Submission updated successfully!")
else:
Expand All @@ -521,23 +545,17 @@ def _update_submission(self, data):
raise SubmissionException("Failure updating submission data.")

def _update_status(self, status, extra_information=None):
# Update submission status
if status not in AVAILABLE_STATUSES:
raise SubmissionException(
f"Status '{status}' is not in available statuses: {AVAILABLE_STATUSES}"
)

data = {
"status": status,
"status_details": extra_information,
}

# TODO: figure out if we should pull this task code later(submission.task should always be set)
# When we start
# if status == STATUS_SCORING:
# data.update({
# "task_pk": self.task_pk,
# })
self._update_submission(data)
data = {"status": status, "status_details": extra_information}
try:
self._update_submission(data)
except Exception as e:
# Always catch exception and never raise error
logger.exception(f"Failed to update submission status to {status}: {e}")

def _get_container_image(self, image_name):
logger.info("Running pull for image: {}".format(image_name))
Expand All @@ -547,6 +565,8 @@ def _get_container_image(self, image_name):
with Progress() as progress:
resp = client.pull(image_name, stream=True, decode=True)
for line in resp:
if isinstance(line, dict) and line.get("error"):
raise DockerImagePullException(line["error"])
show_progress(line, progress)
break # Break if the loop is successful to exit "with Progress() as progress"

Expand Down Expand Up @@ -684,6 +704,7 @@ async def _run_container_engine_cmd(self, container, kind):

# Create a websocket to send the logs in real time to the codabench instance
# We need to set a timeout for the websocket connection otherwise the program will get stuck if he websocket does not connect.
websocket = None
try:
websocket_url = f"{self.websocket_url}?kind={kind}"
logger.debug(
Expand Down Expand Up @@ -733,18 +754,20 @@ async def _run_container_engine_cmd(self, container, kind):
if str(log[0]) != "None":
logger.info(log[0].decode())
try:
await websocket.send(
json.dumps({"kind": kind, "message": log[0].decode()})
)
if websocket is not None:
await websocket.send(
json.dumps({"kind": kind, "message": log[0].decode()})
)
except Exception as e:
logger.error(e)

elif str(log[1]) != "None":
logger.error(log[1].decode())
try:
await websocket.send(
json.dumps({"kind": kind, "message": log[1].decode()})
)
if websocket is not None:
await websocket.send(
json.dumps({"kind": kind, "message": log[1].decode()})
)
except Exception as e:
logger.error(e)

Expand All @@ -765,7 +788,8 @@ async def _run_container_engine_cmd(self, container, kind):
logger.debug(
f"WORKER_MARKER: Disconnecting from {websocket_url}, program counter = {self.completed_program_counter}"
)
await websocket.close()
if websocket is not None:
await websocket.close()
client.remove_container(container, force=True)

logger.debug(
Expand All @@ -783,6 +807,13 @@ async def _run_container_engine_cmd(self, container, kind):
logger.error(e)
return_Code = {"StatusCode": 1}

finally:
try:
# Last chance of removing container
client.remove_container(container_id, force=True)
except Exception:
pass

self.logs[kind] = {
"returncode": return_Code["StatusCode"],
"start": start,
Expand Down Expand Up @@ -1053,9 +1084,8 @@ async def _run_program_directory(self, program_dir, kind):
try:
return await self._run_container_engine_cmd(container, kind=kind)
except Exception as e:
logger.error(e)
if os.environ.get("LOG_LEVEL", "info").lower() == "debug":
logger.exception(e)
logger.exception("Program directory execution failed")
raise SubmissionException(str(e))

def _put_dir(self, url, directory):
"""Zip the directory and send it to the given URL using _put_file."""
Expand Down Expand Up @@ -1097,7 +1127,7 @@ def _put_file(self, url, file=None, raw_data=None, content_type="application/zip
logger.info("Putting file %s in %s" % (file, url))
data = open(file, "rb")
headers["Content-Length"] = str(os.path.getsize(file))
elif raw_data:
elif raw_data is not None:
logger.info("Putting raw data %s in %s" % (raw_data, url))
data = raw_data
else:
Expand Down Expand Up @@ -1183,21 +1213,23 @@ def start(self):

logger.info("Running scoring program, and then ingestion program")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
gathered_tasks = asyncio.gather(
self._run_program_directory(program_dir, kind="program"),
self._run_program_directory(ingestion_program_dir, kind="ingestion"),
self.watch_detailed_results(),
loop=loop,
return_exceptions=True,
)

task_results = [] # will store results/exceptions from gather
signal.signal(signal.SIGALRM, alarm_handler)
signal.alarm(self.execution_time_limit)

try:
# run tasks
# keep what gather returned so we can detect async errors later
task_results = loop.run_until_complete(gathered_tasks) or []

except ExecutionTimeLimitExceeded:
error_message = f"Execution Time Limit exceeded. Limit was {self.execution_time_limit} seconds"
logger.error(error_message)
Expand Down Expand Up @@ -1231,7 +1263,11 @@ def start(self):
# Send error through web socket to the frontend
asyncio.run(self._send_data_through_socket(error_message))
raise SubmissionException(error_message)

finally:
signal.alarm(0)
asyncio.set_event_loop(None)
loop.close()
self.watch = False
for kind, logs in self.logs.items():
if logs["end"] is not None:
Expand Down Expand Up @@ -1277,7 +1313,6 @@ def start(self):

# set logs of this kind to None, since we handled them already
logger.info("Program finished")
signal.alarm(0)

if self.is_scoring:
# Check if scoring program failed
Expand All @@ -1287,7 +1322,7 @@ def start(self):
program_results, BaseException
) and not isinstance(program_results, asyncio.CancelledError)
program_rc = getattr(self, "program_exit_code", None)
failed_rc = program_rc not in (0, None)
failed_rc = (program_rc is None) or (program_rc != 0)
if had_async_exc or failed_rc:
self._update_status(
STATUS_FAILED,
Expand All @@ -1296,6 +1331,7 @@ def start(self):
# Raise so upstream marks failed immediately
raise SubmissionException("Child task failed or non-zero return code")
self._update_status(STATUS_FINISHED)

else:
self._update_status(STATUS_SCORING)

Expand Down
Loading