import Foundation
import NetworkExtension
import os.log
import Network

/// Relays one intercepted TCP flow (`NEAppProxyTCPFlow`) to its original destination over an
/// `NWConnection`.
///
/// The first outgoing packet is inspected to learn the destination domain: TLS SNI when the
/// destination port is 443, the HTTP `Host` header when it is 80. The flow is closed if
/// `checkDomain` rejects that domain.
///
/// Threading: mutable state is only touched on `connectionQueue`. `NWConnection` already calls back
/// on that queue; completions coming from the flow or from `checkDomain` are re-dispatched onto it.
///
/// Lifetime: `startExchangingData()` installs a state handler that strongly references `self`, so the
/// manager stays alive while the connection is active; `tearDown(error:)` releases that reference.
class TransparentProxyManager {
    public let flow: NEAppProxyTCPFlow
    private let connection: NWConnection
    private let log = OSLog(subsystem: "com.proxy.tcp.network.extension", category: "provider")
    private let connectionQueue = DispatchQueue(label: "com.proxy.tcp.network.extension.TCPQueue")

    /// Destination port; decides how (and whether) the first packet is inspected.
    private let remotePort: NWEndpoint.Port

    // Properties for domain extraction and filtering
    private var isFirstPacket = true
    private var blockConnection = false
    private var extractedDomain: String?
    /// Set once the flow and connection have been closed; every callback checks it so nothing
    /// (writes, reads, re-closing) happens after teardown.
    private var isTornDown = false

    private static let httpsPort: UInt16 = 443
    private static let httpPort: UInt16 = 80
    /// Maximum number of bytes requested from the remote side per receive call.
    private static let receiveChunkSize = 16384
    /// Request-line prefixes accepted as "this looks like an HTTP request".
    private static let httpMethodPrefixes = [
        "GET ", "POST ", "HEAD ", "PUT ", "DELETE ", "OPTIONS ", "PATCH ", "CONNECT ", "TRACE ",
    ]

    init(flow: NEAppProxyTCPFlow, endpoint: NWHostEndpoint) {
        self.flow = flow
        let host = NWEndpoint.Host(endpoint.hostname)
        let port = NWEndpoint.Port(endpoint.port) ?? NWEndpoint.Port.any
        self.remotePort = port
        self.connection = NWConnection(host: host, port: port, using: .tcp)
        os_log(.debug, log: self.log, "On TCP init. host: %{public}@, port: %{public}@", String(describing: host), String(describing: port))
    }

    /// Starts connecting to the remote endpoint. Relaying begins once the connection is ready and
    /// the intercepted flow has been opened.
    public func startExchangingData() {
        connection.stateUpdateHandler = self.stateChangedCallback(to:)
        connection.start(queue: connectionQueue)
    }

    // MARK: - Connection lifecycle

    /// Reacts to `NWConnection` state changes (invoked on `connectionQueue`).
    private func stateChangedCallback(to state: NWConnection.State) {
        switch state {
        case .ready:
            os_log(.debug, log: self.log, "TransparentProxyManager::stateChangedCallback. TCP connection is ready %{public}@", "")
            flow.open(withLocalEndpoint: nil) { [weak self] error in
                guard let self = self else { return }
                self.connectionQueue.async {
                    guard !self.isTornDown else { return }
                    if let error = error {
                        os_log(.error, log: self.log, "TCP flow opening failed %{public}@", error.localizedDescription)
                        self.tearDown(error: error)
                        return
                    }
                    os_log(.debug, log: self.log, "TCP flow opened successfully, starting data exchange %{public}@", "")
                    // Relay both directions concurrently. Waiting for the app's EOF before reading the
                    // server's reply deadlocks request/response protocols such as HTTP and TLS.
                    self.handleOutgoingTCPData()
                    self.handleIncomingTCPData()
                }
            }
        case .failed(let error):
            os_log(.error, log: self.log, "TransparentProxyManager::stateChangedCallback. TCP connection failed %{public}@", error.localizedDescription)
            tearDown(error: error)
        case .cancelled:
            os_log(.debug, log: self.log, "TransparentProxyManager::stateChangedCallback. TCP connection is cancelled %{public}@", "")
            tearDown(error: nil)
        case .preparing:
            os_log(.debug, log: self.log, "TransparentProxyManager::stateChangedCallback. TCP connection is preparing %{public}@", "")
        case .setup:
            os_log(.debug, log: self.log, "TransparentProxyManager::stateChangedCallback. TCP connection is in setup state %{public}@", "")
        case .waiting(let error):
            os_log(.error, log: self.log, "TransparentProxyManager::stateChangedCallback. TCP connection is in waiting state, %{public}@", error.localizedDescription)
            tearDown(error: error)
        @unknown default:
            os_log(.debug, log: self.log, "TransparentProxyManager::stateChangedCallback. State is unknown %{public}@", "")
        }
    }

