diff --git a/src/server/services/ai/llm/cohere.ts b/src/server/services/ai/llm/cohere.ts new file mode 100644 index 0000000..37c2005 --- /dev/null +++ b/src/server/services/ai/llm/cohere.ts @@ -0,0 +1,105 @@ +import { + type LlmAdapter, + LlmError, + type LlmRequest, + type LlmResponse, + classifyHttpError, + estimateTokens, +} from "./types.js"; + +interface CohereRequestParams { + apiKey: string; + baseUrl?: string; +} + +/** + * Cohere Chat API adapter (command-r, command-r-plus). + * + * Cohere's wire format differs from the OpenAI chat completions surface: + * - Request: { message, preamble, chat_history? } + * - Response: { text, meta: { tokens: { input_tokens, output_tokens } } } + * - Error: { message: "..." } (no nested error object) + * + * The watcher uses one system + one user message per turn, so we map + * `systemPrompt` -> `preamble` and `transcriptPrompt` -> `message`. + * `chat_history` is omitted because the watcher always sends a fresh + * transcript excerpt; multi-turn conversation history is the caller's + * responsibility. + */ +export function createCohereAdapter(params: CohereRequestParams): LlmAdapter { + const apiKey = params.apiKey; + const baseUrl = params.baseUrl?.replace(/\/+$/, "") || "https://api.cohere.com"; + + return { + kind: "cohere", + async complete(request: LlmRequest): Promise<LlmResponse> { + const timeoutMs = request.timeoutMs ?? 60_000; + const body: Record<string, unknown> = { + model: request.model, + message: request.transcriptPrompt, + preamble: request.systemPrompt, + max_tokens: request.maxTokens ?? 1024, + temperature: request.temperature ?? 
0.2, + }; + + let response: Response; + try { + response = await fetch(`${baseUrl}/v1/chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(body), + signal: AbortSignal.timeout(timeoutMs), + }); + } catch (err) { + if (err instanceof Error && err.name === "TimeoutError") { + throw new LlmError("transient_timeout", "Cohere request timed out", undefined, err); + } + throw new LlmError("unknown", `Cohere request failed: ${err}`, undefined, err); + } + + if (!response.ok) { + const { subType, body: errBody, status } = await classifyHttpError(response); + const message = extractCohereErrorMessage(errBody) ?? errBody.slice(0, 300); + throw new LlmError(subType, `Cohere ${status}: ${message}`, status); + } + + const json = (await response.json()) as { + text?: string; + meta?: { + tokens?: { + input_tokens?: number; + output_tokens?: number; + }; + }; + }; + + const text = json.text ?? ""; + const tokens = json.meta?.tokens ?? {}; + const inputReported = typeof tokens.input_tokens === "number"; + return { + text, + usage: { + inputTokens: tokens.input_tokens ?? estimateTokens(request.transcriptPrompt), + outputTokens: tokens.output_tokens ?? 
estimateTokens(text), + estimated: !inputReported, + }, + rawResponse: json, + }; + }, + }; +} + +function extractCohereErrorMessage(body: string): string | null { + try { + const parsed = JSON.parse(body) as { message?: unknown }; + if (typeof parsed.message === "string" && parsed.message.length > 0) { + return parsed.message.slice(0, 300); + } + } catch { + // fall through + } + return null; +} diff --git a/src/server/services/ai/llm/list-models.ts b/src/server/services/ai/llm/list-models.ts index 82492e4..67a2cd9 100644 --- a/src/server/services/ai/llm/list-models.ts +++ b/src/server/services/ai/llm/list-models.ts @@ -54,6 +54,8 @@ function defaultBaseUrl(kind: ProviderKind): string { return "https://generativelanguage.googleapis.com/v1beta/openai"; case "openai_compatible": return "http://localhost:11434/v1"; + case "cohere": + return "https://api.cohere.com"; } } @@ -77,6 +79,13 @@ export async function listAvailableModels( if (!input.apiKey) return { ok: false, error: "Anthropic API key is required." }; headers["x-api-key"] = input.apiKey; headers["anthropic-version"] = "2023-06-01"; + } else if (kind === "cohere") { + // Cohere exposes /v1/models on the api.cohere.com host. The + // response shape is { models: [{ name }] } which the existing + // `.data ?? .models` fallback below already handles. + url = `${baseUrl}/v1/models`; + if (!input.apiKey) return { ok: false, error: "Cohere API key is required." }; + headers.Authorization = `Bearer ${input.apiKey}`; } else { // OpenAI-compatible surface: baseUrl already ends with /v1 for // the normal provider configs. 
diff --git a/src/server/services/ai/llm/llm.test.ts b/src/server/services/ai/llm/llm.test.ts index 879efee..12e3535 100644 --- a/src/server/services/ai/llm/llm.test.ts +++ b/src/server/services/ai/llm/llm.test.ts @@ -1,5 +1,6 @@ import { afterEach, beforeEach, describe, expect, test } from "bun:test"; import { createAnthropicAdapter } from "./anthropic.js"; +import { createCohereAdapter } from "./cohere.js"; import { createOpenAICompatibleAdapter } from "./openai-compatible.js"; import { priceCompletion } from "./pricing.js"; import { LlmError, estimateTokens } from "./types.js"; @@ -218,4 +219,108 @@ describe("pricing", () => { // default 50c + 200c = 250c expect(cents).toBe(250); }); + + test("charges cohere command-r at flash-tier rates", () => { + const cents = priceCompletion("cohere", "command-r-08-2024", { + inputTokens: 1_000_000, + outputTokens: 100_000, + estimated: false, + }); + // 1M * 15c + 100k * 60c/1M = 15 + 6 = 21c + expect(cents).toBe(21); + }); + + test("charges cohere command-r-plus at premium-tier rates", () => { + const cents = priceCompletion("cohere", "command-r-plus-08-2024", { + inputTokens: 1_000_000, + outputTokens: 100_000, + estimated: false, + }); + // 1M * 250c + 100k * 1000c/1M = 250 + 100 = 350c + expect(cents).toBe(350); + }); +}); + +describe("cohere adapter", () => { + test("sends preamble + message and Authorization header", async () => { + mockFetch( + new Response( + JSON.stringify({ + text: "ack", + meta: { tokens: { input_tokens: 12, output_tokens: 3 } }, + }), + { status: 200 }, + ), + ); + const adapter = createCohereAdapter({ apiKey: "co-key" }); + const res = await adapter.complete({ + systemPrompt: "Watcher", + transcriptPrompt: "events", + model: "command-r", + }); + expect(res.text).toBe("ack"); + expect(res.usage.inputTokens).toBe(12); + expect(res.usage.outputTokens).toBe(3); + expect(res.usage.estimated).toBe(false); + + const [req] = capturedRequests; + expect(req.url).toBe("https://api.cohere.com/v1/chat"); + 
expect((req.init.headers as Record<string, string>).Authorization).toBe("Bearer co-key"); + const body = JSON.parse(String(req.init.body)); + expect(body.model).toBe("command-r"); + expect(body.message).toBe("events"); + expect(body.preamble).toBe("Watcher"); + // Default sampling matches the watcher's policy. + expect(body.temperature).toBe(0.2); + }); + + test("trims trailing slashes on baseUrl override", async () => { + mockFetch(new Response(JSON.stringify({ text: "" }), { status: 200 })); + const adapter = createCohereAdapter({ + apiKey: "co-key", + baseUrl: "https://proxy.example.com//", + }); + await adapter.complete({ + systemPrompt: "x", + transcriptPrompt: "y", + model: "command-r", + }); + expect(capturedRequests[0].url).toBe("https://proxy.example.com/v1/chat"); + }); + + test("normalizes 401 to permanent_auth and surfaces message field", async () => { + mockFetch(new Response(JSON.stringify({ message: "invalid api token" }), { status: 401 })); + const adapter = createCohereAdapter({ apiKey: "bad" }); + const err = await adapter + .complete({ systemPrompt: "x", transcriptPrompt: "y", model: "command-r" }) + .catch((e: unknown) => e); + expect(err).toBeInstanceOf(LlmError); + const llm = err as LlmError; + expect(llm.subType).toBe("permanent_auth"); + expect(llm.status).toBe(401); + expect(llm.message).toContain("invalid api token"); + }); + + test("normalizes 429 to transient_rate_limit", async () => { + mockFetch(new Response(JSON.stringify({ message: "rate limit" }), { status: 429 })); + const adapter = createCohereAdapter({ apiKey: "ok" }); + const err = await adapter + .complete({ systemPrompt: "x", transcriptPrompt: "y", model: "command-r" }) + .catch((e: unknown) => e); + expect((err as LlmError).subType).toBe("transient_rate_limit"); + }); + + test("estimates tokens when meta.tokens is missing", async () => { + mockFetch(new Response(JSON.stringify({ text: "hello" }), { status: 200 })); + const adapter = createCohereAdapter({ apiKey: "ok" }); + const res = await 
adapter.complete({ + systemPrompt: "x", + transcriptPrompt: "transcript body", + model: "command-r", + }); + expect(res.text).toBe("hello"); + expect(res.usage.estimated).toBe(true); + expect(res.usage.inputTokens).toBeGreaterThan(0); + expect(res.usage.outputTokens).toBeGreaterThan(0); + }); }); diff --git a/src/server/services/ai/llm/pricing.ts b/src/server/services/ai/llm/pricing.ts index 4196a21..09144fc 100644 --- a/src/server/services/ai/llm/pricing.ts +++ b/src/server/services/ai/llm/pricing.ts @@ -36,6 +36,9 @@ const MODEL_RATES: Array<{ match: RegExp; rate: ModelRate }> = [ { match: /^gemini-1\.5-flash/, rate: { inputPer1M: 8, outputPer1M: 30 } }, { match: /^gemini-1\.5-pro/, rate: { inputPer1M: 125, outputPer1M: 500 } }, { match: /^gemini-2/, rate: { inputPer1M: 100, outputPer1M: 400 } }, + // Cohere + { match: /^command-r-plus/, rate: { inputPer1M: 250, outputPer1M: 1000 } }, + { match: /^command-r/, rate: { inputPer1M: 15, outputPer1M: 60 } }, ]; /** Local/self-hosted kinds assumed free. 
*/ diff --git a/src/server/services/ai/llm/registry.ts b/src/server/services/ai/llm/registry.ts index a4e5f62..172b6e4 100644 --- a/src/server/services/ai/llm/registry.ts +++ b/src/server/services/ai/llm/registry.ts @@ -1,4 +1,5 @@ import { createAnthropicAdapter } from "./anthropic.js"; +import { createCohereAdapter } from "./cohere.js"; import { createOpenAICompatibleAdapter } from "./openai-compatible.js"; import type { LlmAdapter, ProviderKind } from "./types.js"; @@ -22,6 +23,8 @@ function defaultBaseUrl(kind: ProviderKind): string { return "https://generativelanguage.googleapis.com/v1beta/openai"; case "openai_compatible": return "http://localhost:11434/v1"; // Ollama default + case "cohere": + return "https://api.cohere.com"; } } @@ -32,6 +35,10 @@ export function getAdapter(provider: ProviderConfig): LlmAdapter { return createAnthropicAdapter({ apiKey: provider.apiKey, baseUrl }); } + if (provider.kind === "cohere") { + return createCohereAdapter({ apiKey: provider.apiKey, baseUrl }); + } + // Everything else uses the OpenAI-compatible chat completions surface. // OpenRouter wants attribution headers; include them when applicable so // the request appears cleanly in OpenRouter's dashboards. diff --git a/src/server/services/ai/llm/types.ts b/src/server/services/ai/llm/types.ts index fe49a44..73f702a 100644 --- a/src/server/services/ai/llm/types.ts +++ b/src/server/services/ai/llm/types.ts @@ -7,7 +7,13 @@ * use) stays inside the adapter. */ -export type ProviderKind = "anthropic" | "openai" | "google" | "openrouter" | "openai_compatible"; +export type ProviderKind = + | "anthropic" + | "openai" + | "google" + | "openrouter" + | "openai_compatible" + | "cohere"; export interface LlmRequest { /** Stable across a session — candidate for provider prompt caching. */