Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import asyncio
import functools
import logging
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -43,6 +44,14 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...
# Core generation hook: each concrete generator implements its column's
# transformation here (DataT is the generator's record/frame type).
@abstractmethod
def generate(self, data: DataT) -> DataT: ...

async def agenerate(self, data: dict) -> dict:
    """Async fallback — delegates to sync generate via thread pool.

    Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion)
    should override this with a direct async implementation.
    """
    # Off-load the blocking `generate` call so the event loop stays responsive.
    result = await asyncio.to_thread(self.generate, data)
    return result

Comment thread
eric-tramel marked this conversation as resolved.
def log_pre_generation(self) -> None:
"""A shared method to log info before the generator's `generate` method is called.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import functools
import logging
from typing import TYPE_CHECKING, Any

from data_designer.config.column_configs import (
LLMCodeColumnConfig,
Expand All @@ -24,6 +25,9 @@
from data_designer.engine.models.recipes.base import ResponseRecipe
from data_designer.engine.processing.utils import deserialize_json_values

if TYPE_CHECKING:
from data_designer.engine.models.utils import ChatMessage

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -56,36 +60,55 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer:
)

def generate(self, data: dict) -> dict:
    """Synchronously generate this column's value for one record."""
    generation_kwargs = self._prepare_generation_kwargs(data)
    response, trace = self.model.generate(**generation_kwargs)
    return self._process_generation_result(data, response, trace)

async def agenerate(self, data: dict) -> dict:
    """Asynchronously generate this column's value for one record."""
    generation_kwargs = self._prepare_generation_kwargs(data)
    response, trace = await self.model.agenerate(**generation_kwargs)
    return self._process_generation_result(data, response, trace)

def _prepare_generation_kwargs(self, data: dict) -> dict[str, Any]:
    """Build the keyword arguments shared by model.generate() and model.agenerate().

    Deserializes the incoming record, assembles optional multi-modal
    context, and renders the user and system prompts.
    """
    # Deserialize JSON strings stored by previous columns so Jinja2 templates
    # can reach nested fields (e.g. {{ prev_column.key }}). This builds a new
    # dict; the caller's `data` mapping is not mutated.
    deserialized_record = deserialize_json_values(data)

    multi_modal_context: list[dict[str, Any]] | None = None
    if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
        multi_modal_context = [
            item
            for context in self.config.multi_modal_context
            for item in context.get_contexts(deserialized_record)
        ]

    rendered_user_prompt = self.prompt_renderer.render(
        record=deserialized_record,
        prompt_template=self.config.prompt,
        prompt_type=PromptType.USER_PROMPT,
    )
    rendered_system_prompt = self.prompt_renderer.render(
        record=deserialized_record,
        prompt_template=self.config.system_prompt,
        prompt_type=PromptType.SYSTEM_PROMPT,
    )

    return {
        "prompt": rendered_user_prompt,
        "system_prompt": rendered_system_prompt,
        "parser": self.response_recipe.parse,
        "multi_modal_context": multi_modal_context,
        "tool_alias": self.config.tool_alias,
        "max_correction_steps": self.max_conversation_correction_steps,
        "max_conversation_restarts": self.max_conversation_restarts,
        "purpose": f"running generation for column '{self.config.name}'",
    }

def _process_generation_result(self, data: dict, response: Any, trace: list[ChatMessage]) -> dict:
"""Process model response and trace into the output data dict.

Serializes the response, applies trace column logic, and extracts reasoning content.
"""
serialized_output = self.response_recipe.serialize_output(response)
data[self.config.name] = self._process_serialized_output(serialized_output)

Expand All @@ -102,7 +125,7 @@ def generate(self, data: dict) -> dict:

return data