    /// Closes both halves of the intercepted flow and cancels the remote connection.
    /// Idempotent; must be called on `connectionQueue`.
    private func tearDown(error: Error?) {
        guard !isTornDown else { return }
        isTornDown = true
        connection.cancel()
        flow.closeReadWithError(error)
        flow.closeWriteWithError(error)
        // Break the connection -> stateUpdateHandler -> self retain cycle. Deferred so the handler is
        // never cleared while it is executing (tearDown is also called from stateChangedCallback).
        connectionQueue.async {
            self.connection.stateUpdateHandler = nil
        }
    }

    // MARK: - App -> remote

    /// Reads the next chunk the local app wrote to the flow; the result is handled on `connectionQueue`.
    private func handleOutgoingTCPData() {
        os_log(.debug, log: self.log, "handleOutgoingTCPData. handling TCP flow data %{public}@", "")
        flow.readData { [weak self] data, error in
            guard let self = self else { return }
            self.connectionQueue.async {
                self.processOutgoing(data: data, error: error)
            }
        }
    }

    /// Handles one `readData` result: errors tear down, EOF half-closes the remote side, the first
    /// packet is inspected for a domain, and everything else is forwarded.
    private func processOutgoing(data: Data?, error: Error?) {
        guard !isTornDown else { return }
        if let error = error {
            os_log(.error, log: self.log, "handleOutgoingTCPData. Error on handling TCP flow data: %{public}@", error.localizedDescription)
            tearDown(error: error)
            return
        }
        // Zero-length (or missing) data means the app closed its write side: nothing more can be read,
        // so issuing another readData here would spin forever.
        guard let data = data, !data.isEmpty else {
            finishOutgoing()
            return
        }
        if isFirstPacket {
            isFirstPacket = false
            inspectFirstPacket(data)
            return
        }
        if !blockConnection {
            sendData(data)
        }
    }

    /// Propagates the app's EOF to the remote endpoint as a TCP half-close; the incoming pump keeps
    /// running until the remote side finishes.
    private func finishOutgoing() {
        guard !blockConnection else { return }
        os_log(.debug, log: self.log, "No more outgoing data, half-closing remote connection %{public}@", "")
        connection.send(content: nil, contentContext: .finalMessage, isComplete: true, completion: .contentProcessed({ [weak self] error in
            guard let self = self, !self.isTornDown, let error = error else { return }
            os_log(.error, log: self.log, "TransparentProxyManager::sendDataToEndpoint. TCP error: %{public}@", error.localizedDescription)
            self.tearDown(error: error)
        }))
    }

    /// Tries to learn the destination domain from the first outgoing packet and, if found, defers the
    /// packet to `applyDomainPolicy`. Packets that cannot be inspected are forwarded unchanged.
    private func inspectFirstPacket(_ data: Data) {
        let port = remotePort.rawValue
        if port == TransparentProxyManager.httpsPort {
            // Try to extract domain from TLS ClientHello (SNI)
            if let domain = extractSNI(from: data) {
                os_log(.debug, log: self.log, "Extracted SNI domain: %{public}@", domain)
                applyDomainPolicy(to: data, domain: domain, isHttps: true)
                return
            }
            os_log(.debug, log: self.log, "Could not extract SNI from HTTPS connection, allowing anyway %{public}@", "")
        } else if port == TransparentProxyManager.httpPort {
            // Try to extract Host header from HTTP request
            if let domain = extractHTTPHost(from: data) {
                os_log(.debug, log: self.log, "Extracted HTTP Host: %{public}@", domain)
                applyDomainPolicy(to: data, domain: domain, isHttps: false)
                return
            }
            os_log(.debug, log: self.log, "Could not extract Host from HTTP connection, allowing anyway %{public}@", "")
        }
        // Not HTTP/HTTPS, or the domain could not be extracted.
        sendData(data)
    }

