diff --git a/api/ai/ai_response_model.py b/api/ai/ai_response_model.py
index 4c4faed..17cd943 100644
--- a/api/ai/ai_response_model.py
+++ b/api/ai/ai_response_model.py
@@ -43,8 +43,16 @@ class StreamRequest(BaseModel):
     offset: int = 0
     instruction: Optional[str] = None
 
+class FuzzyMatch(BaseModel):
+    source_text: str
+    target_text: str
+    score: float
+
+
 class WorkflowResult(BaseModel):
     output_text: str
+    from_memory: bool = False
+    fuzzy_matches: List[FuzzyMatch] = []
 
 
 class ResponseMetadata(BaseModel):
diff --git a/api/ai/ai_service.py b/api/ai/ai_service.py
index d603b0a..18f7f2e 100644
--- a/api/ai/ai_service.py
+++ b/api/ai/ai_service.py
@@ -1,4 +1,6 @@
 import asyncio
+import logging
+from datetime import datetime, timezone
 from typing import AsyncGenerator
 
 from api.Assistant.assistant_repository import get_assistant_by_id_repository
@@ -7,20 +9,28 @@
 from api.langgraph.workflow_stream import stream_workflow_events
 from api.ai.ai_response_model import (
     WorkflowRequest,
-    SegmentRequest,
     StreamResponse,
     WorkflowResult,
     ResponseMetadata,
     AvailableModelsResponse,
     ModelInfo,
-    EnhanceResponse
+    EnhanceResponse,
+    FuzzyMatch,
 )
 from api.db.pg_database import SessionLocal
 from api.llm.router import get_model_router
 from api.external_api import get_related_segment_ids, get_segment_content
 from api.ai.prompts import ENHANCE_META_PROMPT
+from api.translation_memory.tm_model import TranslationMemory
+from api.translation_memory.tm_repository import (
+    find_exact_match,
+    find_fuzzy_matches,
+    batch_create_tm_entries,
+)
 from fastapi import HTTPException
 
+logger = logging.getLogger(__name__)
+
 
 def build_workflow_request(db_session, assistant_id, target_language, prompt, segments, model, instruction=None) -> WorkflowRequest:
     assistant_detail = get_assistant_by_id_repository(db_session, assistant_id)
@@ -63,12 +73,106 @@ def validate_model(model: str) -> None:
         )
 
 
+def _lookup_translation_memory(assistant_id, prompt, target_language):
+    tm_hits = {}
+    fuzzy_cache = {}
+    texts_for_llm = []
+
+    if target_language:
+        with SessionLocal() as db_session:
+            for idx, source_text in enumerate(prompt):
+                exact = find_exact_match(
+                    db_session, assistant_id, source_text, target_language
+                )
+                if exact:
+                    tm_hits[idx] = exact.target_text
+                    continue
+
+                rows = find_fuzzy_matches(
+                    db_session, assistant_id, source_text, target_language
+                )
+                fuzzy_cache[idx] = [
+                    FuzzyMatch(
+                        source_text=row.source_text,
+                        target_text=row.target_text,
+                        score=round(float(row.score), 4),
+                    )
+                    for row in rows
+                ]
+                texts_for_llm.append((idx, source_text))
+    else:
+        texts_for_llm = list(enumerate(prompt))
+
+    return tm_hits, fuzzy_cache, texts_for_llm
+
+
+def _save_translations_to_memory(assistant_id, target_language, texts_for_llm, llm_results):
+    try:
+        with SessionLocal() as db_session:
+            new_entries = [
+                TranslationMemory(
+                    assistant_id=assistant_id,
+                    source_text=source_text,
+                    target_text=result.output_text,
+                    target_language=target_language,
+                )
+                for (_, source_text), result in zip(texts_for_llm, llm_results)
+            ]
+            batch_create_tm_entries(db_session, new_entries)
+    except Exception as e:
+        logger.warning(f"Failed to save translations to TM: {e}")
+
+
+def _merge_tm_and_llm_results(prompt, tm_hits, fuzzy_cache, llm_results):
+    llm_iter = iter(llm_results)
+    merged = []
+    for idx in range(len(prompt)):
+        if idx in tm_hits:
+            merged.append(
+                WorkflowResult(output_text=tm_hits[idx], from_memory=True)
+            )
+        else:
+            llm_result = next(llm_iter)
+            merged.append(
+                WorkflowResult(
+                    output_text=llm_result.output_text,
+                    from_memory=False,
+                    fuzzy_matches=fuzzy_cache.get(idx, []),
+                )
+            )
+    return merged
+
+
 async def run_workflow_service(assistant_id, target_language, prompt, segments, model, offset=0, instruction=None):
     validate_model(model)
