diff --git a/src/config.test.ts b/src/config.test.ts index 4fa2186a..d930d1a9 100644 --- a/src/config.test.ts +++ b/src/config.test.ts @@ -664,6 +664,16 @@ describe('Config', () => { ); }); + it('should not throw error when TRANSPORT is "http" and ENABLE_PASSTHROUGH_AUTH is "true" and DANGEROUSLY_DISABLE_OAUTH is "true"', () => { + vi.stubEnv('TRANSPORT', 'http'); + vi.stubEnv('OAUTH_ISSUER', undefined); + vi.stubEnv('DANGEROUSLY_DISABLE_OAUTH', 'true'); + vi.stubEnv('ENABLE_PASSTHROUGH_AUTH', 'true'); + + const config = new Config(); + expect(config.enablePassthroughAuth).toBe(true); + }); + it('should throw error when OAUTH_JWE_PRIVATE_KEY and OAUTH_JWE_PRIVATE_KEY_PATH is not set', () => { stubDefaultOAuthEnvVars(); vi.stubEnv('OAUTH_JWE_PRIVATE_KEY_PATH', ''); @@ -758,6 +768,20 @@ describe('Config', () => { }); }); + describe('Passthrough configuration', () => { + it('should set enablePassthroughAuth to false by default', () => { + const config = new Config(); + expect(config.enablePassthroughAuth).toBe(false); + }); + + it('should set enablePassthroughAuth to true when ENABLE_PASSTHROUGH_AUTH is "true"', () => { + vi.stubEnv('ENABLE_PASSTHROUGH_AUTH', 'true'); + + const config = new Config(); + expect(config.enablePassthroughAuth).toBe(true); + }); + }); + describe('parseNumber', () => { it('should return defaultValue when value is undefined', () => { const result = parseNumber(undefined, { defaultValue: 42 }); diff --git a/src/config.ts b/src/config.ts index f632688c..c615e29e 100644 --- a/src/config.ts +++ b/src/config.ts @@ -54,6 +54,7 @@ export class Config { tableauServerVersionCheckIntervalInHours: number; mcpSiteSettingsCheckIntervalInMinutes: number; enableMcpSiteSettings: boolean; + enablePassthroughAuth: boolean; oauth: { enabled: boolean; embeddedAuthzServer: boolean; @@ -113,6 +114,7 @@ export class Config { TABLEAU_SERVER_VERSION_CHECK_INTERVAL_IN_HOURS: tableauServerVersionCheckIntervalInHours, MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES: mcpSiteSettingsCheckIntervalInMinutes, ENABLE_MCP_SITE_SETTINGS: enableMcpSiteSettings, + ENABLE_PASSTHROUGH_AUTH: enablePassthroughAuth, DANGEROUSLY_DISABLE_OAUTH: disableOauth, OAUTH_EMBEDDED_AUTHZ_SERVER: oauthEmbeddedAuthzServer, OAUTH_ISSUER: oauthIssuer, @@ -175,6 +177,7 @@ export class Config { ); this.enableMcpSiteSettings = enableMcpSiteSettings === 'true'; + this.enablePassthroughAuth = enablePassthroughAuth === 'true'; const disableOauthOverride = disableOauth === 'true'; const disableScopes = oauthDisableScopes === 'true'; const enforceScopes = !disableScopes; @@ -244,9 +247,17 @@ export class Config { this.isHyperforce = isHyperforce === 'true'; this.auth = isAuthType(auth) ? auth : this.oauth.enabled ? 'oauth' : 'pat'; - this.transport = isTransport(transport) ? transport : this.oauth.enabled ? 'http' : 'stdio'; - - if (this.transport === 'http' && !disableOauthOverride && !this.oauth.issuer) { + this.transport = isTransport(transport) + ? transport + : this.oauth.enabled + ? 'http' + : 'stdio'; + + if ( + this.transport === 'http' && + !disableOauthOverride && + !this.oauth.issuer + ) { throw new Error( 'OAUTH_ISSUER must be set when TRANSPORT is "http" unless DANGEROUSLY_DISABLE_OAUTH is "true"', ); diff --git a/src/restApiInstance.test.ts b/src/restApiInstance.test.ts index 7911928f..279f3986 100644 --- a/src/restApiInstance.test.ts +++ b/src/restApiInstance.test.ts @@ -53,6 +53,34 @@ describe('restApiInstance', () => { }); }); + describe('useRestApi with passthrough auth', () => { + it('should use setCredentials when tableauAuthInfo type is Passthrough', async () => { + const restApi = await useRestApi({ + config: mockConfig, + requestId: mockRequestId, + server: new Server(), + tableauAuthInfo: { + type: 'Passthrough', + raw: 'abc123|xyz789|site-luid', + userId: 'user-luid-123', + username: 'testuser', + server: 'https://my-tableau-server.com', + siteId: 'site-id-123', + }, + jwtScopes: [], + signal: new AbortController().signal, + callback: (restApi: RestApi) => Promise.resolve(restApi), + }); + + expect(restApi.setCredentials).toHaveBeenCalledWith( + 'abc123|xyz789|site-luid', + 'user-luid-123', + ); + expect(restApi.signIn).not.toHaveBeenCalled(); + expect(restApi.signOut).not.toHaveBeenCalled(); + }); + }); + describe('Request Interceptor', () => { it('should add User-Agent header and log request', () => { const server = new Server(); diff --git a/src/restApiInstance.ts b/src/restApiInstance.ts index e69c963a..a21ba90b 100644 --- a/src/restApiInstance.ts +++ b/src/restApiInstance.ts @@ -92,7 +92,20 @@ const getNewRestApiInstanceAsync = async ( ], }); - if (config.auth === 'pat') { + if (tableauAuthInfo?.type === 'Passthrough') { + // Pre-authenticated credentials from passthrough middleware + restApi.setCredentials(tableauAuthInfo.raw, tableauAuthInfo.userId); + } else if ( + tableauAuthInfo?.type === 'X-Tableau-Auth' && + tableauAuthInfo.accessToken && + tableauAuthInfo.userId + ) { + // Pre-authenticated credentials from OAuth + restApi.setCredentials(tableauAuthInfo.accessToken, tableauAuthInfo.userId); + } else if (tableauAuthInfo?.type === 'Bearer') { + // Bearer token from OAuth scope-based auth + restApi.setBearerToken(tableauAuthInfo.raw); + } else if (config.auth === 'pat') { await restApi.signIn({ type: 'pat', patName: config.patName, @@ -124,17 +137,7 @@ const getNewRestApiInstanceAsync = async ( additionalPayload: getJwtAdditionalPayload(config, tableauAuthInfo), }); } else { - if (tableauAuthInfo?.type === 'Bearer') { - restApi.setBearerToken(tableauAuthInfo.raw); - } else if (tableauAuthInfo?.type === 'X-Tableau-Auth') { - if (!tableauAuthInfo?.accessToken || !tableauAuthInfo?.userId) { - throw new Error('Auth info is required when not signing in first.'); - } - - restApi.setCredentials(tableauAuthInfo.accessToken, tableauAuthInfo.userId); - } else { - throw new Error('Auth info is required when not signing in first.'); - } + throw new Error('Auth info is required when not signing in first.'); } return restApi; @@ -147,7 +150,7 @@ export const useRestApi = async ( }, ): Promise => { const { callback, ...remaining } = args; - const { config } = remaining; + const { config, tableauAuthInfo } = remaining; const restApi = await getNewRestApiInstanceAsync({ ...remaining, jwtScopes: new Set(args.jwtScopes), @@ -155,9 +158,9 @@ export const useRestApi = async ( try { return await callback(restApi); } finally { - if (config.auth !== 'oauth') { + if (config.auth !== 'oauth' && tableauAuthInfo?.type !== 'Passthrough') { // Tableau REST sessions for 'pat' and 'direct-trust' are intentionally ephemeral. - // Sessions for 'oauth' are not. Signing out would invalidate the session, + // Sessions for 'oauth' and 'passthrough' are not. Signing out would invalidate the session, // preventing the access token from being reused for subsequent requests. await restApi.signOut(); } diff --git a/src/server/express.ts b/src/server/express.ts index 7b52548c..8b658b73 100644 --- a/src/server/express.ts +++ b/src/server/express.ts @@ -15,6 +15,7 @@ import { getTableauAuthInfo } from './oauth/getTableauAuthInfo.js'; import { OAuthProvider } from './oauth/provider.js'; import { TableauAuthInfo } from './oauth/schemas.js'; import { AuthenticatedRequest } from './oauth/types.js'; +import { passthroughMiddleware, X_TABLEAU_AUTH_HEADER } from './passthroughMiddleware.js'; const SESSION_ID_HEADER = 'mcp-session-id'; @@ -43,7 +44,7 @@ export async function startExpressServer({ 'Accept', 'MCP-Protocol-Version', ], - exposedHeaders: [SESSION_ID_HEADER, 'x-session-id'], + exposedHeaders: [SESSION_ID_HEADER, 'x-session-id', X_TABLEAU_AUTH_HEADER], }), ); @@ -52,12 +53,15 @@ export async function startExpressServer({ app.set('trust proxy', config.trustProxyConfig); } - const middleware: Array = [handlePingRequest]; + const middleware: Array = [handlePingRequest, validateProtocolVersion]; + if (config.enablePassthroughAuth) { + middleware.push(passthroughMiddleware()); + } + if (config.oauth.enabled) { const oauthProvider = new OAuthProvider(); oauthProvider.setupRoutes(app); middleware.push(oauthProvider.authMiddleware); - middleware.push(validateProtocolVersion); } const path = `/${basePath}`; @@ -120,7 +124,8 @@ export async function startExpressServer({ server.close(); }); - await connect(server, transport, logLevel, getTableauAuthInfo(req.auth)); + const authInfo = getTableauAuthInfo(req.auth); + await connect(server, transport, logLevel, authInfo); } else { const sessionId = req.headers[SESSION_ID_HEADER] as string | undefined; @@ -132,7 +137,8 @@ export async function startExpressServer({ transport = createSession({ clientInfo }); const server = new Server({ clientInfo }); - await connect(server, transport, logLevel, getTableauAuthInfo(req.auth)); + const authInfo = getTableauAuthInfo(req.auth); + await connect(server, transport, logLevel, authInfo); } else { // Invalid request res.status(400).json({ diff --git a/src/server/oauth/schemas.ts b/src/server/oauth/schemas.ts index 6634b468..42c1b9eb 100644 --- a/src/server/oauth/schemas.ts +++ b/src/server/oauth/schemas.ts @@ -134,6 +134,14 @@ export const tableauAuthInfoSchema = z.discriminatedUnion('type', [ siteId: z.string(), raw: z.string(), }), + z.object({ + type: z.literal('Passthrough'), + username: z.string(), + userId: z.string(), + server: z.string(), + siteId: z.string(), + raw: z.string(), + }), ]); export const cimdMetadataSchema = z.object({ diff --git a/src/server/passthroughMiddleware.ts b/src/server/passthroughMiddleware.ts new file mode 100644 index 00000000..ca31620e --- /dev/null +++ b/src/server/passthroughMiddleware.ts @@ -0,0 +1,86 @@ +import { NextFunction, RequestHandler, Response } from 'express'; +import { z } from 'zod'; + +import { getConfig, TEN_MINUTES_IN_MS } from '../config.js'; +import { RestApi } from '../sdks/tableau/restApi.js'; +import { ExpiringMap } from '../utils/expiringMap.js'; +import { getSupportedMcpScopes } from './oauth/scopes.js'; +import { AuthenticatedRequest } from './oauth/types.js'; + +export const X_TABLEAU_AUTH_HEADER = 'x-tableau-auth'; + +export const passthroughAuthInfoSchema = z.object({ + type: z.literal('Passthrough'), + username: z.string(), + userId: z.string(), + server: z.string(), + siteId: z.string(), + raw: z.string(), +}); + +export type PassthroughAuthInfo = z.infer; + +const passthroughAuthInfoCache = new ExpiringMap({ + defaultExpirationTimeMs: TEN_MINUTES_IN_MS, +}); + +export function passthroughMiddleware(): RequestHandler { + return async (req: AuthenticatedRequest, res: Response, next: NextFunction): Promise => { + const tableauAccessToken: string = + getCookie(req, 'workgroup_session_id') || getHeader(req, X_TABLEAU_AUTH_HEADER); + + if (!tableauAccessToken) { + next(); + return; + } + + const config = getConfig(); + let passthroughAuthInfo = passthroughAuthInfoCache.get(tableauAccessToken); + if (!passthroughAuthInfo) { + const { server, maxRequestTimeoutMs } = config; + + const restApi = new RestApi(server, { + maxRequestTimeoutMs, + }); + + restApi.setCredentials(tableauAccessToken, 'unknown user id'); + const sessionResult = await restApi.authenticatedServerMethods.getCurrentServerSession(); + if (!sessionResult.isOk()) { + res.status(401).json({ + error: 'invalid_token', + error_description: sessionResult.error, + }); + return; + } + + passthroughAuthInfo = { + type: 'Passthrough', + username: sessionResult.value.user.name, + userId: sessionResult.value.user.id, + server, + siteId: sessionResult.value.site.id, + raw: tableauAccessToken, + }; + + passthroughAuthInfoCache.set(tableauAccessToken, passthroughAuthInfo); + } + + req.auth = { + token: 'passthrough', + clientId: 'passthrough', + scopes: config.oauth.enforceScopes ? getSupportedMcpScopes() : [], + extra: passthroughAuthInfo, + }; + next(); + }; +} + +function getCookie(req: AuthenticatedRequest, cookieName: string): string { + const cookieValue = req.cookies?.[cookieName]; + return cookieValue?.toString() ?? ''; +} + +function getHeader(req: AuthenticatedRequest, headerName: string): string { + const headerValue = req.headers[headerName]; + return headerValue?.toString() ?? ''; +} diff --git a/src/testSetup.ts b/src/testSetup.ts index 28019c51..3552a7eb 100644 --- a/src/testSetup.ts +++ b/src/testSetup.ts @@ -17,6 +17,7 @@ vi.mock('./sdks/tableau/restApi.js', async (importOriginal) => ({ RestApi: vi.fn().mockImplementation(() => ({ signIn: vi.fn().mockResolvedValue(undefined), signOut: vi.fn().mockResolvedValue(undefined), + setCredentials: vi.fn(), serverMethods: { getServerInfo: vi.fn().mockResolvedValue({ productVersion: testProductVersion,