Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions python/infinilm/server/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ async def chat_completions(request: Request):
else:
data["messages"] = [{"role": "user", "content": data.get("prompt")}]

# Normalize messages to handle multimodal content (list format)
data["messages"] = self._normalize_messages(data.get("messages", []))

stream = data.get("stream", False)
request_id = f"cmpl-{uuid.uuid4().hex}"

Expand Down Expand Up @@ -206,6 +209,39 @@ async def list_models():
async def list_models_legacy():
    """Return whatever `_models_payload()` builds.

    NOTE(review): the name suggests this is a legacy alias of the
    `list_models` handler; the route registration is outside this view,
    so confirm against the decorator.
    """
    payload = _models_payload()
    return payload

def _normalize_messages(self, messages: list) -> list:
"""Normalize messages to handle multimodal content (list format).

Converts content from list format [{"type": "text", "text": "..."}]
to string format for chat template compatibility.
"""
normalized = []
for msg in messages:
if not isinstance(msg, dict):
normalized.append(msg)
continue

content = msg.get("content")
if isinstance(content, list):
# Extract text from multimodal content list
text_parts = []
for part in content:
if isinstance(part, dict):
if part.get("type") == "text" and "text" in part:
text_parts.append(part["text"])
elif isinstance(part, str):
text_parts.append(part)
elif isinstance(part, str):
text_parts.append(part)
# Join all text parts
normalized_msg = msg.copy()
normalized_msg["content"] = "".join(text_parts) if text_parts else ""
normalized.append(normalized_msg)
else:
normalized.append(msg)

return normalized

def _build_sampling_params(self, data: dict) -> SamplingParams:
"""Build SamplingParams from request data."""
# Support both:
Expand Down