def _extract_reasoning_content(self, trace: list) -> str | None:
def _extract_reasoning_content(self, trace: list[ChatMessage]) -> str | None:
"""Extract reasoning_content from the final assistant message in the trace.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import functools
import logging
import os
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from data_designer.config.column_configs import CustomColumnConfig
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
Expand Down Expand Up @@ -51,6 +52,20 @@

logger = logging.getLogger(__name__)

# Opt-in feature flag: set DATA_DESIGNER_ASYNC_ENGINE=1 in the environment to
# replace the thread-pool fan-out with asyncio-based concurrency.
DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1"

if DATA_DESIGNER_ASYNC_ENGINE:
    import sys

    # Fail fast at import time: the async executor relies on
    # asyncio.TaskGroup, which only exists on Python 3.11+.
    if sys.version_info < (3, 11):
        raise RuntimeError(
            "DATA_DESIGNER_ASYNC_ENGINE requires Python 3.11+ (asyncio.TaskGroup). "
            f"Current version: {sys.version_info.major}.{sys.version_info.minor}"
        )
    # Imported lazily so the async machinery is only loaded when the flag is on.
    from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor

    logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency")

_CLIENT_VERSION: str = get_library_version()


Expand Down Expand Up @@ -255,7 +270,11 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers
if isinstance(generator, ColumnGeneratorWithModel):
max_workers = generator.inference_parameters.max_parallel_requests
self._fan_out_with_threads(generator, max_workers=max_workers)
if DATA_DESIGNER_ASYNC_ENGINE:
logger.info("⚡ Using async engine for concurrent execution")
self._fan_out_with_async(generator, max_workers=max_workers)
else:
self._fan_out_with_threads(generator, max_workers=max_workers)
Comment thread
eric-tramel marked this conversation as resolved.

def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
df = generator.generate(self.batch_manager.get_current_batch(as_dataframe=True))
Expand All @@ -282,42 +301,60 @@ def _run_mcp_tool_check_if_needed(self) -> None:
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
self._resource_provider.mcp_registry.run_health_check(tool_aliases)

def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
def _setup_fan_out(
    self, generator: ColumnGeneratorWithModelRegistry, max_workers: int
) -> tuple[ProgressTracker, dict[str, Any]]:
    """Shared setup for thread- and async-based cell-by-cell fan-out.

    Validates the generator's strategy, starts a progress tracker, and
    builds the keyword arguments common to both executor types.

    Args:
        generator: The cell-by-cell column generator being fanned out.
        max_workers: Concurrency limit, logged with the progress start.

    Returns:
        The started progress tracker and the executor keyword arguments.

    Raises:
        DatasetGenerationError: If the generator is not a CELL_BY_CELL generator.
    """
    # This span previously interleaved stale pre-refactor lines (old error
    # message, duplicated tool-alias logging, and a dead ConcurrentThreadExecutor
    # loop now owned by _fan_out_with_threads); only the shared setup remains.
    if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL:
        raise DatasetGenerationError(
            f"Generator {generator.name} is not a {GenerationStrategy.CELL_BY_CELL} "
            "generator so concurrent fan-out is not supported."
        )

    progress_tracker = ProgressTracker(
        total_records=self.batch_manager.num_records_batch,
        label=f"{generator.config.column_type} column '{generator.config.name}'",
    )
    progress_tracker.log_start(max_workers)

    settings = self._resource_provider.run_config
    executor_kwargs: dict = {
        "column_name": generator.config.name,
        "result_callback": self._make_result_callback(progress_tracker),
        "error_callback": self._make_error_callback(progress_tracker),
        "shutdown_error_rate": settings.shutdown_error_rate,
        "shutdown_error_window": settings.shutdown_error_window,
        "disable_early_shutdown": settings.disable_early_shutdown,
    }

    return progress_tracker, executor_kwargs

def _finalize_fan_out(self, progress_tracker: ProgressTracker) -> None:
    """Log fan-out completion and purge any records flagged for removal."""
    progress_tracker.log_final()

    if not self._records_to_drop:
        return
    # Clean up image artifacts before removing the records from the batch.
    self._cleanup_dropped_record_images(self._records_to_drop)
    self.batch_manager.drop_records(self._records_to_drop)
    self._records_to_drop.clear()

def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
    """Fan out cell-by-cell generation over the current batch via asyncio."""
    if getattr(generator.config, "tool_alias", None):
        logger.info("🛠️ Tool calling enabled")
    progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers)
    executor = AsyncConcurrentExecutor(max_workers=max_workers, **executor_kwargs)
    # One (coroutine, context) pair per record; the executor awaits them.
    work_items = [
        (generator.agenerate(record), {"index": index})
        for index, record in self.batch_manager.iter_current_batch()
    ]
    executor.run(work_items)
    self._finalize_fan_out(progress_tracker)

def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None:
    """Fan out cell-by-cell generation over the current batch via a thread pool."""
    if getattr(generator.config, "tool_alias", None):
        logger.info("🛠️ Tool calling enabled")
    progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers)
    with ConcurrentThreadExecutor(max_workers=max_workers, **executor_kwargs) as executor:
        for index, record in self.batch_manager.iter_current_batch():
            # submit invokes the callable with `record`; passing the bound
            # method directly is equivalent to wrapping it in a lambda.
            executor.submit(generator.generate, record, context={"index": index})
    self._finalize_fan_out(progress_tracker)

def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]:
def callback(result: dict, *, context: dict | None = None) -> None:
self._worker_result_callback(result, context=context)
Expand Down
Loading