    /// Holds `firstPacket` until `checkDomain` decides; forwards it if allowed, otherwise closes the flow.
    private func applyDomainPolicy(to firstPacket: Data, domain: String, isHttps: Bool) {
        extractedDomain = domain
        checkDomain(domain) { [weak self] allowed in
            guard let self = self else { return }
            // checkDomain may complete on any queue; all state lives on connectionQueue.
            self.connectionQueue.async {
                guard !self.isTornDown else { return }
                if allowed {
                    // Domain is allowed, continue with the connection
                    self.sendData(firstPacket)
                    return
                }
                if isHttps {
                    os_log(.info, log: self.log, "Blocking HTTPS connection to: %{public}@", domain)
                } else {
                    os_log(.info, log: self.log, "Blocking HTTP connection to: %{public}@", domain)
                }
                self.blockConnection = true
                self.tearDown(error: nil)
            }
        }
    }

    /// Forwards `data` to the remote endpoint, then reads the next chunk from the flow.
    private func sendData(_ data: Data) {
        connection.send(content: data, completion: .contentProcessed({ [weak self] error in
            guard let self = self, !self.isTornDown else { return }
            if let error = error {
                os_log(.error, log: self.log, "TransparentProxyManager::sendDataToEndpoint. TCP error: %{public}@", error.localizedDescription)
                self.tearDown(error: error)
                return
            }
            os_log(.debug, log: self.log, "TCP data sent successfully %{public}@", "")
            self.handleOutgoingTCPData()
        }))
    }

    // MARK: - Remote -> app

    /// Receives the next chunk from the remote endpoint and writes it to the flow, looping until the
    /// remote side completes or an error occurs.
    private func handleIncomingTCPData() {
        os_log(.debug, log: self.log, "On handleIncomingTCPData %{public}@", "")
        connection.receive(minimumIncompleteLength: 1, maximumLength: TransparentProxyManager.receiveChunkSize) { [weak self] data, _, isComplete, error in
            guard let self = self, !self.isTornDown else { return }
            if let error = error {
                os_log(.error, log: self.log, "On handleIncomingTCPData. Error: %{public}@", error.localizedDescription)
                self.tearDown(error: error)
                return
            }
            if self.blockConnection { return }

            guard let data = data, !data.isEmpty else {
                if isComplete {
                    self.finishIncoming()
                } else {
                    // Empty data but not complete - this shouldn't normally happen
                    os_log(.debug, log: self.log, "On handleIncomingTCPData. Received empty data, continuing %{public}@", "")
                    self.handleIncomingTCPData()
                }
                return
            }

            os_log(.debug, log: self.log, "On handleIncomingTCPData. Received %{public}ld bytes of data", data.count)
            self.flow.write(data) { [weak self] writeError in
                guard let self = self else { return }
                self.connectionQueue.async {
                    guard !self.isTornDown else { return }
                    if let writeError = writeError {
                        os_log(.error, log: self.log, "On handleIncomingTCPData. Write error: %{public}@", writeError.localizedDescription)
                        self.tearDown(error: writeError)
                        return
                    }
                    // The last bytes may arrive together with the completion flag; finish now rather
                    // than relying on another receive callback.
                    if isComplete {
                        self.finishIncoming()
                    } else {
                        self.handleIncomingTCPData()
                    }
                }
            }
        }
    }

    /// The remote side finished sending; close the flow and the connection.
    private func finishIncoming() {
        os_log(.debug, log: self.log, "On handleIncomingTCPData. Connection is completed %{public}@", "")
        tearDown(error: nil)
    }

    // MARK: - Domain Extraction Functions

    /// Extracts the Server Name Indication host name from a TLS ClientHello.
    ///
    /// - Parameter data: The first bytes the client sent.
    /// - Returns: The lowercased host name, or `nil` if `data` is not a ClientHello or carries no SNI.
    ///
    /// Only the bytes actually present are parsed. If the hello is larger than this read, a
    /// `server_name` extension fully contained in the available prefix is still found.
    private func extractSNI(from data: Data) -> String? {
        let bytes = [UInt8](data)
        // Handshake record (0x16) carrying a ClientHello (handshake type 1). 43 bytes covers the record
        // header (5), handshake header (4), client version (2) and random (32).
        guard bytes.count > 43, bytes[0] == 0x16, bytes[5] == 0x01 else { return nil }

        func uint16(at index: Int) -> Int {
            return (Int(bytes[index]) << 8) | Int(bytes[index + 1])
        }

        var position = 43
        // Skip session ID (1-byte length)
        guard position < bytes.count else { return nil }
        position += 1 + Int(bytes[position])
        // Skip cipher suites (2-byte length)
        guard position + 2 <= bytes.count else { return nil }
        position += 2 + uint16(at: position)
        // Skip compression methods (1-byte length)
        guard position < bytes.count else { return nil }
        position += 1 + Int(bytes[position])
        // Extensions block (2-byte length)
        guard position + 2 <= bytes.count else { return nil }
        let declaredEnd = position + 2 + uint16(at: position)
        position += 2
        let extensionsEnd = min(declaredEnd, bytes.count)

        while position + 4 <= extensionsEnd {
            let extensionType = uint16(at: position)
            let extensionLength = uint16(at: position + 2)
            let bodyStart = position + 4
            let bodyEnd = bodyStart + extensionLength
            guard bodyEnd <= extensionsEnd else { return nil }
            // ServerName extension (0)
            if extensionType == 0 {
                return hostName(inServerNameExtension: bytes, from: bodyStart, to: bodyEnd)
            }
            // Always advance by the declared length so the next header is read from the right offset.
            position = bodyEnd
        }
        return nil
    }

