Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3157,6 +3157,10 @@ PLUGINS_CLI_MARKUP_MODE=rich
# Current Ed25519 public key (derived automatically if private key is set)
# ED25519_PUBLIC_KEY=

# SSL context cache settings
# SSL_CONTEXT_CACHE_MAX_SIZE=100
# SSL_CONTEXT_CACHE_TTL=


# =============================================================================
# Bootstrap additional system roles
Expand Down
4 changes: 3 additions & 1 deletion docs/docs/manage/self-signed-certificates.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ When the gateway invokes a tool from an MCP server with a custom CA certificate:
1. **SSL Context Creation**: A custom SSL context is created using Python's `ssl.create_default_context()`
2. **Certificate Loading**: The CA certificate is loaded using `ctx.load_verify_locations(cadata=ca_certificate)`
3. **Signature Validation** (if enabled): The certificate signature is validated using Ed25519 to ensure it hasn't been tampered with
4. **HTTPS Client Configuration**: The SSL context is passed to the HTTPX client as the `verify` parameter
4. **mTLS client cert/key support**: If `client_cert` and `client_key` are configured for the gateway, the SSL context is loaded with `ctx.load_cert_chain(client_cert, client_key)` for mutual TLS
5. **HTTPS Client Configuration**: The SSL context is passed to the HTTPX client as the `verify` parameter
5. **Secure Connection**: All HTTPS requests to the MCP server use the custom CA certificate for validation
6. **HTTP Bypass**: For plain `http://` gateway URLs, SSL context creation is skipped and default HTTPX verification is used for reduced overhead

### Usage During Gateway Registration

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""add_gateway_client_cert_and_key

Revision ID: 615af4ab94b4
Revises: 20a0e0538ac5
Create Date: 2026-03-20 10:47:06.592968

"""

# Standard
from typing import Sequence, Union

# Third-Party
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = "615af4ab94b4"
down_revision: Union[str, Sequence[str], None] = "20a0e0538ac5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
bind = op.get_bind()
inspector = sa.inspect(bind)

if "gateways" not in inspector.get_table_names():
return

columns = [col["name"] for col in inspector.get_columns("gateways")]

if "client_cert" not in columns:
op.add_column("gateways", sa.Column("client_cert", sa.Text(), nullable=True))

if "client_key" not in columns:
op.add_column("gateways", sa.Column("client_key", sa.Text(), nullable=True))


def downgrade() -> None:
"""Downgrade schema."""
bind = op.get_bind()
inspector = sa.inspect(bind)

if "gateways" not in inspector.get_table_names():
return

columns = [col["name"] for col in inspector.get_columns("gateways")]

if "client_cert" in columns:
op.drop_column("gateways", "client_cert")

if "client_key" in columns:
op.drop_column("gateways", "client_key")
4 changes: 4 additions & 0 deletions mcpgateway/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4594,6 +4594,10 @@ class Gateway(Base):
ca_certificate_sig: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
signing_algorithm: Mapped[Optional[str]] = mapped_column(String(20), nullable=True, default="ed25519") # e.g., "sha256"

# mTLS client certificate/key for upstream gateway authentication
client_cert: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
client_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

# Relationship with local tools this gateway provides
tools: Mapped[List["Tool"]] = relationship(back_populates="gateway", foreign_keys="Tool.gateway_id", cascade="all, delete-orphan", passive_deletes=True)

Expand Down
57 changes: 57 additions & 0 deletions mcpgateway/handlers/signal_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
"""Signal handlers for ContextForge Gateway.

Provides SIGHUP handling for certificate rotation without restart.
"""

# Standard
import asyncio
import logging
from typing import Any

logger = logging.getLogger(__name__)


async def sighup_reload() -> None:
"""Clear SSL context cache and drain MCP session pool on SIGHUP for certificate rotation.

Clears the SSL context cache to force recreation of SSL contexts
with potentially updated certificates, and drains the MCP session
pool so pooled connections reconnect with new TLS state.
"""
try:
# First-Party
from mcpgateway.utils.ssl_context_cache import clear_ssl_context_cache # pylint: disable=import-outside-toplevel

clear_ssl_context_cache()
logger.info("SIGHUP: SSL context cache cleared")
except Exception as exc:
logger.error(f"SIGHUP handler failed to clear SSL context cache: {exc}")

try:
# First-Party
from mcpgateway.services.mcp_session_pool import drain_mcp_session_pool # pylint: disable=import-outside-toplevel

await drain_mcp_session_pool()
logger.info("SIGHUP: MCP session pool drained for TLS rotation")
except Exception as exc:
logger.debug(f"SIGHUP: MCP session pool drain skipped: {exc}")


def sighup_handler(_signum: int, _frame: Any) -> None:
"""Handle SIGHUP signal by scheduling async SSL cache reload.

Signal handler that safely schedules an asynchronous task to clear
the SSL context cache. Uses the running event loop to create a task
for the async reload operation.

