Skip to content
11 changes: 10 additions & 1 deletion Sources/OpenAI/OpenAI+OpenAIAsync.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,16 @@ extension OpenAI: OpenAIAsync {
chatsStream(query: query, onResult: onResult, completion: completion)
}
}


public func chatsStream(
query: ChatQuery,
onWebSearchEvent: @escaping @Sendable (WebSearchEvent) -> Void
) -> AsyncThrowingStream<ChatStreamResult, Error> {
makeAsyncStream { onResult, completion in
chatsStream(query: query, onResult: onResult, onWebSearchEvent: onWebSearchEvent, completion: completion)
}
}

public func model(query: ModelQuery) async throws -> ModelResult {
try await performRequestAsync(
request: makeModelRequest(query: query)
Expand Down
69 changes: 68 additions & 1 deletion Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,62 @@ final public class OpenAI: OpenAIProtocol, @unchecked Sendable {
)
}

/// Creates an OpenAI client with a custom URLSession protocol implementation.
///
/// - Important: This initializer only uses the custom session for non-streaming requests.
/// For streaming requests, use the initializer that accepts a `URLSessionFactory`.
///
/// - Parameters:
/// - configuration: The client configuration
/// - customSession: Custom URLSession protocol implementation
/// - middlewares: Optional middlewares for request/response interception
public convenience init(
configuration: Configuration,
customSession: any URLSessionProtocol,
middlewares: [OpenAIMiddleware] = []
) {
let streamingSessionFactory = ImplicitURLSessionStreamingSessionFactory(
middlewares: middlewares,
parsingOptions: configuration.parsingOptions,
sslDelegate: nil
)

self.init(
configuration: configuration,
session: customSession,
streamingSessionFactory: streamingSessionFactory,
middlewares: middlewares
)
}

/// Creates an OpenAI client with custom session handling for both regular and streaming requests.
///
/// - Parameters:
/// - configuration: The client configuration
/// - customSession: Custom URLSession protocol implementation for non-streaming requests
/// - streamingURLSessionFactory: Factory for creating sessions for streaming requests
/// - middlewares: Optional middlewares for request/response interception
public convenience init(
configuration: Configuration,
customSession: any URLSessionProtocol,
streamingURLSessionFactory: URLSessionFactory,
middlewares: [OpenAIMiddleware] = []
) {
let streamingSessionFactory = ImplicitURLSessionStreamingSessionFactory(
urlSessionFactory: streamingURLSessionFactory,
middlewares: middlewares,
parsingOptions: configuration.parsingOptions,
sslDelegate: nil
)

self.init(
configuration: configuration,
session: customSession,
streamingSessionFactory: streamingSessionFactory,
middlewares: middlewares
)
}

init(
configuration: Configuration,
session: URLSessionProtocol,
Expand Down Expand Up @@ -284,9 +340,19 @@ final public class OpenAI: OpenAIProtocol, @unchecked Sendable {
}

public func chatsStream(query: ChatQuery, onResult: @escaping @Sendable (Result<ChatStreamResult, Error>) -> Void, completion: (@Sendable (Error?) -> Void)?) -> CancellableRequest {
chatsStream(query: query, onResult: onResult, onWebSearchEvent: nil, completion: completion)
}

public func chatsStream(
query: ChatQuery,
onResult: @escaping @Sendable (Result<ChatStreamResult, Error>) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?,
completion: (@Sendable (Error?) -> Void)?
) -> CancellableRequest {
performStreamingRequest(
request: JSONRequest<ChatStreamResult>(body: query.makeStreamable(), url: buildURL(path: .chats)),
onResult: onResult,
onWebSearchEvent: onWebSearchEvent,
completion: completion
)
}
Expand Down Expand Up @@ -355,9 +421,10 @@ extension OpenAI {
func performStreamingRequest<ResultType: Codable & Sendable>(
request: any URLRequestBuildable,
onResult: @escaping @Sendable (Result<ResultType, Error>) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? = nil,
completion: (@Sendable (Error?) -> Void)?
) -> CancellableRequest {
streamingClient.performStreamingRequest(request: request, onResult: onResult, completion: completion)
streamingClient.performStreamingRequest(request: request, onResult: onResult, onWebSearchEvent: onWebSearchEvent, completion: completion)
}

func performSpeechRequest(request: any URLRequestBuildable, completion: @escaping @Sendable (Result<AudioSpeechResult, Error>) -> Void) -> CancellableRequest {
Expand Down
25 changes: 15 additions & 10 deletions Sources/OpenAI/Private/Client/StreamingClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ final class StreamingClient: @unchecked Sendable {
func performStreamingRequest<ResultType: Codable & Sendable>(
request: any URLRequestBuildable,
onResult: @escaping @Sendable (Result<ResultType, Error>) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? = nil,
completion: (@Sendable (Error?) -> Void)?
) -> CancellableRequest {
do {
Expand All @@ -41,16 +42,20 @@ final class StreamingClient: @unchecked Sendable {
}

let session = streamingSessionFactory.makeServerSentEventsStreamingSession(
urlRequest: interceptedRequest
) { _, object in
onResult(.success(object))
} onProcessingError: { _, error in
onResult(.failure(error))
} onComplete: { [weak self] session, error in
completion?(error)
self?.invalidateSession(session)
}

urlRequest: interceptedRequest,
onReceiveContent: { _, object in
onResult(.success(object))
},
onWebSearchEvent: onWebSearchEvent,
onProcessingError: { _, error in
onResult(.failure(error))
},
onComplete: { [weak self] session, error in
completion?(error)
self?.invalidateSession(session)
}
)

return runSession(session)
} catch {
completion?(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import Foundation

protocol InvalidatableSession: Sendable {
public protocol InvalidatableSession: Sendable {
func invalidateAndCancel()
func finishTasksAndInvalidate()
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ final class ServerSentEventsStreamInterpreter <ResultType: Codable & Sendable>:
private var previousChunkBuffer = ""

private var onEventDispatched: ((ResultType) -> Void)?
private var onWebSearchEvent: ((WebSearchEvent) -> Void)?
private var onError: ((Error) -> Void)?
private let parsingOptions: ParsingOptions

Expand All @@ -39,8 +40,26 @@ final class ServerSentEventsStreamInterpreter <ResultType: Codable & Sendable>:
/// - Parameters:
/// - onEventDispatched: Can be called multiple times per `processData`
/// - onError: Will only be called once per `processData`
func setCallbackClosures(onEventDispatched: @escaping @Sendable (ResultType) -> Void, onError: @escaping @Sendable (Error) -> Void) {
func setCallbackClosures(
onEventDispatched: @escaping @Sendable (ResultType) -> Void,
onError: @escaping @Sendable (Error) -> Void
) {
setCallbackClosures(onEventDispatched: onEventDispatched, onWebSearchEvent: nil, onError: onError)
}

/// Sets closures an instance of type. Not thread safe.
///
/// - Parameters:
/// - onEventDispatched: Can be called multiple times per `processData`
/// - onWebSearchEvent: Called when a web search event is received (optional)
/// - onError: Will only be called once per `processData`
func setCallbackClosures(
onEventDispatched: @escaping @Sendable (ResultType) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?,
onError: @escaping @Sendable (Error) -> Void
) {
self.onEventDispatched = onEventDispatched
self.onWebSearchEvent = onWebSearchEvent
self.onError = onError
}

Expand All @@ -66,7 +85,21 @@ final class ServerSentEventsStreamInterpreter <ResultType: Codable & Sendable>:
onError?(StreamingError.unknownContent)
return
}


// Handle web search events (they have "type" field instead of "object")
// Event types include: "web_search_call", or prefixed like "response.web_search_call.*"
if let json = try? JSONSerialization.jsonObject(with: jsonData) as? [String: Any],
let eventType = json["type"] as? String,
eventType.contains("web_search") {
do {
let webSearchEvent = try JSONDecoder().decode(WebSearchEvent.self, from: jsonData)
onWebSearchEvent?(webSearchEvent)
} catch {
onError?(error)
}
return
}

let decoder = JSONResponseDecoder(parsingOptions: parsingOptions)
do {
let object: ResultType = try decoder.decodeResponseData(jsonData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ protocol StreamingSessionFactory: Sendable {
func makeServerSentEventsStreamingSession<ResultType: Codable & Sendable>(
urlRequest: URLRequest,
onReceiveContent: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, ResultType) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?,
onProcessingError: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, Error) -> Void,
onComplete: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, Error?) -> Void
) -> StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>
Expand All @@ -35,34 +36,51 @@ protocol StreamingSessionFactory: Sendable {
}

struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory {
let urlSessionFactory: URLSessionFactory
let middlewares: [OpenAIMiddleware]
let parsingOptions: ParsingOptions
let sslDelegate: SSLDelegateProtocol?


init(
urlSessionFactory: URLSessionFactory = FoundationURLSessionFactory(),
middlewares: [OpenAIMiddleware],
parsingOptions: ParsingOptions,
sslDelegate: SSLDelegateProtocol?
) {
self.urlSessionFactory = urlSessionFactory
self.middlewares = middlewares
self.parsingOptions = parsingOptions
self.sslDelegate = sslDelegate
}

func makeServerSentEventsStreamingSession<ResultType>(
urlRequest: URLRequest,
onReceiveContent: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, ResultType) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?,
onProcessingError: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, any Error) -> Void,
onComplete: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, (any Error)?) -> Void
) -> StreamingSession<ServerSentEventsStreamInterpreter<ResultType>> where ResultType : Decodable, ResultType : Encodable, ResultType : Sendable {
.init(
urlSessionFactory: urlSessionFactory,
urlRequest: urlRequest,
interpreter: .init(parsingOptions: parsingOptions),
sslDelegate: sslDelegate,
middlewares: middlewares,
onReceiveContent: onReceiveContent,
onWebSearchEvent: onWebSearchEvent,
onProcessingError: onProcessingError,
onComplete: onComplete
)
}

func makeAudioSpeechStreamingSession(
urlRequest: URLRequest,
onReceiveContent: @Sendable @escaping (StreamingSession<AudioSpeechStreamInterpreter>, AudioSpeechResult) -> Void,
onProcessingError: @Sendable @escaping (StreamingSession<AudioSpeechStreamInterpreter>, any Error) -> Void,
onComplete: @Sendable @escaping (StreamingSession<AudioSpeechStreamInterpreter>, (any Error)?) -> Void
) -> StreamingSession<AudioSpeechStreamInterpreter> {
.init(
urlSessionFactory: urlSessionFactory,
urlRequest: urlRequest,
interpreter: .init(),
sslDelegate: sslDelegate,
Expand All @@ -72,14 +90,15 @@ struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory {
onComplete: onComplete
)
}

func makeModelResponseStreamingSession(
urlRequest: URLRequest,
onReceiveContent: @Sendable @escaping (StreamingSession<ModelResponseEventsStreamInterpreter>, ResponseStreamEvent) -> Void,
onProcessingError: @Sendable @escaping (StreamingSession<ModelResponseEventsStreamInterpreter>, any Error) -> Void,
onComplete: @Sendable @escaping (StreamingSession<ModelResponseEventsStreamInterpreter>, (any Error)?) -> Void
) -> StreamingSession<ModelResponseEventsStreamInterpreter> {
.init(
urlSessionFactory: urlSessionFactory,
urlRequest: urlRequest,
interpreter: .init(),
sslDelegate: sslDelegate,
Expand Down
28 changes: 22 additions & 6 deletions Sources/OpenAI/Private/Streaming/StreamingSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ final class StreamingSession<Interpreter: StreamInterpreter>: NSObject, Identifi
private let middlewares: [OpenAIMiddleware]
private let executionSerializer: ExecutionSerializer
private let onReceiveContent: (@Sendable (StreamingSession, ResultType) -> Void)?
private let onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?
private let onProcessingError: (@Sendable (StreamingSession, Error) -> Void)?
private let onComplete: (@Sendable (StreamingSession, Error?) -> Void)?

Expand All @@ -32,6 +33,7 @@ final class StreamingSession<Interpreter: StreamInterpreter>: NSObject, Identifi
middlewares: [OpenAIMiddleware],
executionSerializer: ExecutionSerializer = GCDQueueAsyncExecutionSerializer(queue: .userInitiated),
onReceiveContent: @escaping @Sendable (StreamingSession, ResultType) -> Void,
onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? = nil,
onProcessingError: @escaping @Sendable (StreamingSession, Error) -> Void,
onComplete: @escaping @Sendable (StreamingSession, Error?) -> Void
) {
Expand All @@ -42,6 +44,7 @@ final class StreamingSession<Interpreter: StreamInterpreter>: NSObject, Identifi
self.middlewares = middlewares
self.executionSerializer = executionSerializer
self.onReceiveContent = onReceiveContent
self.onWebSearchEvent = onWebSearchEvent
self.onProcessingError = onProcessingError
self.onComplete = onComplete
super.init()
Expand Down Expand Up @@ -96,12 +99,25 @@ final class StreamingSession<Interpreter: StreamInterpreter>: NSObject, Identifi
}

private func subscribeToParser() {
interpreter.setCallbackClosures { [weak self] content in
guard let self else { return }
self.onReceiveContent?(self, content)
} onError: { [weak self] error in
guard let self else { return }
self.onProcessingError?(self, error)
// Check if interpreter supports web search events (ServerSentEventsStreamInterpreter)
if let sseInterpreter = interpreter as? ServerSentEventsStreamInterpreter<ResultType> {
sseInterpreter.setCallbackClosures { [weak self] content in
guard let self else { return }
self.onReceiveContent?(self, content)
} onWebSearchEvent: { [weak self] event in
self?.onWebSearchEvent?(event)
} onError: { [weak self] error in
guard let self else { return }
self.onProcessingError?(self, error)
}
} else {
interpreter.setCallbackClosures { [weak self] content in
guard let self else { return }
self.onReceiveContent?(self, content)
} onError: { [weak self] error in
guard let self else { return }
self.onProcessingError?(self, error)
}
}
}
}
6 changes: 3 additions & 3 deletions Sources/OpenAI/Private/URLSessionCombine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ import FoundationNetworking
#if canImport(Combine)
import Combine

protocol URLSessionCombine {
public protocol URLSessionCombine {
func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError>
}

extension URLSession: URLSessionCombine {
func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> {
public func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> {
let typedPublisher: URLSession.DataTaskPublisher = dataTaskPublisher(for: request)
return typedPublisher.eraseToAnyPublisher()
}
}

#else
protocol URLSessionCombine {
public protocol URLSessionCombine {
}

extension URLSession: URLSessionCombine {}
Expand Down
4 changes: 2 additions & 2 deletions Sources/OpenAI/Private/URLSessionDataTaskProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import Foundation
import FoundationNetworking
#endif

protocol URLSessionTaskProtocol: Sendable {
public protocol URLSessionTaskProtocol: Sendable {
var originalRequest: URLRequest? { get }
func cancel()
}

extension URLSessionTask: URLSessionTaskProtocol {}

protocol URLSessionDataTaskProtocol: URLSessionTaskProtocol {
public protocol URLSessionDataTaskProtocol: URLSessionTaskProtocol {
func resume()
}

Expand Down
Loading