diff --git a/.semversioner/next-release/patch-20260315024056229023.json b/.semversioner/next-release/patch-20260315024056229023.json
new file mode 100644
index 000000000..84a731604
--- /dev/null
+++ b/.semversioner/next-release/patch-20260315024056229023.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "reconfigure vector store size by embedding model"
+}
diff --git a/packages/graphrag/graphrag/index/validate_config.py b/packages/graphrag/graphrag/index/validate_config.py
index 4062b8de9..29593b21b 100644
--- a/packages/graphrag/graphrag/index/validate_config.py
+++ b/packages/graphrag/graphrag/index/validate_config.py
@@ -6,12 +6,16 @@
 import asyncio
 import logging
 import sys
+from typing import TYPE_CHECKING
 
 from graphrag_llm.completion import create_completion
 from graphrag_llm.embedding import create_embedding
 
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 
+if TYPE_CHECKING:
+    from graphrag_llm.types import LLMEmbeddingResponse
+
 logger = logging.getLogger(__name__)
 
 
@@ -29,13 +33,41 @@ def validate_config_names(parameters: GraphRagConfig) -> None:
     for id, config in parameters.embedding_models.items():
         embed_llm = create_embedding(config)
         try:
-            asyncio.run(
+            response = asyncio.run(
                 embed_llm.embedding_async(
                     input=["This is an LLM Embedding Test String"]
                 )
             )
             logger.info("Embedding LLM Config Params Validated")
+
+            if id == parameters.embed_text.embedding_model_id:
+                _sync_vector_store_dimensions(parameters, response)
+
         except Exception as e:  # noqa: BLE001
             logger.error(f"Embedding configuration error detected.\n{e}")  # noqa
             print(f"Failed to validate embedding model ({id}) params", e)  # noqa: T201
             sys.exit(1)
+
+
+def _sync_vector_store_dimensions(
+    parameters: GraphRagConfig,
+    response: "LLMEmbeddingResponse",
+) -> None:
+    """Sync vector store dimensions to match the actual embedding model output."""
+    detected = len(response.first_embedding)
+    if detected == 0:
+        return
+
+    configured = parameters.vector_store.vector_size
+    if detected == configured:
+        return
+
+    logger.warning(
+        "Embedding model produces %d-dimensional vectors but vector_size is "
+        "configured as %d. Overriding vector_size to match the model.",
+        detected,
+        configured,
+    )
+    parameters.vector_store.vector_size = detected
+    for schema in parameters.vector_store.index_schema.values():
+        schema.vector_size = detected