From 71806a44e2bbab1216ba1f85783a1364c0da60d8 Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Fri, 26 Jun 2026 10:29:52 +0700 Subject: [PATCH 1/2] fix(ssh): tear down tunnel relay on socket hangup instead of busy-spinning (#1769) --- CHANGELOG.md | 4 + TablePro/Core/SSH/LibSSH2ChannelIO.swift | 41 ++++ TablePro/Core/SSH/LibSSH2Tunnel.swift | 89 ++----- TablePro/Core/SSH/LibSSH2TunnelFactory.swift | 82 ++----- TablePro/Core/SSH/RelayPollState.swift | 20 ++ TablePro/Core/SSH/SSHChannelRelay.swift | 158 ++++++++++++ .../Core/SSH/RelayPollStateTests.swift | 54 +++++ .../Core/SSH/SSHChannelRelayTests.swift | 228 ++++++++++++++++++ 8 files changed, 537 insertions(+), 139 deletions(-) create mode 100644 TablePro/Core/SSH/LibSSH2ChannelIO.swift create mode 100644 TablePro/Core/SSH/RelayPollState.swift create mode 100644 TablePro/Core/SSH/SSHChannelRelay.swift create mode 100644 TableProTests/Core/SSH/RelayPollStateTests.swift create mode 100644 TableProTests/Core/SSH/SSHChannelRelayTests.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index c83a34977..2a7717d29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- SSH tunnels no longer pin a CPU core after the connection drops. A dropped tunnel is now detected and torn down instead of spinning in its relay loop. (#1769) + ## [0.53.0] - 2026-06-25 ### Added diff --git a/TablePro/Core/SSH/LibSSH2ChannelIO.swift b/TablePro/Core/SSH/LibSSH2ChannelIO.swift new file mode 100644 index 000000000..c74a99365 --- /dev/null +++ b/TablePro/Core/SSH/LibSSH2ChannelIO.swift @@ -0,0 +1,41 @@ +// +// LibSSH2ChannelIO.swift +// TablePro +// + +import Foundation + +import CLibSSH2 + +/// Routes channel reads/writes through the serial `sessionQueue` because libssh2 +/// is not thread-safe per session. Maps libssh2 return codes to the transport +/// agnostic results consumed by `SSHChannelRelay`. +internal struct LibSSH2ChannelIO: SSHChannelIO { + let channel: OpaquePointer + let session: OpaquePointer + let sessionQueue: DispatchQueue + + func read(into buffer: UnsafeMutablePointer, count: Int) -> ChannelReadResult { + let result = sessionQueue.sync { Int(tablepro_libssh2_channel_read(channel, buffer, count)) } + if result > 0 { return .bytes(result) } + if result == 0 { return .closed } + if sessionQueue.sync(execute: { libssh2_channel_eof(channel) }) != 0 { return .closed } + if result == Int(LIBSSH2_ERROR_EAGAIN) { return .wouldBlock } + return .closed + } + + func write(_ buffer: UnsafePointer, count: Int) -> ChannelWriteResult { + let result = sessionQueue.sync { Int(tablepro_libssh2_channel_write(channel, buffer, count)) } + if result > 0 { return .bytes(result) } + if result == Int(LIBSSH2_ERROR_EAGAIN) { return .wouldBlock } + return .closed + } + + func blockDirections() -> RelayDirections { + let directions = sessionQueue.sync { libssh2_session_block_directions(session) } + var result: RelayDirections = [] + if directions & LIBSSH2_SESSION_BLOCK_INBOUND != 0 { result.insert(.inbound) } + if directions & LIBSSH2_SESSION_BLOCK_OUTBOUND != 0 { result.insert(.outbound) } + return result + } +} diff --git a/TablePro/Core/SSH/LibSSH2Tunnel.swift b/TablePro/Core/SSH/LibSSH2Tunnel.swift index b0d07ab19..526122a7c 100644 --- a/TablePro/Core/SSH/LibSSH2Tunnel.swift +++ b/TablePro/Core/SSH/LibSSH2Tunnel.swift @@ -350,82 +350,27 @@ internal final class LibSSH2Tunnel: @unchecked Sendable { /// Blocking relay loop. Runs on `relayQueue`; libssh2 calls go through `sessionQueue`. private func runRelay(clientFD: Int32, channel: OpaquePointer) { - let buffer = UnsafeMutablePointer.allocate(capacity: Self.relayBufferSize) - defer { - buffer.deallocate() - Darwin.close(clientFD) - if self.isRunning { - sessionQueue.sync { - libssh2_channel_close(channel) - libssh2_channel_free(channel) - } - } - } + let relay = SSHChannelRelay( + localFD: clientFD, + transportFD: socketFD, + channelIO: LibSSH2ChannelIO(channel: channel, session: session, sessionQueue: sessionQueue), + bufferSize: Self.relayBufferSize, + isActive: { [weak self] in self?.isRunning ?? false } + ) - while self.isRunning { - var pollFDs = [ - pollfd(fd: clientFD, events: Int16(POLLIN), revents: 0), - pollfd(fd: self.socketFD, events: Int16(POLLIN), revents: 0), - ] + let termination = relay.run() - let pollResult = poll(&pollFDs, 2, 500) // 500ms timeout - if pollResult < 0 { break } + Darwin.close(clientFD) + guard self.isRunning else { return } - // Read from SSH channel when the SSH socket has data or on timeout - // (libssh2 may have internally buffered data) - if pollFDs[1].revents & Int16(POLLIN) != 0 || pollResult == 0 { - let readResult: Int = sessionQueue.sync { - Int(tablepro_libssh2_channel_read(channel, buffer, Self.relayBufferSize)) - } - if readResult > 0 { - var totalSent = 0 - while totalSent < readResult { - let sent = send( - clientFD, - buffer.advanced(by: totalSent), - readResult - totalSent, - 0 - ) - if sent <= 0 { return } - totalSent += sent - } - } else if readResult == 0 || sessionQueue.sync(execute: { libssh2_channel_eof(channel) }) != 0 { - return - } else if readResult != Int(LIBSSH2_ERROR_EAGAIN) { - return - } - } + sessionQueue.sync { + libssh2_channel_close(channel) + libssh2_channel_free(channel) + } - // Read from client -> write to SSH channel - if pollFDs[0].revents & Int16(POLLIN) != 0 { - let clientRead = recv(clientFD, buffer, Self.relayBufferSize, 0) - if clientRead <= 0 { return } - - var totalWritten = 0 - while totalWritten < Int(clientRead) { - let written: Int = sessionQueue.sync { - Int(tablepro_libssh2_channel_write( - channel, - buffer.advanced(by: totalWritten), - Int(clientRead) - totalWritten - )) - } - if written > 0 { - totalWritten += written - } else if written == Int(LIBSSH2_ERROR_EAGAIN) { - let directions = sessionQueue.sync { - libssh2_session_block_directions(self.session) - } - _ = self.waitForSocketDirections( - directions: directions, - socketFD: self.socketFD, - timeoutMs: 1_000 - ) - } else { - return - } - } - } + if termination == .transportHangup { + Self.logger.info("SSH transport hung up, marking tunnel dead for \(self.connectionId)") + markDead() } } diff --git a/TablePro/Core/SSH/LibSSH2TunnelFactory.swift b/TablePro/Core/SSH/LibSSH2TunnelFactory.swift index af6295ad2..e9254563c 100644 --- a/TablePro/Core/SSH/LibSSH2TunnelFactory.swift +++ b/TablePro/Core/SSH/LibSSH2TunnelFactory.swift @@ -735,75 +735,23 @@ internal enum LibSSH2TunnelFactory { return Task.detached { await withCheckedContinuation { (continuation: CheckedContinuation) in relayQueue.async { - let bufferSize = 32_768 - let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) - defer { - buffer.deallocate() - Darwin.close(socketFD) - continuation.resume() - } - - while !Task.isCancelled { - var pollFDs = [ - pollfd(fd: socketFD, events: Int16(POLLIN), revents: 0), - pollfd(fd: sshSocketFD, events: Int16(POLLIN), revents: 0), - ] + let relay = SSHChannelRelay( + localFD: socketFD, + transportFD: sshSocketFD, + channelIO: LibSSH2ChannelIO( + channel: channel, + session: session, + sessionQueue: sessionQueue + ), + bufferSize: 32_768, + isActive: { !Task.isCancelled } + ) - let pollResult = poll(&pollFDs, 2, 500) - if pollResult < 0 { break } + _ = relay.run() - // Channel -> socketpair (serialized libssh2 call) - if pollFDs[1].revents & Int16(POLLIN) != 0 || pollResult == 0 { - let channelRead: Int = sessionQueue.sync { - Int(tablepro_libssh2_channel_read(channel, buffer, bufferSize)) - } - if channelRead > 0 { - var totalSent = 0 - while totalSent < channelRead { - let sent = send( - socketFD, - buffer.advanced(by: totalSent), - channelRead - totalSent, - 0 - ) - if sent <= 0 { return } - totalSent += sent - } - } else if channelRead == 0 - || sessionQueue.sync(execute: { libssh2_channel_eof(channel) }) != 0 { - return - } else if channelRead != Int(LIBSSH2_ERROR_EAGAIN) { - return - } - } - - // Socketpair -> channel - if pollFDs[0].revents & Int16(POLLIN) != 0 { - let socketRead = recv(socketFD, buffer, bufferSize, 0) - if socketRead <= 0 { return } - - var totalWritten = 0 - while totalWritten < Int(socketRead) { - let written: Int = sessionQueue.sync { - Int(tablepro_libssh2_channel_write( - channel, - buffer.advanced(by: totalWritten), - Int(socketRead) - totalWritten - )) - } - if written > 0 { - totalWritten += written - } else if written == Int(LIBSSH2_ERROR_EAGAIN) { - var writePollFD = pollfd( - fd: sshSocketFD, events: Int16(POLLOUT), revents: 0 - ) - _ = poll(&writePollFD, 1, 1_000) - } else { - return - } - } - } - } + shutdown(socketFD, SHUT_RDWR) + Darwin.close(socketFD) + continuation.resume() } } } diff --git a/TablePro/Core/SSH/RelayPollState.swift b/TablePro/Core/SSH/RelayPollState.swift new file mode 100644 index 000000000..fc7e8b35f --- /dev/null +++ b/TablePro/Core/SSH/RelayPollState.swift @@ -0,0 +1,20 @@ +// +// RelayPollState.swift +// TablePro +// + +import Foundation + +internal enum RelayFDState: Equatable { + case idle + case readable + case drainThenStop + case stop +} + +internal func relayFDState(_ revents: Int16) -> RelayFDState { + if revents & Int16(POLLERR | POLLNVAL) != 0 { return .stop } + if revents & Int16(POLLHUP) != 0 { return .drainThenStop } + if revents & Int16(POLLIN) != 0 { return .readable } + return .idle +} diff --git a/TablePro/Core/SSH/SSHChannelRelay.swift b/TablePro/Core/SSH/SSHChannelRelay.swift new file mode 100644 index 000000000..f205b0eb9 --- /dev/null +++ b/TablePro/Core/SSH/SSHChannelRelay.swift @@ -0,0 +1,158 @@ +// +// SSHChannelRelay.swift +// TablePro +// + +import Foundation + +internal struct RelayDirections: OptionSet { + let rawValue: Int32 + + static let inbound = RelayDirections(rawValue: 1 << 0) + static let outbound = RelayDirections(rawValue: 1 << 1) +} + +internal enum ChannelReadResult: Equatable { + case bytes(Int) + case wouldBlock + case closed +} + +internal enum ChannelWriteResult: Equatable { + case bytes(Int) + case wouldBlock + case closed +} + +internal protocol SSHChannelIO { + func read(into buffer: UnsafeMutablePointer, count: Int) -> ChannelReadResult + func write(_ buffer: UnsafePointer, count: Int) -> ChannelWriteResult + func blockDirections() -> RelayDirections +} + +internal enum RelayTermination: Equatable { + case localClosed + case transportHangup + case channelClosed + case cancelled +} + +/// Bidirectional relay between a local socket fd and an SSH channel. +/// Polls the local fd and the SSH transport fd, draining buffered data before +/// tearing down on hangup so the last bytes are not dropped. Hangup and error +/// on either fd terminate the loop instead of spinning on a permanently +/// poll-ready, closed fd. +internal struct SSHChannelRelay { + let localFD: Int32 + let transportFD: Int32 + let channelIO: any SSHChannelIO + let bufferSize: Int + let isActive: () -> Bool + + private static let pollTimeoutMs: Int32 = 500 + private static let writeWaitTimeoutMs: Int32 = 1_000 + + func run() -> RelayTermination { + let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) + defer { buffer.deallocate() } + + while isActive() { + var pollFDs = [ + pollfd(fd: localFD, events: Int16(POLLIN), revents: 0), + pollfd(fd: transportFD, events: Int16(POLLIN), revents: 0), + ] + + let pollResult = poll(&pollFDs, 2, Self.pollTimeoutMs) + if pollResult < 0 { + if errno == EINTR { continue } + return .cancelled + } + + let localState = relayFDState(pollFDs[0].revents) + let transportState = relayFDState(pollFDs[1].revents) + + if transportState == .stop { return .transportHangup } + if localState == .stop { return .localClosed } + + if transportState != .idle || pollResult == 0 { + switch pumpChannelToLocal(buffer) { + case .keepGoing: break + case .channelClosed: return .channelClosed + case .localClosed: return .localClosed + } + } + + if localState == .readable || localState == .drainThenStop { + switch pumpLocalToChannel(buffer) { + case .keepGoing: break + case .channelClosed: return .channelClosed + case .localClosed: return .localClosed + } + } + + if transportState == .drainThenStop { return .transportHangup } + if localState == .drainThenStop { return .localClosed } + } + + return .cancelled + } + + private enum PumpOutcome { + case keepGoing + case channelClosed + case localClosed + } + + private func pumpChannelToLocal(_ buffer: UnsafeMutablePointer) -> PumpOutcome { + switch channelIO.read(into: buffer, count: bufferSize) { + case .bytes(let count): + var totalSent = 0 + while totalSent < count { + let sent = send(localFD, buffer.advanced(by: totalSent), count - totalSent, 0) + if sent <= 0 { return .localClosed } + totalSent += sent + } + return .keepGoing + case .wouldBlock: + return .keepGoing + case .closed: + return .channelClosed + } + } + + private func pumpLocalToChannel(_ buffer: UnsafeMutablePointer) -> PumpOutcome { + let localRead = recv(localFD, buffer, bufferSize, 0) + if localRead <= 0 { return .localClosed } + + var totalWritten = 0 + while totalWritten < Int(localRead) { + switch channelIO.write(buffer.advanced(by: totalWritten), count: Int(localRead) - totalWritten) { + case .bytes(let count): + totalWritten += count + case .wouldBlock: + if !waitForTransport(channelIO.blockDirections()) { return .channelClosed } + case .closed: + return .channelClosed + } + } + return .keepGoing + } + + private func waitForTransport(_ directions: RelayDirections) -> Bool { + var events: Int16 = 0 + if directions.contains(.inbound) { events |= Int16(POLLIN) } + if directions.contains(.outbound) { events |= Int16(POLLOUT) } + guard events != 0 else { return true } + + var pollFD = pollfd(fd: transportFD, events: events, revents: 0) + let rc = poll(&pollFD, 1, Self.writeWaitTimeoutMs) + if rc < 0 { return false } + + switch relayFDState(pollFD.revents) { + case .stop, .drainThenStop: + return false + case .readable, .idle: + return true + } + } +} diff --git a/TableProTests/Core/SSH/RelayPollStateTests.swift b/TableProTests/Core/SSH/RelayPollStateTests.swift new file mode 100644 index 000000000..8dd2a1ab1 --- /dev/null +++ b/TableProTests/Core/SSH/RelayPollStateTests.swift @@ -0,0 +1,54 @@ +// +// RelayPollStateTests.swift +// TableProTests +// +// Tests for relayFDState, the poll revents classifier that decides whether a +// relay fd is idle, readable, draining before teardown, or fatally errored. +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("RelayPollState") +struct RelayPollStateTests { + @Test("No events is idle") + func idle() { + #expect(relayFDState(0) == .idle) + } + + @Test("POLLIN alone is readable") + func readable() { + #expect(relayFDState(Int16(POLLIN)) == .readable) + } + + @Test("POLLHUP drains then stops") + func hangup() { + #expect(relayFDState(Int16(POLLHUP)) == .drainThenStop) + } + + @Test("POLLIN with POLLHUP drains then stops") + func hangupWithBufferedData() { + #expect(relayFDState(Int16(POLLIN) | Int16(POLLHUP)) == .drainThenStop) + } + + @Test("POLLERR is fatal") + func error() { + #expect(relayFDState(Int16(POLLERR)) == .stop) + } + + @Test("POLLNVAL is fatal") + func invalid() { + #expect(relayFDState(Int16(POLLNVAL)) == .stop) + } + + @Test("Fatal errors win over readable data") + func errorWithData() { + #expect(relayFDState(Int16(POLLIN) | Int16(POLLERR)) == .stop) + } + + @Test("Fatal errors win over hangup") + func errorWithHangup() { + #expect(relayFDState(Int16(POLLHUP) | Int16(POLLERR)) == .stop) + } +} diff --git a/TableProTests/Core/SSH/SSHChannelRelayTests.swift b/TableProTests/Core/SSH/SSHChannelRelayTests.swift new file mode 100644 index 000000000..7ee41cf9c --- /dev/null +++ b/TableProTests/Core/SSH/SSHChannelRelayTests.swift @@ -0,0 +1,228 @@ +// +// SSHChannelRelayTests.swift +// TableProTests +// +// Tests for SSHChannelRelay termination behaviour. The relay runs over real +// socketpairs with a scripted channel so the regression case (a closed peer fd +// must terminate the loop instead of busy-spinning) is exercised end to end. +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("SSHChannelRelay") +struct SSHChannelRelayTests { + @Test("Cancellation via isActive stops the relay") + func cancelled() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + let result = runRelay( + localFD: local.a, + transportFD: transport.a, + io: FakeChannelIO(fallback: .wouldBlock), + isActive: { false } + ) + + #expect(result == .cancelled) + } + + @Test("Transport hangup terminates instead of spinning") + func transportHangup() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + transport.closeB() + + let result = runRelay( + localFD: local.a, + transportFD: transport.a, + io: FakeChannelIO(fallback: .wouldBlock) + ) + + #expect(result == .transportHangup) + } + + @Test("Local hangup terminates instead of spinning") + func localHangup() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + local.closeB() + + let result = runRelay( + localFD: local.a, + transportFD: transport.a, + io: FakeChannelIO(fallback: .wouldBlock) + ) + + #expect(result == .localClosed) + } + + @Test("Channel close terminates the relay") + func channelClosed() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + let result = runRelay( + localFD: local.a, + transportFD: transport.a, + io: FakeChannelIO(fallback: .closed) + ) + + #expect(result == .channelClosed) + } + + @Test("Channel data is forwarded to the local socket") + func forwardsChannelData() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + let payload = Data("hello".utf8) + var dummy: UInt8 = 1 + _ = Darwin.send(transport.b, &dummy, 1, 0) + + let io = FakeChannelIO(actions: [.data(payload)], fallback: .closed) + let result = runRelay(localFD: local.a, transportFD: transport.a, io: io) + + #expect(result == .channelClosed) + + var received = [UInt8](repeating: 0, count: payload.count) + let count = recv(local.b, &received, received.count, 0) + #expect(count == payload.count) + #expect(Data(received) == payload) + } + + @Test("Local data is forwarded to the channel") + func forwardsLocalData() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + let payload = Data("world".utf8) + payload.withUnsafeBytes { raw in + _ = Darwin.send(local.b, raw.baseAddress, raw.count, 0) + } + + let io = FakeChannelIO(fallback: .wouldBlock) + let result = runRelay( + localFD: local.a, + transportFD: transport.a, + io: io, + isActive: io.activeUntilWritten(payload.count) + ) + + #expect(result == .cancelled) + #expect(io.written == payload) + } + + private func runRelay( + localFD: Int32, + transportFD: Int32, + io: FakeChannelIO, + isActive: @escaping @Sendable () -> Bool = { true }, + timeout: Double = 3 + ) -> RelayTermination? { + let box = ResultBox() + let semaphore = DispatchSemaphore(value: 0) + DispatchQueue.global().async { + let relay = SSHChannelRelay( + localFD: localFD, + transportFD: transportFD, + channelIO: io, + bufferSize: 32_768, + isActive: isActive + ) + box.value = relay.run() + semaphore.signal() + } + _ = semaphore.wait(timeout: .now() + timeout) + return box.value + } +} + +private final class ResultBox: @unchecked Sendable { + var value: RelayTermination? +} + +private final class SocketPair { + let a: Int32 + let b: Int32 + + init() { + var fds: [Int32] = [0, 0] + _ = socketpair(AF_UNIX, SOCK_STREAM, 0, &fds) + a = fds[0] + b = fds[1] + } + + func closeB() { Darwin.close(b) } + + func close() { + Darwin.close(a) + Darwin.close(b) + } +} + +private final class FakeChannelIO: SSHChannelIO, @unchecked Sendable { + enum Action { + case data(Data) + case wouldBlock + case closed + } + + private let lock = NSLock() + private var actions: [Action] + private let fallback: Action + private var writtenBuffer = Data() + + init(actions: [Action] = [], fallback: Action = .wouldBlock) { + self.actions = actions + self.fallback = fallback + } + + var written: Data { + lock.lock() + defer { lock.unlock() } + return writtenBuffer + } + + func activeUntilWritten(_ target: Int) -> @Sendable () -> Bool { + { [weak self] in (self?.written.count ?? target) < target } + } + + func read(into buffer: UnsafeMutablePointer, count: Int) -> ChannelReadResult { + lock.lock() + defer { lock.unlock() } + let action = actions.isEmpty ? fallback : actions.removeFirst() + switch action { + case .data(let data): + let length = min(data.count, count) + buffer.withMemoryRebound(to: UInt8.self, capacity: length) { destination in + _ = data.copyBytes(to: UnsafeMutableBufferPointer(start: destination, count: length)) + } + return .bytes(length) + case .wouldBlock: + return .wouldBlock + case .closed: + return .closed + } + } + + func write(_ buffer: UnsafePointer, count: Int) -> ChannelWriteResult { + lock.lock() + defer { lock.unlock() } + buffer.withMemoryRebound(to: UInt8.self, capacity: count) { source in + writtenBuffer.append(contentsOf: UnsafeBufferPointer(start: source, count: count)) + } + return .bytes(count) + } + + func blockDirections() -> RelayDirections { [] } +} From 442bf4f91568adddb0c16a9e0737b5cd87124896 Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Fri, 26 Jun 2026 10:40:48 +0700 Subject: [PATCH 2/2] refactor(ssh): detect transport EOF directly and propagate write-path hangup --- TablePro/Core/SSH/SSHChannelRelay.swift | 53 +++++++++++++------ .../Core/SSH/SSHChannelRelayTests.swift | 17 ++++++ 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/TablePro/Core/SSH/SSHChannelRelay.swift b/TablePro/Core/SSH/SSHChannelRelay.swift index f205b0eb9..df6f10aa2 100644 --- a/TablePro/Core/SSH/SSHChannelRelay.swift +++ b/TablePro/Core/SSH/SSHChannelRelay.swift @@ -41,7 +41,9 @@ internal enum RelayTermination: Equatable { /// Polls the local fd and the SSH transport fd, draining buffered data before /// tearing down on hangup so the last bytes are not dropped. Hangup and error /// on either fd terminate the loop instead of spinning on a permanently -/// poll-ready, closed fd. +/// poll-ready, closed fd. Because libssh2 reports EAGAIN rather than a channel +/// close when only the transport dies, a transport that polls readable but is +/// at EOF is detected directly so a half-closed transport cannot spin. internal struct SSHChannelRelay { let localFD: Int32 let transportFD: Int32 @@ -75,10 +77,11 @@ internal struct SSHChannelRelay { if localState == .stop { return .localClosed } if transportState != .idle || pollResult == 0 { - switch pumpChannelToLocal(buffer) { + switch pumpChannelToLocal(buffer, transportReadable: transportState != .idle) { case .keepGoing: break case .channelClosed: return .channelClosed case .localClosed: return .localClosed + case .transportHangup: return .transportHangup } } @@ -87,6 +90,7 @@ internal struct SSHChannelRelay { case .keepGoing: break case .channelClosed: return .channelClosed case .localClosed: return .localClosed + case .transportHangup: return .transportHangup } } @@ -101,9 +105,15 @@ internal struct SSHChannelRelay { case keepGoing case channelClosed case localClosed + case transportHangup } - private func pumpChannelToLocal(_ buffer: UnsafeMutablePointer) -> PumpOutcome { + private enum TransportWait { + case ready + case hangup + } + + private func pumpChannelToLocal(_ buffer: UnsafeMutablePointer, transportReadable: Bool) -> PumpOutcome { switch channelIO.read(into: buffer, count: bufferSize) { case .bytes(let count): var totalSent = 0 @@ -114,6 +124,7 @@ internal struct SSHChannelRelay { } return .keepGoing case .wouldBlock: + if transportReadable, transportAtEOF() { return .transportHangup } return .keepGoing case .closed: return .channelClosed @@ -130,7 +141,10 @@ internal struct SSHChannelRelay { case .bytes(let count): totalWritten += count case .wouldBlock: - if !waitForTransport(channelIO.blockDirections()) { return .channelClosed } + switch waitForTransport(channelIO.blockDirections()) { + case .ready: break + case .hangup: return .transportHangup + } case .closed: return .channelClosed } @@ -138,21 +152,30 @@ internal struct SSHChannelRelay { return .keepGoing } - private func waitForTransport(_ directions: RelayDirections) -> Bool { + private func waitForTransport(_ directions: RelayDirections) -> TransportWait { var events: Int16 = 0 if directions.contains(.inbound) { events |= Int16(POLLIN) } if directions.contains(.outbound) { events |= Int16(POLLOUT) } - guard events != 0 else { return true } + guard events != 0 else { return .ready } - var pollFD = pollfd(fd: transportFD, events: events, revents: 0) - let rc = poll(&pollFD, 1, Self.writeWaitTimeoutMs) - if rc < 0 { return false } - - switch relayFDState(pollFD.revents) { - case .stop, .drainThenStop: - return false - case .readable, .idle: - return true + while true { + var pollFD = pollfd(fd: transportFD, events: events, revents: 0) + let rc = poll(&pollFD, 1, Self.writeWaitTimeoutMs) + if rc < 0 { + if errno == EINTR { continue } + return .hangup + } + switch relayFDState(pollFD.revents) { + case .stop, .drainThenStop: + return .hangup + case .readable, .idle: + return .ready + } } } + + private func transportAtEOF() -> Bool { + var byte: UInt8 = 0 + return recv(transportFD, &byte, 1, MSG_PEEK | MSG_DONTWAIT) == 0 + } } diff --git a/TableProTests/Core/SSH/SSHChannelRelayTests.swift b/TableProTests/Core/SSH/SSHChannelRelayTests.swift index 7ee41cf9c..ad818aeb1 100644 --- a/TableProTests/Core/SSH/SSHChannelRelayTests.swift +++ b/TableProTests/Core/SSH/SSHChannelRelayTests.swift @@ -46,6 +46,23 @@ struct SSHChannelRelayTests { #expect(result == .transportHangup) } + @Test("Transport half-close at EOF terminates instead of spinning") + func transportHalfCloseEOF() { + let local = SocketPair() + let transport = SocketPair() + defer { local.close(); transport.close() } + + shutdown(transport.b, SHUT_WR) + + let result = runRelay( + localFD: local.a, + transportFD: transport.a, + io: FakeChannelIO(fallback: .wouldBlock) + ) + + #expect(result == .transportHangup) + } + @Test("Local hangup terminates instead of spinning") func localHangup() { let local = SocketPair()