From 712fb4d0d32c379402ac19edb5a07f86b8ae9d74 Mon Sep 17 00:00:00 2001 From: albexk Date: Thu, 23 Nov 2023 16:45:15 +0300 Subject: [PATCH] Add Wireguard module --- client/android/build.gradle.kts | 1 + client/android/settings.gradle.kts | 1 + client/android/wireguard/build.gradle.kts | 17 +++ .../wireguard/android/backend/GoBackend.kt | 12 ++ .../vpn/protocol/wireguard/Wireguard.kt | 118 ++++++++++++++++++ .../vpn/protocol/wireguard/WireguardConfig.kt | 83 ++++++++++++ 6 files changed, 232 insertions(+) create mode 100644 client/android/wireguard/build.gradle.kts create mode 100644 client/android/wireguard/src/main/kotlin/com/wireguard/android/backend/GoBackend.kt create mode 100644 client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/Wireguard.kt create mode 100644 client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt diff --git a/client/android/build.gradle.kts b/client/android/build.gradle.kts index 1957a526..39736074 100644 --- a/client/android/build.gradle.kts +++ b/client/android/build.gradle.kts @@ -86,6 +86,7 @@ dependencies { implementation(project(":qt")) implementation(project(":utils")) implementation(project(":protocolApi")) + implementation(project(":wireguard")) implementation(libs.androidx.core) implementation(libs.androidx.activity) implementation(libs.androidx.security.crypto) diff --git a/client/android/settings.gradle.kts b/client/android/settings.gradle.kts index 86359df2..9b8d3c9d 100644 --- a/client/android/settings.gradle.kts +++ b/client/android/settings.gradle.kts @@ -32,6 +32,7 @@ rootProject.buildFileName = "build.gradle.kts" include(":qt") include(":utils") include(":protocolApi") +include(":wireguard") // get values from gradle or local properties val androidBuildToolsVersion: String by gradleProperties diff --git a/client/android/wireguard/build.gradle.kts b/client/android/wireguard/build.gradle.kts new file mode 100644 index 00000000..047d838f --- /dev/null +++ b/client/android/wireguard/build.gradle.kts @@ -0,0 +1,17 @@ +plugins { + id(libs.plugins.android.library.get().pluginId) + id(libs.plugins.kotlin.android.get().pluginId) +} + +kotlin { + jvmToolchain(17) +} + +android { + namespace = "org.amnezia.vpn.protocol.wireguard" +} + +dependencies { + compileOnly(project(":utils")) + compileOnly(project(":protocolApi")) +} diff --git a/client/android/wireguard/src/main/kotlin/com/wireguard/android/backend/GoBackend.kt b/client/android/wireguard/src/main/kotlin/com/wireguard/android/backend/GoBackend.kt new file mode 100644 index 00000000..485df5a0 --- /dev/null +++ b/client/android/wireguard/src/main/kotlin/com/wireguard/android/backend/GoBackend.kt @@ -0,0 +1,12 @@ +package com.wireguard.android.backend + +// TODO: Refactor Amnezia wireguard project by changing the JNI method names +// to move this object to 'org.amnezia.vpn.protocol.wireguard.backend' package +object GoBackend { + external fun wgGetConfig(handle: Int): String? + external fun wgGetSocketV4(handle: Int): Int + external fun wgGetSocketV6(handle: Int): Int + external fun wgTurnOff(handle: Int) + external fun wgTurnOn(ifName: String, tunFd: Int, settings: String): Int + external fun wgVersion(): String +} diff --git a/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/Wireguard.kt b/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/Wireguard.kt new file mode 100644 index 00000000..0125198e --- /dev/null +++ b/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/Wireguard.kt @@ -0,0 +1,118 @@ +package org.amnezia.vpn.protocol.wireguard + +import android.content.Context +import android.net.VpnService.Builder +import java.util.TreeMap +import com.wireguard.android.backend.GoBackend +import org.amnezia.vpn.Log +import org.amnezia.vpn.protocol.InetEndpoint +import org.amnezia.vpn.protocol.InetNetwork +import org.amnezia.vpn.protocol.Protocol +import org.amnezia.vpn.protocol.Statistics +import org.amnezia.vpn.protocol.VPN_SESSION_NAME +import org.amnezia.vpn.protocol.VpnStartException +import org.amnezia.vpn.protocol.parseInetAddress +import org.json.JSONObject + +private const val TAG = "Wireguard" + +class Wireguard(context: Context) : Protocol(context) { + + private var tunnelHandle: Int = -1 + private lateinit var wireguardConfig: WireguardConfig + + override val statistics: Statistics + get() { + if (tunnelHandle == -1) return Statistics.EMPTY_STATISTICS + val config = GoBackend.wgGetConfig(tunnelHandle) ?: return Statistics.EMPTY_STATISTICS + return Statistics.build { + var optsCount = 0 + config.splitToSequence("\\n").forEach { line -> + with(line) { + when { + startsWith("rx_bytes=") -> setRxBytes(substring(9).toLong()).also { ++optsCount } + startsWith("tx_bytes=") -> setTxBytes(substring(9).toLong()).also { ++optsCount } + else -> {} + } + } + if (optsCount == 2) return@forEach + } + } + } + + override fun initialize() { + loadSharedLibrary(context, "wg-go") + } + + override fun parseConfig(config: JSONObject) { + val configDataJson = config.getJSONObject("wireguard_config_data") + val configData = parseConfigData(configDataJson.getString("config")) + wireguardConfig = WireguardConfig.build { + configureBaseProtocol(true) { + configData["Address"]?.let { addAddress(InetNetwork.parse(it)) } + configData["DNS"]?.split(",")?.map { dns -> + parseInetAddress(dns.trim()) + }?.forEach(::addDnsServer) + configData["AllowedIPs"]?.split(",")?.map { route -> + InetNetwork.parse(route.trim()) + }?.forEach(::addRoute) + setMtu(configData["MTU"]?.toInt() ?: WIREGUARD_DEFAULT_MTU) + } + configData["Endpoint"]?.let { setEndpoint(InetEndpoint.parse(it)) } + configData["PersistentKeepalive"]?.let { setPersistentKeepalive(it.toInt()) } + configData["PrivateKey"]?.let { setPrivateKeyHex(it.base64ToHex()) } + configData["PublicKey"]?.let { setPublicKeyHex(it.base64ToHex()) } + configData["PresharedKey"]?.let { setPreSharedKeyHex(it.base64ToHex()) } + } + this.config = wireguardConfig.baseProtocolConfig + } + + private fun parseConfigData(data: String): Map { + val parsedData = TreeMap(String.CASE_INSENSITIVE_ORDER) + data.lineSequence() + .filter { it.isNotEmpty() && !it.startsWith('[') } + .forEach { line -> + val attr = line.split("=", limit = 2) + parsedData[attr.first().trim()] = attr.last().trim() + } + return parsedData + } + + override fun startVpn(vpnBuilder: Builder, protect: (Int) -> Boolean) { + if (tunnelHandle != -1) { + Log.w(TAG, "Tunnel already up") + return + } + + buildVpnInterface(vpnBuilder) + + vpnBuilder.establish().use { tunFd -> + if (tunFd == null) { + throw VpnStartException("Create VPN interface: permission not granted or revoked") + } + Log.v(TAG, "Wg-go backend ${GoBackend.wgVersion()}") + tunnelHandle = GoBackend.wgTurnOn(VPN_SESSION_NAME, tunFd.detachFd(), wireguardConfig.toWgUserspaceString()) + } + + if (tunnelHandle < 0) { + tunnelHandle = -1 + throw VpnStartException("Wireguard tunnel creation error") + } + + if (!protect(GoBackend.wgGetSocketV4(tunnelHandle)) || !protect(GoBackend.wgGetSocketV6(tunnelHandle))) { + GoBackend.wgTurnOff(tunnelHandle) + tunnelHandle = -1 + throw VpnStartException("Protect VPN interface: permission not granted or revoked") + } + } + + override fun stopVpn() { + if (tunnelHandle == -1) { + Log.w(TAG, "Tunnel already down") + return + } + val handleToClose = tunnelHandle + tunnelHandle = -1 + GoBackend.wgTurnOff(handleToClose) + } +} diff --git a/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt b/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt new file mode 100644 index 00000000..740d1ab5 --- /dev/null +++ b/client/android/wireguard/src/main/kotlin/org/amnezia/vpn/protocol/wireguard/WireguardConfig.kt @@ -0,0 +1,83 @@ +package org.amnezia.vpn.protocol.wireguard + +import android.util.Base64 +import org.amnezia.vpn.protocol.InetEndpoint +import org.amnezia.vpn.protocol.ProtocolConfig + +internal const val WIREGUARD_DEFAULT_MTU = 1280 + +data class WireguardConfig( + val baseProtocolConfig: ProtocolConfig, + val endpoint: InetEndpoint, + val persistentKeepalive: Int, + val publicKeyHex: String, + val preSharedKeyHex: String, + val privateKeyHex: String +) { + + private constructor(builder: Builder) : this( + builder.baseProtocolConfig, + builder.endpoint, + builder.persistentKeepalive, + builder.publicKeyHex, + builder.preSharedKeyHex, + builder.privateKeyHex + ) + + fun toWgUserspaceString(): String = with(StringBuilder()) { + appendLine("private_key=$privateKeyHex") + appendLine("replace_peers=true") + appendLine("public_key=$publicKeyHex") + baseProtocolConfig.routes.forEach { route -> + appendLine("allowed_ip=$route") + } + appendLine("endpoint=$endpoint") + if (persistentKeepalive != 0) + appendLine("persistent_keepalive_interval=$persistentKeepalive") + appendLine("preshared_key=$preSharedKeyHex") + return this.toString() + } + + class Builder { + internal lateinit var baseProtocolConfig: ProtocolConfig + private set + + internal lateinit var endpoint: InetEndpoint + private set + + internal var persistentKeepalive: Int = 0 + private set + + internal lateinit var publicKeyHex: String + private set + + internal lateinit var preSharedKeyHex: String + private set + + internal lateinit var privateKeyHex: String + private set + + fun configureBaseProtocol(blockingMode: Boolean, block: ProtocolConfig.Builder.() -> Unit) = apply { + baseProtocolConfig = ProtocolConfig.Builder(blockingMode).apply(block).build() + } + + fun setEndpoint(endpoint: InetEndpoint) = apply { this.endpoint = endpoint } + + fun setPersistentKeepalive(persistentKeepalive: Int) = apply { this.persistentKeepalive = persistentKeepalive } + + fun setPublicKeyHex(publicKeyHex: String) = apply { this.publicKeyHex = publicKeyHex } + + fun setPreSharedKeyHex(preSharedKeyHex: String) = apply { this.preSharedKeyHex = preSharedKeyHex } + + fun setPrivateKeyHex(privateKeyHex: String) = apply { this.privateKeyHex = privateKeyHex } + + fun build(): WireguardConfig = WireguardConfig(this) + } + + companion object { + inline fun build(block: Builder.() -> Unit): WireguardConfig = Builder().apply(block).build() + } +} + +@OptIn(ExperimentalStdlibApi::class) +internal fun String.base64ToHex(): String = Base64.decode(this, Base64.DEFAULT).toHexString()