diff --git a/src/SmartTransactionsController.test.ts b/src/SmartTransactionsController.test.ts index 2d77a73..0f303bc 100644 --- a/src/SmartTransactionsController.test.ts +++ b/src/SmartTransactionsController.test.ts @@ -2194,6 +2194,48 @@ describe('SmartTransactionsController', () => { expect(apiCall.isDone()).toBe(true); }); }); + + it('sends Authorization header when getBearerToken is provided', async () => { + const bearerToken = 'test-bearer-token-123'; + await withController( + { + options: { + getBearerToken: async () => Promise.resolve(bearerToken), + }, + }, + async ({ controller }) => { + const apiCall = nock(API_BASE_URL) + .post(`/networks/${ethereumChainIdDec}/cancel`) + .matchHeader('Authorization', `Bearer ${bearerToken}`) + .reply(200, { message: 'successful' }); + + await controller.cancelSmartTransaction('uuid1'); + + expect(apiCall.isDone()).toBe(true); + }, + ); + }); + + it('sends Authorization header to Sentinel /network when getBearerToken is provided', async () => { + const bearerToken = 'test-bearer-token-456'; + await withController( + { + options: { + getBearerToken: async () => Promise.resolve(bearerToken), + }, + }, + async ({ controller }) => { + const apiCall = nock(SENTINEL_API_BASE_URL_MAP[ethereumChainIdDec]) + .get(`/network`) + .matchHeader('Authorization', `Bearer ${bearerToken}`) + .reply(200, createSuccessLivenessApiResponse()); + + await controller.fetchLiveness(); + + expect(apiCall.isDone()).toBe(true); + }, + ); + }); }); describe('getTransactions', () => { diff --git a/src/SmartTransactionsController.ts b/src/SmartTransactionsController.ts index ee5dc6f..d537a22 100644 --- a/src/SmartTransactionsController.ts +++ b/src/SmartTransactionsController.ts @@ -37,9 +37,11 @@ import { BigNumber } from 'bignumber.js'; import cloneDeep from 'lodash/cloneDeep'; import { + API_BASE_URL, DEFAULT_DISABLED_SMART_TRANSACTIONS_FEATURE_FLAGS, MetaMetricsEventCategory, MetaMetricsEventName, + SENTINEL_API_BASE_URL_MAP, SmartTransactionsTraceName, } from './constants'; import { @@ -230,6 +232,14 @@ type SmartTransactionsControllerOptions = { * removed in a future version. */ getFeatureFlags?: () => FeatureFlags; + /** + * Optional callback to obtain a bearer token for authenticating requests to + * the Transaction API. When provided, the token is sent in the + * Authorization header for all Transaction API calls. Can be used with + * the authentication flow from @metamask/core-backend (e.g. from + * AuthenticationController.getBearerToken). + */ + getBearerToken?: () => Promise | string | undefined; trace?: TraceCallback; }; @@ -258,6 +268,11 @@ export class SmartTransactionsController extends StaticIntervalPollingController readonly #getMetaMetricsProps: () => Promise; + readonly #getBearerToken?: () => + | Promise + | string + | undefined; + #trace: TraceCallback; /** @@ -292,11 +307,28 @@ export class SmartTransactionsController extends StaticIntervalPollingController /* istanbul ignore next */ async #fetch(request: string, options?: RequestInit) { + const headers: Record = { + 'Content-Type': 'application/json', + ...(this.#clientId && { 'X-Client-Id': this.#clientId }), + }; + + const urlMatches = + request.startsWith(API_BASE_URL) || + Object.values(SENTINEL_API_BASE_URL_MAP).some((baseUrl) => + request.startsWith(baseUrl), + ); + if (this.#getBearerToken && urlMatches) { + const token = await Promise.resolve(this.#getBearerToken()); + if (token) { + headers.Authorization = `Bearer ${token}`; + } + } + const fetchOptions = { ...options, headers: { - 'Content-Type': 'application/json', - ...(this.#clientId && { 'X-Client-Id': this.#clientId }), + ...headers, + ...options?.headers, }, }; @@ -312,6 +344,7 @@ export class SmartTransactionsController extends StaticIntervalPollingController state = {}, messenger, getMetaMetricsProps, + getBearerToken, trace, }: SmartTransactionsControllerOptions) { super({ @@ -323,6 +356,7 @@ export class SmartTransactionsController extends StaticIntervalPollingController ...state, }, }); + this.#interval = interval; this.#clientId = clientId; this.#chainId = InitialChainId; @@ -331,6 +365,7 @@ export class SmartTransactionsController extends StaticIntervalPollingController this.#ethQuery = undefined; this.#trackMetaMetricsEvent = trackMetaMetricsEvent; this.#getMetaMetricsProps = getMetaMetricsProps; + this.#getBearerToken = getBearerToken; this.#trace = trace ?? (((_request, fn) => fn?.()) as TraceCallback); this.initializeSmartTransactionsForChainId();