Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions examples/survey/survey_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import os
from dataclasses import dataclass
from typing import Annotated

import aiofiles
from aiocsv import AsyncWriter
from aiocsv import AsyncDictWriter
from dotenv import load_dotenv
from pydantic import Field

Expand Down Expand Up @@ -32,6 +31,18 @@
CommuteMethods = ["driving", "bus", "subway", "none"]
WorkStyles = ["independent", "team_player"]

CSV_COLUMNS = (
"name",
"get_name_intro_task",
"get_email_task",
"commute_task",
"experience_task",
"behavorial_task",
"summary",
"evaluation",
"disqualification_reason",
)


@dataclass
class Userdata:
Expand Down Expand Up @@ -67,10 +78,11 @@ class BehavioralResults:

async def write_to_csv(filename: str, data: dict):
async with aiofiles.open(filename, "a", newline="") as csvfile:
writer = AsyncWriter(csvfile, data.keys())
if not os.path.exists(filename):
is_empty = await csvfile.tell() == 0
writer = AsyncDictWriter(csvfile, fieldnames=CSV_COLUMNS, extrasaction="ignore")
if is_empty:
await writer.writeheader()
await writer.writerow(data.values())
await writer.writerow(data)


async def evaluate_candidate(llm_model, summary) -> str:
Expand Down Expand Up @@ -112,7 +124,7 @@ async def disqualify(context: RunContext, disqualification_reason: str) -> None:
disqualification_reason = "[DISQUALIFIED] " + disqualification_reason
data = {
"name": context.session.userdata.candidate_name,
"disqualification reason": disqualification_reason,
"disqualification_reason": disqualification_reason,
}
await write_to_csv(context.session.userdata.filename, data)
context.session.shutdown()
Expand Down Expand Up @@ -317,6 +329,7 @@ async def on_enter(self) -> AgentTask:
# TaskGroup returns a TaskGroupResult object. The task_results field holds a dictionary with Task IDs as the keys and the results as the values
summary = self.chat_ctx.items[-1]
evaluation = await evaluate_candidate(llm_model=self.session.llm, summary=summary)
results["name"] = self.session.userdata.candidate_name
results["summary"] = summary.content
results["evaluation"] = evaluation
self.session.userdata.task_results = results
Expand Down
Loading