Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion mapserver/competency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,40 @@
COMPETENCY_USER = os.environ.get('COMPETENCY_USER')
COMPETENCY_HOST = os.environ.get('COMPETENCY_HOST', 'localhost:5432')

COMPETENCY_SCHEMA_VERSION_KEY = 'schema_version'
COMPETENCY_SCHEMA_STATE_KEY = 'competency-schema-version'
COMPETENCY_SCHEMA_VERSION = '1.1'

if not COMPETENCY_USER:
print('Competency queries are unavailable because COMPETENCY_USER is not set')

#===============================================================================

async def table_exists(connection: asyncpg.Connection, table_name: str) -> bool:
#===========================================================================
reg_class = await connection.fetchval('SELECT to_regclass($1)', table_name)
return reg_class is not None

async def schema_version(connection: asyncpg.Connection) -> str|None:
#===================================================================
if not await table_exists(connection, 'metadata'):
return None
row = await connection.fetchrow(
'SELECT value FROM metadata WHERE name=$1',
COMPETENCY_SCHEMA_VERSION_KEY,
)
return row[0] if row is not None else None

def schema_mismatch_error(expected: str, actual: str|None, query_id: str|None=None) -> str:
#=============================================================================
found = actual if actual is not None else 'missing metadata/schema_version'
query = f' (query {query_id})' if query_id is not None else ''
return (
f'Competency schema version mismatch{query}: '
f'expected `{expected}` but found `{found}`. '
'Some queries may fail until the database schema and query definitions are aligned.'
)

#===============================================================================
#===============================================================================

Expand Down Expand Up @@ -76,6 +107,8 @@ async def competency_connection_context(app: Litestar) -> AsyncGenerator[None, N
timeout=5
)
app.state['competency-pool'] = competency_pool
async with competency_pool.acquire() as connection:
app.state[COMPETENCY_SCHEMA_STATE_KEY] = await schema_version(connection)
except Exception as err:
# log (where?)
print(f'Unable to connect to competency database: {COMPETENCY_HOST}/{COMPETENCY_DATABASE}')
Expand All @@ -91,6 +124,23 @@ def get_competency_pool(app: Litestar) -> Optional[asyncpg.Pool]:
#================================================================
return getattr(app.state, 'competency-pool', None)

def get_competency_schema_version(app: Litestar) -> str|None:
#==============================================================
return getattr(app.state, COMPETENCY_SCHEMA_STATE_KEY, None)

async def get_competency_schema_info(app: Litestar) -> dict[str, str|None]:
#======================================================================
if (get_competency_pool(app)) is None:
return {
'version': None,
'expected': COMPETENCY_SCHEMA_VERSION,
'error': 'Backend cannot connect to Competency database',
}
return {
'version': get_competency_schema_version(app),
'expected': COMPETENCY_SCHEMA_VERSION,
}

#===============================================================================
#===============================================================================

Expand Down Expand Up @@ -118,6 +168,7 @@ async def query(data: QueryRequest, request: Request) -> QueryResults|QueryError
return {'error': f'Error building query: {err}'}
if (pool := get_competency_pool(request.app)) is None:
return {'error': 'Backend cannot connect to Competency database'}
db_schema = get_competency_schema_version(request.app)
try:
async with pool.acquire() as connection:
records = await connection.fetch(sql, *params)
Expand All @@ -133,6 +184,9 @@ async def query(data: QueryRequest, request: Request) -> QueryResults|QueryError
}
}
except Exception as err:
return {'error': f'Error executing query: {err}'}
error_msg = f'Error executing query: {err}.'
if db_schema != COMPETENCY_SCHEMA_VERSION:
error_msg += f' {schema_mismatch_error(COMPETENCY_SCHEMA_VERSION, db_schema, data["query_id"])}'
return {'error': error_msg}

#===============================================================================
9 changes: 8 additions & 1 deletion mapserver/server/competency.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

#===============================================================================

