Add onError callback to handle errors in protocol threads

This commit is contained in:
albexk 2023-12-05 13:47:12 +03:00
parent 5c3e253067
commit 5835a756ce
5 changed files with 28 additions and 27 deletions

View file

@ -11,8 +11,8 @@ import net.openvpn.ovpn3.ClientAPI_Config
import org.amnezia.vpn.protocol.BadConfigException import org.amnezia.vpn.protocol.BadConfigException
import org.amnezia.vpn.protocol.Protocol import org.amnezia.vpn.protocol.Protocol
import org.amnezia.vpn.protocol.ProtocolState import org.amnezia.vpn.protocol.ProtocolState
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED
import org.amnezia.vpn.protocol.Statistics import org.amnezia.vpn.protocol.Statistics
import org.amnezia.vpn.protocol.VpnException
import org.amnezia.vpn.protocol.VpnStartException import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.util.net.InetNetwork import org.amnezia.vpn.util.net.InetNetwork
import org.amnezia.vpn.util.net.getLocalNetworks import org.amnezia.vpn.util.net.getLocalNetworks
@ -53,8 +53,8 @@ open class OpenVpn : Protocol() {
return Statistics.EMPTY_STATISTICS return Statistics.EMPTY_STATISTICS
} }
override fun initialize(context: Context, state: MutableStateFlow<ProtocolState>) { override fun initialize(context: Context, state: MutableStateFlow<ProtocolState>, onError: (String) -> Unit) {
super.initialize(context, state) super.initialize(context, state, onError)
loadSharedLibrary(context, "ovpn3") loadSharedLibrary(context, "ovpn3")
this.context = context this.context = context
scope = CoroutineScope(Dispatchers.IO) scope = CoroutineScope(Dispatchers.IO)
@ -64,11 +64,12 @@ open class OpenVpn : Protocol() {
val configBuilder = OpenVpnConfig.Builder() val configBuilder = OpenVpnConfig.Builder()
openVpnClient = OpenVpnClient( openVpnClient = OpenVpnClient(
configBuilder, configBuilder = configBuilder,
state, state = state,
{ ipv6 -> getLocalNetworks(context, ipv6) }, getLocalNetworks = { ipv6 -> getLocalNetworks(context, ipv6) },
makeEstablish(configBuilder, vpnBuilder), establish = makeEstablish(configBuilder, vpnBuilder),
protect protect = protect,
onError = onError
) )
try { try {
@ -92,7 +93,8 @@ open class OpenVpn : Protocol() {
scope.launch { scope.launch {
val status = client.connect() val status = client.connect()
if (status.error) { if (status.error) {
throw VpnException("OpenVpn connect() error: ${status.status}: ${status.message}") state.value = DISCONNECTED
onError("OpenVpn connect() error: ${status.status}: ${status.message}")
} }
} }
} }

View file

@ -14,7 +14,6 @@ import net.openvpn.ovpn3.ClientAPI_TransportStats
import org.amnezia.vpn.protocol.ProtocolState import org.amnezia.vpn.protocol.ProtocolState
import org.amnezia.vpn.protocol.ProtocolState.CONNECTED import org.amnezia.vpn.protocol.ProtocolState.CONNECTED
import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED import org.amnezia.vpn.protocol.ProtocolState.DISCONNECTED
import org.amnezia.vpn.protocol.VpnStartException
import org.amnezia.vpn.util.Log import org.amnezia.vpn.util.Log
import org.amnezia.vpn.util.net.InetNetwork import org.amnezia.vpn.util.net.InetNetwork
import org.amnezia.vpn.util.net.parseInetAddress import org.amnezia.vpn.util.net.parseInetAddress
@ -27,7 +26,8 @@ class OpenVpnClient(
private val state: MutableStateFlow<ProtocolState>, private val state: MutableStateFlow<ProtocolState>,
private val getLocalNetworks: (Boolean) -> List<InetNetwork>, private val getLocalNetworks: (Boolean) -> List<InetNetwork>,
private val establish: () -> Int, private val establish: () -> Int,
private val protect: (Int) -> Boolean private val protect: (Int) -> Boolean,
private val onError: (String) -> Unit
) : ClientAPI_OpenVPNClient() { ) : ClientAPI_OpenVPNClient() {
/************************************************************************** /**************************************************************************
@ -368,15 +368,11 @@ class OpenVpnClient(
"COMPRESSION_ENABLED", "WARN" -> Log.w(TAG, "$name: $info") "COMPRESSION_ENABLED", "WARN" -> Log.w(TAG, "$name: $info")
"CONNECTED" -> state.value = CONNECTED "CONNECTED" -> state.value = CONNECTED
"DISCONNECTED" -> state.value = DISCONNECTED "DISCONNECTED" -> state.value = DISCONNECTED
"CONNECTION_TIMEOUT" -> {
Log.w(TAG, "$name: $info")
state.value = DISCONNECTED
// todo: test it
throw VpnStartException("Connection timeout")
}
} }
if (event.error) Log.e(TAG, "OpenVpn ERROR: $name: $info") if (event.error || event.fatal) {
if (event.fatal) Log.e(TAG, "OpenVpn FATAL: $name: $info") state.value = DISCONNECTED
onError("OpenVpn ${if (event.error) "ERROR" else "FATAL"}: $name: $info")
}
} }
// Callback for logging. // Callback for logging.

View file

@ -30,9 +30,11 @@ abstract class Protocol {
abstract val statistics: Statistics abstract val statistics: Statistics
protected lateinit var state: MutableStateFlow<ProtocolState> protected lateinit var state: MutableStateFlow<ProtocolState>
protected lateinit var onError: (String) -> Unit
open fun initialize(context: Context, state: MutableStateFlow<ProtocolState>) { open fun initialize(context: Context, state: MutableStateFlow<ProtocolState>, onError: (String) -> Unit) {
this.state = state this.state = state
this.onError = onError
} }
abstract fun startVpn(config: JSONObject, vpnBuilder: Builder, protect: (Int) -> Boolean) abstract fun startVpn(config: JSONObject, vpnBuilder: Builder, protect: (Int) -> Boolean)

View file

@ -345,18 +345,19 @@ class AmneziaVpnService : VpnService() {
"openvpn" -> OpenVpn() "openvpn" -> OpenVpn()
"cloak" -> Cloak() "cloak" -> Cloak()
else -> throw IllegalArgumentException("Protocol '$protocolName' not found") else -> throw IllegalArgumentException("Protocol '$protocolName' not found")
}.apply { initialize(applicationContext, protocolState) } }.apply { initialize(applicationContext, protocolState, ::onError) }
.also { protocolCache[protocolName] = it } .also { protocolCache[protocolName] = it }
/** /**
* Utils methods * Utils methods
*/ */
@MainThread
private fun onError(msg: String) { private fun onError(msg: String) {
Log.e(TAG, msg) Log.e(TAG, msg)
clientMessenger.send { mainScope.launch {
ServiceEvent.ERROR.packToMessage { clientMessenger.send {
putString(ERROR_MSG, msg) ServiceEvent.ERROR.packToMessage {
putString(ERROR_MSG, msg)
}
} }
} }
} }

View file

@ -78,8 +78,8 @@ open class Wireguard : Protocol() {
} }
} }
override fun initialize(context: Context, state: MutableStateFlow<ProtocolState>) { override fun initialize(context: Context, state: MutableStateFlow<ProtocolState>, onError: (String) -> Unit) {
super.initialize(context, state) super.initialize(context, state, onError)
loadSharedLibrary(context, "wg-go") loadSharedLibrary(context, "wg-go")
} }