-
Notifications
You must be signed in to change notification settings - Fork 0
Feat: Translation Memory #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| import asyncio | ||
| import logging | ||
| from datetime import datetime, timezone | ||
| from typing import AsyncGenerator | ||
|
|
||
| from api.Assistant.assistant_repository import get_assistant_by_id_repository | ||
|
|
@@ -7,20 +9,28 @@ | |
| from api.langgraph.workflow_stream import stream_workflow_events | ||
| from api.ai.ai_response_model import ( | ||
| WorkflowRequest, | ||
| SegmentRequest, | ||
| StreamResponse, | ||
| WorkflowResult, | ||
| ResponseMetadata, | ||
| AvailableModelsResponse, | ||
| ModelInfo, | ||
| EnhanceResponse | ||
| EnhanceResponse, | ||
| FuzzyMatch, | ||
| ) | ||
| from api.db.pg_database import SessionLocal | ||
| from api.llm.router import get_model_router | ||
| from api.external_api import get_related_segment_ids, get_segment_content | ||
| from api.ai.prompts import ENHANCE_META_PROMPT | ||
| from api.translation_memory.tm_model import TranslationMemory | ||
| from api.translation_memory.tm_repository import ( | ||
| find_exact_match, | ||
| find_fuzzy_matches, | ||
| batch_create_tm_entries, | ||
| ) | ||
| from fastapi import HTTPException | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def build_workflow_request(db_session, assistant_id, target_language, prompt, segments, model, instruction=None) -> WorkflowRequest: | ||
| assistant_detail = get_assistant_by_id_repository(db_session, assistant_id) | ||
|
|
@@ -63,12 +73,106 @@ def validate_model(model: str) -> None: | |
| ) | ||
|
|
||
|
|
||
| def _lookup_translation_memory(assistant_id, prompt, target_language): | ||
| tm_hits = {} | ||
| fuzzy_cache = {} | ||
| texts_for_llm = [] | ||
|
|
||
| if target_language: | ||
| with SessionLocal() as db_session: | ||
| for idx, source_text in enumerate(prompt): | ||
| exact = find_exact_match( | ||
| db_session, assistant_id, source_text, target_language | ||
| ) | ||
| if exact: | ||
| tm_hits[idx] = exact.target_text | ||
| continue | ||
|
|
||
| rows = find_fuzzy_matches( | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If the exact match fails, we fall back to fuzzy matching on the input
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here the input still goes through LLM translation |
||
| db_session, assistant_id, source_text, target_language | ||
| ) | ||
| fuzzy_cache[idx] = [ | ||
| FuzzyMatch( | ||
| source_text=row.source_text, | ||
| target_text=row.target_text, | ||
| score=round(float(row.score), 4), | ||
| ) | ||
| for row in rows | ||
| ] | ||
| texts_for_llm.append((idx, source_text)) | ||
| else: | ||
| texts_for_llm = list(enumerate(prompt)) | ||
|
|
||
| return tm_hits, fuzzy_cache, texts_for_llm | ||
|
|
||
|
|
||
| def _save_translations_to_memory(assistant_id, target_language, texts_for_llm, llm_results): | ||
| try: | ||
| with SessionLocal() as db_session: | ||
| new_entries = [ | ||
| TranslationMemory( | ||
| assistant_id=assistant_id, | ||
| source_text=source_text, | ||
| target_text=result.output_text, | ||
| target_language=target_language, | ||
| ) | ||
| for (_, source_text), result in zip(texts_for_llm, llm_results) | ||
| ] | ||
| batch_create_tm_entries(db_session, new_entries) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to save translations to TM: {e}") | ||
|
|
||
|
|
||
| def _merge_tm_and_llm_results(prompt, tm_hits, fuzzy_cache, llm_results): | ||
| llm_iter = iter(llm_results) | ||
| merged = [] | ||
| for idx in range(len(prompt)): | ||
| if idx in tm_hits: | ||
| merged.append( | ||
| WorkflowResult(output_text=tm_hits[idx], from_memory=True) | ||
| ) | ||
| else: | ||
| llm_result = next(llm_iter) | ||
| merged.append( | ||
| WorkflowResult( | ||
| output_text=llm_result.output_text, | ||
| from_memory=False, | ||
| fuzzy_matches=fuzzy_cache.get(idx, []), | ||
| ) | ||
| ) | ||
| return merged | ||
|
|
||
|
|
||
| async def run_workflow_service(assistant_id, target_language, prompt, segments, model, offset=0, instruction=None): | ||
| validate_model(model) | ||
|
|
||
|
|
||
| tm_hits, fuzzy_cache, texts_for_llm = _lookup_translation_memory( | ||
| assistant_id, prompt, target_language | ||
| ) | ||
|
|
||
| if not texts_for_llm: | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If true, that means the exact match succeeded, so we return it without needing to call the LLM |
||
| now = datetime.now(timezone.utc).isoformat() | ||
| return StreamResponse( | ||
| results=[ | ||
| WorkflowResult( | ||
| output_text=tm_hits[idx], from_memory=True | ||
| ) | ||
| for idx in range(len(prompt)) | ||
| ], | ||
| metadata=ResponseMetadata( | ||
| initialized_at=now, | ||
| total_batches=0, | ||
| completed_at=now, | ||
| total_processing_time=0.0, | ||
| ), | ||
| errors=[], | ||
| ) | ||
|
|
||
| llm_prompts = [text for _, text in texts_for_llm] | ||
|
|
||
| with SessionLocal() as db_session: | ||
| workflow_request = build_workflow_request( | ||
| db_session, assistant_id, target_language, prompt, segments, model, instruction | ||
| db_session, assistant_id, target_language, llm_prompts, segments, model, instruction | ||
| ) | ||
| if workflow_request.instance_ids and workflow_request.segments: | ||
| all_segment_ids = [] | ||
|
|
@@ -97,29 +201,30 @@ async def run_workflow_service(assistant_id, target_language, prompt, segments, | |
| workflow_request.contexts.append( | ||
| ContextRequest(content=content) | ||
| ) | ||
|
|
||
| workflow_response = await run_workflow(workflow_request) | ||
|
|
||
| results = [ | ||
| WorkflowResult(output_text=result.output_text) | ||
| for result in workflow_response.get("final_results", []) | ||
| ] | ||
|
|
||
| workflow_metadata = workflow_response.get("metadata", {}) | ||
| response_metadata = ResponseMetadata( | ||
| initialized_at=workflow_metadata.get("initialized_at"), | ||
| total_batches=workflow_metadata.get("total_batches"), | ||
| completed_at=workflow_metadata.get("completed_at"), | ||
| total_processing_time=workflow_metadata.get("total_processing_time") | ||
| llm_results = workflow_response.get("final_results", []) | ||
|
|
||
| if target_language and llm_results: | ||
| _save_translations_to_memory( | ||
| assistant_id, target_language, texts_for_llm, llm_results | ||
| ) | ||
|
|
||
| merged_results = _merge_tm_and_llm_results( | ||
| prompt, tm_hits, fuzzy_cache, llm_results | ||
| ) | ||
|
|
||
| response = StreamResponse( | ||
| results=results, | ||
| metadata=response_metadata, | ||
| errors=workflow_response.get("errors", []) | ||
|
|
||
| workflow_metadata = workflow_response.get("metadata", {}) | ||
| return StreamResponse( | ||
| results=merged_results, | ||
| metadata=ResponseMetadata( | ||
| initialized_at=workflow_metadata.get("initialized_at"), | ||
| total_batches=workflow_metadata.get("total_batches"), | ||
| completed_at=workflow_metadata.get("completed_at"), | ||
| total_processing_time=workflow_metadata.get("total_processing_time"), | ||
| ), | ||
| errors=workflow_response.get("errors", []), | ||
| ) | ||
|
|
||
| return response | ||
|
|
||
|
|
||
| async def stream_workflow_service( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| from sqlalchemy.orm import Session | ||
| from sqlalchemy import text | ||
| from api.translation_memory.tm_model import TranslationMemory | ||
| from uuid import UUID | ||
| from typing import List, Optional | ||
| from fastapi import HTTPException, status | ||
| import logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def find_exact_match( | ||
| db: Session, assistant_id: UUID, source_text: str, target_language: str | ||
| ) -> Optional[TranslationMemory]: | ||
| return db.query(TranslationMemory).filter( | ||
| TranslationMemory.assistant_id == assistant_id, | ||
| TranslationMemory.source_text == source_text, | ||
| TranslationMemory.target_language == target_language, | ||
| ).first() | ||
|
|
||
|
|
||
| def find_fuzzy_matches( | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fuzzy-match logic generated via Opus |
||
| db: Session, | ||
| assistant_id: UUID, | ||
| source_text: str, | ||
| target_language: str, | ||
| limit: int = 5, | ||
| threshold: float = 0.3, | ||
| ) -> list: | ||
| results = db.execute( | ||
| text(""" | ||
| SELECT source_text, target_text, | ||
| similarity(source_text, :source_text) AS score | ||
| FROM translation_memory | ||
| WHERE assistant_id = :assistant_id | ||
| AND target_language = :target_language | ||
| AND similarity(source_text, :source_text) > :threshold | ||
| AND source_text != :source_text | ||
| ORDER BY score DESC | ||
| LIMIT :limit | ||
| """), | ||
| { | ||
| "assistant_id": str(assistant_id), | ||
| "source_text": source_text, | ||
| "target_language": target_language, | ||
| "threshold": threshold, | ||
| "limit": limit, | ||
| }, | ||
| ).fetchall() | ||
| return results | ||
|
|
||
| def batch_create_tm_entries( | ||
| db: Session, entries: List[TranslationMemory] | ||
| ) -> List[TranslationMemory]: | ||
| try: | ||
| db.add_all(entries) | ||
| db.commit() | ||
| for entry in entries: | ||
| db.refresh(entry) | ||
| return entries | ||
| except Exception as e: | ||
| db.rollback() | ||
| logger.error(f"Error batch creating TM entries: {e}") | ||
| raise HTTPException( | ||
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
| detail="Failed to batch create translation memory entries", | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First, try to find an exact match for the input; if one is found, skip the LLM step.