Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,16 @@ describe('Config', () => {
);
});

it('should not throw error when TRANSPORT is "http" and ENABLE_PASSTHROUGH_AUTH is "true" and DANGEROUSLY_DISABLE_OAUTH is "true"', () => {
vi.stubEnv('TRANSPORT', 'http');
vi.stubEnv('OAUTH_ISSUER', undefined);
vi.stubEnv('DANGEROUSLY_DISABLE_OAUTH', 'true');
vi.stubEnv('ENABLE_PASSTHROUGH_AUTH', 'true');

const config = new Config();
expect(config.enablePassthroughAuth).toBe(true);
});

it('should throw error when OAUTH_JWE_PRIVATE_KEY and OAUTH_JWE_PRIVATE_KEY_PATH is not set', () => {
stubDefaultOAuthEnvVars();
vi.stubEnv('OAUTH_JWE_PRIVATE_KEY_PATH', '');
Expand Down Expand Up @@ -758,6 +768,20 @@ describe('Config', () => {
});
});

describe('Passthrough configuration', () => {
it('should set enablePassthroughAuth to false by default', () => {
const config = new Config();
expect(config.enablePassthroughAuth).toBe(false);
});

it('should set enablePassthroughAuth to true when ENABLE_PASSTHROUGH_AUTH is "true"', () => {
vi.stubEnv('ENABLE_PASSTHROUGH_AUTH', 'true');

const config = new Config();
expect(config.enablePassthroughAuth).toBe(true);
});
});

