Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 16 additions & 74 deletions Sources/Containerization/UnixSocketRelay.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,63 +21,7 @@ import Foundation
import Logging
import Synchronization

package actor UnixSocketRelayManager {
private let vm: any VirtualMachineInstance
private var relays: [String: SocketRelay]
private let q: DispatchQueue
private let log: Logger?

init(vm: any VirtualMachineInstance, log: Logger? = nil) {
self.vm = vm
self.relays = [:]
self.q = DispatchQueue(label: "com.apple.containerization.socket-relay")
self.log = log
}
}

extension UnixSocketRelayManager {
func start(port: UInt32, socket: UnixSocketConfiguration) async throws {
guard self.relays[socket.id] == nil else {
throw ContainerizationError(
.invalidState,
message: "socket relay \(socket.id) already started"
)
}

let socketRelay = try SocketRelay(
port: port,
socket: socket,
vm: self.vm,
queue: self.q,
log: self.log
)

do {
self.relays[socket.id] = socketRelay
try await socketRelay.start()
} catch {
self.relays.removeValue(forKey: socket.id)
}
}

func stop(socket: UnixSocketConfiguration) async throws {
guard let storedRelay = self.relays.removeValue(forKey: socket.id) else {
throw ContainerizationError(
.notFound,
message: "failed to stop socket relay"
)
}
try storedRelay.stop()
}

func stopAll() async throws {
for (_, relay) in self.relays {
try relay.stop()
}
}
}

