Add split tunneling

This commit is contained in:
albexk 2023-12-01 00:12:50 +03:00
parent 20f3c0388a
commit e7658f9859
14 changed files with 422 additions and 104 deletions

View file

@ -0,0 +1,17 @@
package org.amnezia.vpn.util.net
import java.net.InetAddress
data class InetEndpoint(val address: InetAddress, val port: Int) {
override fun toString(): String = "${address.hostAddress}:$port"
companion object {
fun parse(data: String): InetEndpoint {
val split = data.split(":")
val address = parseInetAddress(split.first())
val port = split.last().toInt()
return InetEndpoint(address, port)
}
}
}

View file

@ -0,0 +1,26 @@
package org.amnezia.vpn.util.net
import java.net.Inet4Address
import java.net.InetAddress
data class InetNetwork(val address: InetAddress, val mask: Int) {
constructor(address: String, mask: Int) : this(parseInetAddress(address), mask)
constructor(address: InetAddress) : this(address, address.maxPrefixLength)
override fun toString(): String = "${address.hostAddress}/$mask"
companion object {
fun parse(data: String): InetNetwork {
val split = data.split("/")
val address = parseInetAddress(split.first())
if (split.size == 1) return InetNetwork(address)
val mask = split.last().toInt()
return InetNetwork(address, mask)
}
}
}
private val InetAddress.maxPrefixLength: Int
get() = if (this is Inet4Address) 32 else 128

View file

@ -0,0 +1,85 @@
package org.amnezia.vpn.util.net
import java.net.InetAddress
@OptIn(ExperimentalUnsignedTypes::class)
class IpAddress private constructor(private val address: UByteArray) : Comparable<IpAddress> {
val size: Int = address.size
val lastIndex: Int = address.lastIndex
val maxMask: Int = size * 8
constructor(inetAddress: InetAddress) : this(inetAddress.address.asUByteArray())
constructor(ipAddress: String) : this(parseInetAddress(ipAddress))
operator fun get(i: Int): UByte = address[i]
operator fun set(i: Int, b: UByte) { address[i] = b }
fun fill(value: UByte, fromByte: Int) = address.fill(value, fromByte)
fun copy(): IpAddress = IpAddress(address.copyOf())
fun inc(): IpAddress {
if (address.all { it == 0xffu.toUByte() }) {
throw RuntimeException("IP address overflow")
}
val copy = copy()
for (i in copy.lastIndex downTo 0) {
if (++copy[i] != 0u.toUByte()) break
}
return copy
}
fun dec(): IpAddress {
if (address.all { it == 0u.toUByte() }) {
throw RuntimeException("IP address overflow")
}
val copy = copy()
for (i in copy.lastIndex downTo 0) {
if (--copy[i] != 0xffu.toUByte()) break
}
return copy
}
fun isMaxIp(): Boolean = address.all { it == 0xffu.toUByte() }
override fun compareTo(other: IpAddress): Int {
if (size != other.size) return size - other.size
for (i in address.indices) {
val d = (address[i] - other.address[i]).toInt()
if (d != 0) return d
}
return 0
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as IpAddress
return compareTo(other) == 0
}
override fun hashCode(): Int {
return address.hashCode()
}
override fun toString(): String {
if (size > 4) return toIpv6String()
return address.joinToString(".")
}
@OptIn(ExperimentalStdlibApi::class)
private fun toIpv6String(): String {
val sb = StringBuilder()
var i = 0
while (i < size) {
sb.append(address[i++].toHexString())
sb.append(address[i++].toHexString())
sb.append(':')
}
sb.deleteAt(sb.lastIndex)
return sb.toString()
}
}

View file