describe('parseNumber', () => {
it('should return defaultValue when value is undefined', () => {
const result = parseNumber(undefined, { defaultValue: 42 });
Expand Down
17 changes: 14 additions & 3 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ export class Config {
tableauServerVersionCheckIntervalInHours: number;
mcpSiteSettingsCheckIntervalInMinutes: number;
enableMcpSiteSettings: boolean;
enablePassthroughAuth: boolean;
oauth: {
enabled: boolean;
embeddedAuthzServer: boolean;
Expand Down Expand Up @@ -113,6 +114,7 @@ export class Config {
TABLEAU_SERVER_VERSION_CHECK_INTERVAL_IN_HOURS: tableauServerVersionCheckIntervalInHours,
MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES: mcpSiteSettingsCheckIntervalInMinutes,
ENABLE_MCP_SITE_SETTINGS: enableMcpSiteSettings,
ENABLE_PASSTHROUGH_AUTH: enablePassthroughAuth,
DANGEROUSLY_DISABLE_OAUTH: disableOauth,
OAUTH_EMBEDDED_AUTHZ_SERVER: oauthEmbeddedAuthzServer,
OAUTH_ISSUER: oauthIssuer,
Expand Down Expand Up @@ -175,6 +177,7 @@ export class Config {
);

this.enableMcpSiteSettings = enableMcpSiteSettings === 'true';
this.enablePassthroughAuth = enablePassthroughAuth === 'true';
const disableOauthOverride = disableOauth === 'true';
const disableScopes = oauthDisableScopes === 'true';
const enforceScopes = !disableScopes;
Expand Down Expand Up @@ -244,9 +247,17 @@ export class Config {
this.isHyperforce = isHyperforce === 'true';

this.auth = isAuthType(auth) ? auth : this.oauth.enabled ? 'oauth' : 'pat';
this.transport = isTransport(transport) ? transport : this.oauth.enabled ? 'http' : 'stdio';

if (this.transport === 'http' && !disableOauthOverride && !this.oauth.issuer) {
this.transport = isTransport(transport)
? transport
: this.oauth.enabled
? 'http'
: 'stdio';

if (
this.transport === 'http' &&
!disableOauthOverride &&
!this.oauth.issuer
) {
throw new Error(
'OAUTH_ISSUER must be set when TRANSPORT is "http" unless DANGEROUSLY_DISABLE_OAUTH is "true"',
);
Expand Down
28 changes: 28 additions & 0 deletions src/restApiInstance.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,34 @@ describe('restApiInstance', () => {
});
});

describe('useRestApi with passthrough auth', () => {
it('should use setCredentials when tableauAuthInfo type is Passthrough', async () => {
const restApi = await useRestApi({
config: mockConfig,
requestId: mockRequestId,
server: new Server(),
tableauAuthInfo: {
type: 'Passthrough',
raw: 'abc123|xyz789|site-luid',
userId: 'user-luid-123',
username: 'testuser',
server: 'https://my-tableau-server.com',
siteId: 'site-id-123',
},
jwtScopes: [],
signal: new AbortController().signal,
callback: (restApi: RestApi) => Promise.resolve(restApi),
});

expect(restApi.setCredentials).toHaveBeenCalledWith(
'abc123|xyz789|site-luid',
'user-luid-123',
);
expect(restApi.signIn).not.toHaveBeenCalled();
expect(restApi.signOut).not.toHaveBeenCalled();
});
});

describe('Request Interceptor', () => {
it('should add User-Agent header and log request', () => {
const server = new Server();
Expand Down
33 changes: 18 additions & 15 deletions src/restApiInstance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,20 @@ const getNewRestApiInstanceAsync = async (
],
});

if (config.auth === 'pat') {
if (tableauAuthInfo?.type === 'Passthrough') {
// Pre-authenticated credentials from passthrough middleware
restApi.setCredentials(tableauAuthInfo.raw, tableauAuthInfo.userId);
} else if (
tableauAuthInfo?.type === 'X-Tableau-Auth' &&
tableauAuthInfo.accessToken &&
tableauAuthInfo.userId
) {
// Pre-authenticated credentials from OAuth
restApi.setCredentials(tableauAuthInfo.accessToken, tableauAuthInfo.userId);
} else if (tableauAuthInfo?.type === 'Bearer') {
// Bearer token from OAuth scope-based auth
restApi.setBearerToken(tableauAuthInfo.raw);
} else if (config.auth === 'pat') {
await restApi.signIn({
type: 'pat',
patName: config.patName,
Expand Down Expand Up @@ -124,17 +137,7 @@ const getNewRestApiInstanceAsync = async (
additionalPayload: getJwtAdditionalPayload(config, tableauAuthInfo),
});
} else {
if (tableauAuthInfo?.type === 'Bearer') {
restApi.setBearerToken(tableauAuthInfo.raw);
} else if (tableauAuthInfo?.type === 'X-Tableau-Auth') {
if (!tableauAuthInfo?.accessToken || !tableauAuthInfo?.userId) {
throw new Error('Auth info is required when not signing in first.');
}

restApi.setCredentials(tableauAuthInfo.accessToken, tableauAuthInfo.userId);
} else {
throw new Error('Auth info is required when not signing in first.');
}
throw new Error('Auth info is required when not signing in first.');
}

return restApi;
Expand All @@ -147,17 +150,17 @@ export const useRestApi = async <T>(
},
): Promise<T> => {
const { callback, ...remaining } = args;
const { config } = remaining;
const { config, tableauAuthInfo } = remaining;
const restApi = await getNewRestApiInstanceAsync({
...remaining,
jwtScopes: new Set(args.jwtScopes),
});
try {
return await callback(restApi);
} finally {
if (config.auth !== 'oauth') {
if (config.auth !== 'oauth' && tableauAuthInfo?.type !== 'Passthrough') {
// Tableau REST sessions for 'pat' and 'direct-trust' are intentionally ephemeral.
// Sessions for 'oauth' are not. Signing out would invalidate the session,
// Sessions for 'oauth' and 'passthrough' are not. Signing out would invalidate the session,
// preventing the access token from being reused for subsequent requests.
await restApi.signOut();
}
Expand Down
16 changes: 11 additions & 5 deletions src/server/express.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { getTableauAuthInfo } from './oauth/getTableauAuthInfo.js';
import { OAuthProvider } from './oauth/provider.js';
import { TableauAuthInfo } from './oauth/schemas.js';
import { AuthenticatedRequest } from './oauth/types.js';
import { passthroughMiddleware, X_TABLEAU_AUTH_HEADER } from './passthroughMiddleware.js';

const SESSION_ID_HEADER = 'mcp-session-id';

Expand Down Expand Up @@ -43,7 +44,7 @@ export async function startExpressServer({
'Accept',
'MCP-Protocol-Version',
],
exposedHeaders: [SESSION_ID_HEADER, 'x-session-id'],
exposedHeaders: [SESSION_ID_HEADER, 'x-session-id', X_TABLEAU_AUTH_HEADER],
}),
);

Expand All @@ -52,12 +53,15 @@ export async function startExpressServer({
app.set('trust proxy', config.trustProxyConfig);
}

const middleware: Array<RequestHandler> = [handlePingRequest];
const middleware: Array<RequestHandler> = [handlePingRequest, validateProtocolVersion];
if (config.enablePassthroughAuth) {
middleware.push(passthroughMiddleware());
}

if (config.oauth.enabled) {
const oauthProvider = new OAuthProvider();
oauthProvider.setupRoutes(app);
middleware.push(oauthProvider.authMiddleware);
middleware.push(validateProtocolVersion);
}

const path = `/${basePath}`;
Expand Down Expand Up @@ -120,7 +124,8 @@ export async function startExpressServer({
server.close();
});

await connect(server, transport, logLevel, getTableauAuthInfo(req.auth));
const authInfo = getTableauAuthInfo(req.auth);
await connect(server, transport, logLevel, authInfo);
} else {
const sessionId = req.headers[SESSION_ID_HEADER] as string | undefined;

Expand All @@ -132,7 +137,8 @@ export async function startExpressServer({
transport = createSession({ clientInfo });

const server = new Server({ clientInfo });
await connect(server, transport, logLevel, getTableauAuthInfo(req.auth));
const authInfo = getTableauAuthInfo(req.auth);
await connect(server, transport, logLevel, authInfo);
} else {
// Invalid request
res.status(400).json({
Expand Down
8 changes: 8 additions & 0 deletions src/server/oauth/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ export const tableauAuthInfoSchema = z.discriminatedUnion('type', [
siteId: z.string(),
raw: z.string(),
}),
z.object({
type: z.literal('Passthrough'),
username: z.string(),
userId: z.string(),
server: z.string(),
siteId: z.string(),
raw: z.string(),
}),
]);

export const cimdMetadataSchema = z.object({
Expand Down
86 changes: 86 additions & 0 deletions src/server/passthroughMiddleware.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import { NextFunction, RequestHandler, Response } from 'express';
import { z } from 'zod';

import { getConfig, TEN_MINUTES_IN_MS } from '../config.js';
import { RestApi } from '../sdks/tableau/restApi.js';
import { ExpiringMap } from '../utils/expiringMap.js';
import { getSupportedMcpScopes } from './oauth/scopes.js';
import { AuthenticatedRequest } from './oauth/types.js';

export const X_TABLEAU_AUTH_HEADER = 'x-tableau-auth';

export const passthroughAuthInfoSchema = z.object({
type: z.literal('Passthrough'),
username: z.string(),
userId: z.string(),
server: z.string(),
siteId: z.string(),
raw: z.string(),
});

export type PassthroughAuthInfo = z.infer<typeof passthroughAuthInfoSchema>;

const passthroughAuthInfoCache = new ExpiringMap<string, PassthroughAuthInfo>({
defaultExpirationTimeMs: TEN_MINUTES_IN_MS,
});

export function passthroughMiddleware(): RequestHandler {
return async (req: AuthenticatedRequest, res: Response, next: NextFunction): Promise<void> => {
const tableauAccessToken: string =
getCookie(req, 'workgroup_session_id') || getHeader(req, X_TABLEAU_AUTH_HEADER);

if (!tableauAccessToken) {
next();
return;
}

const config = getConfig();
let passthroughAuthInfo = passthroughAuthInfoCache.get(tableauAccessToken);
if (!passthroughAuthInfo) {
const { server, maxRequestTimeoutMs } = config;

const restApi = new RestApi(server, {
maxRequestTimeoutMs,
});

restApi.setCredentials(tableauAccessToken, 'unknown user id');
const sessionResult = await restApi.authenticatedServerMethods.getCurrentServerSession();
if (!sessionResult.isOk()) {
res.status(401).json({
error: 'invalid_token',
error_description: sessionResult.error,
});
return;
}

passthroughAuthInfo = {
type: 'Passthrough',
username: sessionResult.value.user.name,
userId: sessionResult.value.user.id,
server,
siteId: sessionResult.value.site.id,
raw: tableauAccessToken,
};

passthroughAuthInfoCache.set(tableauAccessToken, passthroughAuthInfo);
}

req.auth = {
token: 'passthrough',
clientId: 'passthrough',
scopes: config.oauth.enforceScopes ? getSupportedMcpScopes() : [],
extra: passthroughAuthInfo,
};
next();
};
}

function getCookie(req: AuthenticatedRequest, cookieName: string): string {
const cookieValue = req.cookies?.[cookieName];
return cookieValue?.toString() ?? '';
}

function getHeader(req: AuthenticatedRequest, headerName: string): string {
const headerValue = req.headers[headerName];
return headerValue?.toString() ?? '';
}
1 change: 1 addition & 0 deletions src/testSetup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ vi.mock('./sdks/tableau/restApi.js', async (importOriginal) => ({
RestApi: vi.fn().mockImplementation(() => ({
signIn: vi.fn().mockResolvedValue(undefined),
signOut: vi.fn().mockResolvedValue(undefined),
setCredentials: vi.fn(),
serverMethods: {
getServerInfo: vi.fn().mockResolvedValue({
productVersion: testProductVersion,
Expand Down