diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index 49104863c..30c7c877f 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -157,7 +157,7 @@ final class Transaction: break case .forwardStreamFinished(let executor): - executor.finishRequestBodyStream(self, promise: nil) + executor.finishRequestBodyStream(trailers: nil, request: self, promise: nil) } return } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index fc7e0af49..9ca82f9a9 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -242,7 +242,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .sendBodyPart(let part, let writePromise): context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise) - case .sendRequestEnd(let writePromise, let finalAction): + case .sendRequestEnd(let trailers, let writePromise, let finalAction): let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) // We need to defer succeeding the old request to avoid ordering issues @@ -282,7 +282,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } } - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) + context.writeAndFlush(self.wrapOutboundOut(.end(trailers)), promise: writePromise) if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(readTimeoutAction, context: context) @@ -339,7 +339,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(buffer) - case .forwardResponseEnd(let finalAction, let buffer): + case .forwardResponseEnd(let finalAction, let buffer, let trailers): // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet @@ -358,15 +358,15 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .close: self.request = nil context.close(promise: nil) - oldRequest.receiveResponseEnd(buffer, trailers: nil) + oldRequest.receiveResponseEnd(buffer, trailers: trailers) case .none: - oldRequest.receiveResponseEnd(buffer, trailers: nil) + oldRequest.receiveResponseEnd(buffer, trailers: trailers) case .informConnectionIsIdle: self.request = nil self.onConnectionIdle() - oldRequest.receiveResponseEnd(buffer, trailers: nil) + oldRequest.receiveResponseEnd(buffer, trailers: trailers) } case .failRequest(let error, let finalAction): @@ -504,14 +504,18 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - fileprivate func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + fileprivate func finishRequestBodyStream0( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamFinished(promise: promise) + let action = self.state.requestStreamFinished(trailers: trailers, promise: promise) self.run(action, context: context) } @@ -565,9 +569,13 @@ extension HTTP1ClientChannelHandler { } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + func finishRequestBodyStream( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { self.loopBound.execute { - $0.finishRequestBodyStream0(request, promise: promise) + $0.finishRequestBodyStream0(trailers: trailers, request: request, promise: promise) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index 60a98b333..5281dca4b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -57,7 +57,7 @@ struct HTTP1ConnectionStateMachine { startIdleTimer: Bool ) case sendBodyPart(IOData, EventLoopPromise?) - case sendRequestEnd(EventLoopPromise?, FinalSuccessfulStreamAction) + case sendRequestEnd(trailers: HTTPHeaders?, EventLoopPromise?, FinalSuccessfulStreamAction) case failSendBodyPart(Error, EventLoopPromise?) case failSendStreamFinished(Error, EventLoopPromise?) @@ -66,7 +66,7 @@ struct HTTP1ConnectionStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case forwardResponseEnd(FinalSuccessfulStreamAction, CircularBuffer) + case forwardResponseEnd(FinalSuccessfulStreamAction, CircularBuffer, HTTPHeaders?) case failRequest(Error, FinalFailedStreamAction) @@ -218,13 +218,13 @@ struct HTTP1ConnectionStateMachine { } } - mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { + mutating func requestStreamFinished(trailers: HTTPHeaders?, promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamFinished(promise: promise) + let action = requestStateMachine.requestStreamFinished(trailers: trailers, promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } @@ -427,7 +427,7 @@ extension HTTP1ConnectionStateMachine.State { return .resumeRequestBodyStream case .sendBodyPart(let part, let writePromise): return .sendBodyPart(part, writePromise) - case .sendRequestEnd(let writePromise, let finalAction): + case .sendRequestEnd(let trailers, let writePromise, let finalAction): guard case .inRequest(_, close: let close) = self else { assertionFailure("Invalid state: \(self)") self = .closing @@ -450,13 +450,13 @@ extension HTTP1ConnectionStateMachine.State { case .none: newFinalAction = .none } - return .sendRequestEnd(writePromise, newFinalAction) + return .sendRequestEnd(trailers: trailers, writePromise, newFinalAction) case .forwardResponseHead(let head, let pauseRequestBodyStream): return .forwardResponseHead(head, pauseRequestBodyStream: pauseRequestBodyStream) case .forwardResponseBodyParts(let parts): return .forwardResponseBodyParts(parts) - case .forwardResponseEnd(let finalAction, let finalParts): + case .forwardResponseEnd(let finalAction, let finalParts, let trailers): guard case .inRequest(_, close: let close) = self else { assertionFailure("Invalid state: \(self)") self = .closing @@ -480,7 +480,7 @@ extension HTTP1ConnectionStateMachine.State { // request is ongoing. request stream is still alive newFinalAction = .none } - return .forwardResponseEnd(newFinalAction, finalParts) + return .forwardResponseEnd(newFinalAction, finalParts, trailers) case .failRequest(let error, let finalAction): switch self { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 3daa95289..2022fec4e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -196,7 +196,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { case .sendBodyPart(let data, let writePromise): context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: writePromise) - case .sendRequestEnd(let writePromise, let finalAction): + case .sendRequestEnd(let trailers, let writePromise, let finalAction): let promise = writePromise ?? context.eventLoop.makePromise(of: Void.self) // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet @@ -205,7 +205,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { request.requestBodyStreamSent() } - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: promise) + context.writeAndFlush(self.wrapOutboundOut(.end(trailers)), promise: promise) if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(readTimeoutAction, context: context) @@ -256,10 +256,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // the right result for HTTP/1). In the h2 case we MUST always close. self.runFailedFinalAction(finalAction, context: context, error: error) - case .forwardResponseEnd(let finalAction, let finalParts): + case .forwardResponseEnd(let finalAction, let finalParts, let trailers): // We can force unwrap the request here, as we have just validated in the state machine, // that the request object is still present. - self.request!.receiveResponseEnd(finalParts, trailers: nil) + self.request!.receiveResponseEnd(finalParts, trailers: trailers) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) @@ -405,13 +405,17 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + private func finishRequestBodyStream0( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return } - let action = self.state.requestStreamFinished(promise: promise) + let action = self.state.requestStreamFinished(trailers: trailers, promise: promise) self.run(action, context: context) } @@ -461,9 +465,13 @@ extension HTTP2ClientRequestHandler { } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + func finishRequestBodyStream( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { self.loopBound.execute { - $0.finishRequestBodyStream0(request, promise: promise) + $0.finishRequestBodyStream0(trailers: trailers, request: request, promise: promise) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index b502ad034..32308a6be 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -185,7 +185,11 @@ protocol HTTPRequestExecutor: Sendable { /// Signals that the request body stream has finished /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func finishRequestBodyStream(_ task: HTTPExecutableRequest, promise: EventLoopPromise?) + func finishRequestBodyStream( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) /// Signals that more bytes from response body stream can be consumed. /// diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index f876a321b..cef736063 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -102,7 +102,7 @@ struct HTTPRequestStateMachine { startIdleTimer: Bool ) case sendBodyPart(IOData, EventLoopPromise?) - case sendRequestEnd(EventLoopPromise?, FinalSuccessfulRequestAction) + case sendRequestEnd(trailers: HTTPHeaders?, EventLoopPromise?, FinalSuccessfulRequestAction) case failSendBodyPart(Error, EventLoopPromise?) case failSendStreamFinished(Error, EventLoopPromise?) @@ -111,7 +111,7 @@ struct HTTPRequestStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case forwardResponseEnd(FinalSuccessfulRequestAction, CircularBuffer) + case forwardResponseEnd(FinalSuccessfulRequestAction, CircularBuffer, HTTPHeaders?) case failRequest(Error, FinalFailedRequestAction) @@ -353,7 +353,7 @@ struct HTTPRequestStateMachine { } } - mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { + mutating func requestStreamFinished(trailers: HTTPHeaders?, promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable, @@ -370,7 +370,7 @@ struct HTTPRequestStateMachine { } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd(promise, .none) + return .sendRequestEnd(trailers: trailers, promise, .none) case .running( .streaming(let expectedBodyLength, let sentBodyBytes, _), @@ -385,7 +385,7 @@ struct HTTPRequestStateMachine { } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd(promise, .none) + return .sendRequestEnd(trailers: trailers, promise, .none) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { @@ -395,7 +395,7 @@ struct HTTPRequestStateMachine { } self.state = .finished - return .sendRequestEnd(promise, .requestDone) + return .sendRequestEnd(trailers: trailers, promise, .requestDone) case .failed(let error): return .failSendStreamFinished(error, promise) @@ -497,8 +497,8 @@ struct HTTPRequestStateMachine { return self.receivedHTTPResponseHead(head) case .body(let body): return self.receivedHTTPResponseBodyPart(body) - case .end: - return self.receivedHTTPResponseEnd() + case .end(let trailers): + return self.receivedHTTPResponseEnd(trailers: trailers) } } @@ -618,7 +618,7 @@ struct HTTPRequestStateMachine { } } - private mutating func receivedHTTPResponseEnd() -> Action { + private mutating func receivedHTTPResponseEnd(trailers: HTTPHeaders?) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: preconditionFailure( @@ -648,7 +648,7 @@ struct HTTPRequestStateMachine { ), .endReceived ) - return .forwardResponseEnd(.none, remainingBuffer) + return .forwardResponseEnd(.none, remainingBuffer, trailers) case .close: // If we receive a `.close` as a connectionAction from the responseStreamState @@ -672,7 +672,7 @@ struct HTTPRequestStateMachine { // connection should be closed anyway. let (remainingBuffer, _) = responseStreamState.end() state = .finished - return .forwardResponseEnd(.close, remainingBuffer) + return .forwardResponseEnd(.close, remainingBuffer, trailers) } case .running(.endSent, .receivingBody(_, var responseStreamState)): @@ -681,9 +681,9 @@ struct HTTPRequestStateMachine { state = .finished switch action { case .none: - return .forwardResponseEnd(.requestDone, remainingBuffer) + return .forwardResponseEnd(.requestDone, remainingBuffer, trailers) case .close: - return .forwardResponseEnd(.close, remainingBuffer) + return .forwardResponseEnd(.close, remainingBuffer, trailers) } } diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 6bfbf09f4..67385e3f1 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -237,7 +237,7 @@ final class RequestBag: Sendabl promise.futureResult.whenSuccess { self.delegate.didSendRequest(task: self.task) } - writer.finishRequestBodyStream(self, promise: promise) + writer.finishRequestBodyStream(trailers: nil, request: self, promise: promise) case .forwardStreamFinishedAndSucceedTask(let writer, let writerPromise): let promise = writerPromise ?? self.task.eventLoop.makePromise(of: Void.self) @@ -256,7 +256,7 @@ final class RequestBag: Sendabl self.task.promise.fail(error) } } - writer.finishRequestBodyStream(self, promise: promise) + writer.finishRequestBodyStream(trailers: nil, request: self, promise: promise) case .forwardStreamFailureAndFailTask(let writer, let error, let promise): writer.cancelRequest(self) diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 0d871b7dc..72040296a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -897,6 +897,55 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { // and ensure that the state machine can tolerate this embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + + func testSendingAndReceivingTrailers() async throws { + let eventLoop = EmbeddedEventLoop() + let handler = HTTP1ClientChannelHandler( + eventLoop: eventLoop, + backgroundLogger: Logger(label: "no-op", factory: SwiftLogNoOpLogHandler.init), + connectionIdLoggerMetadata: "test connection" + ) + let channel = EmbeddedChannel(handlers: [handler], loop: eventLoop) + XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()) + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + let request = MockHTTPExecutableRequest( + head: .init(version: .http1_1, method: .POST, uri: "http://localhost/"), + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .stream), + raiseErrorIfUnimplementedMethodIsCalled: false + ) + + let executor = handler.requestExecutor + request.resumeRequestBodyStreamCallback = { + executor.writeRequestBodyPart(.byteBuffer(.init(string: "Hello World")), request: request, promise: nil) + executor.finishRequestBodyStream(trailers: ["trailer": "foo"], request: request, promise: nil) + } + + request.receiveResponseEndCallback = { (_, trailers) in + XCTAssertEqual(trailers, ["trailer": "bar"]) + } + + channel.write(request, promise: nil) + + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .head(request.requestHead)) + XCTAssertEqual( + try channel.readOutbound(as: HTTPClientRequestPart.self), + .body(.byteBuffer(.init(string: "Hello World"))) + ) + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .end(["trailer": "foo"])) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(.init(string: "Foo Bar")))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(["trailer": "bar"]))) + + XCTAssertEqual( + request.events.map(\.kind), + [ + .willExecuteRequest, .requestHeadSent, .resumeRequestBodyStream, .requestBodySent, .receiveResponseHead, + .receiveResponseEnd, + ] + ) + } } final class TestBackpressureWriter: Sendable { diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index a2fdfc51c..c9d29d1fc 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -52,7 +52,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil, .none)) + XCTAssertEqual( + state.requestStreamFinished(trailers: nil, promise: nil), + .sendRequestEnd(trailers: nil, nil, .none) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual( @@ -63,7 +66,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual( state.channelRead(.end(nil)), - .forwardResponseEnd(.informConnectionIsIdle, .init([responseBody])) + .forwardResponseEnd(.informConnectionIsIdle, [responseBody], nil) ) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -99,7 +102,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.informConnectionIsIdle, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.informConnectionIsIdle, [], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) } @@ -143,7 +146,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [responseBody], nil)) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) } @@ -166,7 +169,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [responseBody], nil)) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) } @@ -195,7 +198,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual( state.channelRead(.end(nil)), - .forwardResponseEnd(.informConnectionIsIdle, .init([responseBody])) + .forwardResponseEnd(.informConnectionIsIdle, [responseBody], nil) ) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) } @@ -220,7 +223,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [responseBody], nil)) } func testNIOTriggersChannelActiveTwice() { @@ -373,7 +376,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) } func testWeDontCrashAfterEarlyHintsAndConnectionClose() { @@ -451,10 +454,10 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { return lhsData == rhsData case ( - .forwardResponseEnd(let lhsFinalAction, let lhsFinalBuffer), - .forwardResponseEnd(let rhsFinalAction, let rhsFinalBuffer) + .forwardResponseEnd(let lhsFinalAction, let lhsFinalBuffer, let lhsTrailers), + .forwardResponseEnd(let rhsFinalAction, let rhsFinalBuffer, let rhsTrailers) ): - return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer + return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer && lhsTrailers == rhsTrailers case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): return lhsFinalAction == rhsFinalAction diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 71f7f3d1a..07750eafd 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -576,4 +576,50 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { channel.writeAndFlush(request, promise: nil) XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } + + func testSendingAndReceivingTrailers() async throws { + let eventLoop = EmbeddedEventLoop() + let handler = HTTP2ClientRequestHandler(eventLoop: eventLoop) + let channel = EmbeddedChannel(handlers: [handler], loop: eventLoop) + XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()) + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + let request = MockHTTPExecutableRequest( + head: .init(version: .http1_1, method: .POST, uri: "http://localhost/"), + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .stream), + raiseErrorIfUnimplementedMethodIsCalled: false + ) + + let executor = handler.requestExecutor + request.resumeRequestBodyStreamCallback = { + executor.writeRequestBodyPart(.byteBuffer(.init(string: "Hello World")), request: request, promise: nil) + executor.finishRequestBodyStream(trailers: ["trailer": "foo"], request: request, promise: nil) + } + + request.receiveResponseEndCallback = { (_, trailers) in + XCTAssertEqual(trailers, ["trailer": "bar"]) + } + + channel.write(request, promise: nil) + + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .head(request.requestHead)) + XCTAssertEqual( + try channel.readOutbound(as: HTTPClientRequestPart.self), + .body(.byteBuffer(.init(string: "Hello World"))) + ) + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .end(["trailer": "foo"])) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(.init(string: "Foo Bar")))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(["trailer": "bar"]))) + + XCTAssertEqual( + request.events.map(\.kind), + [ + .willExecuteRequest, .requestHeadSent, .resumeRequestBodyStream, .requestBodySent, .receiveResponseHead, + .receiveResponseEnd, + ] + ) + } + } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index c5823ff52..b4845005c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -37,7 +37,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]), nil)) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -77,7 +77,10 @@ class HTTPRequestStateMachineTests: XCTestCase { // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil, .none)) + XCTAssertEqual( + state.requestStreamFinished(trailers: nil, promise: nil), + .sendRequestEnd(trailers: nil, nil, .none) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual( @@ -86,7 +89,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]), nil)) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -132,7 +135,10 @@ class HTTPRequestStateMachineTests: XCTestCase { let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamFinished(trailers: nil, promise: nil).assertFailRequest( + HTTPClientError.bodyLengthMismatch, + .close(nil) + ) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { @@ -169,7 +175,7 @@ class HTTPRequestStateMachineTests: XCTestCase { "Expected to drop all stream data after having received a response head, with status >= 300" ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) XCTAssertEqual( state.requestStreamPartReceived(part, promise: nil), @@ -178,7 +184,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) XCTAssertEqual( - state.requestStreamFinished(promise: nil), + state.requestStreamFinished(trailers: nil, promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300" ) @@ -230,7 +236,7 @@ class HTTPRequestStateMachineTests: XCTestCase { "Expected to drop all stream data after having received a response head, with status >= 300" ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) XCTAssertEqual( state.requestStreamPartReceived(part, promise: nil), @@ -239,7 +245,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) XCTAssertEqual( - state.requestStreamFinished(promise: nil), + state.requestStreamFinished(trailers: nil, promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300" ) @@ -267,13 +273,16 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.none, [], nil)) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil, .requestDone)) + XCTAssertEqual( + state.requestStreamFinished(trailers: nil, promise: nil), + .sendRequestEnd(trailers: nil, nil, .requestDone) + ) XCTAssertEqual( state.requestStreamPartReceived(part2, promise: nil), @@ -308,9 +317,12 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil, .none)) + XCTAssertEqual( + state.requestStreamFinished(trailers: nil, promise: nil), + .sendRequestEnd(trailers: nil, nil, .none) + ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) } func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200() { @@ -335,11 +347,14 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.none, [])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.none, [], nil)) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) - state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamFinished(trailers: nil, promise: nil).assertFailRequest( + HTTPClientError.bodyLengthMismatch, + .close(nil) + ) XCTAssertEqual(state.channelInactive(), .wait) } @@ -368,7 +383,10 @@ class HTTPRequestStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) - state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamFinished(trailers: nil, promise: nil).assertFailRequest( + HTTPClientError.bodyLengthMismatch, + .close(nil) + ) XCTAssertEqual(state.channelRead(.end(nil)), .wait) } @@ -387,7 +405,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [responseBody], nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -430,7 +448,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -467,7 +485,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(part2)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([part2]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [part2], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -513,7 +531,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) } @@ -551,7 +569,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [responseBody], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -630,7 +648,7 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) } @@ -649,7 +667,7 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.idleReadTimeoutTriggered(), .wait, "A read timeout that fires to late must be ignored") } @@ -667,7 +685,7 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -705,7 +723,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -729,7 +747,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [body])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [body], nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -951,7 +969,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(part3)), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [part1, part2, part3])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [part1, part2, part3], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelInactive(), .wait) @@ -973,8 +991,12 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult - case (.sendRequestEnd(let lhsPromise, let lhsAction), .sendRequestEnd(let rhsPromise, let rhsAction)): - return lhsPromise?.futureResult == rhsPromise?.futureResult && lhsAction == rhsAction + case ( + .sendRequestEnd(let lhsTrailers, let lhsPromise, let lhsAction), + .sendRequestEnd(let rhsTrailers, let rhsPromise, let rhsAction) + ): + return lhsTrailers == rhsTrailers && lhsPromise?.futureResult == rhsPromise?.futureResult + && lhsAction == rhsAction case (.pauseRequestBodyStream, .pauseRequestBodyStream): return true @@ -991,10 +1013,10 @@ extension HTTPRequestStateMachine.Action: Equatable { return lhsData == rhsData case ( - .forwardResponseEnd(let lhsFinalAction, let lhsFinalBuffer), - .forwardResponseEnd(let rhsFinalAction, let rhsFinalBuffer) + .forwardResponseEnd(let lhsFinalAction, let lhsFinalBuffer, let lhsTrailers), + .forwardResponseEnd(let rhsFinalAction, let rhsFinalBuffer, let rhsTrailers) ): - return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer + return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer && lhsTrailers == rhsTrailers case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): return lhsFinalAction == rhsFinalAction diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift index 3347cac2e..227bbeff3 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift @@ -72,15 +72,64 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { private let file: StaticString private let line: UInt - let willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil - let requestHeadSentCallback: (@Sendable () -> Void)? = nil - let resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil - let pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil - let requestBodyStreamSentCallback: (@Sendable () -> Void)? = nil - let receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil - let receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil - let receiveResponseEndCallback: (@Sendable (CircularBuffer?, HTTPHeaders?) -> Void)? = nil - let failCallback: (@Sendable (Error) -> Void)? = nil + struct Callbacks { + var willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil + var requestHeadSentCallback: (@Sendable () -> Void)? = nil + var resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + var pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + var requestBodyStreamSentCallback: (@Sendable () -> Void)? = nil + var receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil + var receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil + var receiveResponseEndCallback: (@Sendable (CircularBuffer?, HTTPHeaders?) -> Void)? = nil + var failCallback: (@Sendable (Error) -> Void)? = nil + } + + let callbacks: NIOLockedValueBox = .init(.init()) + + var willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? { + get { self.callbacks.withLockedValue { $0.willExecuteRequestCallback } } + set { self.callbacks.withLockedValue { $0.willExecuteRequestCallback = newValue } } + } + + var requestHeadSentCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.requestHeadSentCallback } } + set { self.callbacks.withLockedValue { $0.requestHeadSentCallback = newValue } } + } + + var resumeRequestBodyStreamCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.resumeRequestBodyStreamCallback } } + set { self.callbacks.withLockedValue { $0.resumeRequestBodyStreamCallback = newValue } } + } + + var pauseRequestBodyStreamCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.pauseRequestBodyStreamCallback } } + set { self.callbacks.withLockedValue { $0.pauseRequestBodyStreamCallback = newValue } } + } + + var requestBodyStreamSentCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.requestBodyStreamSentCallback } } + set { self.callbacks.withLockedValue { $0.requestBodyStreamSentCallback = newValue } } + } + + var receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? { + get { self.callbacks.withLockedValue { $0.receiveResponseHeadCallback } } + set { self.callbacks.withLockedValue { $0.receiveResponseHeadCallback = newValue } } + } + + var receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? { + get { self.callbacks.withLockedValue { $0.receiveResponseBodyPartsCallback } } + set { self.callbacks.withLockedValue { $0.receiveResponseBodyPartsCallback = newValue } } + } + + var receiveResponseEndCallback: (@Sendable (CircularBuffer?, HTTPHeaders?) -> Void)? { + get { self.callbacks.withLockedValue { $0.receiveResponseEndCallback } } + set { self.callbacks.withLockedValue { $0.receiveResponseEndCallback = newValue } } + } + + var failCallback: (@Sendable (Error) -> Void)? { + get { self.callbacks.withLockedValue { $0.failCallback } } + set { self.callbacks.withLockedValue { $0.failCallback = newValue } } + } /// captures all ``HTTPExecutableRequest`` method calls in the order of occurrence, including arguments. /// If you are not interested in the arguments you can use `events.map(\.kind)` to get all events without arguments. @@ -170,11 +219,11 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { } func receiveResponseEnd(_ buffer: CircularBuffer?, trailers: HTTPHeaders?) { - self.events.append(.receiveResponseEnd(buffer, nil)) + self.events.append(.receiveResponseEnd(buffer, trailers)) guard let receiveResponseEndCallback = self.receiveResponseEndCallback else { return self.calledUnimplementedMethod(#function) } - receiveResponseEndCallback(buffer, nil) + receiveResponseEndCallback(buffer, trailers) } func fail(_ error: Error) { diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index e5d9caa8e..12080f887 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -14,6 +14,7 @@ import NIOConcurrencyHelpers import NIOCore +import NIOHTTP1 @testable import AsyncHTTPClient @@ -212,7 +213,11 @@ extension MockRequestExecutor: HTTPRequestExecutor { promise?.succeed(()) } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + func finishRequestBodyStream( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { self.writeNextRequestPart(.endOfStream, request: request) promise?.succeed(()) } diff --git a/Tests/AsyncHTTPClientTests/TransactionTests.swift b/Tests/AsyncHTTPClientTests/TransactionTests.swift index dda216975..b936b7155 100644 --- a/Tests/AsyncHTTPClientTests/TransactionTests.swift +++ b/Tests/AsyncHTTPClientTests/TransactionTests.swift @@ -98,7 +98,8 @@ final class TransactionTests: XCTestCase { } func finishRequestBodyStream( - _ task: AsyncHTTPClient.HTTPExecutableRequest, + trailers: HTTPHeaders?, + request: AsyncHTTPClient.HTTPExecutableRequest, promise: NIOCore.EventLoopPromise? ) { XCTFail()