from ..competency import query, query_definition, query_definitions
from ..competency import query, query_definition, query_definitions, get_competency_schema_info

from ..competency.definition import QueryDefinitionDict, QueryDefinitionSummary
from ..competency.definition import QueryRequest, QueryError, QueryResults

Expand All @@ -52,13 +53,19 @@ async def competency_query(data: QueryRequest, request: Request) -> QueryResults
request.logger.warning(result["error"])
return result

@get('schema-version')
async def competency_schema_version(request: Request) -> dict[str, str|None]:
#==========================================================================
return await get_competency_schema_info(request.app)

#===============================================================================
#===============================================================================

competency_router = Router(
path="/competency",
route_handlers=[
competency_query,
competency_schema_version,
competency_query_definition,
competency_query_definitions,
]
Expand Down
28 changes: 24 additions & 4 deletions tools/competency-query/competency_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def print_table(header: list[str], rows: Iterable[Iterable[str]]):
QUERY_ENDPOINT = '/competency/query'

QUERY_DEFINITIONS_ENDPOINT = '/competency/queries'
QUERY_SCHEMA_VERSION_ENDPOINT = '/competency/schema-version'

#===============================================================================

Expand All @@ -81,7 +82,7 @@ class CompetencyQueryService:
def __init__(self, map_server: str):
self.__map_server = map_server

def request_json(self, method: str, endpoint: str, **kwds) -> dict|list:
def request_json(self, method: str, endpoint: str, quiet: bool=False, **kwds) -> dict|list:
#=======================================================================
endpoint = self.__map_server + endpoint
try:
Expand All @@ -102,15 +103,16 @@ def request_json(self, method: str, endpoint: str, **kwds) -> dict|list:
error = f'HTTP error for request: {response.status_code} {response.reason}'
except requests.exceptions.RequestException as exception:
error = f'Exception: {exception}'
print_formatted_text(FormattedText([('class:error', error),]),
if not quiet:
print_formatted_text(FormattedText([('class:error', error),]),
style=Style.from_dict({'error': '#ff0000 bold'}))
return []

def get_json(self, endpoint: str, param: Optional[str]=None) -> dict|list:
def get_json(self, endpoint: str, param: Optional[str]=None, quiet: bool=False) -> dict|list:
#=========================================================================
if param is not None:
endpoint += f'/{param}'
return self.request_json('GET', endpoint)
return self.request_json('GET', endpoint, quiet=quiet)

def post_query(self, request: QueryRequest) -> dict|list:
#========================================================
Expand All @@ -123,13 +125,31 @@ class CompetencyQueryShell:

def __init__(self, map_server: str):
self.__query_service = CompetencyQueryService(map_server)
self.__warn_if_schema_mismatch()
self.__queries: dict[str, str] = { str(query['id']): str(query['label'])
for query in self.__query_service.get_json(QUERY_DEFINITIONS_ENDPOINT)
if 'id' in query }
self.__cmd_session = PromptSession(message=HTML('<p fg="ansiwhite"><b>cq> </b></p>'),
style=Style.from_dict({'': COMMAND_INPUT_STYLE}))
self.__input_session = PromptSession()

def __warn_if_schema_mismatch(self):
#===================================
schema_info = self.__query_service.get_json(QUERY_SCHEMA_VERSION_ENDPOINT, quiet=True)
if isinstance(schema_info, dict):
server_schema = schema_info.get('version')
expected_schema = schema_info.get('expected')
if expected_schema is not None and server_schema != expected_schema:
warning = (
'WARNING: Competency schema version mismatch. '
f'Expected {expected_schema}, server has {server_schema}. '
'A schema upgrade may be required.'
)
print_formatted_text(
FormattedText([('class:warning', warning)]),
style=Style.from_dict({'warning': '#ffaf00 bold'})
)

def __list_queries(self):
#========================
print_table(['ID', 'Name'], list(self.__queries.items()))
Expand Down