ProtocolApi refactoring, move network classes to NetworkUtils.kt

This commit is contained in:
albexk 2023-11-28 19:47:22 +03:00
parent 8ec105bee0
commit 9738ada946
13 changed files with 213 additions and 122 deletions

View file

@ -17,7 +17,6 @@
<uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" /> <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" /> <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<uses-permission android:name="android.permission.CAMERA" /> <uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE" /> <uses-permission android:name="android.permission.FOREGROUND_SERVICE" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_SPECIAL_USE" /> <uses-permission android:name="android.permission.FOREGROUND_SERVICE_SPECIAL_USE" />

View file

@ -64,7 +64,7 @@ class Awg : Wireguard() {
val configDataJson = config.getJSONObject("awg_config_data") val configDataJson = config.getJSONObject("awg_config_data")
val configData = parseConfigData(configDataJson.getString("config")) val configData = parseConfigData(configDataJson.getString("config"))
return AwgConfig.build { return AwgConfig.build {
configureWireguard(wireguardConfigBuilder(configData)) configWireguard(configData)
configData["Jc"]?.let { setJc(it.toInt()) } configData["Jc"]?.let { setJc(it.toInt()) }
configData["Jmin"]?.let { setJmin(it.toInt()) } configData["Jmin"]?.let { setJmin(it.toInt()) }
configData["Jmax"]?.let { setJmax(it.toInt()) } configData["Jmax"]?.let { setJmax(it.toInt()) }

View file

@ -17,7 +17,7 @@ class AwgConfig private constructor(
) : WireguardConfig(wireguardConfigBuilder) { ) : WireguardConfig(wireguardConfigBuilder) {
private constructor(builder: Builder) : this( private constructor(builder: Builder) : this(
builder.wireguardConfigBuilder, builder,
builder.jc, builder.jc,
builder.jmin, builder.jmin,
builder.jmax, builder.jmax,
@ -43,58 +43,52 @@ class AwgConfig private constructor(
return this.toString() return this.toString()
} }
class Builder { class Builder : WireguardConfig.Builder() {
internal lateinit var wireguardConfigBuilder: WireguardConfig.Builder
private set
private var _jc: Int? = null private var _jc: Int? = null
internal var jc: Int internal var jc: Int
get() = _jc ?: throw BadConfigException("AWG: parameter jc is undefined") get() = _jc ?: throw BadConfigException("AWG: parameter jc is undefined")
private set(value) { _jc = value} private set(value) { _jc = value }
private var _jmin: Int? = null private var _jmin: Int? = null
internal var jmin: Int internal var jmin: Int
get() = _jmin ?: throw BadConfigException("AWG: parameter jmin is undefined") get() = _jmin ?: throw BadConfigException("AWG: parameter jmin is undefined")
private set(value) { _jmin = value} private set(value) { _jmin = value }
private var _jmax: Int? = null private var _jmax: Int? = null
internal var jmax: Int internal var jmax: Int
get() = _jmax ?: throw BadConfigException("AWG: parameter jmax is undefined") get() = _jmax ?: throw BadConfigException("AWG: parameter jmax is undefined")
private set(value) { _jmax = value} private set(value) { _jmax = value }
private var _s1: Int? = null private var _s1: Int? = null
internal var s1: Int internal var s1: Int
get() = _s1 ?: throw BadConfigException("AWG: parameter s1 is undefined") get() = _s1 ?: throw BadConfigException("AWG: parameter s1 is undefined")
private set(value) { _s1 = value} private set(value) { _s1 = value }
private var _s2: Int? = null private var _s2: Int? = null
internal var s2: Int internal var s2: Int
get() = _s2 ?: throw BadConfigException("AWG: parameter s2 is undefined") get() = _s2 ?: throw BadConfigException("AWG: parameter s2 is undefined")
private set(value) { _s2 = value} private set(value) { _s2 = value }
private var _h1: Long? = null private var _h1: Long? = null
internal var h1: Long internal var h1: Long
get() = _h1 ?: throw BadConfigException("AWG: parameter h1 is undefined") get() = _h1 ?: throw BadConfigException("AWG: parameter h1 is undefined")
private set(value) { _h1 = value} private set(value) { _h1 = value }
private var _h2: Long? = null private var _h2: Long? = null
internal var h2: Long internal var h2: Long
get() = _h2 ?: throw BadConfigException("AWG: parameter h2 is undefined") get() = _h2 ?: throw BadConfigException("AWG: parameter h2 is undefined")
private set(value) { _h2 = value} private set(value) { _h2 = value }
private var _h3: Long? = null private var _h3: Long? = null
internal var h3: Long internal var h3: Long
get() = _h3 ?: throw BadConfigException("AWG: parameter h3 is undefined") get() = _h3 ?: throw BadConfigException("AWG: parameter h3 is undefined")
private set(value) { _h3 = value} private set(value) { _h3 = value }
private var _h4: Long? = null private var _h4: Long? = null
internal var h4: Long internal var h4: Long
get() = _h4 ?: throw BadConfigException("AWG: parameter h4 is undefined") get() = _h4 ?: throw BadConfigException("AWG: parameter h4 is undefined")
private set(value) { _h4 = value} private set(value) { _h4 = value }
fun configureWireguard(block: WireguardConfig.Builder.() -> Unit) = apply {
wireguardConfigBuilder = WireguardConfig.Builder().apply(block)
}
fun setJc(jc: Int) = apply { this.jc = jc } fun setJc(jc: Int) = apply { this.jc = jc }
fun setJmin(jmin: Int) = apply { this.jmin = jmin } fun setJmin(jmin: Int) = apply { this.jmin = jmin }
@ -106,7 +100,7 @@ class AwgConfig private constructor(
fun setH3(h3: Long) = apply { this.h3 = h3 } fun setH3(h3: Long) = apply { this.h3 = h3 }
fun setH4(h4: Long) = apply { this.h4 = h4 } fun setH4(h4: Long) = apply { this.h4 = h4 }
fun build(): AwgConfig = AwgConfig(this) override fun build(): AwgConfig = AwgConfig(this)
} }
companion object { companion object {

View file

@ -14,4 +14,5 @@ android {
dependencies { dependencies {
compileOnly(project(":utils")) compileOnly(project(":utils"))
implementation(libs.androidx.annotation) implementation(libs.androidx.annotation)
implementation(libs.kotlinx.coroutines)
} }

View file

@ -11,6 +11,8 @@ import androidx.annotation.RequiresApi
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.util.zip.ZipFile import java.util.zip.ZipFile
import kotlinx.coroutines.flow.MutableStateFlow
import org.amnezia.vpn.util.InetNetwork
import org.amnezia.vpn.util.Log import org.amnezia.vpn.util.Log
import org.json.JSONObject import org.json.JSONObject
@ -21,8 +23,11 @@ const val VPN_SESSION_NAME = "AmneziaVPN"
abstract class Protocol { abstract class Protocol {
abstract val statistics: Statistics abstract val statistics: Statistics
protected lateinit var state: MutableStateFlow<ProtocolState>
abstract fun initialize(context: Context) open fun initialize(context: Context, state: MutableStateFlow<ProtocolState>) {
this.state = state
}
abstract fun startVpn(config: JSONObject, vpnBuilder: Builder, protect: (Int) -> Boolean) abstract fun startVpn(config: JSONObject, vpnBuilder: Builder, protect: (Int) -> Boolean)
@ -30,11 +35,17 @@ abstract class Protocol {
protected open fun buildVpnInterface(config: ProtocolConfig, vpnBuilder: Builder) { protected open fun buildVpnInterface(config: ProtocolConfig, vpnBuilder: Builder) {
vpnBuilder.setSession(VPN_SESSION_NAME) vpnBuilder.setSession(VPN_SESSION_NAME)
vpnBuilder.allowFamily(OsConstants.AF_INET)
vpnBuilder.allowFamily(OsConstants.AF_INET6)
for (addr in config.addresses) vpnBuilder.addAddress(addr) for (addr in config.addresses) vpnBuilder.addAddress(addr)
for (addr in config.dnsServers) vpnBuilder.addDnsServer(addr) for (addr in config.dnsServers) vpnBuilder.addDnsServer(addr)
// fix for Samsung android ignoring DNS servers outside the VPN route range
if (Build.BRAND == "samsung") {
for (addr in config.dnsServers) vpnBuilder.addRoute(InetNetwork(addr))
}
config.searchDomain?.let { vpnBuilder.addSearchDomain(it) }
for (addr in config.routes) vpnBuilder.addRoute(addr) for (addr in config.routes) vpnBuilder.addRoute(addr)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU)
@ -43,6 +54,15 @@ abstract class Protocol {
for (app in config.excludedApplications) vpnBuilder.addDisallowedApplication(app) for (app in config.excludedApplications) vpnBuilder.addDisallowedApplication(app)
vpnBuilder.setMtu(config.mtu) vpnBuilder.setMtu(config.mtu)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)
config.httpProxy?.let { vpnBuilder.setHttpProxy(it) }
if (config.allowAllAF) {
vpnBuilder.allowFamily(OsConstants.AF_INET)
vpnBuilder.allowFamily(OsConstants.AF_INET6)
}
vpnBuilder.setBlocking(config.blockingMode) vpnBuilder.setBlocking(config.blockingMode)
vpnBuilder.setUnderlyingNetworks(null) vpnBuilder.setUnderlyingNetworks(null)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)

View file

@ -1,16 +1,18 @@
package org.amnezia.vpn.protocol package org.amnezia.vpn.protocol
import android.net.InetAddresses import android.net.ProxyInfo
import android.os.Build
import androidx.annotation.RequiresApi
import java.net.InetAddress import java.net.InetAddress
import org.amnezia.vpn.util.InetNetwork
open class ProtocolConfig protected constructor( open class ProtocolConfig protected constructor(
val addresses: Set<InetNetwork>, val addresses: Set<InetNetwork>,
val dnsServers: Set<InetAddress>, val dnsServers: Set<InetAddress>,
val searchDomain: String?,
val routes: Set<InetNetwork>, val routes: Set<InetNetwork>,
val excludedRoutes: Set<InetNetwork>, val excludedRoutes: Set<InetNetwork>,
val excludedApplications: Set<String>, val excludedApplications: Set<String>,
val httpProxy: ProxyInfo?,
val allowAllAF: Boolean,
val blockingMode: Boolean, val blockingMode: Boolean,
val mtu: Int val mtu: Int
) { ) {
@ -18,25 +20,37 @@ open class ProtocolConfig protected constructor(
protected constructor(builder: Builder) : this( protected constructor(builder: Builder) : this(
builder.addresses, builder.addresses,
builder.dnsServers, builder.dnsServers,
builder.searchDomain,
builder.routes, builder.routes,
builder.excludedRoutes, builder.excludedRoutes,
builder.excludedApplications, builder.excludedApplications,
builder.httpProxy,
builder.allowAllAF,
builder.blockingMode, builder.blockingMode,
builder.mtu builder.mtu
) )
class Builder(blockingMode: Boolean) { open class Builder(blockingMode: Boolean) {
internal val addresses: MutableSet<InetNetwork> = hashSetOf() internal val addresses: MutableSet<InetNetwork> = hashSetOf()
internal val dnsServers: MutableSet<InetAddress> = hashSetOf() internal val dnsServers: MutableSet<InetAddress> = hashSetOf()
internal val routes: MutableSet<InetNetwork> = hashSetOf() internal val routes: MutableSet<InetNetwork> = hashSetOf()
internal val excludedRoutes: MutableSet<InetNetwork> = hashSetOf() internal val excludedRoutes: MutableSet<InetNetwork> = hashSetOf()
internal val excludedApplications: MutableSet<String> = hashSetOf() internal val excludedApplications: MutableSet<String> = hashSetOf()
internal var searchDomain: String? = null
private set
internal var httpProxy: ProxyInfo? = null
private set
internal var allowAllAF: Boolean = false
private set
internal var blockingMode: Boolean = blockingMode internal var blockingMode: Boolean = blockingMode
private set private set
internal var mtu: Int = 0 open var mtu: Int = 0
private set protected set
fun addAddress(addr: InetNetwork) = apply { this.addresses += addr } fun addAddress(addr: InetNetwork) = apply { this.addresses += addr }
fun addAddresses(addresses: List<InetNetwork>) = apply { this.addresses += addresses } fun addAddresses(addresses: List<InetNetwork>) = apply { this.addresses += addresses }
@ -44,18 +58,21 @@ open class ProtocolConfig protected constructor(
fun addDnsServer(dnsServer: InetAddress) = apply { this.dnsServers += dnsServer } fun addDnsServer(dnsServer: InetAddress) = apply { this.dnsServers += dnsServer }
fun addDnsServers(dnsServers: List<InetAddress>) = apply { this.dnsServers += dnsServers } fun addDnsServers(dnsServers: List<InetAddress>) = apply { this.dnsServers += dnsServers }
fun setSearchDomain(domain: String) = apply { this.searchDomain = domain }
fun addRoute(route: InetNetwork) = apply { this.routes += route } fun addRoute(route: InetNetwork) = apply { this.routes += route }
fun addRoutes(routes: List<InetNetwork>) = apply { this.routes += routes } fun addRoutes(routes: List<InetNetwork>) = apply { this.routes += routes }
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun excludeRoute(route: InetNetwork) = apply { this.excludedRoutes += route } fun excludeRoute(route: InetNetwork) = apply { this.excludedRoutes += route }
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
fun excludeRoutes(routes: List<InetNetwork>) = apply { this.excludedRoutes += routes } fun excludeRoutes(routes: List<InetNetwork>) = apply { this.excludedRoutes += routes }
fun excludeApplication(application: String) = apply { this.excludedApplications += application } fun excludeApplication(application: String) = apply { this.excludedApplications += application }
fun excludeApplications(applications: List<String>) = apply { this.excludedApplications += applications } fun excludeApplications(applications: List<String>) = apply { this.excludedApplications += applications }
fun setHttpProxy(httpProxy: ProxyInfo) = apply { this.httpProxy = httpProxy }
fun setAllowAllAF(allowAllAF: Boolean) = apply { this.allowAllAF = allowAllAF }
fun setBlockingMode(blockingMode: Boolean) = apply { this.blockingMode = blockingMode } fun setBlockingMode(blockingMode: Boolean) = apply { this.blockingMode = blockingMode }
fun setMtu(mtu: Int) = apply { this.mtu = mtu } fun setMtu(mtu: Int) = apply { this.mtu = mtu }
@ -72,7 +89,7 @@ open class ProtocolConfig protected constructor(
if (errorMessage.isNotEmpty()) throw BadConfigException(errorMessage.toString()) if (errorMessage.isNotEmpty()) throw BadConfigException(errorMessage.toString())
} }
fun build(): ProtocolConfig = validate().run { ProtocolConfig(this@Builder) } open fun build(): ProtocolConfig = validate().run { ProtocolConfig(this@Builder) }
} }
companion object { companion object {
@ -80,43 +97,3 @@ open class ProtocolConfig protected constructor(
Builder(blockingMode).apply(block).build() Builder(blockingMode).apply(block).build()
} }
} }
data class InetNetwork(val address: InetAddress, val mask: Int) {
override fun toString(): String = "${address.hostAddress}/$mask"
companion object {
fun parse(data: String): InetNetwork {
val split = data.split("/")
val address = parseInetAddress(split.first())
val mask = split.last().toInt()
return InetNetwork(address, mask)
}
}
}
data class InetEndpoint(val address: InetAddress, val port: Int) {
override fun toString(): String = "${address.hostAddress}:$port"
companion object {
fun parse(data: String): InetEndpoint {
val split = data.split(":")
val address = parseInetAddress(split.first())
val port = split.last().toInt()
return InetEndpoint(address, port)
}
}
}
fun parseInetAddress(address: String): InetAddress = parseNumericAddressCompat(address)
private val parseNumericAddressCompat: (String) -> InetAddress =
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
InetAddresses::parseNumericAddress
} else {
val m = InetAddress::class.java.getMethod("parseNumericAddress", String::class.java)
fun(address: String): InetAddress {
return m.invoke(null, address) as InetAddress
}
}

View file

@ -4,5 +4,6 @@ enum class ProtocolState {
CONNECTED, CONNECTED,
CONNECTING, CONNECTING,
DISCONNECTED, DISCONNECTED,
DISCONNECTING DISCONNECTING,
UNKNOWN
} }

View file

@ -21,10 +21,13 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
import org.amnezia.vpn.protocol.BadConfigException import org.amnezia.vpn.protocol.BadConfigException
import org.amnezia.vpn.protocol.LoadLibraryException import org.amnezia.vpn.protocol.LoadLibraryException
import org.amnezia.vpn.protocol.Protocol import org.amnezia.vpn.protocol.Protocol
@ -32,10 +35,12 @@ import org.amnezia.vpn.protocol.ProtocolState.CONNECTED
import org.amnezia.vpn.protocol.ProtocolState.CONNECTING import org.amnezia.vpn.protocol.ProtocolState.CONNECTING
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTING import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTING
import org.amnezia.vpn.protocol.ProtocolState.UNKNOWN
import org.amnezia.vpn.protocol.Statistics import org.amnezia.vpn.protocol.Statistics
import org.amnezia.vpn.protocol.Status import org.amnezia.vpn.protocol.Status
import org.amnezia.vpn.protocol.VpnStartException import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.protocol.awg.Awg import org.amnezia.vpn.protocol.awg.Awg
import org.amnezia.vpn.protocol.openvpn.OpenVpn
import org.amnezia.vpn.protocol.putStatistics import org.amnezia.vpn.protocol.putStatistics
import org.amnezia.vpn.protocol.putStatus import org.amnezia.vpn.protocol.putStatus
import org.amnezia.vpn.protocol.wireguard.Wireguard import org.amnezia.vpn.protocol.wireguard.Wireguard
@ -54,10 +59,11 @@ private const val NOTIFICATION_ID = 1337
class AmneziaVpnService : VpnService() { class AmneziaVpnService : VpnService() {
private lateinit var mainScope: CoroutineScope private lateinit var mainScope: CoroutineScope
private lateinit var connectionScope: CoroutineScope
private var isServiceBound = false private var isServiceBound = false
private var protocol: Protocol? = null private var protocol: Protocol? = null
private val protocolCache = mutableMapOf<String, Protocol>() private val protocolCache = mutableMapOf<String, Protocol>()
private var protocolState = MutableStateFlow(DISCONNECTED) private var protocolState = MutableStateFlow(UNKNOWN)
private val isConnected private val isConnected
get() = protocolState.value == CONNECTED get() = protocolState.value == CONNECTED
@ -65,6 +71,9 @@ class AmneziaVpnService : VpnService() {
private val isDisconnected private val isDisconnected
get() = protocolState.value == DISCONNECTED get() = protocolState.value == DISCONNECTED
private val isUnknown
get() = protocolState.value == UNKNOWN
private var connectionJob: Job? = null private var connectionJob: Job? = null
private var disconnectionJob: Job? = null private var disconnectionJob: Job? = null
private lateinit var clientMessenger: IpcMessenger private lateinit var clientMessenger: IpcMessenger
@ -167,7 +176,8 @@ class AmneziaVpnService : VpnService() {
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
Log.v(TAG, "Create Amnezia VPN service") Log.v(TAG, "Create Amnezia VPN service")
mainScope = CoroutineScope(SupervisorJob() + Dispatchers.Main.immediate + connectionExceptionHandler) mainScope = CoroutineScope(SupervisorJob() + Dispatchers.Main.immediate)
connectionScope = CoroutineScope(SupervisorJob() + Dispatchers.IO + connectionExceptionHandler)
clientMessenger = IpcMessenger(messengerName = "Client") clientMessenger = IpcMessenger(messengerName = "Client")
launchProtocolStateHandler() launchProtocolStateHandler()
} }
@ -205,13 +215,14 @@ class AmneziaVpnService : VpnService() {
if (intent?.action != "android.net.VpnService") { if (intent?.action != "android.net.VpnService") {
isServiceBound = false isServiceBound = false
clientMessenger.reset() clientMessenger.reset()
if (isDisconnected) stopSelf() if (isUnknown || isDisconnected) stopSelf()
} }
return super.onUnbind(intent) return super.onUnbind(intent)
} }
override fun onRevoke() { override fun onRevoke() {
Log.v(TAG, "onRevoke") Log.v(TAG, "onRevoke")
// Calls to onRevoke() method may not happen on the main thread of the process
mainScope.launch { mainScope.launch {
disconnect() disconnect()
} }
@ -219,10 +230,9 @@ class AmneziaVpnService : VpnService() {
override fun onDestroy() { override fun onDestroy() {
Log.v(TAG, "Destroy service") Log.v(TAG, "Destroy service")
if (!isDisconnected) { // todo: add sync disconnect
protocol?.stopVpn() disconnect()
protocolState.value = DISCONNECTED connectionScope.cancel()
}
mainScope.cancel() mainScope.cancel()
super.onDestroy() super.onDestroy()
} }
@ -244,7 +254,7 @@ class AmneziaVpnService : VpnService() {
if (!isServiceBound) stopSelf() if (!isServiceBound) stopSelf()
} }
CONNECTING, DISCONNECTING -> {} CONNECTING, DISCONNECTING, UNKNOWN -> {}
} }
} }
} }
@ -270,13 +280,12 @@ class AmneziaVpnService : VpnService() {
return return
} }
connectionJob = mainScope.launch { connectionJob = connectionScope.launch {
disconnectionJob?.join() disconnectionJob?.join()
disconnectionJob = null disconnectionJob = null
protocol = getProtocol(config.getString("protocol")) protocol = getProtocol(config.getString("protocol"))
protocol?.startVpn(config, Builder(), ::protect) protocol?.startVpn(config, Builder(), ::protect)
protocolState.value = CONNECTED
} }
} }
@ -284,11 +293,11 @@ class AmneziaVpnService : VpnService() {
private fun disconnect() { private fun disconnect() {
Log.v(TAG, "Stop VPN connection") Log.v(TAG, "Stop VPN connection")
if (isDisconnected || protocolState.value == DISCONNECTING) return if (isUnknown || isDisconnected || protocolState.value == DISCONNECTING) return
protocolState.value = DISCONNECTING protocolState.value = DISCONNECTING
disconnectionJob = mainScope.launch { disconnectionJob = connectionScope.launch {
connectionJob?.let { connectionJob?.let {
if (it.isActive) it.cancelAndJoin() if (it.isActive) it.cancelAndJoin()
} }
@ -296,7 +305,6 @@ class AmneziaVpnService : VpnService() {
protocol?.stopVpn() protocol?.stopVpn()
protocol = null protocol = null
protocolState.value = DISCONNECTED
} }
} }
@ -306,7 +314,7 @@ class AmneziaVpnService : VpnService() {
"wireguard" -> Wireguard() "wireguard" -> Wireguard()
"awg" -> Awg() "awg" -> Awg()
else -> throw IllegalArgumentException("Failed to load $protocolName protocol") else -> throw IllegalArgumentException("Failed to load $protocolName protocol")
}.apply { initialize(applicationContext) } }.apply { initialize(applicationContext, protocolState) }
/** /**
* Utils methods * Utils methods

View file

@ -0,0 +1,4 @@
<?xml version="1.0"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
</manifest>

View file

@ -0,0 +1,87 @@
package org.amnezia.vpn.util
import android.content.Context
import android.net.ConnectivityManager
import android.net.InetAddresses
import android.net.NetworkCapabilities
import android.os.Build
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
object NetworkUtils {
fun getLocalNetworks(context: Context, ipv6: Boolean): List<InetNetwork> {
val connectivityManager = context.getSystemService(ConnectivityManager::class.java)
connectivityManager.activeNetwork?.let { network ->
val netCapabilities = connectivityManager.getNetworkCapabilities(network)
val linkProperties = connectivityManager.getLinkProperties(network)
if (linkProperties == null ||
netCapabilities == null ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_VPN) ||
netCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)
) return emptyList()
val addresses = mutableListOf<InetNetwork>()
for (linkAddress in linkProperties.linkAddresses) {
val address = linkAddress.address
if ((!ipv6 && address is Inet4Address) || (ipv6 && address is Inet6Address)) {
addresses += InetNetwork(address, linkAddress.prefixLength)
}
}
return addresses
}
return emptyList()
}
}
data class InetNetwork(val address: InetAddress, val mask: Int) {
constructor(address: String, mask: Int) : this(parseInetAddress(address), mask)
constructor(address: InetAddress) : this(address, address.maxPrefixLength)
constructor(address: String) : this(parseInetAddress(address))
override fun toString(): String = "${address.hostAddress}/$mask"
companion object {
fun parse(data: String): InetNetwork {
val split = data.split("/")
val address = parseInetAddress(split.first())
val mask = split.last().toInt()
return InetNetwork(address, mask)
}
}
}
data class InetEndpoint(val address: InetAddress, val port: Int) {
override fun toString(): String = "${address.hostAddress}:$port"
companion object {
fun parse(data: String): InetEndpoint {
val split = data.split(":")
val address = parseInetAddress(split.first())
val port = split.last().toInt()
return InetEndpoint(address, port)
}
}
}
private val InetAddress.maxPrefixLength: Int
get() = if (this is Inet4Address) 32 else 128
fun parseInetAddress(address: String): InetAddress = parseNumericAddressCompat(address)
private val parseNumericAddressCompat: (String) -> InetAddress =
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
InetAddresses::parseNumericAddress
} else {
val m = InetAddress::class.java.getMethod("parseNumericAddress", String::class.java)
fun(address: String): InetAddress {
return m.invoke(null, address) as InetAddress
}
}

View file

@ -14,4 +14,5 @@ android {
dependencies { dependencies {
compileOnly(project(":utils")) compileOnly(project(":utils"))
compileOnly(project(":protocolApi")) compileOnly(project(":protocolApi"))
implementation(libs.kotlinx.coroutines)
} }

View file

@ -4,13 +4,17 @@ import android.content.Context
import android.net.VpnService.Builder import android.net.VpnService.Builder
import java.util.TreeMap import java.util.TreeMap
import com.wireguard.android.backend.GoBackend import com.wireguard.android.backend.GoBackend
import org.amnezia.vpn.protocol.InetEndpoint import kotlinx.coroutines.flow.MutableStateFlow
import org.amnezia.vpn.protocol.InetNetwork
import org.amnezia.vpn.protocol.Protocol import org.amnezia.vpn.protocol.Protocol
import org.amnezia.vpn.protocol.ProtocolState
import org.amnezia.vpn.protocol.ProtocolState.CONNECTED
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED
import org.amnezia.vpn.protocol.Statistics import org.amnezia.vpn.protocol.Statistics
import org.amnezia.vpn.protocol.VpnStartException import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.protocol.parseInetAddress import org.amnezia.vpn.util.InetEndpoint
import org.amnezia.vpn.util.InetNetwork
import org.amnezia.vpn.util.Log import org.amnezia.vpn.util.Log
import org.amnezia.vpn.util.parseInetAddress
import org.json.JSONObject import org.json.JSONObject
/** /**
@ -74,39 +78,38 @@ open class Wireguard : Protocol() {
} }
} }
override fun initialize(context: Context) { override fun initialize(context: Context, state: MutableStateFlow<ProtocolState>) {
super.initialize(context, state)
loadSharedLibrary(context, "wg-go") loadSharedLibrary(context, "wg-go")
} }
override fun startVpn(config: JSONObject, vpnBuilder: Builder, protect: (Int) -> Boolean) { override fun startVpn(config: JSONObject, vpnBuilder: Builder, protect: (Int) -> Boolean) {
val wireguardConfig = parseConfig(config) val wireguardConfig = parseConfig(config)
start(wireguardConfig, vpnBuilder, protect) start(wireguardConfig, vpnBuilder, protect)
state.value = CONNECTED
} }
protected open fun parseConfig(config: JSONObject): WireguardConfig { protected open fun parseConfig(config: JSONObject): WireguardConfig {
val configDataJson = config.getJSONObject("wireguard_config_data") val configDataJson = config.getJSONObject("wireguard_config_data")
val configData = parseConfigData(configDataJson.getString("config")) val configData = parseConfigData(configDataJson.getString("config"))
return WireguardConfig.build(wireguardConfigBuilder(configData)) return WireguardConfig.build { configWireguard(configData) }
} }
protected fun wireguardConfigBuilder(configData: Map<String, String>): WireguardConfig.Builder.() -> Unit = protected fun WireguardConfig.Builder.configWireguard(configData: Map<String, String>) {
{ configData["Address"]?.let { addAddress(InetNetwork.parse(it)) }
configureBaseProtocol(true) { configData["DNS"]?.split(",")?.map { dns ->
configData["Address"]?.let { addAddress(InetNetwork.parse(it)) } parseInetAddress(dns.trim())
configData["DNS"]?.split(",")?.map { dns -> }?.forEach(::addDnsServer)
parseInetAddress(dns.trim()) configData["AllowedIPs"]?.split(",")?.map { route ->
}?.forEach(::addDnsServer) InetNetwork.parse(route.trim())
configData["AllowedIPs"]?.split(",")?.map { route -> }?.forEach(::addRoute)
InetNetwork.parse(route.trim()) configData["MTU"]?.let { setMtu(it.toInt()) }
}?.forEach(::addRoute) configData["Endpoint"]?.let { setEndpoint(InetEndpoint.parse(it)) }
setMtu(configData["MTU"]?.toInt() ?: WIREGUARD_DEFAULT_MTU) configData["PersistentKeepalive"]?.let { setPersistentKeepalive(it.toInt()) }
} configData["PrivateKey"]?.let { setPrivateKeyHex(it.base64ToHex()) }
configData["Endpoint"]?.let { setEndpoint(InetEndpoint.parse(it)) } configData["PublicKey"]?.let { setPublicKeyHex(it.base64ToHex()) }
configData["PersistentKeepalive"]?.let { setPersistentKeepalive(it.toInt()) } configData["PresharedKey"]?.let { setPreSharedKeyHex(it.base64ToHex()) }
configData["PrivateKey"]?.let { setPrivateKeyHex(it.base64ToHex()) } }
configData["PublicKey"]?.let { setPublicKeyHex(it.base64ToHex()) }
configData["PresharedKey"]?.let { setPreSharedKeyHex(it.base64ToHex()) }
}
protected fun parseConfigData(data: String): Map<String, String> { protected fun parseConfigData(data: String): Map<String, String> {
val parsedData = TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER) val parsedData = TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER)
@ -155,5 +158,6 @@ open class Wireguard : Protocol() {
val handleToClose = tunnelHandle val handleToClose = tunnelHandle
tunnelHandle = -1 tunnelHandle = -1
GoBackend.wgTurnOff(handleToClose) GoBackend.wgTurnOff(handleToClose)
state.value = DISCONNECTED
} }
} }

View file

@ -1,10 +1,10 @@
package org.amnezia.vpn.protocol.wireguard package org.amnezia.vpn.protocol.wireguard
import android.util.Base64 import android.util.Base64
import org.amnezia.vpn.protocol.InetEndpoint
import org.amnezia.vpn.protocol.ProtocolConfig import org.amnezia.vpn.protocol.ProtocolConfig
import org.amnezia.vpn.util.InetEndpoint
internal const val WIREGUARD_DEFAULT_MTU = 1280 private const val WIREGUARD_DEFAULT_MTU = 1280
open class WireguardConfig protected constructor( open class WireguardConfig protected constructor(
protocolConfigBuilder: ProtocolConfig.Builder, protocolConfigBuilder: ProtocolConfig.Builder,
@ -16,7 +16,7 @@ open class WireguardConfig protected constructor(
) : ProtocolConfig(protocolConfigBuilder) { ) : ProtocolConfig(protocolConfigBuilder) {
protected constructor(builder: Builder) : this( protected constructor(builder: Builder) : this(
builder.protocolConfigBuilder, builder,
builder.endpoint, builder.endpoint,
builder.persistentKeepalive, builder.persistentKeepalive,
builder.publicKeyHex, builder.publicKeyHex,
@ -38,10 +38,7 @@ open class WireguardConfig protected constructor(
return this.toString() return this.toString()
} }
class Builder { open class Builder : ProtocolConfig.Builder(true) {
internal lateinit var protocolConfigBuilder: ProtocolConfig.Builder
private set
internal lateinit var endpoint: InetEndpoint internal lateinit var endpoint: InetEndpoint
private set private set
@ -57,9 +54,7 @@ open class WireguardConfig protected constructor(
internal lateinit var privateKeyHex: String internal lateinit var privateKeyHex: String
private set private set
fun configureBaseProtocol(blockingMode: Boolean, block: ProtocolConfig.Builder.() -> Unit) = apply { override var mtu: Int = WIREGUARD_DEFAULT_MTU
protocolConfigBuilder = ProtocolConfig.Builder(blockingMode).apply(block)
}
fun setEndpoint(endpoint: InetEndpoint) = apply { this.endpoint = endpoint } fun setEndpoint(endpoint: InetEndpoint) = apply { this.endpoint = endpoint }
@ -71,7 +66,7 @@ open class WireguardConfig protected constructor(
fun setPrivateKeyHex(privateKeyHex: String) = apply { this.privateKeyHex = privateKeyHex } fun setPrivateKeyHex(privateKeyHex: String) = apply { this.privateKeyHex = privateKeyHex }
fun build(): WireguardConfig = WireguardConfig(this) override fun build(): WireguardConfig = WireguardConfig(this)
} }
companion object { companion object {