-    
+
+    tm_hits, fuzzy_cache, texts_for_llm = _lookup_translation_memory(
+        assistant_id, prompt, target_language
+    )
+
+    if not texts_for_llm:
+        now = datetime.now(timezone.utc).isoformat()
+        return StreamResponse(
+            results=[
+                WorkflowResult(
+                    output_text=tm_hits[idx], from_memory=True
+                )
+                for idx in range(len(prompt))
+            ],
+            metadata=ResponseMetadata(
+                initialized_at=now,
+                total_batches=0,
+                completed_at=now,
+                total_processing_time=0.0,
+            ),
+            errors=[],
+        )
+
+    llm_prompts = [text for _, text in texts_for_llm]
+
     with SessionLocal() as db_session:
         workflow_request = build_workflow_request(
-            db_session, assistant_id, target_language, prompt, segments, model, instruction
+            db_session, assistant_id, target_language, llm_prompts, segments, model, instruction
        )
         if workflow_request.instance_ids and workflow_request.segments:
             all_segment_ids = []
@@ -97,29 +201,30 @@ async def run_workflow_service(assistant_id, target_language, prompt, segments,
                 workflow_request.contexts.append(
                     ContextRequest(content=content)
                 )
-    
+
     workflow_response = await run_workflow(workflow_request)
-    
-    results = [
-        WorkflowResult(output_text=result.output_text)
-        for result in workflow_response.get("final_results", [])
-    ]
-    
-    workflow_metadata = workflow_response.get("metadata", {})
-    response_metadata = ResponseMetadata(
-        initialized_at=workflow_metadata.get("initialized_at"),
-        total_batches=workflow_metadata.get("total_batches"),
-        completed_at=workflow_metadata.get("completed_at"),
-        total_processing_time=workflow_metadata.get("total_processing_time")
+    llm_results = workflow_response.get("final_results", [])
+
+    if target_language and llm_results:
+        _save_translations_to_memory(
+            assistant_id, target_language, texts_for_llm, llm_results
+        )
+
+    merged_results = _merge_tm_and_llm_results(
+        prompt, tm_hits, fuzzy_cache, llm_results
     )
-    
-    response = StreamResponse(
-        results=results,
-        metadata=response_metadata,
-        errors=workflow_response.get("errors", [])
+
+    workflow_metadata = workflow_response.get("metadata", {})
+    return StreamResponse(
+        results=merged_results,
+        metadata=ResponseMetadata(
+            initialized_at=workflow_metadata.get("initialized_at"),
+            total_batches=workflow_metadata.get("total_batches"),
+            completed_at=workflow_metadata.get("completed_at"),
+            total_processing_time=workflow_metadata.get("total_processing_time"),
+        ),
+        errors=workflow_response.get("errors", []),
     )
-    
-    return response
 
 
 async def stream_workflow_service(
diff --git a/api/translation-memory/tm_model.py b/api/translation_memory/tm_model.py
similarity index 100%
rename from api/translation-memory/tm_model.py
rename to api/translation_memory/tm_model.py
diff --git a/api/translation_memory/tm_repository.py b/api/translation_memory/tm_repository.py
new file mode 100644
index 0000000..a6c08a9
--- /dev/null
+++ b/api/translation_memory/tm_repository.py
@@ -0,0 +1,67 @@
+from sqlalchemy.orm import Session
+from sqlalchemy import text
+from api.translation_memory.tm_model import TranslationMemory
+from uuid import UUID
+from typing import List, Optional
+from fastapi import HTTPException, status
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def find_exact_match(
+    db: Session, assistant_id: UUID, source_text: str, target_language: str
+) -> Optional[TranslationMemory]:
+    return db.query(TranslationMemory).filter(
+        TranslationMemory.assistant_id == assistant_id,
+        TranslationMemory.source_text == source_text,
+        TranslationMemory.target_language == target_language,
+    ).first()
+
+
+def find_fuzzy_matches(
+    db: Session,
+    assistant_id: UUID,
+    source_text: str,
+    target_language: str,
+    limit: int = 5,
+    threshold: float = 0.3,
+) -> list:
+    results = db.execute(
+        text("""
+            SELECT source_text, target_text,
+                   similarity(source_text, :source_text) AS score
+            FROM translation_memory
+            WHERE assistant_id = :assistant_id
+              AND target_language = :target_language
+              AND similarity(source_text, :source_text) > :threshold
+              AND source_text != :source_text
+            ORDER BY score DESC
+            LIMIT :limit
+        """),
+        {
+            "assistant_id": str(assistant_id),
+            "source_text": source_text,
+            "target_language": target_language,
+            "threshold": threshold,
+            "limit": limit,
+        },
+    ).fetchall()
+    return results
+
+def batch_create_tm_entries(
+    db: Session, entries: List[TranslationMemory]
+) -> List[TranslationMemory]:
+    try:
+        db.add_all(entries)
+        db.commit()
+        for entry in entries:
+            db.refresh(entry)
+        return entries
+    except Exception as e:
+        db.rollback()
+        logger.error(f"Error batch creating TM entries: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to batch create translation memory entries",
+        )
diff --git a/migrations/versions/351aedf7ebe6_translation_memory_schema.py b/migrations/versions/351aedf7ebe6_translation_memory_schema.py
index 6d0cf03..f2f648f 100644
--- a/migrations/versions/351aedf7ebe6_translation_memory_schema.py
+++ b/migrations/versions/351aedf7ebe6_translation_memory_schema.py
@@ -11,7 +11,6 @@
 import sqlalchemy as sa
 
 
-# revision identifiers, used by Alembic.
 revision: str = '351aedf7ebe6'
 down_revision: Union[str, Sequence[str], None] = '8ef199eeb359'
 branch_labels: Union[str, Sequence[str], None] = None
@@ -19,14 +18,38 @@
 
 
 def upgrade() -> None:
-    """Upgrade schema."""
-    # ### commands auto generated by Alembic - please adjust! ###
-    pass
-    # ### end Alembic commands ###
+    op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
+
+    op.create_table(
+        "translation_memory",
+        sa.Column("id", sa.UUID(), primary_key=True),
+        sa.Column(
+            "assistant_id",
+            sa.UUID(),
+            sa.ForeignKey("assistant.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("source_text", sa.Text(), nullable=False),
+        sa.Column("target_text", sa.Text(), nullable=False),
+        sa.Column("target_language", sa.String(255), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            nullable=False,
+            server_default=sa.func.now(),
+        ),
+    )
+
+    op.execute(
+        """
+        CREATE INDEX idx_translation_memory_source_trgm
+        ON translation_memory
+        USING gist (source_text gist_trgm_ops)
+        """
+    )
 
 
 def downgrade() -> None:
-    """Downgrade schema."""
-    # ### commands auto generated by Alembic - please adjust! ###
-    pass
-    # ### end Alembic commands ###
+    op.execute("DROP INDEX IF EXISTS idx_translation_memory_source_trgm")
+    op.drop_table("translation_memory")
+    op.execute("DROP EXTENSION IF EXISTS pg_trgm")