diff --git a/Sources/AWSLambdaRuntime/HTTPServer/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/HTTPServer/Lambda+LocalServer.swift index 84db6505..8f040c72 100644 --- a/Sources/AWSLambdaRuntime/HTTPServer/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntime/HTTPServer/Lambda+LocalServer.swift @@ -131,41 +131,48 @@ internal struct LambdaHTTPServer { _ closure: sending @escaping () async throws -> Result ) async throws -> Result { - let channel = try await ServerBootstrap(group: eventLoopGroup) + let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint) + + // Use traditional callback-based API to avoid NIO async channel cancellation race + // See: https://github.com/apple/swift-nio/issues/2637 + let bootstrap = ServerBootstrap(group: eventLoopGroup) .serverChannelOption(.backlog, value: 256) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) .childChannelOption(.maxMessagesPerRead, value: 1) - .bind( - host: host, - port: port - ) { channel in - channel.eventLoop.makeCompletedFuture { + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in + // Handle connection in a detached task + do { + let asyncChannel = try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerResponsePart.self + ) + ) - try channel.pipeline.syncOperations.configureHTTPServerPipeline( - withErrorHandling: true - ) + Task.detached { + await server.handleConnection(channel: asyncChannel, logger: logger) + } - return try NIOAsyncChannel( - wrappingChannelSynchronously: channel, - configuration: NIOAsyncChannel.Configuration( - inboundType: HTTPServerRequestPart.self, - outboundType: HTTPServerResponsePart.self - ) - ) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } } } + let channel = try await bootstrap.bind(host: host, port: port).get() + // it's ok to keep this at `info` level because it is only used for local testing and unit tests logger.info( "Server started and listening", metadata: [ - "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", - "port": "\(channel.channel.localAddress?.port ?? 0)", + "host": "\(channel.localAddress?.ipAddress?.debugDescription ?? "")", + "port": "\(channel.localAddress?.port ?? 0)", ] ) - let server = LambdaHTTPServer(invocationEndpoint: invocationEndpoint) - // Sadly the Swift compiler does not understand that the passed in closure will only be // invoked once. Because of this we need an unsafe transfer box here. Buuuh! let closureBox = UnsafeTransferBox(value: closure) @@ -183,31 +190,10 @@ internal struct LambdaHTTPServer { } } - // this Task will create one subtask to handle each individual connection + // Server task - just wait for channel to close group.addTask { do { - // We are handling each incoming connection in a separate child task. It is important - // to use a discarding task group here which automatically discards finished child tasks. - // A normal task group retains all child tasks and their outputs in memory until they are - // consumed by iterating the group or by exiting the group. Since, we are never consuming - // the results of the group we need the group to automatically discard them; otherwise, this - // would result in a memory leak over time. - try await withTaskCancellationHandler { - try await withThrowingDiscardingTaskGroup { taskGroup in - try await channel.executeThenClose { inbound in - for try await connectionChannel in inbound { - - taskGroup.addTask { - logger.trace("Handling a new connection") - await server.handleConnection(channel: connectionChannel, logger: logger) - logger.trace("Done handling the connection") - } - } - } - } - } onCancel: { - channel.channel.close(promise: nil) - } + try await channel.closeFuture.get() return .serverReturned(.success(())) } catch { return .serverReturned(.failure(error)) @@ -216,35 +202,29 @@ internal struct LambdaHTTPServer { // Now that the local HTTP server and LambdaHandler tasks are started, wait for the // first of the two that will terminate. - // When the first task terminates, cancel the group and collect the result of the - // second task. - - // collect and return the result of the LambdaHandler + // When first task completes, close the server channel and wait for the other task. + // Note: we intentionally do not call `group.cancelAll()` here. Closing the channel causes + // the server task (which is awaiting `channel.closeFuture`) to complete naturally, and + // we then wait for the remaining task to finish via `group.next()`. let serverOrHandlerResult1 = await group.next()! - group.cancelAll() - // Cancel all waiting continuations in the pools to prevent hangs + channel.close(promise: nil) + server.invocationPool.cancelAll() server.responsePool.cancelAll() + let serverOrHandlerResult2 = await group.next()! + switch serverOrHandlerResult1 { case .closureResult(let result): return result case .serverReturned(let result): - - if result.maybeError is CancellationError { - logger.trace("Server's task cancelled") - } else { - logger.error( - "Server shutdown before closure completed", - metadata: [ - "error": "\(result.maybeError != nil ? "\(result.maybeError!)" : "none")" - ] - ) + if let error = result.maybeError { + logger.error("Server error: \(error)") } - switch await group.next()! { + switch serverOrHandlerResult2 { case .closureResult(let result): return result @@ -271,82 +251,74 @@ internal struct LambdaHTTPServer { var requestBody: ByteBuffer? var requestId: String? - // Note that this method is non-throwing and we are catching any error. - // We do this since we don't want to tear down the whole server when a single connection - // encounters an error. - await withTaskCancellationHandler { - do { - try await channel.executeThenClose { inbound, outbound in - for try await inboundData in inbound { - switch inboundData { - case .head(let head): - requestHead = head - requestId = getRequestId(from: requestHead) - - // for streaming requests, push a partial head response - if self.isStreamingResponse(requestHead) { - self.responsePool.push( - LocalServerResponse( - id: requestId, - status: .ok - ) - ) - } - - case .body(let body): - precondition(requestHead != nil, "Received .body without .head") - - // if this is a request from a Streaming Lambda Handler, - // stream the response instead of buffering it - if self.isStreamingResponse(requestHead) { - self.responsePool.push( - LocalServerResponse(id: requestId, body: body) - ) - } else { - requestBody.setOrWriteImmutableBuffer(body) - } - - case .end: - precondition(requestHead != nil, "Received .end without .head") - - if self.isStreamingResponse(requestHead) { - // for streaming response, send the final response - self.responsePool.push( - LocalServerResponse(id: requestId, final: true) + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + switch inboundData { + case .head(let head): + requestHead = head + requestId = getRequestId(from: requestHead) + + // for streaming requests, push a partial head response + if self.isStreamingResponse(requestHead) { + self.responsePool.push( + LocalServerResponse( + id: requestId, + status: .ok ) + ) + } - // Send acknowledgment back to Lambda runtime client for streaming END - // This is the single HTTP response to the chunked HTTP request - try await self.sendResponse( - .init(id: requestId, status: .accepted, final: true), - outbound: outbound, - logger: logger - ) - } else { - // process the buffered response for non streaming requests - try await self.processRequestAndSendResponse( - head: requestHead, - body: requestBody, - outbound: outbound, - logger: logger - ) - } + case .body(let body): + precondition(requestHead != nil, "Received .body without .head") + + // if this is a request from a Streaming Lambda Handler, + // stream the response instead of buffering it + if self.isStreamingResponse(requestHead) { + self.responsePool.push( + LocalServerResponse(id: requestId, body: body) + ) + } else { + requestBody.setOrWriteImmutableBuffer(body) + } - // reset the request state for next request - requestHead = nil - requestBody = nil - requestId = nil + case .end: + precondition(requestHead != nil, "Received .end without .head") + + if self.isStreamingResponse(requestHead) { + // for streaming response, send the final response + self.responsePool.push( + LocalServerResponse(id: requestId, final: true) + ) + + // Send acknowledgment back to Lambda runtime client for streaming END + // This is the single HTTP response to the chunked HTTP request + try await self.sendResponse( + .init(id: requestId, status: .accepted, final: true), + outbound: outbound, + logger: logger + ) + } else { + // process the buffered response for non streaming requests + try await self.processRequestAndSendResponse( + head: requestHead, + body: requestBody, + outbound: outbound, + logger: logger + ) } + + // reset the request state for next request + requestHead = nil + requestBody = nil + requestId = nil } } - } catch let error as CancellationError { - logger.trace("The task was cancelled", metadata: ["error": "\(error)"]) - } catch { - logger.error("Hit error: \(error)") } - - } onCancel: { - channel.channel.close(promise: nil) + } catch let error as CancellationError { + logger.trace("The task was cancelled", metadata: ["error": "\(error)"]) + } catch { + logger.error("Hit error: \(error)") } }