    /// Returns the first `host_name` entry (lowercased) of a `server_name` extension body occupying
    /// `bytes[start..<end]`, or `nil` if the body is malformed or has no host name.
    private func hostName(inServerNameExtension bytes: [UInt8], from start: Int, to end: Int) -> String? {
        guard start + 2 <= end else { return nil }
        let listEnd = start + 2 + ((Int(bytes[start]) << 8) | Int(bytes[start + 1]))
        guard listEnd <= end else { return nil }

        var position = start + 2
        while position + 3 <= listEnd {
            let nameType = bytes[position]
            let nameLength = (Int(bytes[position + 1]) << 8) | Int(bytes[position + 2])
            let nameStart = position + 3
            let nameEnd = nameStart + nameLength
            guard nameEnd <= listEnd else { return nil }
            // HostName (0)
            if nameType == 0, nameLength > 0,
               let hostname = String(bytes: bytes[nameStart..<nameEnd], encoding: .utf8) {
                os_log(.debug, log: self.log, "Hostname found: %{public}@", hostname)
                // Host names compare case-insensitively; normalize so filtering can't be evaded by case.
                return hostname.lowercased()
            }
            position = nameEnd
        }
        return nil
    }

    /// Extracts the host (without port, lowercased) from the `Host` header of an HTTP request.
    ///
    /// - Parameter data: The first bytes the client sent.
    /// - Returns: The host, or `nil` if `data` does not start with a known HTTP method or has no
    ///   non-empty `Host` header.
    private func extractHTTPHost(from data: Data) -> String? {
        // Decode only the header section. A request body in the same read may not be valid UTF-8,
        // and the headers we need are ASCII.
        let headerEnd = data.range(of: Data("\r\n\r\n".utf8))?.lowerBound ?? data.endIndex
        let httpRequest = String(decoding: data[data.startIndex..<headerEnd], as: UTF8.self)

        // Basic verification that this looks like an HTTP request
        guard TransparentProxyManager.httpMethodPrefixes.contains(where: { httpRequest.hasPrefix($0) }) else {
            return nil
        }

        // Look for the Host: header (case-insensitive) after the request line.
        for line in httpRequest.components(separatedBy: "\r\n").dropFirst() {
            guard line.lowercased().hasPrefix("host:") else { continue }
            let value = line.dropFirst(5).trimmingCharacters(in: .whitespaces)
            let host = hostWithoutPort(value).lowercased()
            return host.isEmpty ? nil : host
        }
        return nil
    }

    /// Removes an optional `:port` suffix from a Host header value, keeping bracketed IPv6 literals
    /// intact (e.g. `[::1]:8080` -> `::1`, `example.com:80` -> `example.com`).
    private func hostWithoutPort(_ value: String) -> String {
        if value.hasPrefix("[") {
            guard let close = value.firstIndex(of: "]") else { return value }
            return String(value[value.index(after: value.startIndex)..<close])
        }
        if let colonIndex = value.firstIndex(of: ":") {
            return String(value[..<colonIndex])
        }
        return value
    }

    // MARK: - Filtering

    /// Decides whether connections to `domain` may proceed.
    ///
    /// - Parameters:
    ///   - domain: Host name taken from SNI or the HTTP `Host` header.
    ///   - completion: Called with `true` to allow the flow, `false` to block it. The caller
    ///     re-dispatches onto `connectionQueue`, so this may complete on any queue.
    private func checkDomain(_ domain: String, completion: @escaping (Bool) -> Void) {
        // Implement your domain filtering logic here
        os_log(.debug, log: self.log, "Checking domain: %{public}@", domain)

        // For now, allow all domains for testing
        completion(true)
    }
}