@ -0,0 +1,119 @@
package org.amnezia.vpn.util.net
import java.net.InetAddress
class IpRange(private val start: IpAddress, private val end: IpAddress) : Comparable<IpRange> {
init {
if (start > end) throw IllegalArgumentException("Start IP: $start is greater then end IP: $end")
}
private constructor(addresses: Pair<IpAddress, IpAddress>) : this(addresses.first, addresses.second)
constructor(inetAddress: InetAddress, mask: Int) : this(from(inetAddress, mask))
constructor(address: String, mask: Int) : this(parseInetAddress(address), mask)
constructor(inetNetwork: InetNetwork) : this(from(inetNetwork))
private operator fun contains(other: IpRange): Boolean =
(start <= other.start) && (end >= other.end)
private fun isIntersect(other: IpRange): Boolean =
(start <= other.end) && (end >= other.start)
operator fun minus(other: IpRange): List<IpRange>? {
if (this in other) return emptyList()
if (!isIntersect(other)) return null
val resultRanges = mutableListOf<IpRange>()
if (start < other.start) resultRanges += IpRange(start, other.start.dec())
if (end > other.end) resultRanges += IpRange(other.end.inc(), end)
return resultRanges
}
fun subnets(): List<InetNetwork> {
var currentIp = start
var mask: Int
val subnets = mutableListOf<InetNetwork>()
while (currentIp <= end) {
mask = getPossibleMaxMask(currentIp)
var lastIp = getLastIpForMask(currentIp, mask)
while (lastIp > end) {
lastIp = getLastIpForMask(currentIp, ++mask)
}
subnets.add(InetNetwork(currentIp.toString(), mask))
if (lastIp.isMaxIp()) break
currentIp = lastIp.inc()
}
return subnets
}
private fun getPossibleMaxMask(ip: IpAddress): Int {
var mask = ip.maxMask
for (i in ip.lastIndex downTo 0) {
val lastZeroBits = ip[i].countTrailingZeroBits()
mask -= lastZeroBits
if (lastZeroBits != 8) break
}
return mask
}
private fun getLastIpForMask(ip: IpAddress, mask: Int): IpAddress {
var remainingBits = ip.maxMask - mask
if (remainingBits == 0) return ip
var i = ip.lastIndex
val lastIp = ip.copy()
while (remainingBits > 0 && i >= 0) {
lastIp[i] =
if (remainingBits > 8) {
lastIp[i] or 0xffu
} else {
lastIp[i] or ((0xffu shl remainingBits).toUByte().inv())
}
remainingBits -= 8
--i
}
return lastIp
}
override fun compareTo(other: IpRange): Int {
val d = start.compareTo(other.start)
return if (d == 0) end.compareTo(other.end) else d
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as IpRange
return compareTo(other) == 0
}
override fun hashCode(): Int {
var result = start.hashCode()
result = 31 * result + end.hashCode()
return result
}
override fun toString(): String {
return "$start - $end"
}
companion object {
private fun from(inetAddress: InetAddress, mask: Int): Pair<IpAddress, IpAddress> {
val start = IpAddress(inetAddress)
val end = IpAddress(inetAddress)
val lastByte = mask / 8
if (lastByte < start.size) {
val byteMask = (0xffu shl (8 - mask % 8)).toUByte()
start[lastByte] = start[lastByte].and(byteMask)
end[lastByte] = end[lastByte].or(byteMask.inv())
start.fill(0u, lastByte + 1)
end.fill(0xffu, lastByte + 1)
}
return Pair(start, end)
}
private fun from(inetNetwork: InetNetwork): Pair<IpAddress, IpAddress> =
from(inetNetwork.address, inetNetwork.mask)
}
}

View file

@ -0,0 +1,26 @@
package org.amnezia.vpn.util.net
class IpRangeSet(ipRange: IpRange = IpRange("0.0.0.0", 0)) {
private val ranges = sortedSetOf(ipRange)
fun remove(ipRange: IpRange) {
val iterator = ranges.iterator()
val splitRanges = mutableListOf<IpRange>()
while (iterator.hasNext()) {
val range = iterator.next()
(range - ipRange)?.let { resultRanges ->
iterator.remove()
splitRanges += resultRanges
}
}
ranges += splitRanges
}
fun subnets(): List<InetNetwork> =
ranges.map(IpRange::subnets).flatten()
override fun toString(): String {
return ranges.toString()
}
}

View file

@ -0,0 +1,46 @@
package org.amnezia.vpn.util.net
import android.content.Context
import android.net.ConnectivityManager
import android.net.InetAddresses
import android.net.NetworkCapabilities
import android.os.Build
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
fun getLocalNetworks(context: Context, ipv6: Boolean): List<InetNetwork> {
val connectivityManager = context.getSystemService(ConnectivityManager::class.java)
connectivityManager.activeNetwork?.let { network ->
val netCapabilities = connectivityManager.getNetworkCapabilities(network)
val linkProperties = connectivityManager.getLinkProperties(network)
if (linkProperties == null ||
netCapabilities == null ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_VPN) ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)
) return emptyList()
val addresses = mutableListOf<InetNetwork>()
for (linkAddress in linkProperties.linkAddresses) {
val address = linkAddress.address
if ((!ipv6 && address is Inet4Address) || (ipv6 && address is Inet6Address)) {
addresses += InetNetwork(address, linkAddress.prefixLength)
}
}
return addresses
}
return emptyList()
}
fun parseInetAddress(address: String): InetAddress = parseNumericAddressCompat(address)
private val parseNumericAddressCompat: (String) -> InetAddress =
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
InetAddresses::parseNumericAddress
} else {
val m = InetAddress::class.java.getMethod("parseNumericAddress", String::class.java)
fun(address: String): InetAddress {
return m.invoke(null, address) as InetAddress
}
}