package final class SocketRelay: Sendable {
package final class UnixSocketRelay: Sendable {
private let port: UInt32
private let configuration: UnixSocketConfiguration
private let log: Logger?
Expand Down Expand Up @@ -107,11 +51,11 @@ package final class SocketRelay: Sendable {
}

deinit {
self.state.withLock { $0.t?.cancel() }
state.withLock { $0.t?.cancel() }
}
}

extension SocketRelay {
extension UnixSocketRelay {
func start() async throws {
switch configuration.direction {
case .outOf:
Expand All @@ -122,7 +66,7 @@ extension SocketRelay {
}

func stop() throws {
try self.state.withLock {
try state.withLock {
guard let t = $0.t else {
throw ContainerizationError(
.invalidState,
Expand All @@ -148,7 +92,7 @@ extension SocketRelay {
}

private func setupHostVsockDial() async throws {
let hostConn = self.configuration.destination
let hostConn = configuration.destination

let socketType = try UnixType(
path: hostConn.path,
Expand All @@ -161,10 +105,10 @@ extension SocketRelay {
"listening on host UDS",
metadata: [
"path": "\(hostConn.path)",
"vport": "\(self.port)",
"vport": "\(port)",
])
let connectionStream = try hostSocket.acceptStream(closeOnDeinit: false)
self.state.withLock {
state.withLock {
$0.t = Task {
do {
for try await connection in connectionStream {
Expand All @@ -184,19 +128,17 @@ extension SocketRelay {
}

private func setupHostVsockListener() throws {
let hostPath = self.configuration.source
let port = self.port
let log = self.log
let hostPath = configuration.source

let listener = try self.vm.listen(self.port)
let listener = try vm.listen(port)
log?.info(
"listening on guest vsock",
metadata: [
"path": "\(hostPath)",
"vport": "\(port)",
])

self.state.withLock {
state.withLock {
$0.listener = listener
$0.t = Task {
do {
Expand All @@ -205,12 +147,12 @@ extension SocketRelay {
try await self.handleGuestVsockConn(
vsockConn: connection,
hostConnectionPath: hostPath,
port: port,
log: log
port: self.port,
log: self.log
)
}
} catch {
log?.error("failed to setup relay between vsock \(port) and \(hostPath.path): \(error)")
self.log?.error("failed to setup relay between vsock \(self.port) and \(hostPath.path): \(error)")
}
}
}
Expand Down Expand Up @@ -282,11 +224,11 @@ extension SocketRelay {
let relay = BidirectionalRelay(
fd1: hostFd,
fd2: guestFd,
queue: self.q,
log: self.log
queue: q,
log: log
)

self.state.withLock {
state.withLock {
$0.activeRelays[relayID] = relay
}

Expand Down
75 changes: 75 additions & 0 deletions Sources/Containerization/UnixSocketRelayManager.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//===----------------------------------------------------------------------===//
// Copyright © 2026 Apple Inc. and the Containerization project authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===----------------------------------------------------------------------===//

import ContainerizationError
import Foundation
import Logging

package actor UnixSocketRelayManager {
private let vm: any VirtualMachineInstance
private var relays: [String: UnixSocketRelay]
private let q: DispatchQueue
private let log: Logger?

init(vm: any VirtualMachineInstance, log: Logger? = nil) {
self.vm = vm
self.relays = [:]
self.q = DispatchQueue(label: "com.apple.containerization.socket-relay")
self.log = log
}
}

extension UnixSocketRelayManager {
func start(port: UInt32, socket: UnixSocketConfiguration) async throws {
guard relays[socket.id] == nil else {
throw ContainerizationError(
.invalidState,
message: "socket relay \(socket.id) already started"
)
}

let relay = try UnixSocketRelay(
port: port,
socket: socket,
vm: vm,
queue: q,
log: log
)

do {
relays[socket.id] = relay
try await relay.start()
} catch {
relays.removeValue(forKey: socket.id)
}
}

func stop(socket: UnixSocketConfiguration) async throws {
guard let storedRelay = relays.removeValue(forKey: socket.id) else {
throw ContainerizationError(
.notFound,
message: "failed to stop socket relay"
)
}
try storedRelay.stop()
}

func stopAll() async throws {
for (_, relay) in relays {
try relay.stop()
}
}
}
30 changes: 30 additions & 0 deletions Sources/ContainerizationOS/Socket/BidirectionalRelay.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ public final class BidirectionalRelay: Sendable {
let source2: DispatchSourceRead
}

private enum CompletionState {
case pending
case waiting(CheckedContinuation<Void, Never>)
case completed
}

private let state: Mutex<ConnectionSources?>
private let completionState: Mutex<CompletionState>

// The buffers aren't used concurrently.
private nonisolated(unsafe) let buffer1: UnsafeMutableBufferPointer<UInt8>
Expand All @@ -56,6 +63,7 @@ public final class BidirectionalRelay: Sendable {
self.queue = queue ?? DispatchQueue(label: "com.apple.containerization.bidirectional-relay")
self.log = log
self.state = Mutex(nil)
self.completionState = Mutex(.pending)

let pageSize = Int(getpagesize())
self.buffer1 = UnsafeMutableBufferPointer<UInt8>.allocate(capacity: pageSize)
Expand Down Expand Up @@ -134,6 +142,22 @@ public final class BidirectionalRelay: Sendable {
}
}

/// Waits for the relay to complete.
public func waitForCompletion() async {
await withCheckedContinuation { c in
completionState.withLock { state in
switch state {
case .pending:
state = .waiting(c)
case .waiting:
fatalError("waitForCompletion called multiple times")
case .completed:
c.resume()
}
}
}
}

private func fdCopyHandler(
buffer: UnsafeMutableBufferPointer<UInt8>,
source: DispatchSourceRead,
Expand Down Expand Up @@ -253,5 +277,11 @@ public final class BidirectionalRelay: Sendable {
)
close(fd1)
close(fd2)
completionState.withLock { state in
if case .waiting(let c) = state {
c.resume()
}
state = .completed
}
}
}
33 changes: 24 additions & 9 deletions vminitd/Sources/vminitd/Server+GRPC.swift
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,20 @@ extension Initd: Com_Apple_Containerization_Sandbox_V3_SandboxContextAsyncProvid
"action": "\(request.action)",
])

do {
let proxy = VsockProxy(
id: request.id,
action: request.action == .into ? .dial : .listen,
port: request.vsockPort,
path: URL(fileURLWithPath: request.guestPath),
udsPerms: request.guestSocketPermissions,
log: log
)
let proxy = VsockProxy(
id: request.id,
action: request.action == .into ? .dial : .listen,
port: request.vsockPort,
path: URL(fileURLWithPath: request.guestPath),
udsPerms: request.guestSocketPermissions,
log: log
)

do {
try await proxy.start()
try await state.add(proxy: proxy)
} catch {
try? await proxy.close()
log.error(
"proxyVsock",
metadata: [
Expand All @@ -222,6 +223,14 @@ extension Initd: Com_Apple_Containerization_Sandbox_V3_SandboxContextAsyncProvid
)
}

log.info(
"proxyVsock started",
metadata: [
"id": "\(request.id)",
"port": "\(request.vsockPort)",
"guestPath": "\(request.guestPath)",
])

return .init()
}

Expand Down Expand Up @@ -250,6 +259,12 @@ extension Initd: Com_Apple_Containerization_Sandbox_V3_SandboxContextAsyncProvid
)
}

log.info(
"stopVsockProxy completed",
metadata: [
"id": "\(request.id)"
])

return .init()
}

Expand Down