From 3afbc248b14164d68570a78dca7288d9762f1d75 Mon Sep 17 00:00:00 2001 From: albexk Date: Sat, 27 Jan 2024 16:55:05 +0300 Subject: [PATCH] Refactor split-tunneling: separate site addresses from routes --- .../android/awg/src/main/kotlin/AwgConfig.kt | 2 +- .../amnezia/vpn/protocol/openvpn/OpenVpn.kt | 13 +--- .../vpn/protocol/openvpn/OpenVpnConfig.kt | 2 +- .../protocolApi/src/main/kotlin/Protocol.kt | 44 +++--------- .../src/main/kotlin/ProtocolConfig.kt | 67 +++++++++++++++++-- .../vpn/protocol/wireguard/WireguardConfig.kt | 2 +- 6 files changed, 76 insertions(+), 54 deletions(-) diff --git a/client/android/awg/src/main/kotlin/AwgConfig.kt b/client/android/awg/src/main/kotlin/AwgConfig.kt index 372747f2..014c6e0a 100644 --- a/client/android/awg/src/main/kotlin/AwgConfig.kt +++ b/client/android/awg/src/main/kotlin/AwgConfig.kt @@ -99,7 +99,7 @@ class AwgConfig private constructor( fun setH3(h3: Long) = apply { this.h3 = h3 } fun setH4(h4: Long) = apply { this.h4 = h4 } - override fun build(): AwgConfig = AwgConfig(this) + override fun build(): AwgConfig = configBuild().run { AwgConfig(this@Builder) } } companion object { diff --git a/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpn.kt b/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpn.kt index 34069a0d..34f2934b 100644 --- a/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpn.kt +++ b/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpn.kt @@ -2,7 +2,6 @@ package org.amnezia.vpn.protocol.openvpn import android.content.Context import android.net.VpnService.Builder -import android.os.Build import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableStateFlow @@ -14,7 +13,6 @@ import org.amnezia.vpn.protocol.ProtocolState import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED import org.amnezia.vpn.protocol.Statistics import org.amnezia.vpn.protocol.VpnStartException -import org.amnezia.vpn.util.net.InetNetwork import org.amnezia.vpn.util.net.getLocalNetworks import org.json.JSONObject @@ -79,16 +77,7 @@ open class OpenVpn : Protocol() { if (evalConfig.error) { throw BadConfigException("OpenVPN config parse error: ${evalConfig.message}") } - configBuilder.apply { - // fix for split tunneling - // The exclude split tunneling OpenVpn configuration does not contain a default route. - // It is required for split tunneling in newer versions of Android. - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { - addRoute(InetNetwork("0.0.0.0", 0)) - addRoute(InetNetwork("::", 0)) - } - configSplitTunneling(config) - } + configBuilder.configSplitTunneling(config) scope.launch { val status = client.connect() diff --git a/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpnConfig.kt b/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpnConfig.kt index 36d8d93b..9554f978 100644 --- a/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpnConfig.kt +++ b/client/android/openvpn/src/main/kotlin/org/amnezia/vpn/protocol/openvpn/OpenVpnConfig.kt @@ -11,7 +11,7 @@ class OpenVpnConfig private constructor( class Builder : ProtocolConfig.Builder(false) { override var mtu: Int = OPENVPN_DEFAULT_MTU - override fun build(): OpenVpnConfig = OpenVpnConfig(this) + override fun build(): OpenVpnConfig = configBuild().run { OpenVpnConfig(this@Builder) } } companion object { diff --git a/client/android/protocolApi/src/main/kotlin/Protocol.kt b/client/android/protocolApi/src/main/kotlin/Protocol.kt index b2c52e9a..b729f9f7 100644 --- a/client/android/protocolApi/src/main/kotlin/Protocol.kt +++ b/client/android/protocolApi/src/main/kotlin/Protocol.kt @@ -14,8 +14,6 @@ import java.util.zip.ZipFile import kotlinx.coroutines.flow.MutableStateFlow import org.amnezia.vpn.util.Log import org.amnezia.vpn.util.net.InetNetwork -import org.amnezia.vpn.util.net.IpRange -import org.amnezia.vpn.util.net.IpRangeSet import org.json.JSONObject private const val TAG = "Protocol" @@ -53,40 +51,16 @@ abstract class Protocol { val splitTunnelType = config.optInt("splitTunnelType") if (splitTunnelType == SPLIT_TUNNEL_DISABLE) return val splitTunnelSites = config.getJSONArray("splitTunnelSites") - when (splitTunnelType) { - SPLIT_TUNNEL_INCLUDE -> { - // remove default routes, if any - removeRoute(InetNetwork("0.0.0.0", 0)) - removeRoute(InetNetwork("::", 0)) - // add routes from config - for (i in 0 until splitTunnelSites.length()) { - val address = InetNetwork.parse(splitTunnelSites.getString(i)) - addRoute(address) - } - } + val addressHandlerFunc = when (splitTunnelType) { + SPLIT_TUNNEL_INCLUDE -> ::includeAddress + SPLIT_TUNNEL_EXCLUDE -> ::excludeAddress - SPLIT_TUNNEL_EXCLUDE -> { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { - // exclude routes from config - for (i in 0 until splitTunnelSites.length()) { - val address = InetNetwork.parse(splitTunnelSites.getString(i)) - excludeRoute(address) - } - } else { - // For older versions of Android, build a list of subnets without excluded addresses - val ipRangeSet = IpRangeSet() - ipRangeSet.remove(IpRange("127.0.0.0", 8)) - for (i in 0 until splitTunnelSites.length()) { - val address = InetNetwork.parse(splitTunnelSites.getString(i)) - ipRangeSet.remove(IpRange(address)) - } - // remove default routes, if any - removeRoute(InetNetwork("0.0.0.0", 0)) - removeRoute(InetNetwork("::", 0)) - ipRangeSet.subnets().forEach(::addRoute) - addRoute(InetNetwork("2000::", 3)) - } - } + else -> throw BadConfigException("Unexpected value of the 'splitTunnelType' parameter: $splitTunnelType") + } + + for (i in 0 until splitTunnelSites.length()) { + val address = InetNetwork.parse(splitTunnelSites.getString(i)) + addressHandlerFunc(address) } } diff --git a/client/android/protocolApi/src/main/kotlin/ProtocolConfig.kt b/client/android/protocolApi/src/main/kotlin/ProtocolConfig.kt index a4d7683e..75ba1abf 100644 --- a/client/android/protocolApi/src/main/kotlin/ProtocolConfig.kt +++ b/client/android/protocolApi/src/main/kotlin/ProtocolConfig.kt @@ -5,6 +5,8 @@ import android.os.Build import androidx.annotation.RequiresApi import java.net.InetAddress import org.amnezia.vpn.util.net.InetNetwork +import org.amnezia.vpn.util.net.IpRange +import org.amnezia.vpn.util.net.IpRangeSet open class ProtocolConfig protected constructor( val addresses: Set, @@ -12,6 +14,8 @@ open class ProtocolConfig protected constructor( val searchDomain: String?, val routes: Set, val excludedRoutes: Set, + val includedAddresses: Set, + val excludedAddresses: Set, val excludedApplications: Set, val httpProxy: ProxyInfo?, val allowAllAF: Boolean, @@ -25,6 +29,8 @@ open class ProtocolConfig protected constructor( builder.searchDomain, builder.routes, builder.excludedRoutes, + builder.includedAddresses, + builder.excludedAddresses, builder.excludedApplications, builder.httpProxy, builder.allowAllAF, @@ -37,6 +43,8 @@ open class ProtocolConfig protected constructor( internal val dnsServers: MutableSet = hashSetOf() internal val routes: MutableSet = hashSetOf() internal val excludedRoutes: MutableSet = hashSetOf() + internal val includedAddresses: MutableSet = hashSetOf() + internal val excludedAddresses: MutableSet = hashSetOf() internal val excludedApplications: MutableSet = hashSetOf() internal var searchDomain: String? = null @@ -71,12 +79,15 @@ open class ProtocolConfig protected constructor( fun removeRoute(route: InetNetwork) = apply { this.routes.remove(route) } fun clearRoutes() = apply { this.routes.clear() } - @RequiresApi(Build.VERSION_CODES.TIRAMISU) fun excludeRoute(route: InetNetwork) = apply { this.excludedRoutes += route } - - @RequiresApi(Build.VERSION_CODES.TIRAMISU) fun excludeRoutes(routes: Collection) = apply { this.excludedRoutes += routes } + fun includeAddress(addr: InetNetwork) = apply { this.includedAddresses += addr } + fun includeAddresses(addresses: Collection) = apply { this.includedAddresses += addresses } + + fun excludeAddress(addr: InetNetwork) = apply { this.excludedAddresses += addr } + fun excludeAddresses(addresses: Collection) = apply { this.excludedAddresses += addresses } + fun excludeApplication(application: String) = apply { this.excludedApplications += application } fun excludeApplications(applications: Collection) = apply { this.excludedApplications += applications } @@ -91,6 +102,48 @@ open class ProtocolConfig protected constructor( fun setMtu(mtu: Int) = apply { this.mtu = mtu } + private fun processSplitTunneling() { + if (includedAddresses.isNotEmpty() && excludedAddresses.isNotEmpty()) { + throw BadConfigException("Config contains addresses for inclusive and exclusive split tunneling at the same time") + } + + if (includedAddresses.isNotEmpty()) { + // remove default routes, if any + removeRoute(InetNetwork("0.0.0.0", 0)) + removeRoute(InetNetwork("::", 0)) + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) { + // for older versions of Android, add the default route to the excluded routes + // to correctly build the excluded subnets list later + excludeRoute(InetNetwork("0.0.0.0", 0)) + } + addRoutes(includedAddresses) + } else if (excludedAddresses.isNotEmpty()) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + // default routes are required for split tunneling in newer versions of Android + addRoute(InetNetwork("0.0.0.0", 0)) + addRoute(InetNetwork("::", 0)) + } + excludeRoutes(excludedAddresses) + } + } + + private fun processExcludedRoutes() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) { + // for older versions of Android, build a list of subnets without excluded routes + // and add them to routes + val ipRangeSet = IpRangeSet() + ipRangeSet.remove(IpRange("127.0.0.0", 8)) + excludedRoutes.forEach { + ipRangeSet.remove(IpRange(it)) + } + // remove default routes, if any + removeRoute(InetNetwork("0.0.0.0", 0)) + removeRoute(InetNetwork("::", 0)) + ipRangeSet.subnets().forEach(::addRoute) + addRoute(InetNetwork("2000::", 3)) + } + } + private fun validate() { val errorMessage = StringBuilder() @@ -103,7 +156,13 @@ open class ProtocolConfig protected constructor( if (errorMessage.isNotEmpty()) throw BadConfigException(errorMessage.toString()) } - open fun build(): ProtocolConfig = validate().run { ProtocolConfig(this@Builder) } + protected fun configBuild() { + processSplitTunneling() + processExcludedRoutes() + validate() + } + + open fun build(): ProtocolConfig = configBuild().run { ProtocolConfig(this@Builder) } } companion object { diff --git a/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt b/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt index 76ccd905..0e303f0e 100644 --- a/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt +++ b/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt @@ -75,7 +75,7 @@ open class WireguardConfig protected constructor( fun setPrivateKeyHex(privateKeyHex: String) = apply { this.privateKeyHex = privateKeyHex } - override fun build(): WireguardConfig = WireguardConfig(this) + override fun build(): WireguardConfig = configBuild().run { WireguardConfig(this@Builder) } } companion object {