Add split tunneling

This commit is contained in:
albexk 2023-12-01 00:12:50 +03:00
parent 20f3c0388a
commit e7658f9859
14 changed files with 422 additions and 104 deletions

View file

@ -65,6 +65,7 @@ class Awg : Wireguard() {
val configData = parseConfigData(configDataJson.getString("config"))
return AwgConfig.build {
configWireguard(configData)
configSplitTunnel(config)
configData["Jc"]?.let { setJc(it.toInt()) }
configData["Jmin"]?.let { setJmin(it.toInt()) }
configData["Jmax"]?.let { setJmax(it.toInt()) }

View file

@ -2,6 +2,7 @@ package org.amnezia.vpn.protocol.openvpn
import android.content.Context
import android.net.VpnService.Builder
import android.os.Build
import kotlinx.coroutines.flow.MutableStateFlow
import net.openvpn.ovpn3.ClientAPI_Config
import org.amnezia.vpn.protocol.BadConfigException
@ -10,7 +11,8 @@ import org.amnezia.vpn.protocol.ProtocolState
import org.amnezia.vpn.protocol.Statistics
import org.amnezia.vpn.protocol.VpnException
import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.util.NetworkUtils
import org.amnezia.vpn.util.net.InetNetwork
import org.amnezia.vpn.util.net.getLocalNetworks
import org.json.JSONObject
/**
@ -59,7 +61,7 @@ open class OpenVpn : Protocol() {
openVpnClient = OpenVpnClient(
configBuilder,
state,
{ ipv6 -> NetworkUtils.getLocalNetworks(context, ipv6) },
{ ipv6 -> getLocalNetworks(context, ipv6) },
makeEstablish(configBuilder, vpnBuilder),
protect
)
@ -71,6 +73,16 @@ open class OpenVpn : Protocol() {
if (evalConfig.error) {
throw BadConfigException("OpenVPN config parse error: ${evalConfig.message}")
}
configBuilder.apply {
// fix for split tunneling
// The exclude split tunneling OpenVpn configuration does not contain a default route.
// It is required for split tunneling in newer versions of Android.
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
addRoute(InetNetwork("0.0.0.0", 0))
addRoute(InetNetwork("::", 0))
}
configSplitTunnel(config)
}
val status = client.connect()
if (status.error) {

View file

@ -1,6 +1,7 @@
package org.amnezia.vpn.protocol.openvpn
import android.net.ProxyInfo
import android.os.Build
import kotlinx.coroutines.flow.MutableStateFlow
import net.openvpn.ovpn3.ClientAPI_Config
import net.openvpn.ovpn3.ClientAPI_EvalConfig
@ -14,9 +15,9 @@ import org.amnezia.vpn.protocol.ProtocolState
import org.amnezia.vpn.protocol.ProtocolState.CONNECTED
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED
import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.util.InetNetwork
import org.amnezia.vpn.util.Log
import org.amnezia.vpn.util.parseInetAddress
import org.amnezia.vpn.util.net.InetNetwork
import org.amnezia.vpn.util.net.parseInetAddress
private const val TAG = "OpenVpnClient"
private const val EMULATED_EXCLUDE_ROUTES = (1 shl 16)
@ -87,7 +88,9 @@ class OpenVpnClient(
// metric is optional and should be ignored if < 0
override fun tun_builder_exclude_route(address: String, prefix_length: Int, metric: Int, ipv6: Boolean): Boolean {
Log.v(TAG, "tun_builder_exclude_route: $address, $prefix_length, $metric, $ipv6")
configBuilder.excludeRoute(InetNetwork(address, prefix_length))
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
configBuilder.excludeRoute(InetNetwork(address, prefix_length))
}
return true
}
@ -179,11 +182,13 @@ class OpenVpnClient(
// Never called more than once per tun_builder session.
override fun tun_builder_set_proxy_http(host: String, port: Int): Boolean {
Log.v(TAG, "tun_builder_set_proxy_http: $host, $port")
try {
configBuilder.setHttpProxy(ProxyInfo.buildDirectProxy(host, port))
} catch (e: Exception) {
Log.e(TAG, "Could not set proxy: ${e.message}")
return false
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
try {
configBuilder.setHttpProxy(ProxyInfo.buildDirectProxy(host, port))
} catch (e: Exception) {
Log.e(TAG, "Could not set proxy: ${e.message}")
return false
}
}
return true
}

View file

@ -12,14 +12,20 @@ import java.io.File
import java.io.FileOutputStream
import java.util.zip.ZipFile
import kotlinx.coroutines.flow.MutableStateFlow
import org.amnezia.vpn.util.InetNetwork
import org.amnezia.vpn.util.Log
import org.amnezia.vpn.util.net.InetNetwork
import org.amnezia.vpn.util.net.IpRange
import org.amnezia.vpn.util.net.IpRangeSet
import org.json.JSONObject
private const val TAG = "Protocol"
const val VPN_SESSION_NAME = "AmneziaVPN"
private const val SPLIT_TUNNEL_DISABLE = 0
private const val SPLIT_TUNNEL_INCLUDE = 1
private const val SPLIT_TUNNEL_EXCLUDE = 2
abstract class Protocol {
abstract val statistics: Statistics
@ -33,6 +39,47 @@ abstract class Protocol {
abstract fun stopVpn()
protected fun ProtocolConfig.Builder.configSplitTunnel(config: JSONObject) {
val splitTunnelType = config.optInt("splitTunnelType")
if (splitTunnelType == SPLIT_TUNNEL_DISABLE) return
val splitTunnelSites = config.getJSONArray("splitTunnelSites")
when (splitTunnelType) {
SPLIT_TUNNEL_INCLUDE -> {
// remove default routes, if any
removeRoute(InetNetwork("0.0.0.0", 0))
removeRoute(InetNetwork("::", 0))
// add routes from config
for (i in 0 until splitTunnelSites.length()) {
val address = InetNetwork.parse(splitTunnelSites.getString(i))
addRoute(address)
}
}
SPLIT_TUNNEL_EXCLUDE -> {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
// exclude routes from config
for (i in 0 until splitTunnelSites.length()) {
val address = InetNetwork.parse(splitTunnelSites.getString(i))
excludeRoute(address)
}
} else {
// For older versions of Android, build a list of subnets without excluded addresses
val ipRangeSet = IpRangeSet()
ipRangeSet.remove(IpRange("127.0.0.0", 8))
for (i in 0 until splitTunnelSites.length()) {
val address = InetNetwork.parse(splitTunnelSites.getString(i))
ipRangeSet.remove(IpRange(address))
}
// remove default routes, if any
removeRoute(InetNetwork("0.0.0.0", 0))
removeRoute(InetNetwork("::", 0))
ipRangeSet.subnets().forEach(::addRoute)
addRoute(InetNetwork("2000::", 3))
}
}
}
}
protected open fun buildVpnInterface(config: ProtocolConfig, vpnBuilder: Builder) {
vpnBuilder.setSession(VPN_SESSION_NAME)

View file

@ -1,8 +1,10 @@
package org.amnezia.vpn.protocol
import android.net.ProxyInfo
import android.os.Build
import androidx.annotation.RequiresApi
import java.net.InetAddress
import org.amnezia.vpn.util.InetNetwork
import org.amnezia.vpn.util.net.InetNetwork
open class ProtocolConfig protected constructor(
val addresses: Set<InetNetwork>,
@ -62,13 +64,19 @@ open class ProtocolConfig protected constructor(
fun addRoute(route: InetNetwork) = apply { this.routes += route }
fun addRoutes(routes: List<InetNetwork>) = apply { this.routes += routes }
fun removeRoute(route: InetNetwork) = apply { this.routes.remove(route) }
fun clearRoutes() = apply { this.routes.clear() }
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun excludeRoute(route: InetNetwork) = apply { this.excludedRoutes += route }
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun excludeRoutes(routes: List<InetNetwork>) = apply { this.excludedRoutes += routes }
fun excludeApplication(application: String) = apply { this.excludedApplications += application }
fun excludeApplications(applications: List<String>) = apply { this.excludedApplications += applications }
@RequiresApi(Build.VERSION_CODES.Q)
fun setHttpProxy(httpProxy: ProxyInfo) = apply { this.httpProxy = httpProxy }
fun setAllowAllAF(allowAllAF: Boolean) = apply { this.allowAllAF = allowAllAF }

View file

@ -1,87 +0,0 @@
package org.amnezia.vpn.util
import android.content.Context
import android.net.ConnectivityManager
import android.net.InetAddresses
import android.net.NetworkCapabilities
import android.os.Build
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
object NetworkUtils {
fun getLocalNetworks(context: Context, ipv6: Boolean): List<InetNetwork> {
val connectivityManager = context.getSystemService(ConnectivityManager::class.java)
connectivityManager.activeNetwork?.let { network ->
val netCapabilities = connectivityManager.getNetworkCapabilities(network)
val linkProperties = connectivityManager.getLinkProperties(network)
if (linkProperties == null ||
netCapabilities == null ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_VPN) ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)
) return emptyList()
val addresses = mutableListOf<InetNetwork>()
for (linkAddress in linkProperties.linkAddresses) {
val address = linkAddress.address
if ((!ipv6 && address is Inet4Address) || (ipv6 && address is Inet6Address)) {
addresses += InetNetwork(address, linkAddress.prefixLength)
}
}
return addresses
}
return emptyList()
}
}
data class InetNetwork(val address: InetAddress, val mask: Int) {
constructor(address: String, mask: Int) : this(parseInetAddress(address), mask)
constructor(address: InetAddress) : this(address, address.maxPrefixLength)
constructor(address: String) : this(parseInetAddress(address))
override fun toString(): String = "${address.hostAddress}/$mask"
companion object {
fun parse(data: String): InetNetwork {
val split = data.split("/")
val address = parseInetAddress(split.first())
val mask = split.last().toInt()
return InetNetwork(address, mask)
}
}
}
data class InetEndpoint(val address: InetAddress, val port: Int) {
override fun toString(): String = "${address.hostAddress}:$port"
companion object {
fun parse(data: String): InetEndpoint {
val split = data.split(":")
val address = parseInetAddress(split.first())
val port = split.last().toInt()
return InetEndpoint(address, port)
}
}
}
private val InetAddress.maxPrefixLength: Int
get() = if (this is Inet4Address) 32 else 128
fun parseInetAddress(address: String): InetAddress = parseNumericAddressCompat(address)
private val parseNumericAddressCompat: (String) -> InetAddress =
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
InetAddresses::parseNumericAddress
} else {
val m = InetAddress::class.java.getMethod("parseNumericAddress", String::class.java)
fun(address: String): InetAddress {
return m.invoke(null, address) as InetAddress
}
}

View file

@ -0,0 +1,17 @@
package org.amnezia.vpn.util.net
import java.net.InetAddress
data class InetEndpoint(val address: InetAddress, val port: Int) {
override fun toString(): String = "${address.hostAddress}:$port"
companion object {
fun parse(data: String): InetEndpoint {
val split = data.split(":")
val address = parseInetAddress(split.first())
val port = split.last().toInt()
return InetEndpoint(address, port)
}
}
}

View file

@ -0,0 +1,26 @@
package org.amnezia.vpn.util.net
import java.net.Inet4Address
import java.net.InetAddress
data class InetNetwork(val address: InetAddress, val mask: Int) {
constructor(address: String, mask: Int) : this(parseInetAddress(address), mask)
constructor(address: InetAddress) : this(address, address.maxPrefixLength)
override fun toString(): String = "${address.hostAddress}/$mask"
companion object {
fun parse(data: String): InetNetwork {
val split = data.split("/")
val address = parseInetAddress(split.first())
if (split.size == 1) return InetNetwork(address)
val mask = split.last().toInt()
return InetNetwork(address, mask)
}
}
}
private val InetAddress.maxPrefixLength: Int
get() = if (this is Inet4Address) 32 else 128

View file

@ -0,0 +1,85 @@
package org.amnezia.vpn.util.net
import java.net.InetAddress
@OptIn(ExperimentalUnsignedTypes::class)
class IpAddress private constructor(private val address: UByteArray) : Comparable<IpAddress> {
val size: Int = address.size
val lastIndex: Int = address.lastIndex
val maxMask: Int = size * 8
constructor(inetAddress: InetAddress) : this(inetAddress.address.asUByteArray())
constructor(ipAddress: String) : this(parseInetAddress(ipAddress))
operator fun get(i: Int): UByte = address[i]
operator fun set(i: Int, b: UByte) { address[i] = b }
fun fill(value: UByte, fromByte: Int) = address.fill(value, fromByte)
fun copy(): IpAddress = IpAddress(address.copyOf())
fun inc(): IpAddress {
if (address.all { it == 0xffu.toUByte() }) {
throw RuntimeException("IP address overflow")
}
val copy = copy()
for (i in copy.lastIndex downTo 0) {
if (++copy[i] != 0u.toUByte()) break
}
return copy
}
fun dec(): IpAddress {
if (address.all { it == 0u.toUByte() }) {
throw RuntimeException("IP address overflow")
}
val copy = copy()
for (i in copy.lastIndex downTo 0) {
if (--copy[i] != 0xffu.toUByte()) break
}
return copy
}
fun isMaxIp(): Boolean = address.all { it == 0xffu.toUByte() }
override fun compareTo(other: IpAddress): Int {
if (size != other.size) return size - other.size
for (i in address.indices) {
val d = (address[i] - other.address[i]).toInt()
if (d != 0) return d
}
return 0
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as IpAddress
return compareTo(other) == 0
}
override fun hashCode(): Int {
return address.hashCode()
}
override fun toString(): String {
if (size > 4) return toIpv6String()
return address.joinToString(".")
}
@OptIn(ExperimentalStdlibApi::class)
private fun toIpv6String(): String {
val sb = StringBuilder()
var i = 0
while (i < size) {
sb.append(address[i++].toHexString())
sb.append(address[i++].toHexString())
sb.append(':')
}
sb.deleteAt(sb.lastIndex)
return sb.toString()
}
}

View file

@ -0,0 +1,119 @@
package org.amnezia.vpn.util.net
import java.net.InetAddress
class IpRange(private val start: IpAddress, private val end: IpAddress) : Comparable<IpRange> {
init {
if (start > end) throw IllegalArgumentException("Start IP: $start is greater then end IP: $end")
}
private constructor(addresses: Pair<IpAddress, IpAddress>) : this(addresses.first, addresses.second)
constructor(inetAddress: InetAddress, mask: Int) : this(from(inetAddress, mask))
constructor(address: String, mask: Int) : this(parseInetAddress(address), mask)
constructor(inetNetwork: InetNetwork) : this(from(inetNetwork))
private operator fun contains(other: IpRange): Boolean =
(start <= other.start) && (end >= other.end)
private fun isIntersect(other: IpRange): Boolean =
(start <= other.end) && (end >= other.start)
operator fun minus(other: IpRange): List<IpRange>? {
if (this in other) return emptyList()
if (!isIntersect(other)) return null
val resultRanges = mutableListOf<IpRange>()
if (start < other.start) resultRanges += IpRange(start, other.start.dec())
if (end > other.end) resultRanges += IpRange(other.end.inc(), end)
return resultRanges
}
fun subnets(): List<InetNetwork> {
var currentIp = start
var mask: Int
val subnets = mutableListOf<InetNetwork>()
while (currentIp <= end) {
mask = getPossibleMaxMask(currentIp)
var lastIp = getLastIpForMask(currentIp, mask)
while (lastIp > end) {
lastIp = getLastIpForMask(currentIp, ++mask)
}
subnets.add(InetNetwork(currentIp.toString(), mask))
if (lastIp.isMaxIp()) break
currentIp = lastIp.inc()
}
return subnets
}
private fun getPossibleMaxMask(ip: IpAddress): Int {
var mask = ip.maxMask
for (i in ip.lastIndex downTo 0) {
val lastZeroBits = ip[i].countTrailingZeroBits()
mask -= lastZeroBits
if (lastZeroBits != 8) break
}
return mask
}
private fun getLastIpForMask(ip: IpAddress, mask: Int): IpAddress {
var remainingBits = ip.maxMask - mask
if (remainingBits == 0) return ip
var i = ip.lastIndex
val lastIp = ip.copy()
while (remainingBits > 0 && i >= 0) {
lastIp[i] =
if (remainingBits > 8) {
lastIp[i] or 0xffu
} else {
lastIp[i] or ((0xffu shl remainingBits).toUByte().inv())
}
remainingBits -= 8
--i
}
return lastIp
}
override fun compareTo(other: IpRange): Int {
val d = start.compareTo(other.start)
return if (d == 0) end.compareTo(other.end) else d
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as IpRange
return compareTo(other) == 0
}
override fun hashCode(): Int {
var result = start.hashCode()
result = 31 * result + end.hashCode()
return result
}
override fun toString(): String {
return "$start - $end"
}
companion object {
private fun from(inetAddress: InetAddress, mask: Int): Pair<IpAddress, IpAddress> {
val start = IpAddress(inetAddress)
val end = IpAddress(inetAddress)
val lastByte = mask / 8
if (lastByte < start.size) {
val byteMask = (0xffu shl (8 - mask % 8)).toUByte()
start[lastByte] = start[lastByte].and(byteMask)
end[lastByte] = end[lastByte].or(byteMask.inv())
start.fill(0u, lastByte + 1)
end.fill(0xffu, lastByte + 1)
}
return Pair(start, end)
}
private fun from(inetNetwork: InetNetwork): Pair<IpAddress, IpAddress> =
from(inetNetwork.address, inetNetwork.mask)
}
}

View file

@ -0,0 +1,26 @@
package org.amnezia.vpn.util.net
class IpRangeSet(ipRange: IpRange = IpRange("0.0.0.0", 0)) {
private val ranges = sortedSetOf(ipRange)
fun remove(ipRange: IpRange) {
val iterator = ranges.iterator()
val splitRanges = mutableListOf<IpRange>()
while (iterator.hasNext()) {
val range = iterator.next()
(range - ipRange)?.let { resultRanges ->
iterator.remove()
splitRanges += resultRanges
}
}
ranges += splitRanges
}
fun subnets(): List<InetNetwork> =
ranges.map(IpRange::subnets).flatten()
override fun toString(): String {
return ranges.toString()
}
}

View file

@ -0,0 +1,46 @@
package org.amnezia.vpn.util.net
import android.content.Context
import android.net.ConnectivityManager
import android.net.InetAddresses
import android.net.NetworkCapabilities
import android.os.Build
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
fun getLocalNetworks(context: Context, ipv6: Boolean): List<InetNetwork> {
val connectivityManager = context.getSystemService(ConnectivityManager::class.java)
connectivityManager.activeNetwork?.let { network ->
val netCapabilities = connectivityManager.getNetworkCapabilities(network)
val linkProperties = connectivityManager.getLinkProperties(network)
if (linkProperties == null ||
netCapabilities == null ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_VPN) ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)
) return emptyList()
val addresses = mutableListOf<InetNetwork>()
for (linkAddress in linkProperties.linkAddresses) {
val address = linkAddress.address
if ((!ipv6 && address is Inet4Address) || (ipv6 && address is Inet6Address)) {
addresses += InetNetwork(address, linkAddress.prefixLength)
}
}
return addresses
}
return emptyList()
}
fun parseInetAddress(address: String): InetAddress = parseNumericAddressCompat(address)
private val parseNumericAddressCompat: (String) -> InetAddress =
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
InetAddresses::parseNumericAddress
} else {
val m = InetAddress::class.java.getMethod("parseNumericAddress", String::class.java)
fun(address: String): InetAddress {
return m.invoke(null, address) as InetAddress
}
}

View file

@ -11,10 +11,10 @@ import org.amnezia.vpn.protocol.ProtocolState.CONNECTED
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED
import org.amnezia.vpn.protocol.Statistics
import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.util.InetEndpoint
import org.amnezia.vpn.util.InetNetwork
import org.amnezia.vpn.util.Log
import org.amnezia.vpn.util.parseInetAddress
import org.amnezia.vpn.util.net.InetEndpoint
import org.amnezia.vpn.util.net.InetNetwork
import org.amnezia.vpn.util.net.parseInetAddress
import org.json.JSONObject
/**
@ -92,7 +92,20 @@ open class Wireguard : Protocol() {
protected open fun parseConfig(config: JSONObject): WireguardConfig {
val configDataJson = config.getJSONObject("wireguard_config_data")
val configData = parseConfigData(configDataJson.getString("config"))
return WireguardConfig.build { configWireguard(configData) }
return WireguardConfig.build {
configWireguard(configData)
// Default Wireguard routes (0.0.0.0/0, ::/0) will be removed,
// allowed routes from the Wireguard configuration will be merged
// with allowed routes from the split tunneling configuration.
//
// Excluded routes from the split tunneling configuration can overwrite
// allowed routes from the Wireguard configuration (two routes are equal
// if they have the same address and prefix).
//
// If multiple routes match the packet destination,
// route with the longest prefix takes precedence
configSplitTunnel(config)
}
}
protected fun WireguardConfig.Builder.configWireguard(configData: Map<String, String>) {

View file

@ -2,7 +2,7 @@ package org.amnezia.vpn.protocol.wireguard
import android.util.Base64
import org.amnezia.vpn.protocol.ProtocolConfig
import org.amnezia.vpn.util.InetEndpoint
import org.amnezia.vpn.util.net.InetEndpoint
private const val WIREGUARD_DEFAULT_MTU = 1280