Args:
_signum: Signal number (unused but required by signal handler signature)
_frame: Current stack frame (unused but required by signal handler signature)
"""
logger.info("Received SIGHUP signal, scheduling SSL context cache refresh")
try:
event_loop = asyncio.get_running_loop()
event_loop.create_task(sighup_reload())
except RuntimeError:
logger.warning("SIGHUP received but event loop not running; skipping async reload")
12 changes: 12 additions & 0 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import json
import logging
import re
import signal
import sys
from typing import Any, AsyncIterator, Dict, List, Optional, TypeAlias, Union
from urllib.parse import urlparse, urlunparse
Expand Down Expand Up @@ -1701,6 +1702,11 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:

logger.info("All services initialized successfully")

# First-Party
from mcpgateway.handlers.signal_handlers import sighup_handler # pylint: disable=import-outside-toplevel

signal.signal(signal.SIGHUP, sighup_handler)

# Start cache invalidation subscriber for cross-worker cache synchronization
# First-Party
from mcpgateway.cache.registry_cache import get_cache_invalidation_subscriber # pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -1769,6 +1775,12 @@ async def run_log_aggregation_loop() -> None:
raise SystemExit(1)
raise
finally:
# Restore default SIGHUP handling in case we reset signal handlers.
try:
signal.signal(signal.SIGHUP, signal.SIG_DFL)
except Exception as exc: # pragma: no cover - defensive
logger.debug(f"Failed to restore default SIGHUP handler: {exc}")

if aggregation_stop_event is not None:
aggregation_stop_event.set()
for task in (aggregation_backfill_task, aggregation_loop_task):
Expand Down
20 changes: 20 additions & 0 deletions mcpgateway/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2683,6 +2683,10 @@ class GatewayCreate(BaseModelWithConfigDict):
ca_certificate_sig: Optional[str] = Field(None, description="Signature of the custom CA certificate for integrity verification")
signing_algorithm: Optional[str] = Field("ed25519", description="Algorithm used for signing the CA certificate")

# mTLS client certificate/key
client_cert: Optional[str] = Field(None, description="Client TLS certificate for mTLS authentication")
client_key: Optional[str] = Field(None, description="Client TLS key for mTLS authentication")

# Per-gateway refresh configuration
refresh_interval_seconds: Optional[int] = Field(None, ge=60, description="Per-gateway refresh interval in seconds (minimum 60); uses global default if not set")

Expand Down Expand Up @@ -3025,6 +3029,15 @@ class GatewayUpdate(BaseModelWithConfigDict):
# Gateway mode configuration
gateway_mode: Optional[str] = Field(None, description="Gateway mode: 'cache' (database caching, default) or 'direct_proxy' (pass-through mode with no caching)", pattern="^(cache|direct_proxy)$")

# CA certificate configuration for custom TLS trust
ca_certificate: Optional[str] = Field(None, description="Custom CA certificate for TLS verification")
ca_certificate_sig: Optional[str] = Field(None, description="Signature of the custom CA certificate")
signing_algorithm: Optional[str] = Field(None, description="Algorithm used for signing the CA certificate")

# mTLS client TLS certificate and key
client_cert: Optional[str] = Field(None, description="Client TLS certificate for mTLS gateway authentication")
client_key: Optional[str] = Field(None, description="Client TLS key for mTLS gateway authentication")

@field_validator("tags")
@classmethod
def validate_tags(cls, v: Optional[List[str]]) -> List[str]:
Expand Down Expand Up @@ -3319,6 +3332,11 @@ class GatewayRead(BaseModelWithConfigDict):
last_seen: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc), description="Last seen timestamp")

passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through from client to target")
ca_certificate: Optional[str] = Field(default=None, description="Custom CA certificate for TLS verification")
ca_certificate_sig: Optional[str] = Field(default=None, description="Signature of the custom CA certificate")
signing_algorithm: Optional[str] = Field(default="ed25519", description="Algorithm used for signing the CA certificate")
client_cert: Optional[str] = Field(default=None, description="Client TLS certificate for mTLS authentication")
client_key: Optional[str] = Field(default=None, description="Client TLS key for mTLS authentication")
# Authorizations
auth_type: Optional[str] = Field(None, description="auth_type: basic, bearer, authheaders, oauth, query_param, or None")
auth_value: Optional[str] = Field(None, description="auth value: username/password or token or custom headers")
Expand Down Expand Up @@ -3591,6 +3609,8 @@ def masked(self) -> "GatewayRead":
masked_data["auth_token_unmasked"] = None
masked_data["auth_header_value_unmasked"] = None
masked_data["auth_headers_unmasked"] = None
# SECURITY: Mask mTLS client private key
masked_data["client_key"] = settings.masked_auth_value if masked_data.get("client_key") else None
return GatewayRead.model_validate(masked_data)


Expand Down
Loading
Loading