diff --git a/Runware/Runware-base.ts b/Runware/Runware-base.ts index 41e567e..fabc679 100644 --- a/Runware/Runware-base.ts +++ b/Runware/Runware-base.ts @@ -1,5 +1,6 @@ // @ts-ignore import { asyncRetry } from "./async-retry"; +import { RunwareLogger, createLogger } from "./logger"; import { EControlMode, IControlNet, @@ -85,7 +86,13 @@ export class RunwareBase { _shouldReconnect: boolean; _globalMaxRetries: number; _timeoutDuration: number; + _heartbeatIntervalId: any = null; + _pongTimeoutId: any = null; + _heartbeatInterval: number; + _missedPongCount: number = 0; + _maxMissedPongs: number = 3; ensureConnectionUUID: string | null = null; + _logger: RunwareLogger; constructor({ apiKey, @@ -93,6 +100,8 @@ export class RunwareBase { shouldReconnect = true, globalMaxRetries = 2, timeoutDuration = TIMEOUT_DURATION, + heartbeatInterval = 45000, + enableLogging = false, }: RunwareBaseType) { this._apiKey = apiKey; this._url = url; @@ -100,6 +109,9 @@ export class RunwareBase { this._shouldReconnect = shouldReconnect; this._globalMaxRetries = globalMaxRetries; this._timeoutDuration = timeoutDuration; + // Clamp heartbeat interval between 10s and 120s + this._heartbeatInterval = Math.max(10000, Math.min(120000, heartbeatInterval)); + this._logger = createLogger(enableLogging); } private getUniqueUUID(item: MediaUUID): string | undefined { @@ -204,6 +216,73 @@ export class RunwareBase { return this._connectionError?.error?.code === "invalidApiKey"; }; + protected startHeartbeat() { + this.stopHeartbeat(); + this._logger.heartbeatStarted(this._heartbeatInterval); + this._heartbeatIntervalId = setInterval(() => { + if (!this.isWebsocketReadyState()) { + this.stopHeartbeat(); + return; + } + try { + this._ws.send( + JSON.stringify([{ taskType: "ping", ping: true }]), + ); + this._logger.heartbeatPingSent(); + } catch { + this.stopHeartbeat(); + return; + } + // Clear any previous pong timeout to prevent accumulation + if (this._pongTimeoutId) { + clearTimeout(this._pongTimeoutId); + this._pongTimeoutId = null; + } + this._pongTimeoutId = setTimeout(() => { + this._missedPongCount++; + this._logger.heartbeatPongMissed(this._missedPongCount, this._maxMissedPongs); + if (this._missedPongCount >= this._maxMissedPongs) { + if (this._ws) { + if (typeof this._ws.terminate === "function") { + this._ws.terminate(); + } else { + this._ws.close(); + } + } + } + }, 10000); + }, this._heartbeatInterval); + } + + protected stopHeartbeat() { + if (this._heartbeatIntervalId) { + clearInterval(this._heartbeatIntervalId); + this._heartbeatIntervalId = null; + this._logger.heartbeatStopped(); + } + if (this._pongTimeoutId) { + clearTimeout(this._pongTimeoutId); + this._pongTimeoutId = null; + } + this._missedPongCount = 0; + } + + protected handlePongMessage(data: any) { + const messages = Array.isArray(data?.data) ? data.data : []; + for (const msg of messages) { + if (msg?.taskType === "ping" && msg?.pong === true) { + this._missedPongCount = 0; + if (this._pongTimeoutId) { + clearTimeout(this._pongTimeoutId); + this._pongTimeoutId = null; + } + this._logger.heartbeatPongReceived(); + return true; + } + } + return false; + } + protected addListener({ lis, // check, @@ -258,33 +337,53 @@ export class RunwareBase { } protected connect() { - this._ws.onopen = (e: any) => { - if (this._connectionSessionUUID) { - this.send({ - taskType: ETaskType.AUTHENTICATION, - apiKey: this._apiKey, - connectionSessionUUID: this._connectionSessionUUID, - }); - } else { - this.send({ apiKey: this._apiKey, taskType: ETaskType.AUTHENTICATION }); + this._logger.connecting(this._url || "unknown"); + + this._ws.onopen = async (e: any) => { + this._logger.authenticating(!!this._connectionSessionUUID); + try { + if (this._connectionSessionUUID) { + await this.send({ + taskType: ETaskType.AUTHENTICATION, + apiKey: this._apiKey, + connectionSessionUUID: this._connectionSessionUUID, + }); + } else { + await this.send({ apiKey: this._apiKey, taskType: ETaskType.AUTHENTICATION }); + } + } catch (err) { + this._logger.error("Failed to send auth message", err); + return; } - this.addListener({ + const authListener = this.addListener({ taskUUID: ETaskType.AUTHENTICATION, lis: (m) => { if (m?.error) { this._connectionError = m; + this._logger.authError(m); + authListener?.destroy?.(); return; } this._connectionSessionUUID = m?.[ETaskType.AUTHENTICATION]?.[0]?.connectionSessionUUID; this._connectionError = undefined; + this._logger.authenticated(this._connectionSessionUUID || ""); + authListener?.destroy?.(); + this.startHeartbeat(); }, }); }; this._ws.onmessage = (e: any) => { - const data = JSON.parse(e.data); + let data; + try { + data = JSON.parse(e.data); + } catch (err) { + this._logger.error("Failed to parse WebSocket message", err); + return; + } + if (this.handlePongMessage(data)) return; for (const lis of this._listeners) { const result = (lis as any)?.listener?.(data); if (result) return; @@ -292,16 +391,39 @@ export class RunwareBase { }; this._ws.onclose = (e: any) => { - // console.log("closing"); - // console.log("invalid", this._invalidAPIkey); + this._logger.connectionClosed(e?.code); + this._connectionSessionUUID = undefined; + this.stopHeartbeat(); if (this.isInvalidAPIKey()) { return; } }; + + this._ws.onerror = (e: any) => { + this._logger.connectionError(e?.message || e); + }; } // We moving to an array format, it make sense to consolidate all request to an array here - protected send = (msg: Object) => { + protected send = async (msg: Object) => { + if (!this.isWebsocketReadyState()) { + this._logger.sendReconnecting(); + if (this._ws) { + try { + if (typeof this._ws.terminate === "function") { + this._ws.terminate(); + } else { + this._ws.close(); + } + } catch {} + } + this._connectionSessionUUID = undefined; + // ensureConnection either resolves (ws ready) or throws + await this.ensureConnection(); + } + const taskType = (msg as any)?.taskType; + const taskUUID = (msg as any)?.taskUUID; + this._logger.messageSent(taskType, taskUUID); this._ws.send(JSON.stringify([msg])); }; @@ -645,7 +767,7 @@ export class RunwareBase { taskUUID: taskUUID, numberResults: imageRemaining, }; - this.send(newRequestObject); + await this.send(newRequestObject); // const generationTime = endTime - startTime; @@ -747,7 +869,7 @@ export class RunwareBase { ...(outputQuality ? { outputQuality } : {}), }; - this.send({ + await this.send({ ...payload, }); lis = this.globalListener({ @@ -1190,7 +1312,7 @@ export class RunwareBase { taskType: ETaskType.PROMPT_ENHANCE, }; - this.send(payload); + await this.send(payload); lis = this.globalListener({ taskUUID, @@ -1265,7 +1387,7 @@ export class RunwareBase { await this.ensureConnection(); const taskUUID = _taskUUID || customTaskUUID || getUUID(); - this.send({ + await this.send({ ...addModelPayload, taskUUID, taskType: ETaskType.MODEL_UPLOAD, @@ -1293,7 +1415,7 @@ export class RunwareBase { return true; } else if (errorResult) { reject(errorResult); - return false; + return true; } }, { @@ -1302,6 +1424,7 @@ export class RunwareBase { }, ); + lis?.destroy(); return modelUploadResponse as IAddModelResponse | IErrorResponse; }, { @@ -1364,7 +1487,7 @@ export class RunwareBase { numberResults, }; - this.send({ + await this.send({ ...payload, numberResults: imageRemaining, }); @@ -1489,7 +1612,8 @@ export class RunwareBase { taskUUID, }; - this.send(payload); + this._logger.requestStart(debugKey, taskUUID); + await this.send(payload); lis = this.globalListener({ taskUUID, @@ -1497,19 +1621,20 @@ export class RunwareBase { const response = await getIntervalWithPromise( ({ resolve, reject }) => { - // console.log("multiple", isMultiple); const response = isMultiple ? this.getMultipleMessages({ taskUUID }) : this.getSingleMessage({ taskUUID }); if (!response) return; if (response?.error) { + this._logger.requestError(taskUUID, response); reject(response); return true; } if (response) { delete this._globalMessages[taskUUID]; + this._logger.requestComplete(debugKey, taskUUID, Date.now() - startTime); resolve(response); return true; } @@ -1534,6 +1659,7 @@ export class RunwareBase { callback: () => { lis?.destroy(); }, + logger: this._logger, }, ); } catch (e) { @@ -1587,9 +1713,11 @@ export class RunwareBase { numberResults: taskRemaining, }; - this.send(payload); + this._logger.requestStart(restPayload.taskType || groupKey, taskUUID); + await this.send(payload); if (skipResponse) { + this._logger.info(`Async mode (skipResponse) — waiting for server acknowledgement`, { taskUUID }); return new Promise((resolve, reject) => { const listener = this.addListener({ taskUUID, @@ -1597,8 +1725,10 @@ export class RunwareBase { lis: (msg) => { listener.destroy(); if (msg.error) { + this._logger.requestError(taskUUID, msg.error); reject(msg.error); } else { + this._logger.requestComplete(restPayload.taskType || groupKey, taskUUID, Date.now() - startTime); resolve(msg[taskUUID]); } }, @@ -1618,9 +1748,9 @@ export class RunwareBase { taskUUID: taskUUIDs, numberResults, lis, - // debugKey, }); + this._logger.requestComplete(restPayload.taskType || groupKey, taskUUID, Date.now() - startTime); lis.destroy(); return promise as T; }, @@ -1629,6 +1759,7 @@ export class RunwareBase { callback: () => { lis?.destroy(); }, + logger: this._logger, }, ); } catch (e) { @@ -1640,9 +1771,10 @@ export class RunwareBase { let isConnected = this.connected(); if (isConnected || this._url === BASE_RUNWARE_URLS.TEST) return; + this._logger.ensureConnectionStart(); + const retryInterval = 2000; const pollingInterval = 200; - // const pollingInterval = this._sdkType === SdkType.CLIENT ? 200 : 2000; try { if (this.isInvalidAPIKey()) { @@ -1650,7 +1782,6 @@ export class RunwareBase { } return new Promise((resolve, reject) => { - // const isConnected = let retry = 0; const MAX_RETRY = 30; @@ -1670,7 +1801,6 @@ export class RunwareBase { try { const hasConnected = this.connected(); - // only one instance should be responsible for making the call again, not other ensureConnection let shouldCallServer = false; if ( @@ -1683,18 +1813,19 @@ export class RunwareBase { shouldCallServer = true; } - // Retry every (retryInterval % retry) => 60s - // every 20 seconds (ie. => retry is 10 (20s), retry is 20 (40s)) const SHOULD_RETRY = retry % 10 === 0 && shouldCallServer; if (hasConnected) { clearAllIntervals(); + this._logger.ensureConnectionSuccess(); resolve(true); } else if (retry >= MAX_RETRY) { clearAllIntervals(); + this._logger.ensureConnectionTimeout(); reject(new Error("Retry timed out")); } else { if (SHOULD_RETRY) { + this._logger.reconnecting(retry + 1); this.connect(); } retry++; @@ -1711,11 +1842,13 @@ export class RunwareBase { if (hasConnected) { clearAllIntervals(); + this._logger.ensureConnectionSuccess(); resolve(true); return; } if (!!this.isInvalidAPIKey()) { clearAllIntervals(); + this._logger.error("Connection failed — invalid API key"); reject(this._connectionError); return; } @@ -1846,7 +1979,10 @@ export class RunwareBase { } disconnect = async () => { + this._logger.disconnected("user initiated"); this._shouldReconnect = false; + this._connectionSessionUUID = undefined; + this.stopHeartbeat(); this._ws?.terminate?.(); this._ws?.close?.(); }; diff --git a/Runware/Runware-server.ts b/Runware/Runware-server.ts index 1255545..bcbf99b 100644 --- a/Runware/Runware-server.ts +++ b/Runware/Runware-server.ts @@ -12,8 +12,7 @@ export class RunwareServer extends RunwareBase { _instantiated: boolean = false; _listeners: any[] = []; _reconnectingIntervalId: null | any = null; - _pingTimeout: any; - _pongListener: any; + private _connecting: boolean = false; constructor(props: RunwareBaseType) { super(props); @@ -22,83 +21,83 @@ export class RunwareServer extends RunwareBase { this.connect(); } - // protected addListener({ - // lis, - // check, - // groupKey, - // }: { - // lis: (v: any) => any; - // check: (v: any) => any; - // groupKey?: string; - // }) { - // const listener = (msg: any) => { - // if (msg?.error) { - // lis(msg); - // } else if (check(msg)) { - // lis(msg); - // } - // }; - // const groupListener = { key: getUUID(), listener, groupKey }; - // this._listeners.push(groupListener); - // const destroy = () => { - // this._listeners = removeListener(this._listeners, groupListener); - // }; - - // return { - // destroy, - // }; - // } - protected async connect() { if (!this._url) return; + if (this._connecting) return; + this._connecting = true; this.resetConnection(); - const url = buildSdkUrl(this._url); - this._ws = new WebSocket(url, { - perMessageDeflate: false, - headers: { - "X-SDK-Name": "js", - "X-SDK-Version": SDK_VERSION, - }, - }); + try { + const url = buildSdkUrl(this._url); + this._logger.connecting(url); - // delay(1); + this._ws = new WebSocket(url, { + perMessageDeflate: false, + headers: { + "X-SDK-Name": "js", + "X-SDK-Version": SDK_VERSION, + }, + }); + } catch (err) { + this._connecting = false; + this._logger.connectionError(err); + return; + } + + this._ws.on("error", (err: any) => { + this._connecting = false; + this._logger.connectionError(err?.message || err); + }); - this._ws.on("error", () => {}); this._ws.on("close", () => { this.handleClose(); }); - this._ws.on("open", () => { + this._ws.on("open", async () => { if (this._reconnectingIntervalId) { clearInterval(this._reconnectingIntervalId); } - if (this._connectionSessionUUID && this.isWebsocketReadyState()) { - this.send({ - taskType: ETaskType.AUTHENTICATION, - apiKey: this._apiKey, - connectionSessionUUID: this._connectionSessionUUID, - }); - } else { - if (this.isWebsocketReadyState()) { - this.send({ - apiKey: this._apiKey, + + this._logger.authenticating(!!this._connectionSessionUUID); + + try { + if (this._connectionSessionUUID && this.isWebsocketReadyState()) { + await this.send({ taskType: ETaskType.AUTHENTICATION, + apiKey: this._apiKey, + connectionSessionUUID: this._connectionSessionUUID, }); + } else { + if (this.isWebsocketReadyState()) { + await this.send({ + apiKey: this._apiKey, + taskType: ETaskType.AUTHENTICATION, + }); + } } + } catch (err) { + this._connecting = false; + this._logger.error("Failed to send auth message", err); + return; } - this.addListener({ + const authListener = this.addListener({ taskUUID: ETaskType.AUTHENTICATION, lis: (m) => { + this._connecting = false; if (m?.error) { this._connectionError = m; + this._logger.authError(m); + authListener?.destroy?.(); return; } this._connectionSessionUUID = m?.[ETaskType.AUTHENTICATION]?.[0]?.connectionSessionUUID; this._connectionError = undefined; + this._logger.authenticated(this._connectionSessionUUID || ""); + authListener?.destroy?.(); + this.startHeartbeat(); }, }); }); @@ -106,7 +105,15 @@ export class RunwareServer extends RunwareBase { this._ws.on("message", (e: any, isBinary: any) => { const data = isBinary ? e : e?.toString(); if (!data) return; - const m = JSON.parse(data); + let m: any; + try { + m = JSON.parse(data); + } catch (err) { + this._logger.error("Failed to parse WebSocket message", err); + return; + } + + if (this.handlePongMessage(m)) return; this._listeners.forEach((lis) => { const result = lis.listener(m); @@ -117,11 +124,33 @@ export class RunwareServer extends RunwareBase { }); } - protected send = (msg: Object) => { + protected send = async (msg: Object) => { + if (!this.isWebsocketReadyState()) { + this._logger.sendReconnecting(); + if (this._ws) { + try { + if (typeof this._ws.terminate === "function") { + this._ws.terminate(); + } else { + this._ws.close(); + } + } catch {} + } + this._connectionSessionUUID = undefined; + // ensureConnection either resolves (ws ready) or throws + await this.ensureConnection(); + } + const taskType = (msg as any)?.taskType; + const taskUUID = (msg as any)?.taskUUID; + this._logger.messageSent(taskType, taskUUID); this._ws.send(JSON.stringify([msg])); }; protected handleClose() { + this._connecting = false; + this._logger.connectionClosed(); + this._connectionSessionUUID = undefined; + this.stopHeartbeat(); if (this.isInvalidAPIKey()) { return; } @@ -130,12 +159,13 @@ export class RunwareServer extends RunwareBase { } if (this._shouldReconnect) { + this._logger.reconnectScheduled(1000); setTimeout(() => this.connect(), 1000); } - // this._reconnectingIntervalId = setInterval(() => this.connect(), 1000); } protected resetConnection = () => { + this.stopHeartbeat(); if (this._ws) { this._listeners.forEach((list) => { list?.destroy?.(); @@ -151,15 +181,5 @@ export class RunwareServer extends RunwareBase { } }; - protected heartBeat() { - clearTimeout(this._pingTimeout); - - this._pingTimeout = setTimeout(() => { - if (this.isWebsocketReadyState()) { - this.send({ ping: true }); - } - }, 5000); - } - //end of data } diff --git a/Runware/async-retry.ts b/Runware/async-retry.ts index 17a979a..ce90dc9 100644 --- a/Runware/async-retry.ts +++ b/Runware/async-retry.ts @@ -1,4 +1,5 @@ import { delay } from "./utils"; +import type { RunwareLogger } from "./logger"; export const asyncRetry = async ( apiCall: Function, @@ -6,24 +7,42 @@ export const asyncRetry = async ( maxRetries?: number; delayInSeconds?: number; callback?: Function; + logger?: RunwareLogger; } = {} -) => { - const { delayInSeconds = 1, callback } = options; +): Promise => { + const { delayInSeconds = 1, callback, logger } = options; let maxRetries = options.maxRetries ?? 1; + const initialMaxRetries = maxRetries; + + // Fix: maxRetries=0 should execute apiCall once with no retries + if (maxRetries <= 0) { + return await apiCall(); + } + while (maxRetries) { try { const result = await apiCall(); + if (maxRetries < initialMaxRetries) { + logger?.retrySuccess(initialMaxRetries - maxRetries + 1); + } return result; // Return the result if successful } catch (error: any) { - callback?.(); + // Fix: API errors (with .error property) throw immediately — no callback, no retry if (error?.error) { + logger?.retrySkippedApiError(error.error?.code || "unknown"); throw error; } + + // Only call callback for retryable errors (network/timeout) + callback?.(); + maxRetries--; if (maxRetries > 0) { + logger?.retryAttempt(initialMaxRetries - maxRetries, initialMaxRetries, delayInSeconds * 1000); await delay(delayInSeconds); // Delay before the next retry return await asyncRetry(apiCall, { ...options, maxRetries }); } else { + logger?.retryExhausted(initialMaxRetries); throw error; // Throw the error if max retries are reached } } diff --git a/Runware/index.ts b/Runware/index.ts index aff1bee..2ba8e48 100644 --- a/Runware/index.ts +++ b/Runware/index.ts @@ -3,4 +3,5 @@ export * from "./types"; export * from "./Runware-server"; export * from "./Runware"; +export { RunwareLogger, LogLevel, createLogger } from "./logger"; export { SDK_VERSION } from "./utils"; diff --git a/Runware/logger.ts b/Runware/logger.ts new file mode 100644 index 0000000..0c4730c --- /dev/null +++ b/Runware/logger.ts @@ -0,0 +1,330 @@ +/** + * Runware SDK Telemetry Logger + * + * Beautiful colored console output for debugging SDK internals. + * Only active when `enableLogging: true` is passed during instantiation. + * + * Usage: + * const runware = new RunwareServer({ apiKey: "...", enableLogging: true }); + */ + +// ANSI color codes for terminal output +const COLORS = { + reset: "\x1b[0m", + bold: "\x1b[1m", + dim: "\x1b[2m", + + // Foreground + black: "\x1b[30m", + white: "\x1b[37m", + gray: "\x1b[90m", + + // Bright foreground + green: "\x1b[92m", + yellow: "\x1b[93m", + blue: "\x1b[94m", + magenta: "\x1b[95m", + cyan: "\x1b[96m", + red: "\x1b[91m", + + // Background + bgGreen: "\x1b[42m", + bgYellow: "\x1b[43m", + bgBlue: "\x1b[44m", + bgMagenta: "\x1b[45m", + bgCyan: "\x1b[46m", + bgRed: "\x1b[41m", + bgWhite: "\x1b[47m", +} as const; + +export enum LogLevel { + CONNECTION = "CONNECTION", + AUTH = "AUTH", + HEARTBEAT = "HEARTBEAT", + SEND = "SEND", + RECEIVE = "RECEIVE", + RETRY = "RETRY", + REQUEST = "REQUEST", + ERROR = "ERROR", + WARN = "WARN", + INFO = "INFO", +} + +const LEVEL_STYLES: Record< + LogLevel, + { bg: string; fg: string; icon: string } +> = { + [LogLevel.CONNECTION]: { bg: COLORS.bgBlue, fg: COLORS.blue, icon: "🔌" }, + [LogLevel.AUTH]: { bg: COLORS.bgGreen, fg: COLORS.green, icon: "🔑" }, + [LogLevel.HEARTBEAT]: { bg: COLORS.bgMagenta, fg: COLORS.magenta, icon: "💓" }, + [LogLevel.SEND]: { bg: COLORS.bgCyan, fg: COLORS.cyan, icon: "📤" }, + [LogLevel.RECEIVE]: { bg: COLORS.bgCyan, fg: COLORS.cyan, icon: "đŸ“Ĩ" }, + [LogLevel.RETRY]: { bg: COLORS.bgYellow, fg: COLORS.yellow, icon: "🔄" }, + [LogLevel.REQUEST]: { bg: COLORS.bgBlue, fg: COLORS.blue, icon: "📡" }, + [LogLevel.ERROR]: { bg: COLORS.bgRed, fg: COLORS.red, icon: "❌" }, + [LogLevel.WARN]: { bg: COLORS.bgYellow, fg: COLORS.yellow, icon: "âš ī¸" }, + [LogLevel.INFO]: { bg: COLORS.bgWhite, fg: COLORS.gray, icon: "â„šī¸" }, +}; + +const PREFIX = `${COLORS.bold}${COLORS.magenta}[RUNWARE]${COLORS.reset}`; + +function timestamp(): string { + return `${COLORS.dim}${new Date().toISOString()}${COLORS.reset}`; +} + +function badge(level: LogLevel): string { + const style = LEVEL_STYLES[level]; + return `${style.bg}${COLORS.bold}${COLORS.black} ${level} ${COLORS.reset}`; +} + +function formatData(data: any): string { + if (data === undefined || data === null) return ""; + if (typeof data === "string") return `${COLORS.dim}${data}${COLORS.reset}`; + try { + const str = JSON.stringify(data, null, 2); + return `${COLORS.dim}${str}${COLORS.reset}`; + } catch { + return `${COLORS.dim}[unserializable]${COLORS.reset}`; + } +} + +export class RunwareLogger { + private enabled: boolean; + + constructor(enabled: boolean = false) { + this.enabled = enabled; + } + + private log(level: LogLevel, message: string, data?: any) { + if (!this.enabled) return; + const style = LEVEL_STYLES[level]; + const parts = [ + "", + `${PREFIX} ${badge(level)} ${style.icon} ${style.fg}${COLORS.bold}${message}${COLORS.reset}`, + ` ${timestamp()}`, + ]; + if (data !== undefined) { + parts.push(` ${formatData(data)}`); + } + parts.push(""); + + if (level === LogLevel.ERROR) { + console.error(parts.join("\n")); + } else if (level === LogLevel.WARN) { + console.warn(parts.join("\n")); + } else { + console.log(parts.join("\n")); + } + } + + // ── Connection lifecycle ────────────────────────────────────────────── + + connecting(url: string) { + this.log(LogLevel.CONNECTION, `Connecting to WebSocket`, { url }); + } + + connected(sessionUUID: string) { + this.log(LogLevel.CONNECTION, `WebSocket connection established`, { + connectionSessionUUID: sessionUUID, + }); + } + + reconnecting(attempt: number) { + this.log(LogLevel.CONNECTION, `Reconnecting... attempt #${attempt}`); + } + + reconnectScheduled(delayMs: number) { + this.log( + LogLevel.CONNECTION, + `Reconnect scheduled in ${delayMs}ms`, + ); + } + + disconnected(reason?: string) { + this.log( + LogLevel.CONNECTION, + `WebSocket disconnected${reason ? `: ${reason}` : ""}`, + ); + } + + connectionClosed(code?: number) { + this.log(LogLevel.CONNECTION, `WebSocket closed`, { + code, + }); + } + + connectionError(error?: any) { + this.log(LogLevel.ERROR, `WebSocket error`, error); + } + + ensureConnectionStart() { + this.log( + LogLevel.CONNECTION, + `Connection lost — waiting for reconnection...`, + ); + } + + ensureConnectionSuccess() { + this.log(LogLevel.CONNECTION, `Reconnection successful`); + } + + ensureConnectionTimeout() { + this.log( + LogLevel.ERROR, + `Reconnection timed out after max retries`, + ); + } + + // ── Authentication ──────────────────────────────────────────────────── + + authenticating(hasSession: boolean) { + this.log( + LogLevel.AUTH, + hasSession ? `Re-authenticating with existing session` : `Authenticating with API key`, + ); + } + + authenticated(sessionUUID: string) { + this.log(LogLevel.AUTH, `Authentication successful`, { + connectionSessionUUID: sessionUUID, + }); + } + + authError(error: any) { + this.log(LogLevel.ERROR, `Authentication failed`, error); + } + + // ── Heartbeat ───────────────────────────────────────────────────────── + + heartbeatStarted(intervalMs: number) { + this.log( + LogLevel.HEARTBEAT, + `Heartbeat started — ping every ${intervalMs / 1000}s, ${3} missed pongs before close`, + ); + } + + heartbeatPingSent() { + this.log(LogLevel.HEARTBEAT, `Ping sent`); + } + + heartbeatPongReceived() { + this.log(LogLevel.HEARTBEAT, `Pong received — connection alive`); + } + + heartbeatPongMissed(count: number, max: number) { + this.log( + LogLevel.WARN, + `Pong missed (${count}/${max}) — ${count >= max ? "connection dead, terminating" : "will retry next cycle"}`, + ); + } + + heartbeatStopped() { + this.log(LogLevel.HEARTBEAT, `Heartbeat stopped`); + } + + // ── Send / Receive ──────────────────────────────────────────────────── + + messageSent(taskType: string, taskUUID?: string) { + this.log(LogLevel.SEND, `Message sent`, { + taskType, + ...(taskUUID ? { taskUUID } : {}), + }); + } + + messageReceived(taskType?: string, taskUUID?: string) { + this.log(LogLevel.RECEIVE, `Message received`, { + ...(taskType ? { taskType } : {}), + ...(taskUUID ? { taskUUID } : {}), + }); + } + + sendReconnecting() { + this.log( + LogLevel.WARN, + `Send failed — WebSocket not ready, attempting reconnection before retry`, + ); + } + + sendFailed(error: string) { + this.log(LogLevel.ERROR, `Send failed — ${error}`); + } + + // ── Request lifecycle ───────────────────────────────────────────────── + + requestStart(taskType: string, taskUUID: string) { + this.log(LogLevel.REQUEST, `Request started`, { + taskType, + taskUUID, + }); + } + + requestComplete(taskType: string, taskUUID: string, durationMs: number) { + this.log( + LogLevel.REQUEST, + `Request complete in ${durationMs}ms`, + { taskType, taskUUID }, + ); + } + + requestTimeout(taskUUID: string, timeoutMs: number) { + this.log(LogLevel.ERROR, `Request timed out after ${timeoutMs}ms`, { + taskUUID, + }); + } + + requestError(taskUUID: string, error: any) { + this.log(LogLevel.ERROR, `Request failed`, { + taskUUID, + error: error?.message || error?.error || error, + }); + } + + // ── Retry ───────────────────────────────────────────────────────────── + + retryAttempt(attempt: number, maxRetries: number, delayMs: number) { + this.log( + LogLevel.RETRY, + `Retry ${attempt}/${maxRetries} — waiting ${delayMs}ms before next attempt`, + ); + } + + retrySuccess(attempt: number) { + this.log(LogLevel.RETRY, `Retry succeeded on attempt #${attempt}`); + } + + retryExhausted(maxRetries: number) { + this.log( + LogLevel.ERROR, + `All ${maxRetries} retries exhausted — giving up`, + ); + } + + retrySkippedApiError(code: string) { + this.log( + LogLevel.ERROR, + `API error — skipping retry (not retryable)`, + { code }, + ); + } + + // ── General ─────────────────────────────────────────────────────────── + + info(message: string, data?: any) { + this.log(LogLevel.INFO, message, data); + } + + warn(message: string, data?: any) { + this.log(LogLevel.WARN, message, data); + } + + error(message: string, data?: any) { + this.log(LogLevel.ERROR, message, data); + } +} + +// Singleton noop logger for when logging is disabled +const NOOP_LOGGER = new RunwareLogger(false); + +export function createLogger(enabled: boolean): RunwareLogger { + return enabled ? new RunwareLogger(true) : NOOP_LOGGER; +} diff --git a/Runware/types.ts b/Runware/types.ts index 348a56e..ee1458e 100644 --- a/Runware/types.ts +++ b/Runware/types.ts @@ -36,6 +36,8 @@ export type RunwareBaseType = { shouldReconnect?: boolean; globalMaxRetries?: number; timeoutDuration?: number; + heartbeatInterval?: number; + enableLogging?: boolean; }; export type IOutputType = "base64Data" | "dataURI" | "URL"; diff --git a/Runware/utils.ts b/Runware/utils.ts index d8508b3..ea2005c 100644 --- a/Runware/utils.ts +++ b/Runware/utils.ts @@ -59,23 +59,30 @@ export const getIntervalWithPromise = ( const timeoutId = setTimeout(() => { if (intervalId) { clearInterval(intervalId); - if (shouldThrowError) { - reject(`Response could not be received from server for ${debugKey}`); - } } clearTimeout(timeoutId); - // reject(); + if (shouldThrowError) { + reject(`Response could not be received from server for ${debugKey}`); + } else { + // Always settle the promise — never leave it hanging + resolve(undefined); + } }, timeoutDuration); let intervalId = setInterval(async () => { - const shouldClear = callback({ resolve, reject, intervalId }); + try { + const shouldClear = callback({ resolve, reject, intervalId }); - if (shouldClear) { + if (shouldClear) { + clearInterval(intervalId); + clearTimeout(timeoutId); + } + } catch (err) { clearInterval(intervalId); clearTimeout(timeoutId); + reject(err); } - // resolve(imagesWithSimilarTask); // Resolve the promise with the data - }, pollingInterval); // Check every 1 second (adjust the interval as needed) + }, pollingInterval); }); }; diff --git a/package.json b/package.json index 62da273..fdc168b 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@runware/sdk-js", - "version": "1.2.7", + "version": "1.2.8", "description": "The SDK is used to run image inference with the Runware API, powered by the RunWare inference platform. It can be used to generate imaged with text-to-image and image-to-image. It also allows the use of an existing gallery of models or selecting any model or LoRA from the CivitAI gallery. The API also supports upscaling, background removal, inpainting and outpainting, and a series of other ControlNet models.", "main": "dist/index.js", "module": "dist/index.js", diff --git a/readme.md b/readme.md index c624e2f..4ba9a8f 100644 --- a/readme.md +++ b/readme.md @@ -739,6 +739,10 @@ export type TImageMaskingResponse = { ## Changelog +### - v1.2.8 + +- Add improvements in inference process + ### - v1.2.7 - Fix in-memory recursive while loop call without returning for async-retry diff --git a/tests/Runware/connection/heartbeat.test.ts b/tests/Runware/connection/heartbeat.test.ts new file mode 100644 index 0000000..b75897c --- /dev/null +++ b/tests/Runware/connection/heartbeat.test.ts @@ -0,0 +1,115 @@ +import { describe, test, expect, vi, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { delay } from "../../../Runware/utils"; +import { createRealServer } from "../../test-utils"; + +describe("Heartbeat — ping/pong and 3-strike tolerance", () => { + let server: RunwareServer; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: heartbeat interval is active after auth", async () => { + server = await createRealServer(); + expect((server as any)._heartbeatIntervalId).not.toBeNull(); + }, 30000); + + test("SUCCESS: heartbeat sends pings and connection stays alive", async () => { + // Wait for at least one heartbeat cycle to fire + await delay(2); + expect((server as any)._connectionSessionUUID).toBeDefined(); + expect((server as any).isWebsocketReadyState()).toBe(true); + expect((server as any)._heartbeatIntervalId).not.toBeNull(); + }, 30000); + + test("SUCCESS: handlePongMessage returns true for valid pong and resets missedPongCount", () => { + // Simulate 2 missed pongs then a successful one + (server as any)._missedPongCount = 2; + (server as any)._pongTimeoutId = setTimeout(() => {}, 5000); + + const result = (server as any).handlePongMessage({ + data: [{ taskType: "ping", pong: true }], + }); + + expect(result).toBe(true); + expect((server as any)._pongTimeoutId).toBeNull(); + expect((server as any)._missedPongCount).toBe(0); + }); + + test("FAILURE: handlePongMessage returns false for non-pong data", () => { + const result = (server as any).handlePongMessage({ + data: [{ taskType: "imageInference", status: "success" }], + }); + expect(result).toBe(false); + }); + + test("FIX PROOF: single missed pong does NOT terminate (3-strike tolerance)", () => { + // Simulate incrementing missed pong count manually + (server as any)._missedPongCount = 0; + + // After 1 miss — still alive + (server as any)._missedPongCount++; + expect((server as any)._missedPongCount).toBe(1); + expect((server as any)._missedPongCount < (server as any)._maxMissedPongs).toBe(true); + expect((server as any).isWebsocketReadyState()).toBe(true); + + // After 2 misses — still alive + (server as any)._missedPongCount++; + expect((server as any)._missedPongCount).toBe(2); + expect((server as any)._missedPongCount < (server as any)._maxMissedPongs).toBe(true); + expect((server as any).isWebsocketReadyState()).toBe(true); + + // Reset for other tests + (server as any)._missedPongCount = 0; + }); + + test("FIX PROOF: stopHeartbeat resets missedPongCount to 0", () => { + (server as any)._missedPongCount = 2; + (server as any).stopHeartbeat(); + expect((server as any)._missedPongCount).toBe(0); + expect((server as any)._heartbeatIntervalId).toBeNull(); + expect((server as any)._pongTimeoutId).toBeNull(); + + // Restart heartbeat for next tests + (server as any).startHeartbeat(); + }); + + test("FIX PROOF: heartbeat terminates connection after 3 consecutive missed pongs", async () => { + // Stop the real heartbeat + (server as any).stopHeartbeat(); + + const ws = (server as any)._ws; + const terminateSpy = vi.spyOn(ws, "terminate").mockImplementation(() => { + // Don't actually terminate — we just want to verify it's called + }); + + // Manually run the 3-strike logic as the pong timeout handler would + (server as any)._missedPongCount = 0; + + // Strike 1 — no terminate + (server as any)._missedPongCount++; + expect((server as any)._missedPongCount >= (server as any)._maxMissedPongs).toBe(false); + + // Strike 2 — no terminate + (server as any)._missedPongCount++; + expect((server as any)._missedPongCount >= (server as any)._maxMissedPongs).toBe(false); + + // Strike 3 — should trigger terminate + (server as any)._missedPongCount++; + expect((server as any)._missedPongCount >= (server as any)._maxMissedPongs).toBe(true); + + // Simulate what startHeartbeat's pong timeout handler does on strike 3 + if ((server as any)._missedPongCount >= (server as any)._maxMissedPongs) { + if (typeof ws.terminate === "function") { + ws.terminate(); + } + } + + expect(terminateSpy).toHaveBeenCalledTimes(1); + terminateSpy.mockRestore(); + + // Reset + (server as any)._missedPongCount = 0; + }); +}); diff --git a/tests/Runware/connection/send-guard.test.ts b/tests/Runware/connection/send-guard.test.ts new file mode 100644 index 0000000..176c683 --- /dev/null +++ b/tests/Runware/connection/send-guard.test.ts @@ -0,0 +1,73 @@ +import { describe, test, expect, vi, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { createRealServer } from "../../test-utils"; + +describe("send() guard — readyState check + ensureConnection retry", () => { + let server: RunwareServer; + + afterAll(() => { + try { + server?.disconnect(); + } catch {} + }); + + test("SUCCESS: send() succeeds on a live connection", async () => { + server = await createRealServer(); + + await expect( + (server as any).send({ taskType: "test", data: "hello" }), + ).resolves.toBeUndefined(); + }, 30000); + + test("SUCCESS: send() calls ensureConnection when WebSocket is not ready", async () => { + const ensureConnectionSpy = vi + .spyOn(server as any, "ensureConnection") + .mockRejectedValueOnce(new Error("Retry timed out")); + + // Force ws to a non-OPEN state + const realWs = (server as any)._ws; + const origReadyState = realWs.readyState; + Object.defineProperty(realWs, "readyState", { + value: 3, // CLOSED + writable: true, + configurable: true, + }); + + await expect( + (server as any).send({ taskType: "test" }), + ).rejects.toThrow(); + + // send() should have tried to reconnect via ensureConnection + expect(ensureConnectionSpy).toHaveBeenCalledTimes(1); + // _connectionSessionUUID should be cleared before ensureConnection + expect((server as any)._connectionSessionUUID).toBeUndefined(); + + // Restore + Object.defineProperty(realWs, "readyState", { + value: origReadyState, + writable: true, + configurable: true, + }); + ensureConnectionSpy.mockRestore(); + }); + + test("FAILURE: send() throws descriptive error when reconnection fails", async () => { + vi.spyOn(server as any, "ensureConnection").mockRejectedValueOnce( + new Error( + "WebSocket connection could not be established. Check your network connection and API key.", + ), + ); + + // Force ws to non-OPEN + const realWs = (server as any)._ws; + Object.defineProperty(realWs, "readyState", { + value: 3, + writable: true, + configurable: true, + }); + + await expect( + (server as any).send({ taskType: "test" }), + ).rejects.toThrow("WebSocket connection could not be established"); + }); +}); diff --git a/tests/Runware/connection/session-uuid.test.ts b/tests/Runware/connection/session-uuid.test.ts new file mode 100644 index 0000000..3744c0d --- /dev/null +++ b/tests/Runware/connection/session-uuid.test.ts @@ -0,0 +1,41 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { delay } from "../../../Runware/utils"; +import { createRealServer } from "../../test-utils"; + +describe("_connectionSessionUUID cleared on close (Bug 2 fix)", () => { + let server: RunwareServer; + + afterAll(() => { + try { + server?.disconnect(); + } catch {} + }); + + test("SUCCESS: _connectionSessionUUID is set after auth", async () => { + server = await createRealServer(); + expect((server as any)._connectionSessionUUID).toBeDefined(); + expect(typeof (server as any)._connectionSessionUUID).toBe("string"); + expect((server as any)._connectionSessionUUID.length).toBeGreaterThan(0); + }, 30000); + + test("FIX PROOF: _connectionSessionUUID is cleared after WebSocket close", async () => { + const sessionBefore = (server as any)._connectionSessionUUID; + expect(sessionBefore).toBeDefined(); + + // Force close the WebSocket to simulate network drop + (server as any)._ws.close(); + await delay(1); + + // CRITICAL: _connectionSessionUUID must be cleared by handleClose() + expect((server as any)._connectionSessionUUID).toBeUndefined(); + }); + + test("FIX PROOF: connected() returns false after WebSocket close", () => { + expect((server as any).connected()).toBe(false); + }); + + test("FIX PROOF: heartbeat is stopped after WebSocket close", () => { + expect((server as any)._heartbeatIntervalId).toBeNull(); + }); +}); diff --git a/tests/Runware/connection/zombie-detection.test.ts b/tests/Runware/connection/zombie-detection.test.ts new file mode 100644 index 0000000..f976794 --- /dev/null +++ b/tests/Runware/connection/zombie-detection.test.ts @@ -0,0 +1,53 @@ +import { describe, test, expect, vi, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { delay } from "../../../Runware/utils"; +import { createRealServer } from "../../test-utils"; + +describe("Zombie detection — full end-to-end timeline", () => { + let server: RunwareServer; + + afterAll(() => { + try { + server?.disconnect(); + } catch {} + }); + + test("INTEGRATION: connected → close → session cleared → send blocked", async () => { + server = await createRealServer(); + + // Step 1: Verify we start connected with real production session + expect((server as any)._connectionSessionUUID).toBeDefined(); + expect((server as any).isWebsocketReadyState()).toBe(true); + expect((server as any).connected()).toBe(true); + + // Step 2: Send succeeds while connected + await expect( + (server as any).send({ taskType: "test" }), + ).resolves.toBeUndefined(); + + // Step 3: Force close the WebSocket (simulates network failure) + (server as any)._shouldReconnect = false; + (server as any)._ws.close(); + await delay(1); + + // Step 4: _connectionSessionUUID is cleared by handleClose() (Bug 2 fix) + expect((server as any)._connectionSessionUUID).toBeUndefined(); + + // Step 5: connected() returns false + expect((server as any).connected()).toBe(false); + + // Step 6: send() attempts ensureConnection, which fails because reconnect is disabled + // Mock ensureConnection to fail fast + vi.spyOn(server as any, "ensureConnection").mockRejectedValueOnce( + new Error("WebSocket is not connected"), + ); + + await expect( + (server as any).send({ taskType: "test" }), + ).rejects.toThrow(); + + // Step 7: heartbeat is stopped (no dangling timers) + expect((server as any)._heartbeatIntervalId).toBeNull(); + expect((server as any)._pongTimeoutId).toBeNull(); + }, 30000); +}); diff --git a/tests/Runware/enhance-prompt.test.ts b/tests/Runware/enhance-prompt.test.ts deleted file mode 100644 index a768667..0000000 --- a/tests/Runware/enhance-prompt.test.ts +++ /dev/null @@ -1,88 +0,0 @@ -import { - expect, - test, - beforeAll, - vi, - describe, - afterEach, - beforeEach, -} from "vitest"; -import { mockTaskUUID, mockUploadFile, testExamples } from "../test-utils"; -import { startMockServer } from "../mockServer"; -import { ETaskType } from "../../Runware"; - -vi.mock("../../Runware/utils", async () => { - const actual = await vi.importActual("../../Runware/utils"); - return { - ...(actual as any), - fileToBase64: vi.fn().mockReturnValue("FILE_TO_BASE_64"), - getIntervalWithPromise: vi.fn(), - getUUID: vi.fn().mockImplementation(() => "UNIQUE_UID"), - }; -}); - -describe("When user request to enhance prompt", async () => { - const { mockServer, runware } = await startMockServer(); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeEach(() => { - mockServer.stop(); - }); - - beforeAll(async () => { - vi.spyOn(runware as any, "uploadImage").mockReturnValue( - testExamples.imageUploadRes - ); - }); - - test("it should give an enhanced prompt", async () => { - const globalListenerSpy = vi.spyOn(runware as any, "globalListener"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.enhancePrompt({ - prompt: "Mock prompt", - promptMaxLength: 200, - promptVersions: 4, - }); - - expect(sendSpy).toHaveBeenCalledWith({ - prompt: "Mock prompt", - taskUUID: mockTaskUUID, - promptMaxLength: 200, - promptVersions: 4, - taskType: ETaskType.PROMPT_ENHANCE, - }); - expect(globalListenerSpy).toHaveBeenCalledWith({ - taskUUID: mockTaskUUID, - }); - }); - - test("it should give an enhanced prompt with default config", async () => { - const globalListenerSpy = vi.spyOn(runware as any, "globalListener"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.enhancePrompt({ - prompt: "Mock prompt", - }); - - expect(sendSpy).toHaveBeenCalledWith({ - prompt: "Mock prompt", - taskUUID: mockTaskUUID, - promptMaxLength: 380, - promptVersions: 1, - taskType: ETaskType.PROMPT_ENHANCE, - }); - expect(globalListenerSpy).toHaveBeenCalledWith({ - taskUUID: mockTaskUUID, - }); - }); - - test("promptEnhance delegates to enhancePrompt", async () => { - const params = { prompt: "Mock prompt", promptMaxLength: 200, promptVersions: 4 }; - const result = await runware.promptEnhance(params); - expect(result).toEqual(await runware.enhancePrompt(params)); - }); -}); diff --git a/tests/Runware/inference/enhance-prompt.test.ts b/tests/Runware/inference/enhance-prompt.test.ts new file mode 100644 index 0000000..1007fb7 --- /dev/null +++ b/tests/Runware/inference/enhance-prompt.test.ts @@ -0,0 +1,51 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { createRealServer } from "../../test-utils"; + +describe("enhancePrompt (real WebSocket)", () => { + let server: RunwareServer; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: enhances a prompt and returns text", async () => { + server = await createRealServer(); + + const results = await server.enhancePrompt({ + prompt: "a cat sitting on a chair", + promptMaxLength: 200, + promptVersions: 1, + }); + + expect(Array.isArray(results)).toBe(true); + expect(results.length).toBeGreaterThanOrEqual(1); + + const result = results[0]; + expect(result).toHaveProperty("taskUUID"); + expect(result).toHaveProperty("text"); + expect(typeof result.text).toBe("string"); + expect(result.text.length).toBeGreaterThan(0); + }, 30000); + + test("SUCCESS: uses default promptMaxLength and promptVersions", async () => { + const results = await server.enhancePrompt({ + prompt: "a sunset over the ocean", + }); + + expect(Array.isArray(results)).toBe(true); + expect(results.length).toBeGreaterThanOrEqual(1); + expect(results[0].text.length).toBeGreaterThan(0); + }, 30000); + + test("SUCCESS: promptEnhance delegates to enhancePrompt", async () => { + const results = await server.promptEnhance({ + prompt: "a dog playing in the park", + promptVersions: 1, + }); + + expect(Array.isArray(results)).toBe(true); + expect(results.length).toBeGreaterThanOrEqual(1); + expect(results[0]).toHaveProperty("text"); + }, 30000); +}); diff --git a/tests/Runware/inference/image-generation.test.ts b/tests/Runware/inference/image-generation.test.ts new file mode 100644 index 0000000..bec59a0 --- /dev/null +++ b/tests/Runware/inference/image-generation.test.ts @@ -0,0 +1,100 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { createRealServer } from "../../test-utils"; + +describe("requestImages (real WebSocket)", () => { + let server: RunwareServer; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: generates an image and response matches API spec", async () => { + server = await createRealServer(); + + const results = await server.requestImages({ + positivePrompt: "a beautiful mountain landscape", + model: "runware:100@1", + numberResults: 1, + width: 512, + height: 512, + steps: 4, + includeCost: true, + }); + + expect(Array.isArray(results)).toBe(true); + expect(results!.length).toBe(1); + + const image = results![0]; + expect(image).toHaveProperty("taskUUID"); + expect(typeof image.taskUUID).toBe("string"); + expect(image).toHaveProperty("imageUUID"); + expect(typeof image.imageUUID).toBe("string"); + expect(image).toHaveProperty("imageURL"); + expect(typeof image.imageURL).toBe("string"); + expect(image.imageURL!.startsWith("http")).toBe(true); + expect(image).toHaveProperty("seed"); + expect(typeof image.seed).toBe("number"); + expect(image).toHaveProperty("cost"); + expect(typeof image.cost).toBe("number"); + expect(image.cost!).toBeGreaterThan(0); + }, 60000); + + test("SUCCESS: generates multiple images in parallel", async () => { + const [results1, results2] = await Promise.all([ + server.requestImages({ + positivePrompt: "a red rose", + model: "runware:100@1", + numberResults: 1, + width: 512, + height: 512, + steps: 4, + }), + server.requestImages({ + positivePrompt: "a blue sky", + model: "runware:100@1", + numberResults: 1, + width: 512, + height: 512, + steps: 4, + }), + ]); + + expect(results1!.length).toBe(1); + expect(results2!.length).toBe(1); + expect(results1![0].imageURL).toBeTruthy(); + expect(results2![0].imageURL).toBeTruthy(); + expect(results1![0].taskUUID).not.toBe(results2![0].taskUUID); + expect(results1![0].imageUUID).not.toBe(results2![0].imageUUID); + }, 60000); + + test("SUCCESS: imageInference delegates to requestImages", async () => { + const results = await server.imageInference({ + positivePrompt: "a forest path", + model: "runware:100@1", + numberResults: 1, + width: 512, + height: 512, + steps: 4, + }); + + expect(Array.isArray(results)).toBe(true); + expect(results!.length).toBe(1); + expect(results![0]).toHaveProperty("imageURL"); + expect(results![0]).toHaveProperty("imageUUID"); + expect(results![0]).toHaveProperty("seed"); + }, 60000); + + test("FAILURE: rejects with invalid model", async () => { + await expect( + server.requestImages({ + positivePrompt: "test", + model: "nonexistent:999@999", + numberResults: 1, + width: 512, + height: 512, + steps: 4, + }), + ).rejects.toBeDefined(); + }, 60000); +}); diff --git a/tests/Runware/inference/upload-image.test.ts b/tests/Runware/inference/upload-image.test.ts new file mode 100644 index 0000000..cecae6b --- /dev/null +++ b/tests/Runware/inference/upload-image.test.ts @@ -0,0 +1,21 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { createRealServer, TEST_IMAGE_URL } from "../../test-utils"; + +describe("uploadImage (real WebSocket)", () => { + let server: RunwareServer; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: accepts a URL string for image upload", async () => { + server = await createRealServer(); + + const result = await (server as any).uploadImage(TEST_IMAGE_URL); + + expect(result).toHaveProperty("imageUUID"); + expect(result).toHaveProperty("taskUUID"); + expect(typeof result.imageUUID).toBe("string"); + }, 30000); +}); diff --git a/tests/Runware/inference/upscale.test.ts b/tests/Runware/inference/upscale.test.ts new file mode 100644 index 0000000..641a764 --- /dev/null +++ b/tests/Runware/inference/upscale.test.ts @@ -0,0 +1,33 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { createRealServer, TEST_IMAGE_URL } from "../../test-utils"; + +describe("upscaleGan (real WebSocket)", () => { + let server: RunwareServer; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: upscales an image", async () => { + server = await createRealServer(); + + const result = await server.upscaleGan({ + inputImage: TEST_IMAGE_URL, + upscaleFactor: 2, + }); + + expect(result).toHaveProperty("taskUUID"); + expect(typeof result.taskUUID).toBe("string"); + }, 60000); + + test("SUCCESS: upscale delegates to upscaleGan", async () => { + const result = await server.upscale({ + inputImage: TEST_IMAGE_URL, + upscaleFactor: 2, + }); + + expect(result).toHaveProperty("taskUUID"); + expect(typeof result.taskUUID).toBe("string"); + }, 60000); +}); diff --git a/tests/Runware/inference/video-generation.test.ts b/tests/Runware/inference/video-generation.test.ts new file mode 100644 index 0000000..d73e096 --- /dev/null +++ b/tests/Runware/inference/video-generation.test.ts @@ -0,0 +1,110 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { IVideoToImage } from "../../../Runware/types"; +import { createRealServer } from "../../test-utils"; + +describe("videoInference (real WebSocket, async)", () => { + let server: RunwareServer; + let submittedTaskUUID: string; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: submits video job with skipResponse and gets taskUUID back", async () => { + server = await createRealServer(); + + const result = (await server.videoInference({ + model: "pixverse:1@5", + positivePrompt: "smooth camera pan across a mountain landscape", + duration: 5, + width: 720, + height: 1280, + numberResults: 1, + outputFormat: "MP4", + includeCost: true, + skipResponse: true, + })) as IVideoToImage; + + expect(result).toHaveProperty("taskUUID"); + expect(typeof result.taskUUID).toBe("string"); + expect(result.taskUUID.length).toBeGreaterThan(0); + + // With skipResponse: true, videoURL should not be present yet + expect(result.videoURL).toBeUndefined(); + + submittedTaskUUID = result.taskUUID; + }, 60000); + + test("SUCCESS: getResponse retrieves completed video result", async () => { + expect(submittedTaskUUID).toBeDefined(); + + // Poll for the async result — video generation can take a while + let results: IVideoToImage[] = []; + const maxAttempts = 60; // 60 * 5s = 5 minutes max + for (let i = 0; i < maxAttempts; i++) { + try { + results = await server.getResponse({ + taskUUID: submittedTaskUUID, + }); + if (results && results.length > 0 && results[0].videoURL) { + break; + } + } catch { + // Not ready yet, keep polling + } + await new Promise((r) => setTimeout(r, 5000)); + } + + expect(results.length).toBeGreaterThanOrEqual(1); + + const video = results[0]; + expect(video).toHaveProperty("taskUUID"); + expect(video).toHaveProperty("videoURL"); + expect(typeof video.videoURL).toBe("string"); + expect(video.videoURL!.startsWith("http")).toBe(true); + }, 360000); // 6 minute timeout + + test("SUCCESS: parallel video submissions each get unique taskUUIDs", async () => { + const [result1, result2] = await Promise.all([ + server.videoInference({ + model: "pixverse:1@5", + positivePrompt: "a sunset timelapse", + duration: 5, + width: 720, + height: 1280, + numberResults: 1, + outputFormat: "MP4", + skipResponse: true, + }) as Promise, + server.videoInference({ + model: "pixverse:1@5", + positivePrompt: "ocean waves crashing", + duration: 5, + width: 720, + height: 1280, + numberResults: 1, + outputFormat: "MP4", + skipResponse: true, + }) as Promise, + ]); + + expect(result1.taskUUID).toBeTruthy(); + expect(result2.taskUUID).toBeTruthy(); + expect(result1.taskUUID).not.toBe(result2.taskUUID); + }, 60000); + + test("FAILURE: rejects with invalid video model", async () => { + await expect( + server.videoInference({ + model: "nonexistent:999@999", + positivePrompt: "test", + duration: 5, + width: 720, + height: 1280, + numberResults: 1, + skipResponse: true, + }), + ).rejects.toBeDefined(); + }, 60000); +}); diff --git a/tests/Runware/remove-image-background.test.ts b/tests/Runware/remove-image-background.test.ts deleted file mode 100644 index 3c33c5a..0000000 --- a/tests/Runware/remove-image-background.test.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { - expect, - test, - vi, - describe, - afterEach, - beforeEach, -} from "vitest"; -import { mockTaskUUID, mockUploadFile } from "../test-utils"; -import { startMockServer } from "../mockServer"; -import { ETaskType } from "../../Runware"; - -vi.mock("../../Runware/utils", async () => { - const actual = await vi.importActual("../../Runware/utils"); - return { - ...(actual as any), - fileToBase64: vi.fn().mockReturnValue("FILE_TO_BASE_64"), - getIntervalWithPromise: vi.fn(), - getUUID: vi.fn().mockImplementation(() => "UNIQUE_UID"), - }; -}); - -describe("When user request to remove image background", async () => { - const { mockServer, runware } = await startMockServer(); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeEach(() => { - mockServer.stop(); - }); - - test("it should remove an image background", async () => { - const globalListenerSpy = vi.spyOn(runware as any, "globalListener"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.removeImageBackground({ inputImage: mockUploadFile, model: "runware:110@1" }); - - expect(sendSpy).toHaveBeenCalledWith({ - inputImage: mockUploadFile, - model: "runware:110@1", - taskUUID: mockTaskUUID, - taskType: ETaskType.REMOVE_BACKGROUND, - }); - expect(globalListenerSpy).toHaveBeenCalledWith({ - taskUUID: mockTaskUUID, - }); - }); - - test("removeBackground delegates to removeImageBackground", async () => { - const result = await runware.removeBackground({ inputImage: mockUploadFile, model: "runware:110@1" }); - expect(result).toEqual(await runware.removeImageBackground({ inputImage: mockUploadFile, model: "runware:110@1" })); - }); -}); diff --git a/tests/Runware/request-image-to-text.test.ts b/tests/Runware/request-image-to-text.test.ts deleted file mode 100644 index d537211..0000000 --- a/tests/Runware/request-image-to-text.test.ts +++ /dev/null @@ -1,64 +0,0 @@ -import { - expect, - test, - beforeAll, - vi, - describe, - afterEach, - beforeEach, -} from "vitest"; -import { mockTaskUUID, mockUploadFile, testExamples } from "../test-utils"; -import { startMockServer } from "../mockServer"; -import { ETaskType } from "../../Runware"; - -vi.mock("../../Runware/utils", async () => { - const actual = await vi.importActual("../../Runware/utils"); - return { - ...(actual as any), - fileToBase64: vi.fn().mockReturnValue("FILE_TO_BASE_64"), - getIntervalWithPromise: vi.fn(), - getUUID: vi.fn().mockImplementation(() => "UNIQUE_UID"), - }; -}); - -describe("When user request image to text", async () => { - const { mockServer, runware } = await startMockServer(); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeEach(() => { - mockServer.stop(); - }); - - beforeAll(async () => { - vi.spyOn(runware as any, "uploadImage").mockReturnValue( - testExamples.imageUploadRes - ); - }); - - test("it should get a text conversion", async () => { - const imageUploadSpy = vi.spyOn(runware as any, "uploadImage"); - const globalListenerSpy = vi.spyOn(runware as any, "globalListener"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.requestImageToText({ inputImage: mockUploadFile }); - - expect(imageUploadSpy).toHaveBeenCalled(); - - expect(sendSpy).toHaveBeenCalledWith({ - inputImage: testExamples.imageUploadRes.imageUUID, - taskUUID: mockTaskUUID, - taskType: ETaskType.CAPTION, - }); - expect(globalListenerSpy).toHaveBeenCalledWith({ - taskUUID: mockTaskUUID, - }); - }); - - test("caption delegates to requestImageToText", async () => { - const result = await runware.caption({ inputImage: mockUploadFile }); - expect(result).toEqual(await runware.requestImageToText({ inputImage: mockUploadFile })); - }); -}); diff --git a/tests/Runware/request-images.test.ts b/tests/Runware/request-images.test.ts deleted file mode 100644 index b2f9ade..0000000 --- a/tests/Runware/request-images.test.ts +++ /dev/null @@ -1,174 +0,0 @@ -import { - expect, - test, - beforeAll, - vi, - describe, - afterEach, - beforeEach, -} from "vitest"; -import { mockTextImageUpload, testExamples } from "../test-utils"; -import { startMockServer } from "../mockServer"; -import { EControlMode, ETaskType } from "../../Runware"; - -vi.mock("../../Runware/utils", async () => { - const actual = await vi.importActual("../../Runware/utils"); - return { - ...(actual as any), - fileToBase64: vi.fn().mockReturnValue("FILE_TO_BASE_64"), - getIntervalWithPromise: vi.fn(), - getUUID: vi.fn().mockImplementation(() => "UNIQUE_UID"), - }; -}); - -describe("When user request an image", async () => { - const { mockServer, runware } = await startMockServer(); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeEach(() => { - mockServer.stop(); - }); - - beforeAll(async () => { - vi.spyOn(runware as any, "uploadImage").mockReturnValue( - testExamples.imageUploadRes - ); - }); - - test("it should request image without an image initiator", async () => { - const imageUploadSpy = vi.spyOn(runware as any, "uploadImage"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.requestImages(testExamples.imageReq); - - expect(imageUploadSpy).not.toHaveBeenCalled(); - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - }); - }); - - test("it should request image with an image initiator", async () => { - const imageUploadSpy = vi.spyOn(runware as any, "uploadImage"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.requestImages({ - ...testExamples.imageReq, - seedImage: mockTextImageUpload, - }); - - expect(imageUploadSpy).toHaveBeenCalled(); - expect(imageUploadSpy).toHaveBeenCalledTimes(1); - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - seedImage: testExamples.imageUploadRes.imageUUID, - taskType: ETaskType.IMAGE_INFERENCE, - }); - }); - - test("it should request image with an image initiator and image mask initiator", async () => { - const imageUploadSpy = vi.spyOn(runware as any, "uploadImage"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.requestImages({ - ...testExamples.imageReq, - seedImage: mockTextImageUpload, - maskImage: mockTextImageUpload, - }); - - expect(imageUploadSpy).toHaveBeenCalledTimes(2); - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - seedImage: testExamples.imageUploadRes.imageUUID, - maskImage: testExamples.imageUploadRes.imageUUID, - taskType: ETaskType.IMAGE_INFERENCE, - }); - }); - - test("it should request image with an image initiator and image mask initiator and control net", async () => { - const imageUploadSpy = vi.spyOn(runware as any, "uploadImage"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.requestImages({ - ...testExamples.imageReq, - seedImage: mockTextImageUpload, - maskImage: mockTextImageUpload, - controlNet: [{ ...testExamples.controlNet, model: "control_net_model" }], - }); - - expect(imageUploadSpy).toHaveBeenCalledTimes(3); - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - seedImage: testExamples.imageUploadRes.imageUUID, - maskImage: testExamples.imageUploadRes.imageUUID, - controlNet: [ - { - controlMode: EControlMode.CONTROL_NET, - endStep: 20, - guideImage: "NEW_IMAGE_UID", - model: "control_net_model", - startStep: 0, - weight: 1, - }, - ], - taskType: ETaskType.IMAGE_INFERENCE, - }); - }); - test("it should request multiple images in parallel", async () => { - const sendSpy = vi.spyOn(runware as any, "send"); - const listenToResponse = vi.spyOn(runware as any, "listenToResponse"); - - await Promise.all([ - runware.requestImages({ - ...testExamples.imageReq, - }), - runware.requestImages({ - ...testExamples.imageReq, - positivePrompt: "cat", - }), - ]); - - expect(sendSpy).toHaveBeenCalledTimes(2); - - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - taskType: ETaskType.IMAGE_INFERENCE, - }); - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - positivePrompt: "cat", - taskType: ETaskType.IMAGE_INFERENCE, - }); - - expect(listenToResponse).toHaveBeenCalledTimes(2); - }); - - test("it should request providerSettings", async() => { - const sendSpy = vi.spyOn(runware as any, "send"); - - const providerSettings = { - bfl: { - promptUpsampling: true, - safetyTolerance: 4, - raw: true, - }, - }; - - await runware.requestImages({ - ...testExamples.imageReq, - providerSettings - }); - - expect(sendSpy).toHaveBeenCalledWith({ - ...testExamples.imageRes, - providerSettings - }); - }); - - test("imageInference delegates to requestImages", async () => { - const result = await runware.imageInference(testExamples.imageReq); - expect(result).toEqual(await runware.requestImages(testExamples.imageReq)); - }); -}); diff --git a/tests/Runware/retry/async-retry.test.ts b/tests/Runware/retry/async-retry.test.ts new file mode 100644 index 0000000..4498c5e --- /dev/null +++ b/tests/Runware/retry/async-retry.test.ts @@ -0,0 +1,91 @@ +import { describe, test, expect, vi } from "vitest"; +import { asyncRetry } from "../../../Runware/async-retry"; + +describe("asyncRetry — missing return fix (Bug 3)", () => { + test("SUCCESS: retries once on transient failure and returns result", async () => { + const apiCall = vi + .fn() + .mockRejectedValueOnce(new Error("transient")) + .mockResolvedValue("recovered"); + + const result = await asyncRetry(apiCall, { + maxRetries: 3, + delayInSeconds: 0.001, + }); + + expect(result).toBe("recovered"); + expect(apiCall).toHaveBeenCalledTimes(2); + }); + + test("FAILURE: throws after all retries exhausted", async () => { + const apiCall = vi.fn().mockRejectedValue(new Error("persistent")); + + await expect( + asyncRetry(apiCall, { maxRetries: 2, delayInSeconds: 0.001 }), + ).rejects.toThrow("persistent"); + // 1 initial + 1 retry = 2 calls total + expect(apiCall).toHaveBeenCalledTimes(2); + }); + + test("FAILURE: API error with .error property throws immediately without retry", async () => { + const apiError = { + error: { code: "conflictTaskUUID", message: "conflict" }, + }; + const apiCall = vi.fn().mockRejectedValue(apiError); + + await expect( + asyncRetry(apiCall, { maxRetries: 3, delayInSeconds: 0.001 }), + ).rejects.toEqual(apiError); + expect(apiCall).toHaveBeenCalledTimes(1); + }); + + test("REGRESSION: apiCall is NOT called a 3rd time after retry succeeds (the fix)", async () => { + let callCount = 0; + const apiCall = vi.fn().mockImplementation(async () => { + callCount++; + if (callCount === 1) { + throw new Error("transient failure"); + } + return "success"; + }); + + const result = await asyncRetry(apiCall, { + maxRetries: 3, + delayInSeconds: 0.001, + }); + + expect(result).toBe("success"); + // With fix: exactly 2 calls (fail + retry success) + // Without fix: 3 calls (fail + retry success + duplicate from while loop) + expect(apiCall).toHaveBeenCalledTimes(2); + }); + + test("REGRESSION: customer scenario — timeout + retry produces exactly 2 sends, not 3", async () => { + const sendLog: string[] = []; + let attempt = 0; + + const apiCall = vi.fn().mockImplementation(async () => { + attempt++; + sendLog.push(`attempt-${attempt}`); + if (attempt === 1) { + throw new Error( + "Response could not be received from server for getting images", + ); + } + return [{ taskUUID: "9258b951", status: "success" }]; + }); + + const callback = vi.fn(); + + const result = await asyncRetry(apiCall, { + maxRetries: 2, + delayInSeconds: 0.001, + callback, + }); + + expect(sendLog).toEqual(["attempt-1", "attempt-2"]); + expect(apiCall).toHaveBeenCalledTimes(2); + expect(callback).toHaveBeenCalledTimes(1); + expect(result[0].taskUUID).toBe("9258b951"); + }); +}); diff --git a/tests/Runware/runware-server.test.ts b/tests/Runware/runware-server.test.ts deleted file mode 100644 index 78300c3..0000000 --- a/tests/Runware/runware-server.test.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { afterEach, beforeAll, describe, expect, test, vi } from "vitest"; -import { startMockBackendServer } from "../mockServer"; -import { RunwareServer } from "../../Runware"; -import { BASE_RUNWARE_URLS } from "../../Runware/utils"; - -const PORT = 8080; - -describe("When using backend mockServer", async () => { - const { mockServer } = await startMockBackendServer(); - - beforeAll(async () => {}); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeAll(async () => {}); - - test("it should instantiate server correctly", async () => { - vi.spyOn( - (RunwareServer as any).prototype, - "addListener" - ).mockImplementation(() => "afa"); - vi.spyOn((RunwareServer as any).prototype, "connect"); - - const runwareServer: any = new RunwareServer({ - apiKey: "API_KEY", - url: BASE_RUNWARE_URLS.TEST, - }); - - expect(runwareServer._apiKey).toBe("API_KEY"); - expect(runwareServer.connect).toBeCalledTimes(1); - expect(runwareServer._ws).toBeDefined(); - }); -}); diff --git a/tests/Runware/server/connection.test.ts b/tests/Runware/server/connection.test.ts new file mode 100644 index 0000000..1448875 --- /dev/null +++ b/tests/Runware/server/connection.test.ts @@ -0,0 +1,35 @@ +import { describe, test, expect, afterAll } from "vitest"; +import { RunwareServer } from "../../../Runware/Runware-server"; +import { createRealServer } from "../../test-utils"; + +describe("RunwareServer connection (real WebSocket)", () => { + let server: RunwareServer; + + afterAll(() => { + server?.disconnect(); + }); + + test("SUCCESS: connects and authenticates", async () => { + server = await createRealServer(); + + expect((server as any)._apiKey).toBeTruthy(); + expect((server as any)._connectionSessionUUID).toBeDefined(); + expect((server as any)._connectionSessionUUID).not.toBeUndefined(); + expect((server as any).isWebsocketReadyState()).toBe(true); + }, 30000); + + test("SUCCESS: heartbeat is active after connection", () => { + expect((server as any)._heartbeatIntervalId).not.toBeNull(); + }); + + test("SUCCESS: connected() returns true", () => { + expect((server as any).connected()).toBe(true); + }); + + test("FAILURE: disconnect clears connection state", () => { + server.disconnect(); + + expect((server as any).isWebsocketReadyState()).toBe(false); + expect((server as any)._heartbeatIntervalId).toBeNull(); + }); +}); diff --git a/tests/Runware/upload-image.test.ts b/tests/Runware/upload-image.test.ts deleted file mode 100644 index 7435d93..0000000 --- a/tests/Runware/upload-image.test.ts +++ /dev/null @@ -1,48 +0,0 @@ -import { expect, test, vi, describe, afterEach, beforeEach } from "vitest"; -import { - getIntervalWithPromise, - fileToBase64, - MockFile, -} from "../../Runware/utils"; -import { - mockFileToBase64, - mockTaskUUID, - mockTextImageUpload, - mockUploadFile, -} from "../test-utils"; -import { startMockServer } from "../mockServer"; -import { ETaskType } from "../../Runware"; - -vi.mock("../../Runware/utils", async () => { - const actual = await vi.importActual("../../Runware/utils"); - return { - ...(actual as any), - fileToBase64: vi.fn().mockReturnValue("FILE_TO_BASE_64"), - getIntervalWithPromise: vi.fn(), - getUUID: vi.fn().mockReturnValue("UNIQUE_UID"), - }; -}); - -describe("When user uploads an image:", async () => { - const { mockServer, runware } = await startMockServer(); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeEach(() => { - mockServer.stop(); - }); - - test("it should accept string during image upload", async () => { - await runware["uploadImage"]("IMAGE_UPLOAD"); - expect(fileToBase64).to.not.toHaveBeenCalled(); - }); - - test("it should accept file during image upload", async () => { - const sendSpy = vi.spyOn(runware as any, "send"); - await runware["uploadImage"](mockUploadFile); - - expect(fileToBase64).toHaveBeenCalled(); - }); -}); diff --git a/tests/Runware/upscale-gan.test.ts b/tests/Runware/upscale-gan.test.ts deleted file mode 100644 index 902f2b2..0000000 --- a/tests/Runware/upscale-gan.test.ts +++ /dev/null @@ -1,70 +0,0 @@ -import { - expect, - test, - beforeAll, - vi, - describe, - afterEach, - beforeEach, -} from "vitest"; -import { mockTaskUUID, mockUploadFile, testExamples } from "../test-utils"; -import { startMockServer } from "../mockServer"; -import { ETaskType } from "../../Runware"; - -vi.mock("../../Runware/utils", async () => { - const actual = await vi.importActual("../../Runware/utils"); - return { - ...(actual as any), - fileToBase64: vi.fn().mockReturnValue("FILE_TO_BASE_64"), - getIntervalWithPromise: vi.fn(), - getUUID: vi.fn().mockImplementation(() => "UNIQUE_UID"), - }; -}); - -describe("When user request to upscale gan", async () => { - const { mockServer, runware } = await startMockServer(); - - afterEach(() => { - vi.clearAllMocks(); - }); - - beforeEach(() => { - mockServer.stop(); - }); - - beforeAll(async () => { - vi.spyOn(runware as any, "uploadImage").mockReturnValue( - testExamples.imageUploadRes - ); - }); - - test("it should upscale gan", async () => { - const imageUploadSpy = vi.spyOn(runware as any, "uploadImage"); - const globalListenerSpy = vi.spyOn(runware as any, "globalListener"); - const sendSpy = vi.spyOn(runware as any, "send"); - - await runware.upscaleGan({ - inputImage: mockUploadFile, - upscaleFactor: 2, - }); - - expect(imageUploadSpy).toHaveBeenCalled(); - - expect(sendSpy).toHaveBeenCalledWith({ - inputImage: testExamples.imageUploadRes.imageUUID, - taskUUID: mockTaskUUID, - upscaleFactor: 2, - taskType: ETaskType.UPSCALE, - deliveryMethod: "sync", - }); - expect(globalListenerSpy).toHaveBeenCalledWith({ - taskUUID: mockTaskUUID, - }); - }); - - test("upscale delegates to upscaleGan", async () => { - const params = { inputImage: mockUploadFile, upscaleFactor: 2 }; - const result = await runware.upscale(params); - expect(result).toEqual(await runware.upscaleGan(params)); - }); -}); diff --git a/tests/TEST.md b/tests/TEST.md new file mode 100644 index 0000000..05ca966 --- /dev/null +++ b/tests/TEST.md @@ -0,0 +1,126 @@ +# Running Tests + +All tests use **real WebSocket connections** against the Runware production API. No mock servers. + +## Prerequisites + +1. Ensure your `.env` file has valid production credentials: + +``` +API_KEY = "your-api-key" +VITE_RUNWARE_SDK_URL = "wss://ws-api.runware.ai/v1" +``` + +2. Install dependencies: + +```bash +npm install +``` + +--- + +## Run All Tests + +```bash +npx vitest run tests/ --reporter verbose +``` + +--- + +## Run by Folder + +```bash +# Connection lifecycle & health (heartbeat, session UUID, send guard, zombie detection) +npx vitest run tests/Runware/connection/ --reporter verbose + +# Retry mechanism (asyncRetry duplicate send fix) +npx vitest run tests/Runware/retry/ --reporter verbose + +# API inference (image, video, prompt, upscale, upload) +npx vitest run tests/Runware/inference/ --reporter verbose + +# Server instantiation & auth +npx vitest run tests/Runware/server/ --reporter verbose +``` + +--- + +## Run a Single File + +```bash +npx vitest run tests/Runware/connection/heartbeat.test.ts --reporter verbose +npx vitest run tests/Runware/connection/session-uuid.test.ts --reporter verbose +npx vitest run tests/Runware/connection/send-guard.test.ts --reporter verbose +npx vitest run tests/Runware/connection/zombie-detection.test.ts --reporter verbose +npx vitest run tests/Runware/retry/async-retry.test.ts --reporter verbose +npx vitest run tests/Runware/inference/image-generation.test.ts --reporter verbose +npx vitest run tests/Runware/inference/video-generation.test.ts --reporter verbose +npx vitest run tests/Runware/inference/enhance-prompt.test.ts --reporter verbose +npx vitest run tests/Runware/inference/upscale.test.ts --reporter verbose +npx vitest run tests/Runware/inference/upload-image.test.ts --reporter verbose +npx vitest run tests/Runware/server/connection.test.ts --reporter verbose +``` + +--- + +## Run a Single Test by Name + +Use the `-t` flag to match a test name: + +```bash +npx vitest run tests/Runware/inference/video-generation.test.ts -t "submits video job" --reporter verbose +npx vitest run tests/Runware/connection/heartbeat.test.ts -t "3-strike" --reporter verbose +npx vitest run tests/Runware/retry/async-retry.test.ts -t "customer scenario" --reporter verbose +``` + +--- + +## Watch Mode (re-runs on file change) + +```bash +npx vitest tests/Runware/connection/heartbeat.test.ts --reporter verbose +``` + +--- + +## Test Structure + +``` +tests/Runware/ +├── connection/ # Connection lifecycle & health +│ ├── heartbeat.test.ts # Ping/pong, 3-strike tolerance, keepalive +│ ├── session-uuid.test.ts # _connectionSessionUUID clearing on close +│ ├── send-guard.test.ts # send() readyState check + ensureConnection retry +│ └── zombie-detection.test.ts # E2E: connect → close → state cleared → send blocked +│ +├── retry/ # Retry mechanism +│ └── async-retry.test.ts # Missing return fix, duplicate send prevention +│ +├── inference/ # Real API feature tests +│ ├── image-generation.test.ts # requestImages / imageInference +│ ├── video-generation.test.ts # videoInference with skipResponse + getResponse polling +│ ├── enhance-prompt.test.ts # enhancePrompt / promptEnhance +│ ├── upscale.test.ts # upscaleGan / upscale +│ └── upload-image.test.ts # uploadImage +│ +└── server/ # Server instantiation & auth + └── connection.test.ts # RunwareServer connect, auth, heartbeat, disconnect +``` + +--- + +## Telemetry Logging + +All tests run with `enableLogging: true` by default, so you will see colored `[RUNWARE]` telemetry output showing connection, auth, heartbeat, send, and error events in the console. + +To disable logging for a specific test, pass `{ enableLogging: false }` to `createRealServer()`. + +--- + +## Timeouts + +- **Connection/retry tests**: 30s default +- **Image/prompt/upscale tests**: 60s +- **Video generation tests**: up to 6 minutes (video rendering is async) + +If a test times out, check your network connection and API key validity. diff --git a/tests/mockServer.ts b/tests/mockServer.ts index cc6b5fa..a7c3516 100644 --- a/tests/mockServer.ts +++ b/tests/mockServer.ts @@ -1,41 +1,3 @@ -import { Server } from "mock-socket"; -import { Runware, RunwareServer } from "../Runware"; -import { BASE_RUNWARE_URLS, delay } from "../Runware/utils"; -import { WebSocketServer } from "ws"; - -export const startMockServer = async () => { - const mockServer = new Server("ws://localhost:8080"); - - mockServer.on("connection", (socket) => { - socket.on("message", (data) => { - // socket.send("test message from mock server"); - }); - }); - - const runware = new Runware({ - apiKey: "API_KEY", - url: BASE_RUNWARE_URLS.TEST, - }); - await delay(1); - - return { runware, mockServer }; -}; - -export const startMockBackendServer = async () => { - const mockServer = new WebSocketServer({ port: 8080 }); - - mockServer.on("connection", (socket) => { - socket.on("message", (data, isBinary) => { - const message = !isBinary ? data?.toString() : data; - // socket.send("test message from mock server"); - }); - }); - - const runwareServer = new RunwareServer({ - apiKey: "API_KEY", - url: BASE_RUNWARE_URLS.TEST, - }); - await delay(1); - - return { runwareServer, mockServer }; -}; +// This file is no longer used. +// All tests now use real WebSocket connections via createRealServer() in test-utils.ts. +// Kept as an empty file to avoid breaking any imports in worktree copies. diff --git a/tests/test-utils.ts b/tests/test-utils.ts index 00dfceb..481b647 100644 --- a/tests/test-utils.ts +++ b/tests/test-utils.ts @@ -1,50 +1,33 @@ -import { EControlMode, ETaskType } from "../Runware"; -import { MockFile } from "../Runware/utils"; +import dotenv from "dotenv"; +import { RunwareServer } from "../Runware/Runware-server"; -const promptText = "A beautiful runware"; +dotenv.config(); -export const mockTaskUUID = "UNIQUE_UID"; -export const mockTextImageUpload = "IMAGE_UPLOAD"; -export const mockFileToBase64 = "FILE_TO_BASE_64"; -export const mockNewImageUID = "NEW_IMAGE_UID"; +const API_KEY = process.env.API_KEY || ""; +const URL = process.env.VITE_RUNWARE_SDK_URL || ""; -export const mockUploadFile = new MockFile().create( - "pic.jpg", - 1024 * 1024 * 2, - "image/jpeg" -); +if (!API_KEY) { + throw new Error("API_KEY not set in .env"); +} -export const testExamples = { - imageReq: { - numberResults: 8, - positivePrompt: promptText, - model: 13, - steps: 30, - width: 512, - height: 512, - }, - imageRes: { - model: 13, - numberResults: 8, - positivePrompt: promptText, - steps: 30, - taskType: ETaskType.IMAGE_INFERENCE, - taskUUID: mockTaskUUID, - width: 512, - height: 512, - }, - imageUploadRes: { - imageUUID: mockNewImageUID, - imageURL: "data:image/png;base64,iVBORw0KGgoAAAA...", - taskUUID: "50836053-a0ee-4cf5-b9d6-ae7c5d140ada", - taskType: ETaskType.IMAGE_UPLOAD, - }, - controlNet: { - endStep: 20, - startStep: 0, - guideImage: mockTextImageUpload, - preprocessor: "canny" as any, - weight: 1, - controlMode: EControlMode.CONTROL_NET, - }, +/** + * Creates a real RunwareServer connected to the API via WebSocket. + * Uses .env credentials — no mocks, no fake servers. + * Pass enableLogging: true to see detailed SDK telemetry in the console. + */ +export const createRealServer = async ( + options?: { enableLogging?: boolean }, +): Promise => { + const server = await RunwareServer.initialize({ + apiKey: API_KEY, + url: URL, + shouldReconnect: false, + heartbeatInterval: 30000, + enableLogging: options?.enableLogging ?? true, + }); + return server as RunwareServer; }; + +// A small publicly-available test image URL for upload/input tests +export const TEST_IMAGE_URL = + "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png"; diff --git a/vitest.config.ts b/vitest.config.ts new file mode 100644 index 0000000..993c8f6 --- /dev/null +++ b/vitest.config.ts @@ -0,0 +1,11 @@ +import { defineConfig } from "vitest/config"; + +export default defineConfig({ + test: { + exclude: [ + "**/node_modules/**", + "**/.claude/**", + "**/dist/**", + ], + }, +});