diff --git a/.gitignore b/.gitignore index aa26c3d7..ebe0acdf 100644 --- a/.gitignore +++ b/.gitignore @@ -43,4 +43,46 @@ CMakeLists.txt.user* # tmp files *.*~ +######################### Android +# Built application files +*.apk +*.aar +*.ap_ +*.aab +# Files for the ART/Dalvik VM +*.dex + +# Java class files +*.class + +# Gradle files +.gradle/ +build/ + +# Local configuration file (sdk path, etc) +local.properties + +# Proguard folder generated by Eclipse +proguard/ + +# Android Studio Navigation editor temp files +.navigation/ + +# Android Studio captures folder +captures/ + +# IntelliJ +*.iml +.idea/ + +# Keystore files +# Uncomment the following lines if you do not want to check your keystore files in. +#*.jks +#*.keystore + +# External native build folder generated in Android Studio 2.2 and later +.externalNativeBuild + +# Android Profiling +*.hprof diff --git a/client/android/AndroidManifest.xml b/client/android/AndroidManifest.xml index 092cf165..f3cf53a5 100644 --- a/client/android/AndroidManifest.xml +++ b/client/android/AndroidManifest.xml @@ -1,8 +1,13 @@ - - - + + + + + + + + + @@ -77,6 +82,19 @@ + + + + + + + + + + diff --git a/client/android/build.gradle b/client/android/build.gradle index 443a8002..355d4c72 100644 --- a/client/android/build.gradle +++ b/client/android/build.gradle @@ -1,23 +1,52 @@ buildscript { + ext{ + kotlin_version = "1.4.30-M1" + // for libwg + appcompatVersion = '1.1.0' + annotationsVersion = '1.0.1' + databindingVersion = '3.3.1' + jsr305Version = '3.0.2' + streamsupportVersion = '1.7.0' + threetenabpVersion = '1.1.1' + groupName = 'org.amnezia.vpn' + } + repositories { google() jcenter() + mavenCentral() } dependencies { - classpath 'com.android.tools.build:gradle:3.6.0' + classpath 'com.android.tools.build:gradle:4.0.0' + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" + classpath "org.jetbrains.kotlin:kotlin-serialization:$kotlin_version" } } repositories { google() jcenter() + mavenCentral() } apply plugin: 'com.android.application' +apply plugin: 'kotlin-android' +apply plugin: 'kotlin-android-extensions' +apply plugin: 'kotlinx-serialization' dependencies { implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.core:core-ktx:1.1.0' + implementation 'com.android.installreferrer:installreferrer:2.2' + implementation 'com.android.billingclient:billing-ktx:4.0.0' + implementation "androidx.lifecycle:lifecycle-livedata-ktx:2.4.0-alpha02" + implementation "androidx.security:security-crypto:1.1.0-alpha03" + implementation "androidx.security:security-identity-credential:1.0.0-alpha02" + implementation 'com.adjust.sdk:adjust-android:4.28.2' + implementation 'com.google.android.gms:play-services-ads-identifier:17.0.1' + implementation "org.jetbrains.kotlinx:kotlinx-serialization-json:1.2.2" + coreLibraryDesugaring "com.android.tools:desugar_jdk_libs:1.0.10" } android { @@ -71,7 +100,13 @@ android { defaultConfig { resConfig "en" - minSdkVersion = qtMinSdkVersion + minSdkVersion = 24 targetSdkVersion = qtTargetSdkVersion } + +// externalNativeBuild { +// cmake { +// path 'tunnel/CMakeLists.txt' +// } +// } } diff --git a/client/android/gradle.properties b/client/android/gradle.properties index fded106b..c5a864cb 100644 --- a/client/android/gradle.properties +++ b/client/android/gradle.properties @@ -9,3 +9,15 @@ org.gradle.jvmargs=-Xmx2048m # build with the same inputs. However, over time, the cache size will # grow. Uncomment the following line to enable it. #org.gradle.caching=true + +android.useAndroidX=true +# Automatically convert third-party libraries to use AndroidX +android.enableJetifier=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official + +android.bundle.enableUncompressedNativeLibs=false +androidBuildToolsVersion=30.0.2 +androidCompileSdkVersion=30 +org.gradle.caching=true +org.gradle.parallel=true diff --git a/client/android/gradle/wrapper/gradle-wrapper.properties b/client/android/gradle/wrapper/gradle-wrapper.properties index 5028f28f..4e1cc9db 100644 --- a/client/android/gradle/wrapper/gradle-wrapper.properties +++ b/client/android/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.4-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-all.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/client/android/res/drawable/ic_amnezia_round.xml b/client/android/res/drawable/ic_amnezia_round.xml new file mode 100644 index 00000000..c3158ddd --- /dev/null +++ b/client/android/res/drawable/ic_amnezia_round.xml @@ -0,0 +1,11 @@ + + + + diff --git a/client/android/res/drawable/ic_launcher_foreground.xml b/client/android/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 00000000..9e459c81 --- /dev/null +++ b/client/android/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,15 @@ + + + + + diff --git a/client/android/res/drawable/ic_logo_on.xml b/client/android/res/drawable/ic_logo_on.xml new file mode 100644 index 00000000..0aa2db2e --- /dev/null +++ b/client/android/res/drawable/ic_logo_on.xml @@ -0,0 +1,10 @@ + + + diff --git a/client/android/res/drawable/splash_background.xml b/client/android/res/drawable/splash_background.xml new file mode 100644 index 00000000..b37de4df --- /dev/null +++ b/client/android/res/drawable/splash_background.xml @@ -0,0 +1,4 @@ + + + + diff --git a/client/android/res/mipmap-anydpi-v26/vpnicon.xml b/client/android/res/mipmap-anydpi-v26/vpnicon.xml new file mode 100644 index 00000000..172bc624 --- /dev/null +++ b/client/android/res/mipmap-anydpi-v26/vpnicon.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/client/android/res/mipmap-anydpi-v26/vpnicon_round.xml b/client/android/res/mipmap-anydpi-v26/vpnicon_round.xml new file mode 100644 index 00000000..172bc624 --- /dev/null +++ b/client/android/res/mipmap-anydpi-v26/vpnicon_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/client/android/res/mipmap-hdpi/vpnicon.png b/client/android/res/mipmap-hdpi/vpnicon.png new file mode 100644 index 00000000..89b60f6f Binary files /dev/null and b/client/android/res/mipmap-hdpi/vpnicon.png differ diff --git a/client/android/res/mipmap-hdpi/vpnicon_foreground.png b/client/android/res/mipmap-hdpi/vpnicon_foreground.png new file mode 100644 index 00000000..72283f59 Binary files /dev/null and b/client/android/res/mipmap-hdpi/vpnicon_foreground.png differ diff --git a/client/android/res/mipmap-hdpi/vpnicon_round.png b/client/android/res/mipmap-hdpi/vpnicon_round.png new file mode 100644 index 00000000..7fe050ca Binary files /dev/null and b/client/android/res/mipmap-hdpi/vpnicon_round.png differ diff --git a/client/android/res/mipmap-mdpi/vpnicon.png b/client/android/res/mipmap-mdpi/vpnicon.png new file mode 100644 index 00000000..319c03a5 Binary files /dev/null and b/client/android/res/mipmap-mdpi/vpnicon.png differ diff --git a/client/android/res/mipmap-mdpi/vpnicon_foreground.png b/client/android/res/mipmap-mdpi/vpnicon_foreground.png new file mode 100644 index 00000000..ff5d71d6 Binary files /dev/null and b/client/android/res/mipmap-mdpi/vpnicon_foreground.png differ diff --git a/client/android/res/mipmap-mdpi/vpnicon_round.png b/client/android/res/mipmap-mdpi/vpnicon_round.png new file mode 100644 index 00000000..f28295b5 Binary files /dev/null and b/client/android/res/mipmap-mdpi/vpnicon_round.png differ diff --git a/client/android/res/mipmap-xhdpi/vpnicon.png b/client/android/res/mipmap-xhdpi/vpnicon.png new file mode 100644 index 00000000..f59ae9d4 Binary files /dev/null and b/client/android/res/mipmap-xhdpi/vpnicon.png differ diff --git a/client/android/res/mipmap-xhdpi/vpnicon_foreground.png b/client/android/res/mipmap-xhdpi/vpnicon_foreground.png new file mode 100644 index 00000000..57a02420 Binary files /dev/null and b/client/android/res/mipmap-xhdpi/vpnicon_foreground.png differ diff --git a/client/android/res/mipmap-xhdpi/vpnicon_round.png b/client/android/res/mipmap-xhdpi/vpnicon_round.png new file mode 100644 index 00000000..61ad0714 Binary files /dev/null and b/client/android/res/mipmap-xhdpi/vpnicon_round.png differ diff --git a/client/android/res/mipmap-xxhdpi/vpnicon.png b/client/android/res/mipmap-xxhdpi/vpnicon.png new file mode 100644 index 00000000..01a7d620 Binary files /dev/null and b/client/android/res/mipmap-xxhdpi/vpnicon.png differ diff --git a/client/android/res/mipmap-xxhdpi/vpnicon_foreground.png b/client/android/res/mipmap-xxhdpi/vpnicon_foreground.png new file mode 100644 index 00000000..12bcb9dd Binary files /dev/null and b/client/android/res/mipmap-xxhdpi/vpnicon_foreground.png differ diff --git a/client/android/res/mipmap-xxhdpi/vpnicon_round.png b/client/android/res/mipmap-xxhdpi/vpnicon_round.png new file mode 100644 index 00000000..afd69c31 Binary files /dev/null and b/client/android/res/mipmap-xxhdpi/vpnicon_round.png differ diff --git a/client/android/res/mipmap-xxxhdpi/vpnicon.png b/client/android/res/mipmap-xxxhdpi/vpnicon.png new file mode 100644 index 00000000..187fc2df Binary files /dev/null and b/client/android/res/mipmap-xxxhdpi/vpnicon.png differ diff --git a/client/android/res/mipmap-xxxhdpi/vpnicon_foreground.png b/client/android/res/mipmap-xxxhdpi/vpnicon_foreground.png new file mode 100644 index 00000000..3921d3fe Binary files /dev/null and b/client/android/res/mipmap-xxxhdpi/vpnicon_foreground.png differ diff --git a/client/android/res/mipmap-xxxhdpi/vpnicon_round.png b/client/android/res/mipmap-xxxhdpi/vpnicon_round.png new file mode 100644 index 00000000..52f4299a Binary files /dev/null and b/client/android/res/mipmap-xxxhdpi/vpnicon_round.png differ diff --git a/client/android/res/values/style.xml b/client/android/res/values/style.xml new file mode 100644 index 00000000..24380943 --- /dev/null +++ b/client/android/res/values/style.xml @@ -0,0 +1,10 @@ + + + + + + diff --git a/client/android/res/values/vpnicon_background.xml b/client/android/res/values/vpnicon_background.xml new file mode 100644 index 00000000..5f343ea1 --- /dev/null +++ b/client/android/res/values/vpnicon_background.xml @@ -0,0 +1,4 @@ + + + #000000 + \ No newline at end of file diff --git a/client/android/src/com/wireguard/android/util/SharedLibraryLoader.java b/client/android/src/com/wireguard/android/util/SharedLibraryLoader.java new file mode 100644 index 00000000..d872fd4f --- /dev/null +++ b/client/android/src/com/wireguard/android/util/SharedLibraryLoader.java @@ -0,0 +1,93 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import android.os.Build; +import android.util.Log; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; + +import androidx.annotation.RestrictTo; +import androidx.annotation.RestrictTo.Scope; + +public final class SharedLibraryLoader { + private static final String TAG = "WireGuard/SharedLibraryLoader"; + + private SharedLibraryLoader() {} + + public static boolean extractLibrary( + final Context context, final String libName, final File destination) throws IOException { + final Collection apks = new HashSet<>(); + Log.d(TAG, "Loading Lib ->" + libName); + if (context.getApplicationInfo().sourceDir != null) + apks.add(context.getApplicationInfo().sourceDir); + if (context.getApplicationInfo().splitSourceDirs != null) + apks.addAll(Arrays.asList(context.getApplicationInfo().splitSourceDirs)); + + for (final String abi : Build.SUPPORTED_ABIS) { + for (final String apk : apks) { + try (final ZipFile zipFile = new ZipFile(new File(apk), ZipFile.OPEN_READ)) { + final String mappedLibName = System.mapLibraryName(libName); + final String libZipPath = + "lib" + File.separatorChar + abi + File.separatorChar + mappedLibName; + final ZipEntry zipEntry = zipFile.getEntry(libZipPath); + if (zipEntry == null) + continue; + Log.d(TAG, "Extracting apk:/" + libZipPath + " to " + destination.getAbsolutePath()); + try (final FileOutputStream out = new FileOutputStream(destination); + final InputStream in = zipFile.getInputStream(zipEntry)) { + int len; + final byte[] buffer = new byte[1024 * 32]; + while ((len = in.read(buffer)) != -1) { + out.write(buffer, 0, len); + } + out.getFD().sync(); + } + } + return true; + } + } + return false; + } + + public static void loadSharedLibrary(final Context context, final String libName) { + Throwable noAbiException; + try { + System.loadLibrary(libName); + return; + } catch (final UnsatisfiedLinkError e) { + Log.d(TAG, "Failed to load library normally, so attempting to extract from apk", e); + noAbiException = e; + } + File f = null; + try { + f = File.createTempFile("lib", ".so", context.getCodeCacheDir()); + if (extractLibrary(context, libName, f)) { + System.load(f.getAbsolutePath()); + return; + } + } catch (final Exception e) { + Log.d(TAG, "Failed to load library apk:/" + libName, e); + noAbiException = e; + } finally { + if (f != null) + // noinspection ResultOfMethodCallIgnored + f.delete(); + } + if (noAbiException instanceof RuntimeException) + throw(RuntimeException) noAbiException; + throw new RuntimeException(noAbiException); + } +} diff --git a/client/android/src/com/wireguard/config/Attribute.java b/client/android/src/com/wireguard/config/Attribute.java new file mode 100644 index 00000000..2cabe428 --- /dev/null +++ b/client/android/src/com/wireguard/config/Attribute.java @@ -0,0 +1,57 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.util.Iterator; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public final class Attribute { + private static final Pattern LINE_PATTERN = Pattern.compile("(\\w+)\\s*=\\s*([^\\s#][^#]*)"); + private static final Pattern LIST_SEPARATOR = Pattern.compile("\\s*,\\s*"); + + private final String key; + private final String value; + + private Attribute(final String key, final String value) { + this.key = key; + this.value = value; + } + + public static String join(final Iterable values) { + final Iterator it = values.iterator(); + if (!it.hasNext()) { + return ""; + } + final StringBuilder sb = new StringBuilder(); + sb.append(it.next()); + while (it.hasNext()) { + sb.append(", "); + sb.append(it.next()); + } + return sb.toString(); + } + + public static Optional parse(final CharSequence line) { + final Matcher matcher = LINE_PATTERN.matcher(line); + if (!matcher.matches()) + return Optional.empty(); + return Optional.of(new Attribute(matcher.group(1), matcher.group(2))); + } + + public static String[] split(final CharSequence value) { + return LIST_SEPARATOR.split(value); + } + + public String getKey() { + return key; + } + + public String getValue() { + return value; + } +} diff --git a/client/android/src/com/wireguard/config/BadConfigException.java b/client/android/src/com/wireguard/config/BadConfigException.java new file mode 100644 index 00000000..33910501 --- /dev/null +++ b/client/android/src/com/wireguard/config/BadConfigException.java @@ -0,0 +1,116 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import com.wireguard.crypto.KeyFormatException; + +import androidx.annotation.Nullable; + +public class BadConfigException extends Exception { + private final Location location; + private final Reason reason; + private final Section section; + @Nullable private final CharSequence text; + + private BadConfigException(final Section section, final Location location, final Reason reason, + @Nullable final CharSequence text, @Nullable final Throwable cause) { + super(cause); + this.section = section; + this.location = location; + this.reason = reason; + this.text = text; + } + + public BadConfigException(final Section section, final Location location, final Reason reason, + @Nullable final CharSequence text) { + this(section, location, reason, text, null); + } + + public BadConfigException( + final Section section, final Location location, final KeyFormatException cause) { + this(section, location, Reason.INVALID_KEY, null, cause); + } + + public BadConfigException(final Section section, final Location location, + @Nullable final CharSequence text, final NumberFormatException cause) { + this(section, location, Reason.INVALID_NUMBER, text, cause); + } + + public BadConfigException( + final Section section, final Location location, final ParseException cause) { + this(section, location, Reason.INVALID_VALUE, cause.getText(), cause); + } + + public Location getLocation() { + return location; + } + + public Reason getReason() { + return reason; + } + + public Section getSection() { + return section; + } + + @Nullable + public CharSequence getText() { + return text; + } + + public enum Location { + TOP_LEVEL(""), + ADDRESS("Address"), + ALLOWED_IPS("AllowedIPs"), + DNS("DNS"), + ENDPOINT("Endpoint"), + EXCLUDED_APPLICATIONS("ExcludedApplications"), + INCLUDED_APPLICATIONS("IncludedApplications"), + LISTEN_PORT("ListenPort"), + MTU("MTU"), + PERSISTENT_KEEPALIVE("PersistentKeepalive"), + PRE_SHARED_KEY("PresharedKey"), + PRIVATE_KEY("PrivateKey"), + PUBLIC_KEY("PublicKey"); + + private final String name; + + Location(final String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + public enum Reason { + INVALID_KEY, + INVALID_NUMBER, + INVALID_VALUE, + MISSING_ATTRIBUTE, + MISSING_SECTION, + SYNTAX_ERROR, + UNKNOWN_ATTRIBUTE, + UNKNOWN_SECTION + } + + public enum Section { + CONFIG("Config"), + INTERFACE("Interface"), + PEER("Peer"); + + private final String name; + + Section(final String name) { + this.name = name; + } + + public String getName() { + return name; + } + } +} diff --git a/client/android/src/com/wireguard/config/Config.java b/client/android/src/com/wireguard/config/Config.java new file mode 100644 index 00000000..06409bab --- /dev/null +++ b/client/android/src/com/wireguard/config/Config.java @@ -0,0 +1,218 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import com.wireguard.config.BadConfigException.Location; +import com.wireguard.config.BadConfigException.Reason; +import com.wireguard.config.BadConfigException.Section; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import androidx.annotation.Nullable; + +/** + * Represents the contents of a wg-quick configuration file, made up of one or more "Interface" + * sections (combined together), and zero or more "Peer" sections (treated individually). + *

+ * Instances of this class are immutable. + */ + +public final class Config { + private final Interface interfaze; + private final List peers; + + private Config(final Builder builder) { + interfaze = Objects.requireNonNull(builder.interfaze, "An [Interface] section is required"); + // Defensively copy to ensure immutability even if the Builder is reused. + peers = Collections.unmodifiableList(new ArrayList<>(builder.peers)); + } + + /** + * Parses an series of "Interface" and "Peer" sections into a {@code Config}. Throws + * {@link BadConfigException} if the input is not well-formed or contains data that cannot + * be parsed. + * + * @param stream a stream of UTF-8 text that is interpreted as a WireGuard configuration + * @return a {@code Config} instance representing the supplied configuration + */ + public static Config parse(final InputStream stream) throws IOException, BadConfigException { + return parse(new BufferedReader(new InputStreamReader(stream))); + } + + /** + * Parses an series of "Interface" and "Peer" sections into a {@code Config}. Throws + * {@link BadConfigException} if the input is not well-formed or contains data that cannot + * be parsed. + * + * @param reader a BufferedReader of UTF-8 text that is interpreted as a WireGuard configuration + * @return a {@code Config} instance representing the supplied configuration + */ + public static Config parse(final BufferedReader reader) throws IOException, BadConfigException { + final Builder builder = new Builder(); + final Collection interfaceLines = new ArrayList<>(); + final Collection peerLines = new ArrayList<>(); + boolean inInterfaceSection = false; + boolean inPeerSection = false; + boolean seenInterfaceSection = false; + @Nullable String line; + while ((line = reader.readLine()) != null) { + final int commentIndex = line.indexOf('#'); + if (commentIndex != -1) + line = line.substring(0, commentIndex); + line = line.trim(); + if (line.isEmpty()) + continue; + if (line.startsWith("[")) { + // Consume all [Peer] lines read so far. + if (inPeerSection) { + builder.parsePeer(peerLines); + peerLines.clear(); + } + if ("[Interface]".equalsIgnoreCase(line)) { + inInterfaceSection = true; + inPeerSection = false; + seenInterfaceSection = true; + } else if ("[Peer]".equalsIgnoreCase(line)) { + inInterfaceSection = false; + inPeerSection = true; + } else { + throw new BadConfigException( + Section.CONFIG, Location.TOP_LEVEL, Reason.UNKNOWN_SECTION, line); + } + } else if (inInterfaceSection) { + interfaceLines.add(line); + } else if (inPeerSection) { + peerLines.add(line); + } else { + throw new BadConfigException( + Section.CONFIG, Location.TOP_LEVEL, Reason.UNKNOWN_SECTION, line); + } + } + if (inPeerSection) + builder.parsePeer(peerLines); + if (!seenInterfaceSection) + throw new BadConfigException( + Section.CONFIG, Location.TOP_LEVEL, Reason.MISSING_SECTION, null); + // Combine all [Interface] sections in the file. + builder.parseInterface(interfaceLines); + return builder.build(); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Config)) + return false; + final Config other = (Config) obj; + return interfaze.equals(other.interfaze) && peers.equals(other.peers); + } + + /** + * Returns the interface section of the configuration. + * + * @return the interface configuration + */ + public Interface getInterface() { + return interfaze; + } + + /** + * Returns a list of the configuration's peer sections. + * + * @return a list of {@link Peer}s + */ + public List getPeers() { + return peers; + } + + @Override + public int hashCode() { + return 31 * interfaze.hashCode() + peers.hashCode(); + } + + /** + * Converts the {@code Config} into a string suitable for debugging purposes. The {@code Config} + * is identified by its interface's public key and the number of peers it has. + * + * @return a concise single-line identifier for the {@code Config} + */ + @Override + public String toString() { + return "(Config " + interfaze + " (" + peers.size() + " peers))"; + } + + /** + * Converts the {@code Config} into a string suitable for use as a {@code wg-quick} + * configuration file. + * + * @return the {@code Config} represented as one [Interface] and zero or more [Peer] sections + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + sb.append("[Interface]\n").append(interfaze.toWgQuickString()); + for (final Peer peer : peers) sb.append("\n[Peer]\n").append(peer.toWgQuickString()); + return sb.toString(); + } + + /** + * Serializes the {@code Config} for use with the WireGuard cross-platform userspace API. + * + * @return the {@code Config} represented as a series of "key=value" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + sb.append(interfaze.toWgUserspaceString()); + sb.append("replace_peers=true\n"); + for (final Peer peer : peers) sb.append(peer.toWgUserspaceString()); + return sb.toString(); + } + + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // Defaults to an empty set. + private final ArrayList peers = new ArrayList<>(); + // No default; must be provided before building. + @Nullable private Interface interfaze; + + public Builder addPeer(final Peer peer) { + peers.add(peer); + return this; + } + + public Builder addPeers(final Collection peers) { + this.peers.addAll(peers); + return this; + } + + public Config build() { + if (interfaze == null) + throw new IllegalArgumentException("An [Interface] section is required"); + return new Config(this); + } + + public Builder parseInterface(final Iterable lines) + throws BadConfigException { + return setInterface(Interface.parse(lines)); + } + + public Builder parsePeer(final Iterable lines) + throws BadConfigException { + return addPeer(Peer.parse(lines)); + } + + public Builder setInterface(final Interface interfaze) { + this.interfaze = interfaze; + return this; + } + } +} diff --git a/client/android/src/com/wireguard/config/InetAddresses.java b/client/android/src/com/wireguard/config/InetAddresses.java new file mode 100644 index 00000000..bb6a6854 --- /dev/null +++ b/client/android/src/com/wireguard/config/InetAddresses.java @@ -0,0 +1,73 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.lang.reflect.Method; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.regex.Pattern; + +import androidx.annotation.Nullable; + +/** + * Utility methods for creating instances of {@link InetAddress}. + */ + +public final class InetAddresses { + @Nullable private static final Method PARSER_METHOD; + private static final Pattern WONT_TOUCH_RESOLVER = Pattern.compile( + "^(((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}|((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})|:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(:(((:[0-9A-Fa-f]{1,4}){1,7})|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:)))(%.+)?)|((?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))$"); + + static { + Method m = null; + try { + if (android.os.Build.VERSION.SDK_INT < android.os.Build.VERSION_CODES.Q) + // noinspection JavaReflectionMemberAccess + m = InetAddress.class.getMethod("parseNumericAddress", String.class); + } catch (final Exception ignored) { + } + PARSER_METHOD = m; + } + + private InetAddresses() {} + + /** + * Parses a numeric IPv4 or IPv6 address without performing any DNS lookups. + * + * @param address a string representing the IP address + * @return an instance of {@link Inet4Address} or {@link Inet6Address}, as appropriate + */ + public static InetAddress parse(final String address) throws ParseException { + if (address.isEmpty()) + throw new ParseException(InetAddress.class, address, "Empty address"); + try { + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.Q) + return android.net.InetAddresses.parseNumericAddress(address); + else if (PARSER_METHOD != null) + return (InetAddress) PARSER_METHOD.invoke(null, address); + else + throw new NoSuchMethodException("parseNumericAddress"); + } catch (final IllegalArgumentException e) { + throw new ParseException(InetAddress.class, address, e); + } catch (final Exception e) { + final Throwable cause = e.getCause(); + // Re-throw parsing exceptions with the original type, as callers might try to catch + // them. On the other hand, callers cannot be expected to handle reflection failures. + if (cause instanceof IllegalArgumentException) + throw new ParseException(InetAddress.class, address, cause); + try { + if (WONT_TOUCH_RESOLVER.matcher(address).matches()) + return InetAddress.getByName(address); + else + throw new ParseException(InetAddress.class, address, "Not an IP address"); + } catch (final UnknownHostException f) { + throw new ParseException(InetAddress.class, address, f); + } + } + } +} diff --git a/client/android/src/com/wireguard/config/InetEndpoint.java b/client/android/src/com/wireguard/config/InetEndpoint.java new file mode 100644 index 00000000..655835fb --- /dev/null +++ b/client/android/src/com/wireguard/config/InetEndpoint.java @@ -0,0 +1,123 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; +import java.util.regex.Pattern; + +import androidx.annotation.Nullable; + +/** + * An external endpoint (host and port) used to connect to a WireGuard {@link Peer}. + *

+ * Instances of this class are externally immutable. + */ + +public final class InetEndpoint { + private static final Pattern BARE_IPV6 = Pattern.compile("^[^\\[\\]]*:[^\\[\\]]*"); + private static final Pattern FORBIDDEN_CHARACTERS = Pattern.compile("[/?#]"); + + private final String host; + private final boolean isResolved; + private final Object lock = new Object(); + private final int port; + private Instant lastResolution = Instant.EPOCH; + @Nullable private InetEndpoint resolved; + + private InetEndpoint(final String host, final boolean isResolved, final int port) { + this.host = host; + this.isResolved = isResolved; + this.port = port; + } + + public static InetEndpoint parse(final String endpoint) throws ParseException { + if (FORBIDDEN_CHARACTERS.matcher(endpoint).find()) + throw new ParseException(InetEndpoint.class, endpoint, "Forbidden characters"); + final URI uri; + try { + uri = new URI("wg://" + endpoint); + } catch (final URISyntaxException e) { + throw new ParseException(InetEndpoint.class, endpoint, e); + } + if (uri.getPort() < 0 || uri.getPort() > 65535) + throw new ParseException(InetEndpoint.class, endpoint, "Missing/invalid port number"); + try { + InetAddresses.parse(uri.getHost()); + // Parsing ths host as a numeric address worked, so we don't need to do DNS lookups. + return new InetEndpoint(uri.getHost(), true, uri.getPort()); + } catch (final ParseException ignored) { + // Failed to parse the host as a numeric address, so it must be a DNS hostname/FQDN. + return new InetEndpoint(uri.getHost(), false, uri.getPort()); + } + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof InetEndpoint)) + return false; + final InetEndpoint other = (InetEndpoint) obj; + return host.equals(other.host) && port == other.port; + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + /** + * Generate an {@code InetEndpoint} instance with the same port and the host resolved using DNS + * to a numeric address. If the host is already numeric, the existing instance may be returned. + * Because this function may perform network I/O, it must not be called from the main thread. + * + * @return the resolved endpoint, or {@link Optional#empty()} + */ + public Optional getResolved() { + if (isResolved) + return Optional.of(this); + synchronized (lock) { + // TODO(zx2c4): Implement a real timeout mechanism using DNS TTL + if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { + try { + // Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues. + final InetAddress[] candidates = InetAddress.getAllByName(host); + InetAddress address = candidates[0]; + for (final InetAddress candidate : candidates) { + if (candidate instanceof Inet4Address) { + address = candidate; + break; + } + } + resolved = new InetEndpoint(address.getHostAddress(), true, port); + lastResolution = Instant.now(); + } catch (final UnknownHostException e) { + resolved = null; + } + } + return Optional.ofNullable(resolved); + } + } + + @Override + public int hashCode() { + return host.hashCode() ^ port; + } + + @Override + public String toString() { + final boolean isBareIpv6 = isResolved && BARE_IPV6.matcher(host).matches(); + return (isBareIpv6 ? '[' + host + ']' : host) + ':' + port; + } +} diff --git a/client/android/src/com/wireguard/config/InetNetwork.java b/client/android/src/com/wireguard/config/InetNetwork.java new file mode 100644 index 00000000..584950d4 --- /dev/null +++ b/client/android/src/com/wireguard/config/InetNetwork.java @@ -0,0 +1,77 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.net.Inet4Address; +import java.net.InetAddress; + +/** + * An Internet network, denoted by its address and netmask + *

+ * Instances of this class are immutable. + */ + +public final class InetNetwork { + private final InetAddress address; + private final int mask; + + private InetNetwork(final InetAddress address, final int mask) { + this.address = address; + this.mask = mask; + } + + public static InetNetwork parse(final String network) throws ParseException { + final int slash = network.lastIndexOf('/'); + final String maskString; + final int rawMask; + final String rawAddress; + if (slash >= 0) { + maskString = network.substring(slash + 1); + try { + rawMask = Integer.parseInt(maskString, 10); + } catch (final NumberFormatException ignored) { + throw new ParseException(Integer.class, maskString); + } + rawAddress = network.substring(0, slash); + } else { + maskString = ""; + rawMask = -1; + rawAddress = network; + } + final InetAddress address = InetAddresses.parse(rawAddress); + final int maxMask = (address instanceof Inet4Address) ? 32 : 128; + if (rawMask > maxMask) + throw new ParseException(InetNetwork.class, maskString, "Invalid network mask"); + final int mask = rawMask >= 0 ? rawMask : maxMask; + return new InetNetwork(address, mask); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof InetNetwork)) + return false; + final InetNetwork other = (InetNetwork) obj; + return address.equals(other.address) && mask == other.mask; + } + + public InetAddress getAddress() { + return address; + } + + public int getMask() { + return mask; + } + + @Override + public int hashCode() { + return address.hashCode() ^ mask; + } + + @Override + public String toString() { + return address.getHostAddress() + '/' + mask; + } +} diff --git a/client/android/src/com/wireguard/config/Interface.java b/client/android/src/com/wireguard/config/Interface.java new file mode 100644 index 00000000..2594d701 --- /dev/null +++ b/client/android/src/com/wireguard/config/Interface.java @@ -0,0 +1,394 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import com.wireguard.config.BadConfigException.Location; +import com.wireguard.config.BadConfigException.Reason; +import com.wireguard.config.BadConfigException.Section; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; +import com.wireguard.crypto.KeyPair; + +import java.net.InetAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import androidx.annotation.Nullable; + +/** + * Represents the configuration for a WireGuard interface (an [Interface] block). Interfaces must + * have a private key (used to initialize a {@code KeyPair}), and may optionally have several other + * attributes. + *

+ * Instances of this class are immutable. + */ + +public final class Interface { + private static final int MAX_UDP_PORT = 65535; + private static final int MIN_UDP_PORT = 0; + + private final Set addresses; + private final Set dnsServers; + private final Set excludedApplications; + private final Set includedApplications; + private final KeyPair keyPair; + private final Optional listenPort; + private final Optional mtu; + + private Interface(final Builder builder) { + // Defensively copy to ensure immutability even if the Builder is reused. + addresses = Collections.unmodifiableSet(new LinkedHashSet<>(builder.addresses)); + dnsServers = Collections.unmodifiableSet(new LinkedHashSet<>(builder.dnsServers)); + excludedApplications = + Collections.unmodifiableSet(new LinkedHashSet<>(builder.excludedApplications)); + includedApplications = + Collections.unmodifiableSet(new LinkedHashSet<>(builder.includedApplications)); + keyPair = Objects.requireNonNull(builder.keyPair, "Interfaces must have a private key"); + listenPort = builder.listenPort; + mtu = builder.mtu; + } + + /** + * Parses an series of "KEY = VALUE" lines into an {@code Interface}. Throws + * {@link ParseException} if the input is not well-formed or contains unknown attributes. + * + * @param lines An iterable sequence of lines, containing at least a private key attribute + * @return An {@code Interface} with all of the attributes from {@code lines} set + */ + public static Interface parse(final Iterable lines) + throws BadConfigException { + final Builder builder = new Builder(); + for (final CharSequence line : lines) { + final Attribute attribute = + Attribute.parse(line).orElseThrow(() + -> new BadConfigException(Section.INTERFACE, + Location.TOP_LEVEL, Reason.SYNTAX_ERROR, line)); + switch (attribute.getKey().toLowerCase(Locale.ENGLISH)) { + case "address": + builder.parseAddresses(attribute.getValue()); + break; + case "dns": + builder.parseDnsServers(attribute.getValue()); + break; + case "excludedapplications": + builder.parseExcludedApplications(attribute.getValue()); + break; + case "includedapplications": + builder.parseIncludedApplications(attribute.getValue()); + break; + case "listenport": + builder.parseListenPort(attribute.getValue()); + break; + case "mtu": + builder.parseMtu(attribute.getValue()); + break; + case "privatekey": + builder.parsePrivateKey(attribute.getValue()); + break; + default: + throw new BadConfigException( + Section.INTERFACE, Location.TOP_LEVEL, Reason.UNKNOWN_ATTRIBUTE, attribute.getKey()); + } + } + return builder.build(); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Interface)) + return false; + final Interface other = (Interface) obj; + return addresses.equals(other.addresses) && dnsServers.equals(other.dnsServers) + && excludedApplications.equals(other.excludedApplications) + && includedApplications.equals(other.includedApplications) && keyPair.equals(other.keyPair) + && listenPort.equals(other.listenPort) && mtu.equals(other.mtu); + } + + /** + * Returns the set of IP addresses assigned to the interface. + * + * @return a set of {@link InetNetwork}s + */ + public Set getAddresses() { + // The collection is already immutable. + return addresses; + } + + /** + * Returns the set of DNS servers associated with the interface. + * + * @return a set of {@link InetAddress}es + */ + public Set getDnsServers() { + // The collection is already immutable. + return dnsServers; + } + + /** + * Returns the set of applications excluded from using the interface. + * + * @return a set of package names + */ + public Set getExcludedApplications() { + // The collection is already immutable. + return excludedApplications; + } + + /** + * Returns the set of applications included exclusively for using the interface. + * + * @return a set of package names + */ + public Set getIncludedApplications() { + // The collection is already immutable. + return includedApplications; + } + + /** + * Returns the public/private key pair used by the interface. + * + * @return a key pair + */ + public KeyPair getKeyPair() { + return keyPair; + } + + /** + * Returns the UDP port number that the WireGuard interface will listen on. + * + * @return a UDP port number, or {@code Optional.empty()} if none is configured + */ + public Optional getListenPort() { + return listenPort; + } + + /** + * Returns the MTU used for the WireGuard interface. + * + * @return the MTU, or {@code Optional.empty()} if none is configured + */ + public Optional getMtu() { + return mtu; + } + + @Override + public int hashCode() { + int hash = 1; + hash = 31 * hash + addresses.hashCode(); + hash = 31 * hash + dnsServers.hashCode(); + hash = 31 * hash + excludedApplications.hashCode(); + hash = 31 * hash + includedApplications.hashCode(); + hash = 31 * hash + keyPair.hashCode(); + hash = 31 * hash + listenPort.hashCode(); + hash = 31 * hash + mtu.hashCode(); + return hash; + } + + /** + * Converts the {@code Interface} into a string suitable for debugging purposes. The {@code + * Interface} is identified by its public key and (if set) the port used for its UDP socket. + * + * @return A concise single-line identifier for the {@code Interface} + */ + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("(Interface "); + sb.append(keyPair.getPublicKey().toBase64()); + listenPort.ifPresent(lp -> sb.append(" @").append(lp)); + sb.append(')'); + return sb.toString(); + } + + /** + * Converts the {@code Interface} into a string suitable for inclusion in a {@code wg-quick} + * configuration file. + * + * @return The {@code Interface} represented as a series of "Key = Value" lines + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + if (!addresses.isEmpty()) + sb.append("Address = ").append(Attribute.join(addresses)).append('\n'); + if (!dnsServers.isEmpty()) { + final List dnsServerStrings = + dnsServers.stream().map(InetAddress::getHostAddress).collect(Collectors.toList()); + sb.append("DNS = ").append(Attribute.join(dnsServerStrings)).append('\n'); + } + if (!excludedApplications.isEmpty()) + sb.append("ExcludedApplications = ") + .append(Attribute.join(excludedApplications)) + .append('\n'); + if (!includedApplications.isEmpty()) + sb.append("IncludedApplications = ") + .append(Attribute.join(includedApplications)) + .append('\n'); + listenPort.ifPresent(lp -> sb.append("ListenPort = ").append(lp).append('\n')); + mtu.ifPresent(m -> sb.append("MTU = ").append(m).append('\n')); + sb.append("PrivateKey = ").append(keyPair.getPrivateKey().toBase64()).append('\n'); + return sb.toString(); + } + + /** + * Serializes the {@code Interface} for use with the WireGuard cross-platform userspace API. + * Note that not all attributes are included in this representation. + * + * @return the {@code Interface} represented as a series of "KEY=VALUE" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + sb.append("private_key=").append(keyPair.getPrivateKey().toHex()).append('\n'); + listenPort.ifPresent(lp -> sb.append("listen_port=").append(lp).append('\n')); + return sb.toString(); + } + + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // Defaults to an empty set. + private final Set addresses = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set dnsServers = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set excludedApplications = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set includedApplications = new LinkedHashSet<>(); + // No default; must be provided before building. + @Nullable private KeyPair keyPair; + // Defaults to not present. + private Optional listenPort = Optional.empty(); + // Defaults to not present. + private Optional mtu = Optional.empty(); + + public Builder addAddress(final InetNetwork address) { + addresses.add(address); + return this; + } + + public Builder addAddresses(final Collection addresses) { + this.addresses.addAll(addresses); + return this; + } + + public Builder addDnsServer(final InetAddress dnsServer) { + dnsServers.add(dnsServer); + return this; + } + + public Builder addDnsServers(final Collection dnsServers) { + this.dnsServers.addAll(dnsServers); + return this; + } + + public Interface build() throws BadConfigException { + if (keyPair == null) + throw new BadConfigException( + Section.INTERFACE, Location.PRIVATE_KEY, Reason.MISSING_ATTRIBUTE, null); + if (!includedApplications.isEmpty() && !excludedApplications.isEmpty()) + throw new BadConfigException( + Section.INTERFACE, Location.INCLUDED_APPLICATIONS, Reason.INVALID_KEY, null); + return new Interface(this); + } + + public Builder excludeApplication(final String application) { + excludedApplications.add(application); + return this; + } + + public Builder excludeApplications(final Collection applications) { + excludedApplications.addAll(applications); + return this; + } + + public Builder includeApplication(final String application) { + includedApplications.add(application); + return this; + } + + public Builder includeApplications(final Collection applications) { + includedApplications.addAll(applications); + return this; + } + + public Builder parseAddresses(final CharSequence addresses) throws BadConfigException { + try { + for (final String address : Attribute.split(addresses)) + addAddress(InetNetwork.parse(address)); + return this; + } catch (final ParseException e) { + throw new BadConfigException(Section.INTERFACE, Location.ADDRESS, e); + } + } + + public Builder parseDnsServers(final CharSequence dnsServers) throws BadConfigException { + try { + for (final String dnsServer : Attribute.split(dnsServers)) + addDnsServer(InetAddresses.parse(dnsServer)); + return this; + } catch (final ParseException e) { + throw new BadConfigException(Section.INTERFACE, Location.DNS, e); + } + } + + public Builder parseExcludedApplications(final CharSequence apps) { + return excludeApplications(List.of(Attribute.split(apps))); + } + + public Builder parseIncludedApplications(final CharSequence apps) { + return includeApplications(List.of(Attribute.split(apps))); + } + + public Builder parseListenPort(final String listenPort) throws BadConfigException { + try { + return setListenPort(Integer.parseInt(listenPort)); + } catch (final NumberFormatException e) { + throw new BadConfigException(Section.INTERFACE, Location.LISTEN_PORT, listenPort, e); + } + } + + public Builder parseMtu(final String mtu) throws BadConfigException { + try { + return setMtu(Integer.parseInt(mtu)); + } catch (final NumberFormatException e) { + throw new BadConfigException(Section.INTERFACE, Location.MTU, mtu, e); + } + } + + public Builder parsePrivateKey(final String privateKey) throws BadConfigException { + try { + return setKeyPair(new KeyPair(Key.fromBase64(privateKey))); + } catch (final KeyFormatException e) { + throw new BadConfigException(Section.INTERFACE, Location.PRIVATE_KEY, e); + } + } + + public Builder setKeyPair(final KeyPair keyPair) { + this.keyPair = keyPair; + return this; + } + + public Builder setListenPort(final int listenPort) throws BadConfigException { + if (listenPort < MIN_UDP_PORT || listenPort > MAX_UDP_PORT) + throw new BadConfigException(Section.INTERFACE, Location.LISTEN_PORT, Reason.INVALID_VALUE, + String.valueOf(listenPort)); + this.listenPort = listenPort == 0 ? Optional.empty() : Optional.of(listenPort); + return this; + } + + public Builder setMtu(final int mtu) throws BadConfigException { + if (mtu < 0) + throw new BadConfigException( + Section.INTERFACE, Location.LISTEN_PORT, Reason.INVALID_VALUE, String.valueOf(mtu)); + this.mtu = mtu == 0 ? Optional.empty() : Optional.of(mtu); + return this; + } + } +} diff --git a/client/android/src/com/wireguard/config/ParseException.java b/client/android/src/com/wireguard/config/ParseException.java new file mode 100644 index 00000000..57d19e9b --- /dev/null +++ b/client/android/src/com/wireguard/config/ParseException.java @@ -0,0 +1,46 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +/** + * + */ + +public class ParseException extends Exception { + private final Class parsingClass; + private final CharSequence text; + + public ParseException(final Class parsingClass, final CharSequence text, + @Nullable final String message, @Nullable final Throwable cause) { + super(message, cause); + this.parsingClass = parsingClass; + this.text = text; + } + + public ParseException( + final Class parsingClass, final CharSequence text, @Nullable final String message) { + this(parsingClass, text, message, null); + } + + public ParseException( + final Class parsingClass, final CharSequence text, @Nullable final Throwable cause) { + this(parsingClass, text, null, cause); + } + + public ParseException(final Class parsingClass, final CharSequence text) { + this(parsingClass, text, null, null); + } + + public Class getParsingClass() { + return parsingClass; + } + + public CharSequence getText() { + return text; + } +} diff --git a/client/android/src/com/wireguard/config/Peer.java b/client/android/src/com/wireguard/config/Peer.java new file mode 100644 index 00000000..8b66b7d0 --- /dev/null +++ b/client/android/src/com/wireguard/config/Peer.java @@ -0,0 +1,306 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import com.wireguard.config.BadConfigException.Location; +import com.wireguard.config.BadConfigException.Reason; +import com.wireguard.config.BadConfigException.Section; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; + +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import androidx.annotation.Nullable; + +/** + * Represents the configuration for a WireGuard peer (a [Peer] block). Peers must have a public key, + * and may optionally have several other attributes. + *

+ * Instances of this class are immutable. + */ + +public final class Peer { + private final Set allowedIps; + private final Optional endpoint; + private final Optional persistentKeepalive; + private final Optional preSharedKey; + private final Key publicKey; + + private Peer(final Builder builder) { + // Defensively copy to ensure immutability even if the Builder is reused. + allowedIps = Collections.unmodifiableSet(new LinkedHashSet<>(builder.allowedIps)); + endpoint = builder.endpoint; + persistentKeepalive = builder.persistentKeepalive; + preSharedKey = builder.preSharedKey; + publicKey = Objects.requireNonNull(builder.publicKey, "Peers must have a public key"); + } + + /** + * Parses an series of "KEY = VALUE" lines into a {@code Peer}. Throws {@link ParseException} if + * the input is not well-formed or contains unknown attributes. + * + * @param lines an iterable sequence of lines, containing at least a public key attribute + * @return a {@code Peer} with all of its attributes set from {@code lines} + */ + public static Peer parse(final Iterable lines) throws BadConfigException { + final Builder builder = new Builder(); + for (final CharSequence line : lines) { + final Attribute attribute = + Attribute.parse(line).orElseThrow(() + -> new BadConfigException(Section.PEER, + Location.TOP_LEVEL, Reason.SYNTAX_ERROR, line)); + switch (attribute.getKey().toLowerCase(Locale.ENGLISH)) { + case "allowedips": + builder.parseAllowedIPs(attribute.getValue()); + break; + case "endpoint": + builder.parseEndpoint(attribute.getValue()); + break; + case "persistentkeepalive": + builder.parsePersistentKeepalive(attribute.getValue()); + break; + case "presharedkey": + builder.parsePreSharedKey(attribute.getValue()); + break; + case "publickey": + builder.parsePublicKey(attribute.getValue()); + break; + default: + throw new BadConfigException( + Section.PEER, Location.TOP_LEVEL, Reason.UNKNOWN_ATTRIBUTE, attribute.getKey()); + } + } + return builder.build(); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Peer)) + return false; + final Peer other = (Peer) obj; + return allowedIps.equals(other.allowedIps) && endpoint.equals(other.endpoint) + && persistentKeepalive.equals(other.persistentKeepalive) + && preSharedKey.equals(other.preSharedKey) && publicKey.equals(other.publicKey); + } + + /** + * Returns the peer's set of allowed IPs. + * + * @return the set of allowed IPs + */ + public Set getAllowedIps() { + // The collection is already immutable. + return allowedIps; + } + + /** + * Returns the peer's endpoint. + * + * @return the endpoint, or {@code Optional.empty()} if none is configured + */ + public Optional getEndpoint() { + return endpoint; + } + + /** + * Returns the peer's persistent keepalive. + * + * @return the persistent keepalive, or {@code Optional.empty()} if none is configured + */ + public Optional getPersistentKeepalive() { + return persistentKeepalive; + } + + /** + * Returns the peer's pre-shared key. + * + * @return the pre-shared key, or {@code Optional.empty()} if none is configured + */ + public Optional getPreSharedKey() { + return preSharedKey; + } + + /** + * Returns the peer's public key. + * + * @return the public key + */ + public Key getPublicKey() { + return publicKey; + } + + @Override + public int hashCode() { + int hash = 1; + hash = 31 * hash + allowedIps.hashCode(); + hash = 31 * hash + endpoint.hashCode(); + hash = 31 * hash + persistentKeepalive.hashCode(); + hash = 31 * hash + preSharedKey.hashCode(); + hash = 31 * hash + publicKey.hashCode(); + return hash; + } + + /** + * Converts the {@code Peer} into a string suitable for debugging purposes. The {@code Peer} is + * identified by its public key and (if known) its endpoint. + * + * @return a concise single-line identifier for the {@code Peer} + */ + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("(Peer "); + sb.append(publicKey.toBase64()); + endpoint.ifPresent(ep -> sb.append(" @").append(ep)); + sb.append(')'); + return sb.toString(); + } + + /** + * Converts the {@code Peer} into a string suitable for inclusion in a {@code wg-quick} + * configuration file. + * + * @return the {@code Peer} represented as a series of "Key = Value" lines + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + if (!allowedIps.isEmpty()) + sb.append("AllowedIPs = ").append(Attribute.join(allowedIps)).append('\n'); + endpoint.ifPresent(ep -> sb.append("Endpoint = ").append(ep).append('\n')); + persistentKeepalive.ifPresent( + pk -> sb.append("PersistentKeepalive = ").append(pk).append('\n')); + preSharedKey.ifPresent(psk -> sb.append("PreSharedKey = ").append(psk.toBase64()).append('\n')); + sb.append("PublicKey = ").append(publicKey.toBase64()).append('\n'); + return sb.toString(); + } + + /** + * Serializes the {@code Peer} for use with the WireGuard cross-platform userspace API. Note + * that not all attributes are included in this representation. + * + * @return the {@code Peer} represented as a series of "key=value" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + // The order here is important: public_key signifies the beginning of a new peer. + sb.append("public_key=").append(publicKey.toHex()).append('\n'); + for (final InetNetwork allowedIp : allowedIps) + sb.append("allowed_ip=").append(allowedIp).append('\n'); + endpoint.flatMap(InetEndpoint::getResolved) + .ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + persistentKeepalive.ifPresent( + pk -> sb.append("persistent_keepalive_interval=").append(pk).append('\n')); + preSharedKey.ifPresent(psk -> sb.append("preshared_key=").append(psk.toHex()).append('\n')); + return sb.toString(); + } + + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // See wg(8) + private static final int MAX_PERSISTENT_KEEPALIVE = 65535; + + // Defaults to an empty set. + private final Set allowedIps = new LinkedHashSet<>(); + // Defaults to not present. + private Optional endpoint = Optional.empty(); + // Defaults to not present. + private Optional persistentKeepalive = Optional.empty(); + // Defaults to not present. + private Optional preSharedKey = Optional.empty(); + // No default; must be provided before building. + @Nullable private Key publicKey; + + public Builder addAllowedIp(final InetNetwork allowedIp) { + allowedIps.add(allowedIp); + return this; + } + + public Builder addAllowedIps(final Collection allowedIps) { + this.allowedIps.addAll(allowedIps); + return this; + } + + public Peer build() throws BadConfigException { + if (publicKey == null) + throw new BadConfigException( + Section.PEER, Location.PUBLIC_KEY, Reason.MISSING_ATTRIBUTE, null); + return new Peer(this); + } + + public Builder parseAllowedIPs(final CharSequence allowedIps) throws BadConfigException { + try { + for (final String allowedIp : Attribute.split(allowedIps)) + addAllowedIp(InetNetwork.parse(allowedIp)); + return this; + } catch (final ParseException e) { + throw new BadConfigException(Section.PEER, Location.ALLOWED_IPS, e); + } + } + + public Builder parseEndpoint(final String endpoint) throws BadConfigException { + try { + return setEndpoint(InetEndpoint.parse(endpoint)); + } catch (final ParseException e) { + throw new BadConfigException(Section.PEER, Location.ENDPOINT, e); + } + } + + public Builder parsePersistentKeepalive(final String persistentKeepalive) + throws BadConfigException { + try { + return setPersistentKeepalive(Integer.parseInt(persistentKeepalive)); + } catch (final NumberFormatException e) { + throw new BadConfigException( + Section.PEER, Location.PERSISTENT_KEEPALIVE, persistentKeepalive, e); + } + } + + public Builder parsePreSharedKey(final String preSharedKey) throws BadConfigException { + try { + return setPreSharedKey(Key.fromBase64(preSharedKey)); + } catch (final KeyFormatException e) { + throw new BadConfigException(Section.PEER, Location.PRE_SHARED_KEY, e); + } + } + + public Builder parsePublicKey(final String publicKey) throws BadConfigException { + try { + return setPublicKey(Key.fromBase64(publicKey)); + } catch (final KeyFormatException e) { + throw new BadConfigException(Section.PEER, Location.PUBLIC_KEY, e); + } + } + + public Builder setEndpoint(final InetEndpoint endpoint) { + this.endpoint = Optional.of(endpoint); + return this; + } + + public Builder setPersistentKeepalive(final int persistentKeepalive) throws BadConfigException { + if (persistentKeepalive < 0 || persistentKeepalive > MAX_PERSISTENT_KEEPALIVE) + throw new BadConfigException(Section.PEER, Location.PERSISTENT_KEEPALIVE, + Reason.INVALID_VALUE, String.valueOf(persistentKeepalive)); + this.persistentKeepalive = + persistentKeepalive == 0 ? Optional.empty() : Optional.of(persistentKeepalive); + return this; + } + + public Builder setPreSharedKey(final Key preSharedKey) { + this.preSharedKey = Optional.of(preSharedKey); + return this; + } + + public Builder setPublicKey(final Key publicKey) { + this.publicKey = publicKey; + return this; + } + } +} diff --git a/client/android/src/com/wireguard/crypto/Curve25519.java b/client/android/src/com/wireguard/crypto/Curve25519.java new file mode 100644 index 00000000..9c314202 --- /dev/null +++ b/client/android/src/com/wireguard/crypto/Curve25519.java @@ -0,0 +1,497 @@ +/* + * Copyright © 2016 Southern Storm Software, Pty Ltd. + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import java.util.Arrays; + +import androidx.annotation.Nullable; + +/** + * Implementation of Curve25519 ECDH. + *

+ * This implementation was imported to WireGuard from noise-java: + * https://github.com/rweather/noise-java + *

+ * This implementation is based on that from arduinolibs: + * https://github.com/rweather/arduinolibs + *

+ * Differences in this version are due to using 26-bit limbs for the + * representation instead of the 8/16/32-bit limbs in the original. + *

+ * References: http://cr.yp.to/ecdh.html, RFC 7748 + */ +@SuppressWarnings({"MagicNumber", "NonConstantFieldWithUpperCaseName", "SuspiciousNameCombination"}) + +public final class Curve25519 { + // Numbers modulo 2^255 - 19 are broken up into ten 26-bit words. + private static final int NUM_LIMBS_255BIT = 10; + private static final int NUM_LIMBS_510BIT = 20; + + private final int[] A; + private final int[] AA; + private final int[] B; + private final int[] BB; + private final int[] C; + private final int[] CB; + private final int[] D; + private final int[] DA; + private final int[] E; + private final long[] t1; + private final int[] t2; + private final int[] x_1; + private final int[] x_2; + private final int[] x_3; + private final int[] z_2; + private final int[] z_3; + + /** + * Constructs the temporary state holder for Curve25519 evaluation. + */ + private Curve25519() { + // Allocate memory for all of the temporary variables we will need. + x_1 = new int[NUM_LIMBS_255BIT]; + x_2 = new int[NUM_LIMBS_255BIT]; + x_3 = new int[NUM_LIMBS_255BIT]; + z_2 = new int[NUM_LIMBS_255BIT]; + z_3 = new int[NUM_LIMBS_255BIT]; + A = new int[NUM_LIMBS_255BIT]; + B = new int[NUM_LIMBS_255BIT]; + C = new int[NUM_LIMBS_255BIT]; + D = new int[NUM_LIMBS_255BIT]; + E = new int[NUM_LIMBS_255BIT]; + AA = new int[NUM_LIMBS_255BIT]; + BB = new int[NUM_LIMBS_255BIT]; + DA = new int[NUM_LIMBS_255BIT]; + CB = new int[NUM_LIMBS_255BIT]; + t1 = new long[NUM_LIMBS_510BIT]; + t2 = new int[NUM_LIMBS_510BIT]; + } + + /** + * Conditional swap of two values. + * + * @param select Set to 1 to swap, 0 to leave as-is. + * @param x The first value. + * @param y The second value. + */ + private static void cswap(int select, final int[] x, final int[] y) { + select = -select; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + final int dummy = select & (x[index] ^ y[index]); + x[index] ^= dummy; + y[index] ^= dummy; + } + } + + /** + * Evaluates the Curve25519 curve. + * + * @param result Buffer to place the result of the evaluation into. + * @param offset Offset into the result buffer. + * @param privateKey The private key to use in the evaluation. + * @param publicKey The public key to use in the evaluation, or null + * if the base point of the curve should be used. + */ + public static void eval(final byte[] result, final int offset, final byte[] privateKey, + @Nullable final byte[] publicKey) { + final Curve25519 state = new Curve25519(); + try { + // Unpack the public key value. If null, use 9 as the base point. + Arrays.fill(state.x_1, 0); + if (publicKey != null) { + // Convert the input value from little-endian into 26-bit limbs. + for (int index = 0; index < 32; ++index) { + final int bit = (index * 8) % 26; + final int word = (index * 8) / 26; + final int value = publicKey[index] & 0xFF; + if (bit <= (26 - 8)) { + state.x_1[word] |= value << bit; + } else { + state.x_1[word] |= value << bit; + state.x_1[word] &= 0x03FFFFFF; + state.x_1[word + 1] |= value >> (26 - bit); + } + } + + // Just in case, we reduce the number modulo 2^255 - 19 to + // make sure that it is in range of the field before we start. + // This eliminates values between 2^255 - 19 and 2^256 - 1. + state.reduceQuick(state.x_1); + state.reduceQuick(state.x_1); + } else { + state.x_1[0] = 9; + } + + // Initialize the other temporary variables. + Arrays.fill(state.x_2, 0); // x_2 = 1 + state.x_2[0] = 1; + Arrays.fill(state.z_2, 0); // z_2 = 0 + System.arraycopy(state.x_1, 0, state.x_3, 0, state.x_1.length); // x_3 = x_1 + Arrays.fill(state.z_3, 0); // z_3 = 1 + state.z_3[0] = 1; + + // Evaluate the curve for every bit of the private key. + state.evalCurve(privateKey); + + // Compute x_2 * (z_2 ^ (p - 2)) where p = 2^255 - 19. + state.recip(state.z_3, state.z_2); + state.mul(state.x_2, state.x_2, state.z_3); + + // Convert x_2 into little-endian in the result buffer. + for (int index = 0; index < 32; ++index) { + final int bit = (index * 8) % 26; + final int word = (index * 8) / 26; + if (bit <= (26 - 8)) + result[offset + index] = (byte) (state.x_2[word] >> bit); + else + result[offset + index] = + (byte) ((state.x_2[word] >> bit) | (state.x_2[word + 1] << (26 - bit))); + } + } finally { + // Clean up all temporary state before we exit. + state.destroy(); + } + } + + /** + * Subtracts two numbers modulo 2^255 - 19. + * + * @param result The result. + * @param x The first number to subtract. + * @param y The second number to subtract. + */ + private static void sub(final int[] result, final int[] x, final int[] y) { + int index; + int borrow; + + // Subtract y from x to generate the intermediate result. + borrow = 0; + for (index = 0; index < NUM_LIMBS_255BIT; ++index) { + borrow = x[index] - y[index] - ((borrow >> 26) & 0x01); + result[index] = borrow & 0x03FFFFFF; + } + + // If we had a borrow, then the result has gone negative and we + // have to add 2^255 - 19 to the result to make it positive again. + // The top bits of "borrow" will be all 1's if there is a borrow + // or it will be all 0's if there was no borrow. Easiest is to + // conditionally subtract 19 and then mask off the high bits. + borrow = result[0] - ((-((borrow >> 26) & 0x01)) & 19); + result[0] = borrow & 0x03FFFFFF; + for (index = 1; index < NUM_LIMBS_255BIT; ++index) { + borrow = result[index] - ((borrow >> 26) & 0x01); + result[index] = borrow & 0x03FFFFFF; + } + result[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + } + + /** + * Adds two numbers modulo 2^255 - 19. + * + * @param result The result. + * @param x The first number to add. + * @param y The second number to add. + */ + private void add(final int[] result, final int[] x, final int[] y) { + int carry = x[0] + y[0]; + result[0] = carry & 0x03FFFFFF; + for (int index = 1; index < NUM_LIMBS_255BIT; ++index) { + carry = (carry >> 26) + x[index] + y[index]; + result[index] = carry & 0x03FFFFFF; + } + reduceQuick(result); + } + + /** + * Destroy all sensitive data in this object. + */ + private void destroy() { + // Destroy all temporary variables. + Arrays.fill(x_1, 0); + Arrays.fill(x_2, 0); + Arrays.fill(x_3, 0); + Arrays.fill(z_2, 0); + Arrays.fill(z_3, 0); + Arrays.fill(A, 0); + Arrays.fill(B, 0); + Arrays.fill(C, 0); + Arrays.fill(D, 0); + Arrays.fill(E, 0); + Arrays.fill(AA, 0); + Arrays.fill(BB, 0); + Arrays.fill(DA, 0); + Arrays.fill(CB, 0); + Arrays.fill(t1, 0L); + Arrays.fill(t2, 0); + } + + /** + * Evaluates the curve for every bit in a secret key. + * + * @param s The 32-byte secret key. + */ + private void evalCurve(final byte[] s) { + int sposn = 31; + int sbit = 6; + int svalue = s[sposn] | 0x40; + int swap = 0; + + // Iterate over all 255 bits of "s" from the highest to the lowest. + // We ignore the high bit of the 256-bit representation of "s". + while (true) { + // Conditional swaps on entry to this bit but only if we + // didn't swap on the previous bit. + final int select = (svalue >> sbit) & 0x01; + swap ^= select; + cswap(swap, x_2, x_3); + cswap(swap, z_2, z_3); + swap = select; + + // Evaluate the curve. + add(A, x_2, z_2); // A = x_2 + z_2 + square(AA, A); // AA = A^2 + sub(B, x_2, z_2); // B = x_2 - z_2 + square(BB, B); // BB = B^2 + sub(E, AA, BB); // E = AA - BB + add(C, x_3, z_3); // C = x_3 + z_3 + sub(D, x_3, z_3); // D = x_3 - z_3 + mul(DA, D, A); // DA = D * A + mul(CB, C, B); // CB = C * B + add(x_3, DA, CB); // x_3 = (DA + CB)^2 + square(x_3, x_3); + sub(z_3, DA, CB); // z_3 = x_1 * (DA - CB)^2 + square(z_3, z_3); + mul(z_3, z_3, x_1); + mul(x_2, AA, BB); // x_2 = AA * BB + mulA24(z_2, E); // z_2 = E * (AA + a24 * E) + add(z_2, z_2, AA); + mul(z_2, z_2, E); + + // Move onto the next lower bit of "s". + if (sbit > 0) { + --sbit; + } else if (sposn == 0) { + break; + } else if (sposn == 1) { + --sposn; + svalue = s[sposn] & 0xF8; + sbit = 7; + } else { + --sposn; + svalue = s[sposn]; + sbit = 7; + } + } + + // Final conditional swaps. + cswap(swap, x_2, x_3); + cswap(swap, z_2, z_3); + } + + /** + * Multiplies two numbers modulo 2^255 - 19. + * + * @param result The result. + * @param x The first number to multiply. + * @param y The second number to multiply. + */ + private void mul(final int[] result, final int[] x, final int[] y) { + // Multiply the two numbers to create the intermediate result. + long v = x[0]; + for (int i = 0; i < NUM_LIMBS_255BIT; ++i) { + t1[i] = v * y[i]; + } + for (int i = 1; i < NUM_LIMBS_255BIT; ++i) { + v = x[i]; + for (int j = 0; j < (NUM_LIMBS_255BIT - 1); ++j) { + t1[i + j] += v * y[j]; + } + t1[i + NUM_LIMBS_255BIT - 1] = v * y[NUM_LIMBS_255BIT - 1]; + } + + // Propagate carries and convert back into 26-bit words. + v = t1[0]; + t2[0] = ((int) v) & 0x03FFFFFF; + for (int i = 1; i < NUM_LIMBS_510BIT; ++i) { + v = (v >> 26) + t1[i]; + t2[i] = ((int) v) & 0x03FFFFFF; + } + + // Reduce the result modulo 2^255 - 19. + reduce(result, t2, NUM_LIMBS_255BIT); + } + + /** + * Multiplies a number by the a24 constant, modulo 2^255 - 19. + * + * @param result The result. + * @param x The number to multiply by a24. + */ + private void mulA24(final int[] result, final int[] x) { + final long a24 = 121665; + long carry = 0; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + carry += a24 * x[index]; + t2[index] = ((int) carry) & 0x03FFFFFF; + carry >>= 26; + } + t2[NUM_LIMBS_255BIT] = ((int) carry) & 0x03FFFFFF; + reduce(result, t2, 1); + } + + /** + * Raise x to the power of (2^250 - 1). + * + * @param result The result. Must not overlap with x. + * @param x The argument. + */ + private void pow250(final int[] result, final int[] x) { + // The big-endian hexadecimal expansion of (2^250 - 1) is: + // 03FFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF + // + // The naive implementation needs to do 2 multiplications per 1 bit and + // 1 multiplication per 0 bit. We can improve upon this by creating a + // pattern 0000000001 ... 0000000001. If we square and multiply the + // pattern by itself we can turn the pattern into the partial results + // 0000000011 ... 0000000011, 0000000111 ... 0000000111, etc. + // This averages out to about 1.1 multiplications per 1 bit instead of 2. + + // Build a pattern of 250 bits in length of repeated copies of 0000000001. + square(A, x); + for (int j = 0; j < 9; ++j) square(A, A); + mul(result, A, x); + for (int i = 0; i < 23; ++i) { + for (int j = 0; j < 10; ++j) square(A, A); + mul(result, result, A); + } + + // Multiply bit-shifted versions of the 0000000001 pattern into + // the result to "fill in" the gaps in the pattern. + square(A, result); + mul(result, result, A); + for (int j = 0; j < 8; ++j) { + square(A, A); + mul(result, result, A); + } + } + + /** + * Computes the reciprocal of a number modulo 2^255 - 19. + * + * @param result The result. Must not overlap with x. + * @param x The argument. + */ + private void recip(final int[] result, final int[] x) { + // The reciprocal is the same as x ^ (p - 2) where p = 2^255 - 19. + // The big-endian hexadecimal expansion of (p - 2) is: + // 7FFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFEB + // Start with the 250 upper bits of the expansion of (p - 2). + pow250(result, x); + + // Deal with the 5 lowest bits of (p - 2), 01011, from highest to lowest. + square(result, result); + square(result, result); + mul(result, result, x); + square(result, result); + square(result, result); + mul(result, result, x); + square(result, result); + mul(result, result, x); + } + + /** + * Reduce a number modulo 2^255 - 19. + * + * @param result The result. + * @param x The value to be reduced. This array will be + * modified during the reduction. + * @param size The number of limbs in the high order half of x. + */ + private void reduce(final int[] result, final int[] x, final int size) { + // Calculate (x mod 2^255) + ((x / 2^255) * 19) which will + // either produce the answer we want or it will produce a + // value of the form "answer + j * (2^255 - 19)". There are + // 5 left-over bits in the top-most limb of the bottom half. + int carry = 0; + int limb = x[NUM_LIMBS_255BIT - 1] >> 21; + x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + for (int index = 0; index < size; ++index) { + limb += x[NUM_LIMBS_255BIT + index] << 5; + carry += (limb & 0x03FFFFFF) * 19 + x[index]; + x[index] = carry & 0x03FFFFFF; + limb >>= 26; + carry >>= 26; + } + if (size < NUM_LIMBS_255BIT) { + // The high order half of the number is short; e.g. for mulA24(). + // Propagate the carry through the rest of the low order part. + for (int index = size; index < NUM_LIMBS_255BIT; ++index) { + carry += x[index]; + x[index] = carry & 0x03FFFFFF; + carry >>= 26; + } + } + + // The "j" value may still be too large due to the final carry-out. + // We must repeat the reduction. If we already have the answer, + // then this won't do any harm but we must still do the calculation + // to preserve the overall timing. The "j" value will be between + // 0 and 19, which means that the carry we care about is in the + // top 5 bits of the highest limb of the bottom half. + carry = (x[NUM_LIMBS_255BIT - 1] >> 21) * 19; + x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + carry += x[index]; + result[index] = carry & 0x03FFFFFF; + carry >>= 26; + } + + // At this point "x" will either be the answer or it will be the + // answer plus (2^255 - 19). Perform a trial subtraction to + // complete the reduction process. + reduceQuick(result); + } + + /** + * Reduces a number modulo 2^255 - 19 where it is known that the + * number can be reduced with only 1 trial subtraction. + * + * @param x The number to reduce, and the result. + */ + private void reduceQuick(final int[] x) { + // Perform a trial subtraction of (2^255 - 19) from "x" which is + // equivalent to adding 19 and subtracting 2^255. We add 19 here; + // the subtraction of 2^255 occurs in the next step. + int carry = 19; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + carry += x[index]; + t2[index] = carry & 0x03FFFFFF; + carry >>= 26; + } + + // If there was a borrow, then the original "x" is the correct answer. + // If there was no borrow, then "t2" is the correct answer. Select the + // correct answer but do it in a way that instruction timing will not + // reveal which value was selected. Borrow will occur if bit 21 of + // "t2" is zero. Turn the bit into a selection mask. + final int mask = -((t2[NUM_LIMBS_255BIT - 1] >> 21) & 0x01); + final int nmask = ~mask; + t2[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) + x[index] = (x[index] & nmask) | (t2[index] & mask); + } + + /** + * Squares a number modulo 2^255 - 19. + * + * @param result The result. + * @param x The number to square. + */ + private void square(final int[] result, final int[] x) { + mul(result, x, x); + } +} diff --git a/client/android/src/com/wireguard/crypto/Ed25519.java b/client/android/src/com/wireguard/crypto/Ed25519.java new file mode 100644 index 00000000..b3dcd69c --- /dev/null +++ b/client/android/src/com/wireguard/crypto/Ed25519.java @@ -0,0 +1,2435 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright 2017 Google Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.util.Arrays; + +/** + * Implementation of Ed25519 signature verification. + * + *

This implementation is based on the ed25519/ref10 implementation in NaCl.

+ * + *

It implements this twisted Edwards curve: + * + *

+ * -x^2 + y^2 = 1 + (-121665 / 121666 mod 2^255-19)*x^2*y^2
+ * 
+ * + * @see Bernstein D.J., Birkner P., Joye M., Lange + * T., Peters C. (2008) Twisted Edwards Curves + * @see Hisil H., Wong K.KH., Carter G., Dawson E. + * (2008) Twisted Edwards Curves Revisited + */ +public final class Ed25519 { + // d = -121665 / 121666 mod 2^255-19 + private static final long[] D; + // 2d + private static final long[] D2; + // 2^((p-1)/4) mod p where p = 2^255-19 + private static final long[] SQRTM1; + + /** + * Base point for the Edwards twisted curve = (x, 4/5) and its exponentiations. B_TABLE[i][j] = + * (j+1)*256^i*B for i in [0, 32) and j in [0, 8). Base point B = B_TABLE[0][0] + */ + private static final CachedXYT[][] B_TABLE; + private static final CachedXYT[] B2; + + private static final BigInteger P_BI = + BigInteger.valueOf(2).pow(255).subtract(BigInteger.valueOf(19)); + private static final BigInteger D_BI = + BigInteger.valueOf(-121665).multiply(BigInteger.valueOf(121666).modInverse(P_BI)).mod(P_BI); + private static final BigInteger D2_BI = BigInteger.valueOf(2).multiply(D_BI).mod(P_BI); + private static final BigInteger SQRTM1_BI = BigInteger.valueOf(2).modPow( + P_BI.subtract(BigInteger.ONE).divide(BigInteger.valueOf(4)), P_BI); + + private Ed25519() {} + + private static class Point { + private BigInteger x; + private BigInteger y; + } + + private static BigInteger recoverX(BigInteger y) { + // x^2 = (y^2 - 1) / (d * y^2 + 1) mod 2^255-19 + BigInteger xx = y.pow(2) + .subtract(BigInteger.ONE) + .multiply(D_BI.multiply(y.pow(2)).add(BigInteger.ONE).modInverse(P_BI)); + BigInteger x = xx.modPow(P_BI.add(BigInteger.valueOf(3)).divide(BigInteger.valueOf(8)), P_BI); + if (!x.pow(2).subtract(xx).mod(P_BI).equals(BigInteger.ZERO)) { + x = x.multiply(SQRTM1_BI).mod(P_BI); + } + if (x.testBit(0)) { + x = P_BI.subtract(x); + } + return x; + } + + private static Point edwards(Point a, Point b) { + Point o = new Point(); + BigInteger xxyy = D_BI.multiply(a.x.multiply(b.x).multiply(a.y).multiply(b.y)).mod(P_BI); + o.x = (a.x.multiply(b.y).add(b.x.multiply(a.y))) + .multiply(BigInteger.ONE.add(xxyy).modInverse(P_BI)) + .mod(P_BI); + o.y = (a.y.multiply(b.y).add(a.x.multiply(b.x))) + .multiply(BigInteger.ONE.subtract(xxyy).modInverse(P_BI)) + .mod(P_BI); + return o; + } + + private static byte[] toLittleEndian(BigInteger n) { + byte[] b = new byte[32]; + byte[] nBytes = n.toByteArray(); + System.arraycopy(nBytes, 0, b, 32 - nBytes.length, nBytes.length); + for (int i = 0; i < b.length / 2; i++) { + byte t = b[i]; + b[i] = b[b.length - i - 1]; + b[b.length - i - 1] = t; + } + return b; + } + + private static CachedXYT getCachedXYT(Point p) { + return new CachedXYT(Field25519.expand(toLittleEndian(p.y.add(p.x).mod(P_BI))), + Field25519.expand(toLittleEndian(p.y.subtract(p.x).mod(P_BI))), + Field25519.expand(toLittleEndian(D2_BI.multiply(p.x).multiply(p.y).mod(P_BI)))); + } + + static { + Point b = new Point(); + b.y = BigInteger.valueOf(4).multiply(BigInteger.valueOf(5).modInverse(P_BI)).mod(P_BI); + b.x = recoverX(b.y); + + D = Field25519.expand(toLittleEndian(D_BI)); + D2 = Field25519.expand(toLittleEndian(D2_BI)); + SQRTM1 = Field25519.expand(toLittleEndian(SQRTM1_BI)); + + Point bi = b; + B_TABLE = new CachedXYT[32][8]; + for (int i = 0; i < 32; i++) { + Point bij = bi; + for (int j = 0; j < 8; j++) { + B_TABLE[i][j] = getCachedXYT(bij); + bij = edwards(bij, bi); + } + for (int j = 0; j < 8; j++) { + bi = edwards(bi, bi); + } + } + bi = b; + Point b2 = edwards(b, b); + B2 = new CachedXYT[8]; + for (int i = 0; i < 8; i++) { + B2[i] = getCachedXYT(bi); + bi = edwards(bi, b2); + } + } + + private static final int PUBLIC_KEY_LEN = Field25519.FIELD_LEN; + private static final int SIGNATURE_LEN = Field25519.FIELD_LEN * 2; + + /** + * Defines field 25519 function based on curve25519-donna + * C implementation (mostly identical). + * + *

Field elements are written as an array of signed, 64-bit limbs (an array of longs), least + * significant first. The value of the field element is: + * + *

+   * x[0] + 2^26·x[1] + 2^51·x[2] + 2^77·x[3] + 2^102·x[4] + 2^128·x[5] + 2^153·x[6] + 2^179·x[7] +
+   * 2^204·x[8] + 2^230·x[9],
+   * 
+ * + *

i.e. the limbs are 26, 25, 26, 25, ... bits wide. + */ + private static final class Field25519 { + /** + * During Field25519 computation, the mixed radix representation may be in different forms: + *

    + *
  • Reduced-size form: the array has size at most 10. + *
  • Non-reduced-size form: the array is not reduced modulo 2^255 - 19 and has size at most + * 19. + *
+ *

+ * TODO(quannguyen): + *

    + *
  • Clarify ill-defined terminologies. + *
  • The reduction procedure is different from DJB's paper + * (http://cr.yp.to/ecdh/curve25519-20060209.pdf). The coefficients after reducing degree and + * reducing coefficients aren't guaranteed to be in range {-2^25, ..., 2^25}. We should check + * to see what's going on.
  • Consider using method mult() everywhere and making product() + * private. + *
+ */ + + static final int FIELD_LEN = 32; + static final int LIMB_CNT = 10; + private static final long TWO_TO_25 = 1 << 25; + private static final long TWO_TO_26 = TWO_TO_25 << 1; + + private static final int[] EXPAND_START = {0, 3, 6, 9, 12, 16, 19, 22, 25, 28}; + private static final int[] EXPAND_SHIFT = {0, 2, 3, 5, 6, 0, 1, 3, 4, 6}; + private static final int[] MASK = {0x3ffffff, 0x1ffffff}; + private static final int[] SHIFT = {26, 25}; + + /** + * Sums two numbers: output = in1 + in2 + *

+ * On entry: in1, in2 are in reduced-size form. + */ + static void sum(long[] output, long[] in1, long[] in2) { + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = in1[i] + in2[i]; + } + } + + /** + * Sums two numbers: output += in + *

+ * On entry: in is in reduced-size form. + */ + static void sum(long[] output, long[] in) { + sum(output, output, in); + } + + /** + * Find the difference of two numbers: output = in1 - in2 + * (note the order of the arguments!). + *

+ * On entry: in1, in2 are in reduced-size form. + */ + static void sub(long[] output, long[] in1, long[] in2) { + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = in1[i] - in2[i]; + } + } + + /** + * Find the difference of two numbers: output = in - output + * (note the order of the arguments!). + *

+ * On entry: in, output are in reduced-size form. + */ + static void sub(long[] output, long[] in) { + sub(output, in, output); + } + + /** + * Multiply a number by a scalar: output = in * scalar + */ + static void scalarProduct(long[] output, long[] in, long scalar) { + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = in[i] * scalar; + } + } + + /** + * Multiply two numbers: out = in2 * in + *

+ * output must be distinct to both inputs. The inputs are reduced coefficient form, + * the output is not. + *

+ * out[x] <= 14 * the largest product of the input limbs. + */ + static void product(long[] out, long[] in2, long[] in) { + out[0] = in2[0] * in[0]; + out[1] = in2[0] * in[1] + in2[1] * in[0]; + out[2] = 2 * in2[1] * in[1] + in2[0] * in[2] + in2[2] * in[0]; + out[3] = in2[1] * in[2] + in2[2] * in[1] + in2[0] * in[3] + in2[3] * in[0]; + out[4] = + in2[2] * in[2] + 2 * (in2[1] * in[3] + in2[3] * in[1]) + in2[0] * in[4] + in2[4] * in[0]; + out[5] = in2[2] * in[3] + in2[3] * in[2] + in2[1] * in[4] + in2[4] * in[1] + in2[0] * in[5] + + in2[5] * in[0]; + out[6] = 2 * (in2[3] * in[3] + in2[1] * in[5] + in2[5] * in[1]) + in2[2] * in[4] + + in2[4] * in[2] + in2[0] * in[6] + in2[6] * in[0]; + out[7] = in2[3] * in[4] + in2[4] * in[3] + in2[2] * in[5] + in2[5] * in[2] + in2[1] * in[6] + + in2[6] * in[1] + in2[0] * in[7] + in2[7] * in[0]; + out[8] = in2[4] * in[4] + + 2 * (in2[3] * in[5] + in2[5] * in[3] + in2[1] * in[7] + in2[7] * in[1]) + in2[2] * in[6] + + in2[6] * in[2] + in2[0] * in[8] + in2[8] * in[0]; + out[9] = in2[4] * in[5] + in2[5] * in[4] + in2[3] * in[6] + in2[6] * in[3] + in2[2] * in[7] + + in2[7] * in[2] + in2[1] * in[8] + in2[8] * in[1] + in2[0] * in[9] + in2[9] * in[0]; + out[10] = + 2 * (in2[5] * in[5] + in2[3] * in[7] + in2[7] * in[3] + in2[1] * in[9] + in2[9] * in[1]) + + in2[4] * in[6] + in2[6] * in[4] + in2[2] * in[8] + in2[8] * in[2]; + out[11] = in2[5] * in[6] + in2[6] * in[5] + in2[4] * in[7] + in2[7] * in[4] + in2[3] * in[8] + + in2[8] * in[3] + in2[2] * in[9] + in2[9] * in[2]; + out[12] = in2[6] * in[6] + + 2 * (in2[5] * in[7] + in2[7] * in[5] + in2[3] * in[9] + in2[9] * in[3]) + in2[4] * in[8] + + in2[8] * in[4]; + out[13] = in2[6] * in[7] + in2[7] * in[6] + in2[5] * in[8] + in2[8] * in[5] + in2[4] * in[9] + + in2[9] * in[4]; + out[14] = + 2 * (in2[7] * in[7] + in2[5] * in[9] + in2[9] * in[5]) + in2[6] * in[8] + in2[8] * in[6]; + out[15] = in2[7] * in[8] + in2[8] * in[7] + in2[6] * in[9] + in2[9] * in[6]; + out[16] = in2[8] * in[8] + 2 * (in2[7] * in[9] + in2[9] * in[7]); + out[17] = in2[8] * in[9] + in2[9] * in[8]; + out[18] = 2 * in2[9] * in[9]; + } + + /** + * Reduce a field element by calling reduceSizeByModularReduction and reduceCoefficients. + * + * @param input An input array of any length. If the array has 19 elements, it will be used as + * temporary buffer and its contents changed. + * @param output An output array of size LIMB_CNT. After the call |output[i]| < 2^26 will hold. + */ + static void reduce(long[] input, long[] output) { + long[] tmp; + if (input.length == 19) { + tmp = input; + } else { + tmp = new long[19]; + System.arraycopy(input, 0, tmp, 0, input.length); + } + reduceSizeByModularReduction(tmp); + reduceCoefficients(tmp); + System.arraycopy(tmp, 0, output, 0, LIMB_CNT); + } + + /** + * Reduce a long form to a reduced-size form by taking the input mod 2^255 - 19. + *

+ * On entry: |output[i]| < 14*2^54 + * On exit: |output[0..8]| < 280*2^54 + */ + static void reduceSizeByModularReduction(long[] output) { + // The coefficients x[10], x[11],..., x[18] are eliminated by reduction modulo 2^255 - 19. + // For example, the coefficient x[18] is multiplied by 19 and added to the coefficient x[8]. + // + // Each of these shifts and adds ends up multiplying the value by 19. + // + // For output[0..8], the absolute entry value is < 14*2^54 and we add, at most, 19*14*2^54 + // thus, on exit, |output[0..8]| < 280*2^54. + output[8] += output[18] << 4; + output[8] += output[18] << 1; + output[8] += output[18]; + output[7] += output[17] << 4; + output[7] += output[17] << 1; + output[7] += output[17]; + output[6] += output[16] << 4; + output[6] += output[16] << 1; + output[6] += output[16]; + output[5] += output[15] << 4; + output[5] += output[15] << 1; + output[5] += output[15]; + output[4] += output[14] << 4; + output[4] += output[14] << 1; + output[4] += output[14]; + output[3] += output[13] << 4; + output[3] += output[13] << 1; + output[3] += output[13]; + output[2] += output[12] << 4; + output[2] += output[12] << 1; + output[2] += output[12]; + output[1] += output[11] << 4; + output[1] += output[11] << 1; + output[1] += output[11]; + output[0] += output[10] << 4; + output[0] += output[10] << 1; + output[0] += output[10]; + } + + /** + * Reduce all coefficients of the short form input so that |x| < 2^26. + *

+ * On entry: |output[i]| < 280*2^54 + */ + static void reduceCoefficients(long[] output) { + output[10] = 0; + + for (int i = 0; i < LIMB_CNT; i += 2) { + long over = output[i] / TWO_TO_26; + // The entry condition (that |output[i]| < 280*2^54) means that over is, at most, 280*2^28 + // in the first iteration of this loop. This is added to the next limb and we can + // approximate the resulting bound of that limb by 281*2^54. + output[i] -= over << 26; + output[i + 1] += over; + + // For the first iteration, |output[i+1]| < 281*2^54, thus |over| < 281*2^29. When this is + // added to the next limb, the resulting bound can be approximated as 281*2^54. + // + // For subsequent iterations of the loop, 281*2^54 remains a conservative bound and no + // overflow occurs. + over = output[i + 1] / TWO_TO_25; + output[i + 1] -= over << 25; + output[i + 2] += over; + } + // Now |output[10]| < 281*2^29 and all other coefficients are reduced. + output[0] += output[10] << 4; + output[0] += output[10] << 1; + output[0] += output[10]; + + output[10] = 0; + // Now output[1..9] are reduced, and |output[0]| < 2^26 + 19*281*2^29 so |over| will be no + // more than 2^16. + long over = output[0] / TWO_TO_26; + output[0] -= over << 26; + output[1] += over; + // Now output[0,2..9] are reduced, and |output[1]| < 2^25 + 2^16 < 2^26. The bound on + // |output[1]| is sufficient to meet our needs. + } + + /** + * A helpful wrapper around {@ref Field25519#product}: output = in * in2. + *

+ * On entry: |in[i]| < 2^27 and |in2[i]| < 2^27. + *

+ * The output is reduced degree (indeed, one need only provide storage for 10 limbs) and + * |output[i]| < 2^26. + */ + static void mult(long[] output, long[] in, long[] in2) { + long[] t = new long[19]; + product(t, in, in2); + // |t[i]| < 2^26 + reduce(t, output); + } + + /** + * Square a number: out = in**2 + *

+ * output must be distinct from the input. The inputs are reduced coefficient form, the output + * is not.

out[x] <= 14 * the largest product of the input limbs. + */ + private static void squareInner(long[] out, long[] in) { + out[0] = in[0] * in[0]; + out[1] = 2 * in[0] * in[1]; + out[2] = 2 * (in[1] * in[1] + in[0] * in[2]); + out[3] = 2 * (in[1] * in[2] + in[0] * in[3]); + out[4] = in[2] * in[2] + 4 * in[1] * in[3] + 2 * in[0] * in[4]; + out[5] = 2 * (in[2] * in[3] + in[1] * in[4] + in[0] * in[5]); + out[6] = 2 * (in[3] * in[3] + in[2] * in[4] + in[0] * in[6] + 2 * in[1] * in[5]); + out[7] = 2 * (in[3] * in[4] + in[2] * in[5] + in[1] * in[6] + in[0] * in[7]); + out[8] = + in[4] * in[4] + 2 * (in[2] * in[6] + in[0] * in[8] + 2 * (in[1] * in[7] + in[3] * in[5])); + out[9] = 2 * (in[4] * in[5] + in[3] * in[6] + in[2] * in[7] + in[1] * in[8] + in[0] * in[9]); + out[10] = + 2 * (in[5] * in[5] + in[4] * in[6] + in[2] * in[8] + 2 * (in[3] * in[7] + in[1] * in[9])); + out[11] = 2 * (in[5] * in[6] + in[4] * in[7] + in[3] * in[8] + in[2] * in[9]); + out[12] = in[6] * in[6] + 2 * (in[4] * in[8] + 2 * (in[5] * in[7] + in[3] * in[9])); + out[13] = 2 * (in[6] * in[7] + in[5] * in[8] + in[4] * in[9]); + out[14] = 2 * (in[7] * in[7] + in[6] * in[8] + 2 * in[5] * in[9]); + out[15] = 2 * (in[7] * in[8] + in[6] * in[9]); + out[16] = in[8] * in[8] + 4 * in[7] * in[9]; + out[17] = 2 * in[8] * in[9]; + out[18] = 2 * in[9] * in[9]; + } + + /** + * Returns in^2. + *

+ * On entry: The |in| argument is in reduced coefficients form and |in[i]| < 2^27. + *

+ * On exit: The |output| argument is in reduced coefficients form (indeed, one need only provide + * storage for 10 limbs) and |out[i]| < 2^26. + */ + static void square(long[] output, long[] in) { + long[] t = new long[19]; + squareInner(t, in); + // |t[i]| < 14*2^54 because the largest product of two limbs will be < 2^(27+27) and + // SquareInner adds together, at most, 14 of those products. + reduce(t, output); + } + + /** + * Takes a little-endian, 32-byte number and expands it into mixed radix form. + */ + static long[] expand(byte[] input) { + long[] output = new long[LIMB_CNT]; + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = ((((long) (input[EXPAND_START[i]] & 0xff)) + | ((long) (input[EXPAND_START[i] + 1] & 0xff)) << 8 + | ((long) (input[EXPAND_START[i] + 2] & 0xff)) << 16 + | ((long) (input[EXPAND_START[i] + 3] & 0xff)) << 24) + >> EXPAND_SHIFT[i]) + & MASK[i & 1]; + } + return output; + } + + /** + * Takes a fully reduced mixed radix form number and contract it into a little-endian, 32-byte + * array. + *

+ * On entry: |input_limbs[i]| < 2^26 + */ + @SuppressWarnings("NarrowingCompoundAssignment") + static byte[] contract(long[] inputLimbs) { + long[] input = Arrays.copyOf(inputLimbs, LIMB_CNT); + for (int j = 0; j < 2; j++) { + for (int i = 0; i < 9; i++) { + // This calculation is a time-invariant way to make input[i] non-negative by borrowing + // from the next-larger limb. + int carry = -(int) ((input[i] & (input[i] >> 31)) >> SHIFT[i & 1]); + input[i] = input[i] + (carry << SHIFT[i & 1]); + input[i + 1] -= carry; + } + + // There's no greater limb for input[9] to borrow from, but we can multiply by 19 and borrow + // from input[0], which is valid mod 2^255-19. + { + int carry = -(int) ((input[9] & (input[9] >> 31)) >> 25); + input[9] += (carry << 25); + input[0] -= (carry * 19); + } + + // After the first iteration, input[1..9] are non-negative and fit within 25 or 26 bits, + // depending on position. However, input[0] may be negative. + } + + // The first borrow-propagation pass above ended with every limb except (possibly) input[0] + // non-negative. + // + // If input[0] was negative after the first pass, then it was because of a carry from + // input[9]. On entry, input[9] < 2^26 so the carry was, at most, one, since (2**26-1) >> 25 + // = 1. Thus input[0] >= -19. + // + // In the second pass, each limb is decreased by at most one. Thus the second + // borrow-propagation pass could only have wrapped around to decrease input[0] again if the + // first pass left input[0] negative *and* input[1] through input[9] were all zero. In that + // case, input[1] is now 2^25 - 1, and this last borrow-propagation step will leave input[1] + // non-negative. + { + int carry = -(int) ((input[0] & (input[0] >> 31)) >> 26); + input[0] += (carry << 26); + input[1] -= carry; + } + + // All input[i] are now non-negative. However, there might be values between 2^25 and 2^26 in + // a limb which is, nominally, 25 bits wide. + for (int j = 0; j < 2; j++) { + for (int i = 0; i < 9; i++) { + int carry = (int) (input[i] >> SHIFT[i & 1]); + input[i] &= MASK[i & 1]; + input[i + 1] += carry; + } + } + + { + int carry = (int) (input[9] >> 25); + input[9] &= 0x1ffffff; + input[0] += 19 * carry; + } + + // If the first carry-chain pass, just above, ended up with a carry from input[9], and that + // caused input[0] to be out-of-bounds, then input[0] was < 2^26 + 2*19, because the carry + // was, at most, two. + // + // If the second pass carried from input[9] again then input[0] is < 2*19 and the input[9] -> + // input[0] carry didn't push input[0] out of bounds. + + // It still remains the case that input might be between 2^255-19 and 2^255. In this case, + // input[1..9] must take their maximum value and input[0] must be >= (2^255-19) & 0x3ffffff, + // which is 0x3ffffed. + int mask = gte((int) input[0], 0x3ffffed); + for (int i = 1; i < LIMB_CNT; i++) { + mask &= eq((int) input[i], MASK[i & 1]); + } + + // mask is either 0xffffffff (if input >= 2^255-19) and zero otherwise. Thus this + // conditionally subtracts 2^255-19. + input[0] -= mask & 0x3ffffed; + input[1] -= mask & 0x1ffffff; + for (int i = 2; i < LIMB_CNT; i += 2) { + input[i] -= mask & 0x3ffffff; + input[i + 1] -= mask & 0x1ffffff; + } + + for (int i = 0; i < LIMB_CNT; i++) { + input[i] <<= EXPAND_SHIFT[i]; + } + byte[] output = new byte[FIELD_LEN]; + for (int i = 0; i < LIMB_CNT; i++) { + output[EXPAND_START[i]] |= input[i] & 0xff; + output[EXPAND_START[i] + 1] |= (input[i] >> 8) & 0xff; + output[EXPAND_START[i] + 2] |= (input[i] >> 16) & 0xff; + output[EXPAND_START[i] + 3] |= (input[i] >> 24) & 0xff; + } + return output; + } + + /** + * Computes inverse of z = z(2^255 - 21) + *

+ * Shamelessly copied from agl's code which was shamelessly copied from djb's code. Only the + * comment format and the variable namings are different from those. + */ + static void inverse(long[] out, long[] z) { + long[] z2 = new long[Field25519.LIMB_CNT]; + long[] z9 = new long[Field25519.LIMB_CNT]; + long[] z11 = new long[Field25519.LIMB_CNT]; + long[] z2To5Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To10Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To20Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To50Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To100Minus1 = new long[Field25519.LIMB_CNT]; + long[] t0 = new long[Field25519.LIMB_CNT]; + long[] t1 = new long[Field25519.LIMB_CNT]; + + square(z2, z); // 2 + square(t1, z2); // 4 + square(t0, t1); // 8 + mult(z9, t0, z); // 9 + mult(z11, z9, z2); // 11 + square(t0, z11); // 22 + mult(z2To5Minus1, t0, z9); // 2^5 - 2^0 = 31 + + square(t0, z2To5Minus1); // 2^6 - 2^1 + square(t1, t0); // 2^7 - 2^2 + square(t0, t1); // 2^8 - 2^3 + square(t1, t0); // 2^9 - 2^4 + square(t0, t1); // 2^10 - 2^5 + mult(z2To10Minus1, t0, z2To5Minus1); // 2^10 - 2^0 + + square(t0, z2To10Minus1); // 2^11 - 2^1 + square(t1, t0); // 2^12 - 2^2 + for (int i = 2; i < 10; i += 2) { // 2^20 - 2^10 + square(t0, t1); + square(t1, t0); + } + mult(z2To20Minus1, t1, z2To10Minus1); // 2^20 - 2^0 + + square(t0, z2To20Minus1); // 2^21 - 2^1 + square(t1, t0); // 2^22 - 2^2 + for (int i = 2; i < 20; i += 2) { // 2^40 - 2^20 + square(t0, t1); + square(t1, t0); + } + mult(t0, t1, z2To20Minus1); // 2^40 - 2^0 + + square(t1, t0); // 2^41 - 2^1 + square(t0, t1); // 2^42 - 2^2 + for (int i = 2; i < 10; i += 2) { // 2^50 - 2^10 + square(t1, t0); + square(t0, t1); + } + mult(z2To50Minus1, t0, z2To10Minus1); // 2^50 - 2^0 + + square(t0, z2To50Minus1); // 2^51 - 2^1 + square(t1, t0); // 2^52 - 2^2 + for (int i = 2; i < 50; i += 2) { // 2^100 - 2^50 + square(t0, t1); + square(t1, t0); + } + mult(z2To100Minus1, t1, z2To50Minus1); // 2^100 - 2^0 + + square(t1, z2To100Minus1); // 2^101 - 2^1 + square(t0, t1); // 2^102 - 2^2 + for (int i = 2; i < 100; i += 2) { // 2^200 - 2^100 + square(t1, t0); + square(t0, t1); + } + mult(t1, t0, z2To100Minus1); // 2^200 - 2^0 + + square(t0, t1); // 2^201 - 2^1 + square(t1, t0); // 2^202 - 2^2 + for (int i = 2; i < 50; i += 2) { // 2^250 - 2^50 + square(t0, t1); + square(t1, t0); + } + mult(t0, t1, z2To50Minus1); // 2^250 - 2^0 + + square(t1, t0); // 2^251 - 2^1 + square(t0, t1); // 2^252 - 2^2 + square(t1, t0); // 2^253 - 2^3 + square(t0, t1); // 2^254 - 2^4 + square(t1, t0); // 2^255 - 2^5 + mult(out, t1, z11); // 2^255 - 21 + } + + /** + * Returns 0xffffffff iff a == b and zero otherwise. + */ + private static int eq(int a, int b) { + a = ~(a ^ b); + a &= a << 16; + a &= a << 8; + a &= a << 4; + a &= a << 2; + a &= a << 1; + return a >> 31; + } + + /** + * returns 0xffffffff if a >= b and zero otherwise, where a and b are both non-negative. + */ + private static int gte(int a, int b) { + a -= b; + // a >= 0 iff a >= b. + return ~(a >> 31); + } + } + + // (x = 0, y = 1) point + private static final CachedXYT CACHED_NEUTRAL = + new CachedXYT(new long[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new long[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, new long[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + private static final PartialXYZT NEUTRAL = new PartialXYZT( + new XYZ(new long[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, new long[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new long[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}), + new long[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + /** + * Projective point representation (X:Y:Z) satisfying x = X/Z, y = Y/Z + *

+ * Note that this is referred as ge_p2 in ref10 impl. + * Also note that x = X, y = Y and z = Z below following Java coding style. + *

+ * See + * Koyama K., Tsuruoka Y. (1993) Speeding up Elliptic Cryptosystems by Using a Signed Binary + * Window Method. + *

+ * https://hyperelliptic.org/EFD/g1p/auto-twisted-projective.html + */ + private static class XYZ { + final long[] x; + final long[] y; + final long[] z; + + XYZ() { + this(new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT], + new long[Field25519.LIMB_CNT]); + } + + XYZ(long[] x, long[] y, long[] z) { + this.x = x; + this.y = y; + this.z = z; + } + + XYZ(XYZ xyz) { + x = Arrays.copyOf(xyz.x, Field25519.LIMB_CNT); + y = Arrays.copyOf(xyz.y, Field25519.LIMB_CNT); + z = Arrays.copyOf(xyz.z, Field25519.LIMB_CNT); + } + + XYZ(PartialXYZT partialXYZT) { + this(); + fromPartialXYZT(this, partialXYZT); + } + + /** + * ge_p1p1_to_p2.c + */ + static XYZ fromPartialXYZT(XYZ out, PartialXYZT in) { + Field25519.mult(out.x, in.xyz.x, in.t); + Field25519.mult(out.y, in.xyz.y, in.xyz.z); + Field25519.mult(out.z, in.xyz.z, in.t); + return out; + } + + /** + * Encodes this point to bytes. + */ + byte[] toBytes() { + long[] recip = new long[Field25519.LIMB_CNT]; + long[] x = new long[Field25519.LIMB_CNT]; + long[] y = new long[Field25519.LIMB_CNT]; + Field25519.inverse(recip, z); + Field25519.mult(x, this.x, recip); + Field25519.mult(y, this.y, recip); + byte[] s = Field25519.contract(y); + s[31] = (byte) (s[31] ^ (getLsb(x) << 7)); + return s; + } + + /** + * Best effort fix-timing array comparison. + * + * @return true if two arrays are equal. + */ + private static boolean bytesEqual(final byte[] x, final byte[] y) { + if (x == null || y == null) { + return false; + } + if (x.length != y.length) { + return false; + } + int res = 0; + for (int i = 0; i < x.length; i++) { + res |= x[i] ^ y[i]; + } + return res == 0; + } + + /** + * Checks that the point is on curve + */ + boolean isOnCurve() { + long[] x2 = new long[Field25519.LIMB_CNT]; + Field25519.square(x2, x); + long[] y2 = new long[Field25519.LIMB_CNT]; + Field25519.square(y2, y); + long[] z2 = new long[Field25519.LIMB_CNT]; + Field25519.square(z2, z); + long[] z4 = new long[Field25519.LIMB_CNT]; + Field25519.square(z4, z2); + long[] lhs = new long[Field25519.LIMB_CNT]; + // lhs = y^2 - x^2 + Field25519.sub(lhs, y2, x2); + // lhs = z^2 * (y2 - x2) + Field25519.mult(lhs, lhs, z2); + long[] rhs = new long[Field25519.LIMB_CNT]; + // rhs = x^2 * y^2 + Field25519.mult(rhs, x2, y2); + // rhs = D * x^2 * y^2 + Field25519.mult(rhs, rhs, D); + // rhs = z^4 + D * x^2 * y^2 + Field25519.sum(rhs, z4); + // Field25519.mult reduces its output, but Field25519.sum does not, so we have to manually + // reduce it here. + Field25519.reduce(rhs, rhs); + // z^2 (y^2 - x^2) == z^4 + D * x^2 * y^2 + return bytesEqual(Field25519.contract(lhs), Field25519.contract(rhs)); + } + } + + /** + * Represents extended projective point representation (X:Y:Z:T) satisfying x = X/Z, y = Y/Z, + * XY = ZT + *

+ * Note that this is referred as ge_p3 in ref10 impl. + * Also note that t = T below following Java coding style. + *

+ * See + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + *

+ * https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html + */ + private static class XYZT { + final XYZ xyz; + final long[] t; + + XYZT() { + this(new XYZ(), new long[Field25519.LIMB_CNT]); + } + + XYZT(XYZ xyz, long[] t) { + this.xyz = xyz; + this.t = t; + } + + XYZT(PartialXYZT partialXYZT) { + this(); + fromPartialXYZT(this, partialXYZT); + } + + /** + * ge_p1p1_to_p2.c + */ + private static XYZT fromPartialXYZT(XYZT out, PartialXYZT in) { + Field25519.mult(out.xyz.x, in.xyz.x, in.t); + Field25519.mult(out.xyz.y, in.xyz.y, in.xyz.z); + Field25519.mult(out.xyz.z, in.xyz.z, in.t); + Field25519.mult(out.t, in.xyz.x, in.xyz.y); + return out; + } + + /** + * Decodes {@code s} into an extented projective point. + * See Section 5.1.3 Decoding in https://tools.ietf.org/html/rfc8032#section-5.1.3 + */ + private static XYZT fromBytesNegateVarTime(byte[] s) throws GeneralSecurityException { + long[] x = new long[Field25519.LIMB_CNT]; + long[] y = Field25519.expand(s); + long[] z = new long[Field25519.LIMB_CNT]; + z[0] = 1; + long[] t = new long[Field25519.LIMB_CNT]; + long[] u = new long[Field25519.LIMB_CNT]; + long[] v = new long[Field25519.LIMB_CNT]; + long[] vxx = new long[Field25519.LIMB_CNT]; + long[] check = new long[Field25519.LIMB_CNT]; + Field25519.square(u, y); + Field25519.mult(v, u, D); + Field25519.sub(u, u, z); // u = y^2 - 1 + Field25519.sum(v, v, z); // v = dy^2 + 1 + + long[] v3 = new long[Field25519.LIMB_CNT]; + Field25519.square(v3, v); + Field25519.mult(v3, v3, v); // v3 = v^3 + Field25519.square(x, v3); + Field25519.mult(x, x, v); + Field25519.mult(x, x, u); // x = uv^7 + + pow2252m3(x, x); // x = (uv^7)^((q-5)/8) + Field25519.mult(x, x, v3); + Field25519.mult(x, x, u); // x = uv^3(uv^7)^((q-5)/8) + + Field25519.square(vxx, x); + Field25519.mult(vxx, vxx, v); + Field25519.sub(check, vxx, u); // vx^2-u + if (isNonZeroVarTime(check)) { + Field25519.sum(check, vxx, u); // vx^2+u + if (isNonZeroVarTime(check)) { + throw new GeneralSecurityException("Cannot convert given bytes to extended projective " + + "coordinates. No square root exists for modulo 2^255-19"); + } + Field25519.mult(x, x, SQRTM1); + } + + if (!isNonZeroVarTime(x) && (s[31] & 0xff) >> 7 != 0) { + throw new GeneralSecurityException("Cannot convert given bytes to extended projective " + + "coordinates. Computed x is zero and encoded x's least significant bit is not zero"); + } + if (getLsb(x) == ((s[31] & 0xff) >> 7)) { + neg(x, x); + } + + Field25519.mult(t, x, y); + return new XYZT(new XYZ(x, y, z), t); + } + } + + /** + * Partial projective point representation ((X:Z),(Y:T)) satisfying x=X/Z, y=Y/T + *

+ * Note that this is referred as complete form in the original ref10 impl (ge_p1p1). + * Also note that t = T below following Java coding style. + *

+ * Although this has the same types as XYZT, it is redefined to have its own type so that it is + * readable and 1:1 corresponds to ref10 impl. + *

+ * Can be converted to XYZT as follows: + * X1 = X * T = x * Z * T = x * Z1 + * Y1 = Y * Z = y * T * Z = y * Z1 + * Z1 = Z * T = Z * T + * T1 = X * Y = x * Z * y * T = x * y * Z1 = X1Y1 / Z1 + */ + private static class PartialXYZT { + final XYZ xyz; + final long[] t; + + PartialXYZT() { + this(new XYZ(), new long[Field25519.LIMB_CNT]); + } + + PartialXYZT(XYZ xyz, long[] t) { + this.xyz = xyz; + this.t = t; + } + + PartialXYZT(PartialXYZT other) { + xyz = new XYZ(other.xyz); + t = Arrays.copyOf(other.t, Field25519.LIMB_CNT); + } + } + + /** + * Corresponds to the caching mentioned in the last paragraph of Section 3.1 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + * with Z = 1. + */ + private static class CachedXYT { + final long[] yPlusX; + final long[] yMinusX; + final long[] t2d; + + /** + * Creates a cached XYZT with Z = 1 + * + * @param yPlusX y + x + * @param yMinusX y - x + * @param t2d 2d * xy + */ + CachedXYT(long[] yPlusX, long[] yMinusX, long[] t2d) { + this.yPlusX = yPlusX; + this.yMinusX = yMinusX; + this.t2d = t2d; + } + + CachedXYT(CachedXYT other) { + yPlusX = Arrays.copyOf(other.yPlusX, Field25519.LIMB_CNT); + yMinusX = Arrays.copyOf(other.yMinusX, Field25519.LIMB_CNT); + t2d = Arrays.copyOf(other.t2d, Field25519.LIMB_CNT); + } + + // z is one implicitly, so this just copies {@code in} to {@code output}. + void multByZ(long[] output, long[] in) { + System.arraycopy(in, 0, output, 0, Field25519.LIMB_CNT); + } + + /** + * If icopy is 1, copies {@code other} into this point. Time invariant wrt to icopy value. + */ + void copyConditional(CachedXYT other, int icopy) { + copyConditional(yPlusX, other.yPlusX, icopy); + copyConditional(yMinusX, other.yMinusX, icopy); + copyConditional(t2d, other.t2d, icopy); + } + + /** + * Conditionally copies a reduced-form limb arrays {@code b} into {@code a} if {@code icopy} is + * 1, but leave {@code a} unchanged if 'iswap' is 0. Runs in data-invariant time to avoid + * side-channel attacks. + * + *

NOTE that this function requires that {@code icopy} be 1 or 0; other values give wrong + * results. Also, the two limb arrays must be in reduced-coefficient, reduced-degree form: the + * values in a[10..19] or b[10..19] aren't swapped, and all all values in a[0..9],b[0..9] must + * have magnitude less than Integer.MAX_VALUE. + */ + static void copyConditional(long[] a, long[] b, int icopy) { + int copy = -icopy; + for (int i = 0; i < Field25519.LIMB_CNT; i++) { + int x = copy & (((int) a[i]) ^ ((int) b[i])); + a[i] = ((int) a[i]) ^ x; + } + } + } + + private static class CachedXYZT extends CachedXYT { + private final long[] z; + + CachedXYZT() { + this(new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT], + new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT]); + } + + /** + * ge_p3_to_cached.c + */ + CachedXYZT(XYZT xyzt) { + this(); + Field25519.sum(yPlusX, xyzt.xyz.y, xyzt.xyz.x); + Field25519.sub(yMinusX, xyzt.xyz.y, xyzt.xyz.x); + System.arraycopy(xyzt.xyz.z, 0, z, 0, Field25519.LIMB_CNT); + Field25519.mult(t2d, xyzt.t, D2); + } + + /** + * Creates a cached XYZT + * + * @param yPlusX Y + X + * @param yMinusX Y - X + * @param z Z + * @param t2d 2d * (XY/Z) + */ + CachedXYZT(long[] yPlusX, long[] yMinusX, long[] z, long[] t2d) { + super(yPlusX, yMinusX, t2d); + this.z = z; + } + + @Override + public void multByZ(long[] output, long[] in) { + Field25519.mult(output, in, z); + } + } + + /** + * Addition defined in Section 3.1 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + *

+ * Please note that this is a partial of the operation listed there leaving out the final + * conversion from PartialXYZT to XYZT. + * + * @param extended extended projective point input + * @param cached cached projective point input + */ + private static void add(PartialXYZT partialXYZT, XYZT extended, CachedXYT cached) { + long[] t = new long[Field25519.LIMB_CNT]; + + // Y1 + X1 + Field25519.sum(partialXYZT.xyz.x, extended.xyz.y, extended.xyz.x); + + // Y1 - X1 + Field25519.sub(partialXYZT.xyz.y, extended.xyz.y, extended.xyz.x); + + // A = (Y1 - X1) * (Y2 - X2) + Field25519.mult(partialXYZT.xyz.y, partialXYZT.xyz.y, cached.yMinusX); + + // B = (Y1 + X1) * (Y2 + X2) + Field25519.mult(partialXYZT.xyz.z, partialXYZT.xyz.x, cached.yPlusX); + + // C = T1 * 2d * T2 = 2d * T1 * T2 (2d is written as k in the paper) + Field25519.mult(partialXYZT.t, extended.t, cached.t2d); + + // Z1 * Z2 + cached.multByZ(partialXYZT.xyz.x, extended.xyz.z); + + // D = 2 * Z1 * Z2 + Field25519.sum(t, partialXYZT.xyz.x, partialXYZT.xyz.x); + + // X3 = B - A + Field25519.sub(partialXYZT.xyz.x, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Y3 = B + A + Field25519.sum(partialXYZT.xyz.y, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Z3 = D + C + Field25519.sum(partialXYZT.xyz.z, t, partialXYZT.t); + + // T3 = D - C + Field25519.sub(partialXYZT.t, t, partialXYZT.t); + } + + /** + * Based on the addition defined in Section 3.1 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + *

+ * Please note that this is a partial of the operation listed there leaving out the final + * conversion from PartialXYZT to XYZT. + * + * @param extended extended projective point input + * @param cached cached projective point input + */ + private static void sub(PartialXYZT partialXYZT, XYZT extended, CachedXYT cached) { + long[] t = new long[Field25519.LIMB_CNT]; + + // Y1 + X1 + Field25519.sum(partialXYZT.xyz.x, extended.xyz.y, extended.xyz.x); + + // Y1 - X1 + Field25519.sub(partialXYZT.xyz.y, extended.xyz.y, extended.xyz.x); + + // A = (Y1 - X1) * (Y2 + X2) + Field25519.mult(partialXYZT.xyz.y, partialXYZT.xyz.y, cached.yPlusX); + + // B = (Y1 + X1) * (Y2 - X2) + Field25519.mult(partialXYZT.xyz.z, partialXYZT.xyz.x, cached.yMinusX); + + // C = T1 * 2d * T2 = 2d * T1 * T2 (2d is written as k in the paper) + Field25519.mult(partialXYZT.t, extended.t, cached.t2d); + + // Z1 * Z2 + cached.multByZ(partialXYZT.xyz.x, extended.xyz.z); + + // D = 2 * Z1 * Z2 + Field25519.sum(t, partialXYZT.xyz.x, partialXYZT.xyz.x); + + // X3 = B - A + Field25519.sub(partialXYZT.xyz.x, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Y3 = B + A + Field25519.sum(partialXYZT.xyz.y, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Z3 = D - C + Field25519.sub(partialXYZT.xyz.z, t, partialXYZT.t); + + // T3 = D + C + Field25519.sum(partialXYZT.t, t, partialXYZT.t); + } + + /** + * Doubles {@code p} and puts the result into this PartialXYZT. + *

+ * This is based on the addition defined in formula 7 in Section 3.3 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + *

+ * Please note that this is a partial of the operation listed there leaving out the final + * conversion from PartialXYZT to XYZT and also this fixes a typo in calculation of Y3 and T3 in + * the paper, H should be replaced with A+B. + */ + private static void doubleXYZ(PartialXYZT partialXYZT, XYZ p) { + long[] t0 = new long[Field25519.LIMB_CNT]; + + // XX = X1^2 + Field25519.square(partialXYZT.xyz.x, p.x); + + // YY = Y1^2 + Field25519.square(partialXYZT.xyz.z, p.y); + + // B' = Z1^2 + Field25519.square(partialXYZT.t, p.z); + + // B = 2 * B' + Field25519.sum(partialXYZT.t, partialXYZT.t, partialXYZT.t); + + // A = X1 + Y1 + Field25519.sum(partialXYZT.xyz.y, p.x, p.y); + + // AA = A^2 + Field25519.square(t0, partialXYZT.xyz.y); + + // Y3 = YY + XX + Field25519.sum(partialXYZT.xyz.y, partialXYZT.xyz.z, partialXYZT.xyz.x); + + // Z3 = YY - XX + Field25519.sub(partialXYZT.xyz.z, partialXYZT.xyz.z, partialXYZT.xyz.x); + + // X3 = AA - Y3 + Field25519.sub(partialXYZT.xyz.x, t0, partialXYZT.xyz.y); + + // T3 = B - Z3 + Field25519.sub(partialXYZT.t, partialXYZT.t, partialXYZT.xyz.z); + } + + /** + * Doubles {@code p} and puts the result into this PartialXYZT. + */ + private static void doubleXYZT(PartialXYZT partialXYZT, XYZT p) { + doubleXYZ(partialXYZT, p.xyz); + } + + /** + * Compares two byte values in constant time. + */ + private static int eq(int a, int b) { + int r = ~(a ^ b) & 0xff; + r &= r << 4; + r &= r << 2; + r &= r << 1; + return (r >> 7) & 1; + } + + /** + * This is a constant time operation where point b*B*256^pos is stored in {@code t}. + * When b is 0, t remains the same (i.e., neutral point). + *

+ * Although B_TABLE[32][8] (B_TABLE[i][j] = (j+1)*B*256^i) has j values in [0, 7], the select + * method negates the corresponding point if b is negative (which is straight forward in elliptic + * curves by just negating y coordinate). Therefore we can get multiples of B with the half of + * memory requirements. + * + * @param t neutral element (i.e., point 0), also serves as output. + * @param pos in B[pos][j] = (j+1)*B*256^pos + * @param b value in [-8, 8] range. + */ + private static void select(CachedXYT t, int pos, byte b) { + int bnegative = (b & 0xff) >> 7; + int babs = b - (((-bnegative) & b) << 1); + + t.copyConditional(B_TABLE[pos][0], eq(babs, 1)); + t.copyConditional(B_TABLE[pos][1], eq(babs, 2)); + t.copyConditional(B_TABLE[pos][2], eq(babs, 3)); + t.copyConditional(B_TABLE[pos][3], eq(babs, 4)); + t.copyConditional(B_TABLE[pos][4], eq(babs, 5)); + t.copyConditional(B_TABLE[pos][5], eq(babs, 6)); + t.copyConditional(B_TABLE[pos][6], eq(babs, 7)); + t.copyConditional(B_TABLE[pos][7], eq(babs, 8)); + + long[] yPlusX = Arrays.copyOf(t.yMinusX, Field25519.LIMB_CNT); + long[] yMinusX = Arrays.copyOf(t.yPlusX, Field25519.LIMB_CNT); + long[] t2d = Arrays.copyOf(t.t2d, Field25519.LIMB_CNT); + neg(t2d, t2d); + CachedXYT minust = new CachedXYT(yPlusX, yMinusX, t2d); + t.copyConditional(minust, bnegative); + } + + /** + * Computes {@code a}*B + * where a = a[0]+256*a[1]+...+256^31 a[31] and + * B is the Ed25519 base point (x,4/5) with x positive. + *

+ * Preconditions: + * a[31] <= 127 + * + * @throws IllegalStateException iff there is arithmetic error. + */ + @SuppressWarnings("NarrowingCompoundAssignment") + private static XYZ scalarMultWithBase(byte[] a) { + byte[] e = new byte[2 * Field25519.FIELD_LEN]; + for (int i = 0; i < Field25519.FIELD_LEN; i++) { + e[2 * i + 0] = (byte) (((a[i] & 0xff) >> 0) & 0xf); + e[2 * i + 1] = (byte) (((a[i] & 0xff) >> 4) & 0xf); + } + // each e[i] is between 0 and 15 + // e[63] is between 0 and 7 + + // Rewrite e in a way that each e[i] is in [-8, 8]. + // This can be done since a[63] is in [0, 7], the carry-over onto the most significant byte + // a[63] can be at most 1. + int carry = 0; + for (int i = 0; i < e.length - 1; i++) { + e[i] += carry; + carry = e[i] + 8; + carry >>= 4; + e[i] -= carry << 4; + } + e[e.length - 1] += carry; + + PartialXYZT ret = new PartialXYZT(NEUTRAL); + XYZT xyzt = new XYZT(); + // Although B_TABLE's i can be at most 31 (stores only 32 4bit multiples of B) and we have 64 + // 4bit values in e array, the below for loop adds cached values by iterating e by two in odd + // indices. After the result, we can double the result point 4 times to shift the multiplication + // scalar by 4 bits. + for (int i = 1; i < e.length; i += 2) { + CachedXYT t = new CachedXYT(CACHED_NEUTRAL); + select(t, i / 2, e[i]); + add(ret, XYZT.fromPartialXYZT(xyzt, ret), t); + } + + // Doubles the result 4 times to shift the multiplication scalar 4 bits to get the actual result + // for the odd indices in e. + XYZ xyz = new XYZ(); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + + // Add multiples of B for even indices of e. + for (int i = 0; i < e.length; i += 2) { + CachedXYT t = new CachedXYT(CACHED_NEUTRAL); + select(t, i / 2, e[i]); + add(ret, XYZT.fromPartialXYZT(xyzt, ret), t); + } + + // This check is to protect against flaws, i.e. if there is a computation error through a + // faulty CPU or if the implementation contains a bug. + XYZ result = new XYZ(ret); + if (!result.isOnCurve()) { + throw new IllegalStateException("arithmetic error in scalar multiplication"); + } + return result; + } + + @SuppressWarnings("NarrowingCompoundAssignment") + private static byte[] slide(byte[] a) { + byte[] r = new byte[256]; + // Writes each bit in a[0..31] into r[0..255]: + // a = a[0]+256*a[1]+...+256^31*a[31] is equal to + // r = r[0]+2*r[1]+...+2^255*r[255] + for (int i = 0; i < 256; i++) { + r[i] = (byte) (1 & ((a[i >> 3] & 0xff) >> (i & 7))); + } + + // Transforms r[i] as odd values in [-15, 15] + for (int i = 0; i < 256; i++) { + if (r[i] != 0) { + for (int b = 1; b <= 6 && i + b < 256; b++) { + if (r[i + b] != 0) { + if (r[i] + (r[i + b] << b) <= 15) { + r[i] += r[i + b] << b; + r[i + b] = 0; + } else if (r[i] - (r[i + b] << b) >= -15) { + r[i] -= r[i + b] << b; + for (int k = i + b; k < 256; k++) { + if (r[k] == 0) { + r[k] = 1; + break; + } + r[k] = 0; + } + } else { + break; + } + } + } + } + } + return r; + } + + /** + * Computes {@code a}*{@code pointA}+{@code b}*B + * where a = a[0]+256*a[1]+...+256^31*a[31]. + * and b = b[0]+256*b[1]+...+256^31*b[31]. + * B is the Ed25519 base point (x,4/5) with x positive. + *

+ * Note that execution time varies based on the input since this will only be used in verification + * of signatures. + */ + private static XYZ doubleScalarMultVarTime(byte[] a, XYZT pointA, byte[] b) { + // pointA, 3*pointA, 5*pointA, 7*pointA, 9*pointA, 11*pointA, 13*pointA, 15*pointA + CachedXYZT[] pointAArray = new CachedXYZT[8]; + pointAArray[0] = new CachedXYZT(pointA); + PartialXYZT t = new PartialXYZT(); + doubleXYZT(t, pointA); + XYZT doubleA = new XYZT(t); + for (int i = 1; i < pointAArray.length; i++) { + add(t, doubleA, pointAArray[i - 1]); + pointAArray[i] = new CachedXYZT(new XYZT(t)); + } + + byte[] aSlide = slide(a); + byte[] bSlide = slide(b); + t = new PartialXYZT(NEUTRAL); + XYZT u = new XYZT(); + int i = 255; + for (; i >= 0; i--) { + if (aSlide[i] != 0 || bSlide[i] != 0) { + break; + } + } + for (; i >= 0; i--) { + doubleXYZ(t, new XYZ(t)); + if (aSlide[i] > 0) { + add(t, XYZT.fromPartialXYZT(u, t), pointAArray[aSlide[i] / 2]); + } else if (aSlide[i] < 0) { + sub(t, XYZT.fromPartialXYZT(u, t), pointAArray[-aSlide[i] / 2]); + } + if (bSlide[i] > 0) { + add(t, XYZT.fromPartialXYZT(u, t), B2[bSlide[i] / 2]); + } else if (bSlide[i] < 0) { + sub(t, XYZT.fromPartialXYZT(u, t), B2[-bSlide[i] / 2]); + } + } + + return new XYZ(t); + } + + /** + * Returns true if {@code in} is nonzero. + *

+ * Note that execution time might depend on the input {@code in}. + */ + private static boolean isNonZeroVarTime(long[] in) { + long[] inCopy = new long[in.length + 1]; + System.arraycopy(in, 0, inCopy, 0, in.length); + Field25519.reduceCoefficients(inCopy); + byte[] bytes = Field25519.contract(inCopy); + for (byte b : bytes) { + if (b != 0) { + return true; + } + } + return false; + } + + /** + * Returns the least significant bit of {@code in}. + */ + private static int getLsb(long[] in) { + return Field25519.contract(in)[0] & 1; + } + + /** + * Negates all values in {@code in} and store it in {@code out}. + */ + private static void neg(long[] out, long[] in) { + for (int i = 0; i < in.length; i++) { + out[i] = -in[i]; + } + } + + /** + * Computes {@code in}^(2^252-3) mod 2^255-19 and puts the result in {@code out}. + */ + private static void pow2252m3(long[] out, long[] in) { + long[] t0 = new long[Field25519.LIMB_CNT]; + long[] t1 = new long[Field25519.LIMB_CNT]; + long[] t2 = new long[Field25519.LIMB_CNT]; + + // z2 = z1^2^1 + Field25519.square(t0, in); + + // z8 = z2^2^2 + Field25519.square(t1, t0); + for (int i = 1; i < 2; i++) { + Field25519.square(t1, t1); + } + + // z9 = z1*z8 + Field25519.mult(t1, in, t1); + + // z11 = z2*z9 + Field25519.mult(t0, t0, t1); + + // z22 = z11^2^1 + Field25519.square(t0, t0); + + // z_5_0 = z9*z22 + Field25519.mult(t0, t1, t0); + + // z_10_5 = z_5_0^2^5 + Field25519.square(t1, t0); + for (int i = 1; i < 5; i++) { + Field25519.square(t1, t1); + } + + // z_10_0 = z_10_5*z_5_0 + Field25519.mult(t0, t1, t0); + + // z_20_10 = z_10_0^2^10 + Field25519.square(t1, t0); + for (int i = 1; i < 10; i++) { + Field25519.square(t1, t1); + } + + // z_20_0 = z_20_10*z_10_0 + Field25519.mult(t1, t1, t0); + + // z_40_20 = z_20_0^2^20 + Field25519.square(t2, t1); + for (int i = 1; i < 20; i++) { + Field25519.square(t2, t2); + } + + // z_40_0 = z_40_20*z_20_0 + Field25519.mult(t1, t2, t1); + + // z_50_10 = z_40_0^2^10 + Field25519.square(t1, t1); + for (int i = 1; i < 10; i++) { + Field25519.square(t1, t1); + } + + // z_50_0 = z_50_10*z_10_0 + Field25519.mult(t0, t1, t0); + + // z_100_50 = z_50_0^2^50 + Field25519.square(t1, t0); + for (int i = 1; i < 50; i++) { + Field25519.square(t1, t1); + } + + // z_100_0 = z_100_50*z_50_0 + Field25519.mult(t1, t1, t0); + + // z_200_100 = z_100_0^2^100 + Field25519.square(t2, t1); + for (int i = 1; i < 100; i++) { + Field25519.square(t2, t2); + } + + // z_200_0 = z_200_100*z_100_0 + Field25519.mult(t1, t2, t1); + + // z_250_50 = z_200_0^2^50 + Field25519.square(t1, t1); + for (int i = 1; i < 50; i++) { + Field25519.square(t1, t1); + } + + // z_250_0 = z_250_50*z_50_0 + Field25519.mult(t0, t1, t0); + + // z_252_2 = z_250_0^2^2 + Field25519.square(t0, t0); + for (int i = 1; i < 2; i++) { + Field25519.square(t0, t0); + } + + // z_252_3 = z_252_2*z1 + Field25519.mult(out, t0, in); + } + + /** + * Returns 3 bytes of {@code in} starting from {@code idx} in Little-Endian format. + */ + private static long load3(byte[] in, int idx) { + long result; + result = (long) in[idx] & 0xff; + result |= (long) (in[idx + 1] & 0xff) << 8; + result |= (long) (in[idx + 2] & 0xff) << 16; + return result; + } + + /** + * Returns 4 bytes of {@code in} starting from {@code idx} in Little-Endian format. + */ + private static long load4(byte[] in, int idx) { + long result = load3(in, idx); + result |= (long) (in[idx + 3] & 0xff) << 24; + return result; + } + + /** + * Input: + * s[0]+256*s[1]+...+256^63*s[63] = s + *

+ * Output: + * s[0]+256*s[1]+...+256^31*s[31] = s mod l + * where l = 2^252 + 27742317777372353535851937790883648493. + * Overwrites s in place. + */ + private static void reduce(byte[] s) { + // Observation: + // 2^252 mod l is equivalent to -27742317777372353535851937790883648493 mod l + // Let m = -27742317777372353535851937790883648493 + // Thus a*2^252+b mod l is equivalent to a*m+b mod l + // + // First s is divided into chunks of 21 bits as follows: + // s0+2^21*s1+2^42*s3+...+2^462*s23 = s[0]+256*s[1]+...+256^63*s[63] + long s0 = 2097151 & load3(s, 0); + long s1 = 2097151 & (load4(s, 2) >> 5); + long s2 = 2097151 & (load3(s, 5) >> 2); + long s3 = 2097151 & (load4(s, 7) >> 7); + long s4 = 2097151 & (load4(s, 10) >> 4); + long s5 = 2097151 & (load3(s, 13) >> 1); + long s6 = 2097151 & (load4(s, 15) >> 6); + long s7 = 2097151 & (load3(s, 18) >> 3); + long s8 = 2097151 & load3(s, 21); + long s9 = 2097151 & (load4(s, 23) >> 5); + long s10 = 2097151 & (load3(s, 26) >> 2); + long s11 = 2097151 & (load4(s, 28) >> 7); + long s12 = 2097151 & (load4(s, 31) >> 4); + long s13 = 2097151 & (load3(s, 34) >> 1); + long s14 = 2097151 & (load4(s, 36) >> 6); + long s15 = 2097151 & (load3(s, 39) >> 3); + long s16 = 2097151 & load3(s, 42); + long s17 = 2097151 & (load4(s, 44) >> 5); + long s18 = 2097151 & (load3(s, 47) >> 2); + long s19 = 2097151 & (load4(s, 49) >> 7); + long s20 = 2097151 & (load4(s, 52) >> 4); + long s21 = 2097151 & (load3(s, 55) >> 1); + long s22 = 2097151 & (load4(s, 57) >> 6); + long s23 = (load4(s, 60) >> 3); + long carry0; + long carry1; + long carry2; + long carry3; + long carry4; + long carry5; + long carry6; + long carry7; + long carry8; + long carry9; + long carry10; + long carry11; + long carry12; + long carry13; + long carry14; + long carry15; + long carry16; + + // s23*2^462 = s23*2^210*2^252 is equivalent to s23*2^210*m in mod l + // As m is a 125 bit number, the result needs to scattered to 6 limbs (125/21 ceil is 6) + // starting from s11 (s11*2^210) + // m = [666643, 470296, 654183, -997805, 136657, -683901] in 21-bit limbs + s11 += s23 * 666643; + s12 += s23 * 470296; + s13 += s23 * 654183; + s14 -= s23 * 997805; + s15 += s23 * 136657; + s16 -= s23 * 683901; + // s23 = 0; + + s10 += s22 * 666643; + s11 += s22 * 470296; + s12 += s22 * 654183; + s13 -= s22 * 997805; + s14 += s22 * 136657; + s15 -= s22 * 683901; + // s22 = 0; + + s9 += s21 * 666643; + s10 += s21 * 470296; + s11 += s21 * 654183; + s12 -= s21 * 997805; + s13 += s21 * 136657; + s14 -= s21 * 683901; + // s21 = 0; + + s8 += s20 * 666643; + s9 += s20 * 470296; + s10 += s20 * 654183; + s11 -= s20 * 997805; + s12 += s20 * 136657; + s13 -= s20 * 683901; + // s20 = 0; + + s7 += s19 * 666643; + s8 += s19 * 470296; + s9 += s19 * 654183; + s10 -= s19 * 997805; + s11 += s19 * 136657; + s12 -= s19 * 683901; + // s19 = 0; + + s6 += s18 * 666643; + s7 += s18 * 470296; + s8 += s18 * 654183; + s9 -= s18 * 997805; + s10 += s18 * 136657; + s11 -= s18 * 683901; + // s18 = 0; + + // Reduce the bit length of limbs from s6 to s15 to 21-bits. + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry12 = (s12 + (1 << 20)) >> 21; + s13 += carry12; + s12 -= carry12 << 21; + carry14 = (s14 + (1 << 20)) >> 21; + s15 += carry14; + s14 -= carry14 << 21; + carry16 = (s16 + (1 << 20)) >> 21; + s17 += carry16; + s16 -= carry16 << 21; + + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + carry13 = (s13 + (1 << 20)) >> 21; + s14 += carry13; + s13 -= carry13 << 21; + carry15 = (s15 + (1 << 20)) >> 21; + s16 += carry15; + s15 -= carry15 << 21; + + // Resume reduction where we left off. + s5 += s17 * 666643; + s6 += s17 * 470296; + s7 += s17 * 654183; + s8 -= s17 * 997805; + s9 += s17 * 136657; + s10 -= s17 * 683901; + // s17 = 0; + + s4 += s16 * 666643; + s5 += s16 * 470296; + s6 += s16 * 654183; + s7 -= s16 * 997805; + s8 += s16 * 136657; + s9 -= s16 * 683901; + // s16 = 0; + + s3 += s15 * 666643; + s4 += s15 * 470296; + s5 += s15 * 654183; + s6 -= s15 * 997805; + s7 += s15 * 136657; + s8 -= s15 * 683901; + // s15 = 0; + + s2 += s14 * 666643; + s3 += s14 * 470296; + s4 += s14 * 654183; + s5 -= s14 * 997805; + s6 += s14 * 136657; + s7 -= s14 * 683901; + // s14 = 0; + + s1 += s13 * 666643; + s2 += s13 * 470296; + s3 += s13 * 654183; + s4 -= s13 * 997805; + s5 += s13 * 136657; + s6 -= s13 * 683901; + // s13 = 0; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + // Reduce the range of limbs from s0 to s11 to 21-bits. + carry0 = (s0 + (1 << 20)) >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry2 = (s2 + (1 << 20)) >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry4 = (s4 + (1 << 20)) >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + carry1 = (s1 + (1 << 20)) >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry3 = (s3 + (1 << 20)) >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry5 = (s5 + (1 << 20)) >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + // Carry chain reduction to propagate excess bits from s0 to s5 to the most significant limbs. + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry11 = s11 >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + // Do one last reduction as s12 might be 1. + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + // s12 = 0; + + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + // Serialize the result into the s. + s[0] = (byte) s0; + s[1] = (byte) (s0 >> 8); + s[2] = (byte) ((s0 >> 16) | (s1 << 5)); + s[3] = (byte) (s1 >> 3); + s[4] = (byte) (s1 >> 11); + s[5] = (byte) ((s1 >> 19) | (s2 << 2)); + s[6] = (byte) (s2 >> 6); + s[7] = (byte) ((s2 >> 14) | (s3 << 7)); + s[8] = (byte) (s3 >> 1); + s[9] = (byte) (s3 >> 9); + s[10] = (byte) ((s3 >> 17) | (s4 << 4)); + s[11] = (byte) (s4 >> 4); + s[12] = (byte) (s4 >> 12); + s[13] = (byte) ((s4 >> 20) | (s5 << 1)); + s[14] = (byte) (s5 >> 7); + s[15] = (byte) ((s5 >> 15) | (s6 << 6)); + s[16] = (byte) (s6 >> 2); + s[17] = (byte) (s6 >> 10); + s[18] = (byte) ((s6 >> 18) | (s7 << 3)); + s[19] = (byte) (s7 >> 5); + s[20] = (byte) (s7 >> 13); + s[21] = (byte) s8; + s[22] = (byte) (s8 >> 8); + s[23] = (byte) ((s8 >> 16) | (s9 << 5)); + s[24] = (byte) (s9 >> 3); + s[25] = (byte) (s9 >> 11); + s[26] = (byte) ((s9 >> 19) | (s10 << 2)); + s[27] = (byte) (s10 >> 6); + s[28] = (byte) ((s10 >> 14) | (s11 << 7)); + s[29] = (byte) (s11 >> 1); + s[30] = (byte) (s11 >> 9); + s[31] = (byte) (s11 >> 17); + } + + /** + * Input: + * a[0]+256*a[1]+...+256^31*a[31] = a + * b[0]+256*b[1]+...+256^31*b[31] = b + * c[0]+256*c[1]+...+256^31*c[31] = c + *

+ * Output: + * s[0]+256*s[1]+...+256^31*s[31] = (ab+c) mod l + * where l = 2^252 + 27742317777372353535851937790883648493. + */ + private static void mulAdd(byte[] s, byte[] a, byte[] b, byte[] c) { + // This is very similar to Ed25519.reduce, the difference in here is that it computes ab+c + // See Ed25519.reduce for related comments. + long a0 = 2097151 & load3(a, 0); + long a1 = 2097151 & (load4(a, 2) >> 5); + long a2 = 2097151 & (load3(a, 5) >> 2); + long a3 = 2097151 & (load4(a, 7) >> 7); + long a4 = 2097151 & (load4(a, 10) >> 4); + long a5 = 2097151 & (load3(a, 13) >> 1); + long a6 = 2097151 & (load4(a, 15) >> 6); + long a7 = 2097151 & (load3(a, 18) >> 3); + long a8 = 2097151 & load3(a, 21); + long a9 = 2097151 & (load4(a, 23) >> 5); + long a10 = 2097151 & (load3(a, 26) >> 2); + long a11 = (load4(a, 28) >> 7); + long b0 = 2097151 & load3(b, 0); + long b1 = 2097151 & (load4(b, 2) >> 5); + long b2 = 2097151 & (load3(b, 5) >> 2); + long b3 = 2097151 & (load4(b, 7) >> 7); + long b4 = 2097151 & (load4(b, 10) >> 4); + long b5 = 2097151 & (load3(b, 13) >> 1); + long b6 = 2097151 & (load4(b, 15) >> 6); + long b7 = 2097151 & (load3(b, 18) >> 3); + long b8 = 2097151 & load3(b, 21); + long b9 = 2097151 & (load4(b, 23) >> 5); + long b10 = 2097151 & (load3(b, 26) >> 2); + long b11 = (load4(b, 28) >> 7); + long c0 = 2097151 & load3(c, 0); + long c1 = 2097151 & (load4(c, 2) >> 5); + long c2 = 2097151 & (load3(c, 5) >> 2); + long c3 = 2097151 & (load4(c, 7) >> 7); + long c4 = 2097151 & (load4(c, 10) >> 4); + long c5 = 2097151 & (load3(c, 13) >> 1); + long c6 = 2097151 & (load4(c, 15) >> 6); + long c7 = 2097151 & (load3(c, 18) >> 3); + long c8 = 2097151 & load3(c, 21); + long c9 = 2097151 & (load4(c, 23) >> 5); + long c10 = 2097151 & (load3(c, 26) >> 2); + long c11 = (load4(c, 28) >> 7); + long s0; + long s1; + long s2; + long s3; + long s4; + long s5; + long s6; + long s7; + long s8; + long s9; + long s10; + long s11; + long s12; + long s13; + long s14; + long s15; + long s16; + long s17; + long s18; + long s19; + long s20; + long s21; + long s22; + long s23; + long carry0; + long carry1; + long carry2; + long carry3; + long carry4; + long carry5; + long carry6; + long carry7; + long carry8; + long carry9; + long carry10; + long carry11; + long carry12; + long carry13; + long carry14; + long carry15; + long carry16; + long carry17; + long carry18; + long carry19; + long carry20; + long carry21; + long carry22; + + s0 = c0 + a0 * b0; + s1 = c1 + a0 * b1 + a1 * b0; + s2 = c2 + a0 * b2 + a1 * b1 + a2 * b0; + s3 = c3 + a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0; + s4 = c4 + a0 * b4 + a1 * b3 + a2 * b2 + a3 * b1 + a4 * b0; + s5 = c5 + a0 * b5 + a1 * b4 + a2 * b3 + a3 * b2 + a4 * b1 + a5 * b0; + s6 = c6 + a0 * b6 + a1 * b5 + a2 * b4 + a3 * b3 + a4 * b2 + a5 * b1 + a6 * b0; + s7 = c7 + a0 * b7 + a1 * b6 + a2 * b5 + a3 * b4 + a4 * b3 + a5 * b2 + a6 * b1 + a7 * b0; + s8 = c8 + a0 * b8 + a1 * b7 + a2 * b6 + a3 * b5 + a4 * b4 + a5 * b3 + a6 * b2 + a7 * b1 + + a8 * b0; + s9 = c9 + a0 * b9 + a1 * b8 + a2 * b7 + a3 * b6 + a4 * b5 + a5 * b4 + a6 * b3 + a7 * b2 + + a8 * b1 + a9 * b0; + s10 = c10 + a0 * b10 + a1 * b9 + a2 * b8 + a3 * b7 + a4 * b6 + a5 * b5 + a6 * b4 + a7 * b3 + + a8 * b2 + a9 * b1 + a10 * b0; + s11 = c11 + a0 * b11 + a1 * b10 + a2 * b9 + a3 * b8 + a4 * b7 + a5 * b6 + a6 * b5 + a7 * b4 + + a8 * b3 + a9 * b2 + a10 * b1 + a11 * b0; + s12 = a1 * b11 + a2 * b10 + a3 * b9 + a4 * b8 + a5 * b7 + a6 * b6 + a7 * b5 + a8 * b4 + a9 * b3 + + a10 * b2 + a11 * b1; + s13 = a2 * b11 + a3 * b10 + a4 * b9 + a5 * b8 + a6 * b7 + a7 * b6 + a8 * b5 + a9 * b4 + a10 * b3 + + a11 * b2; + s14 = + a3 * b11 + a4 * b10 + a5 * b9 + a6 * b8 + a7 * b7 + a8 * b6 + a9 * b5 + a10 * b4 + a11 * b3; + s15 = a4 * b11 + a5 * b10 + a6 * b9 + a7 * b8 + a8 * b7 + a9 * b6 + a10 * b5 + a11 * b4; + s16 = a5 * b11 + a6 * b10 + a7 * b9 + a8 * b8 + a9 * b7 + a10 * b6 + a11 * b5; + s17 = a6 * b11 + a7 * b10 + a8 * b9 + a9 * b8 + a10 * b7 + a11 * b6; + s18 = a7 * b11 + a8 * b10 + a9 * b9 + a10 * b8 + a11 * b7; + s19 = a8 * b11 + a9 * b10 + a10 * b9 + a11 * b8; + s20 = a9 * b11 + a10 * b10 + a11 * b9; + s21 = a10 * b11 + a11 * b10; + s22 = a11 * b11; + s23 = 0; + + carry0 = (s0 + (1 << 20)) >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry2 = (s2 + (1 << 20)) >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry4 = (s4 + (1 << 20)) >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry12 = (s12 + (1 << 20)) >> 21; + s13 += carry12; + s12 -= carry12 << 21; + carry14 = (s14 + (1 << 20)) >> 21; + s15 += carry14; + s14 -= carry14 << 21; + carry16 = (s16 + (1 << 20)) >> 21; + s17 += carry16; + s16 -= carry16 << 21; + carry18 = (s18 + (1 << 20)) >> 21; + s19 += carry18; + s18 -= carry18 << 21; + carry20 = (s20 + (1 << 20)) >> 21; + s21 += carry20; + s20 -= carry20 << 21; + carry22 = (s22 + (1 << 20)) >> 21; + s23 += carry22; + s22 -= carry22 << 21; + + carry1 = (s1 + (1 << 20)) >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry3 = (s3 + (1 << 20)) >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry5 = (s5 + (1 << 20)) >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + carry13 = (s13 + (1 << 20)) >> 21; + s14 += carry13; + s13 -= carry13 << 21; + carry15 = (s15 + (1 << 20)) >> 21; + s16 += carry15; + s15 -= carry15 << 21; + carry17 = (s17 + (1 << 20)) >> 21; + s18 += carry17; + s17 -= carry17 << 21; + carry19 = (s19 + (1 << 20)) >> 21; + s20 += carry19; + s19 -= carry19 << 21; + carry21 = (s21 + (1 << 20)) >> 21; + s22 += carry21; + s21 -= carry21 << 21; + + s11 += s23 * 666643; + s12 += s23 * 470296; + s13 += s23 * 654183; + s14 -= s23 * 997805; + s15 += s23 * 136657; + s16 -= s23 * 683901; + // s23 = 0; + + s10 += s22 * 666643; + s11 += s22 * 470296; + s12 += s22 * 654183; + s13 -= s22 * 997805; + s14 += s22 * 136657; + s15 -= s22 * 683901; + // s22 = 0; + + s9 += s21 * 666643; + s10 += s21 * 470296; + s11 += s21 * 654183; + s12 -= s21 * 997805; + s13 += s21 * 136657; + s14 -= s21 * 683901; + // s21 = 0; + + s8 += s20 * 666643; + s9 += s20 * 470296; + s10 += s20 * 654183; + s11 -= s20 * 997805; + s12 += s20 * 136657; + s13 -= s20 * 683901; + // s20 = 0; + + s7 += s19 * 666643; + s8 += s19 * 470296; + s9 += s19 * 654183; + s10 -= s19 * 997805; + s11 += s19 * 136657; + s12 -= s19 * 683901; + // s19 = 0; + + s6 += s18 * 666643; + s7 += s18 * 470296; + s8 += s18 * 654183; + s9 -= s18 * 997805; + s10 += s18 * 136657; + s11 -= s18 * 683901; + // s18 = 0; + + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry12 = (s12 + (1 << 20)) >> 21; + s13 += carry12; + s12 -= carry12 << 21; + carry14 = (s14 + (1 << 20)) >> 21; + s15 += carry14; + s14 -= carry14 << 21; + carry16 = (s16 + (1 << 20)) >> 21; + s17 += carry16; + s16 -= carry16 << 21; + + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + carry13 = (s13 + (1 << 20)) >> 21; + s14 += carry13; + s13 -= carry13 << 21; + carry15 = (s15 + (1 << 20)) >> 21; + s16 += carry15; + s15 -= carry15 << 21; + + s5 += s17 * 666643; + s6 += s17 * 470296; + s7 += s17 * 654183; + s8 -= s17 * 997805; + s9 += s17 * 136657; + s10 -= s17 * 683901; + // s17 = 0; + + s4 += s16 * 666643; + s5 += s16 * 470296; + s6 += s16 * 654183; + s7 -= s16 * 997805; + s8 += s16 * 136657; + s9 -= s16 * 683901; + // s16 = 0; + + s3 += s15 * 666643; + s4 += s15 * 470296; + s5 += s15 * 654183; + s6 -= s15 * 997805; + s7 += s15 * 136657; + s8 -= s15 * 683901; + // s15 = 0; + + s2 += s14 * 666643; + s3 += s14 * 470296; + s4 += s14 * 654183; + s5 -= s14 * 997805; + s6 += s14 * 136657; + s7 -= s14 * 683901; + // s14 = 0; + + s1 += s13 * 666643; + s2 += s13 * 470296; + s3 += s13 * 654183; + s4 -= s13 * 997805; + s5 += s13 * 136657; + s6 -= s13 * 683901; + // s13 = 0; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + carry0 = (s0 + (1 << 20)) >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry2 = (s2 + (1 << 20)) >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry4 = (s4 + (1 << 20)) >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + carry1 = (s1 + (1 << 20)) >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry3 = (s3 + (1 << 20)) >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry5 = (s5 + (1 << 20)) >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry11 = s11 >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + // s12 = 0; + + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + s[0] = (byte) s0; + s[1] = (byte) (s0 >> 8); + s[2] = (byte) ((s0 >> 16) | (s1 << 5)); + s[3] = (byte) (s1 >> 3); + s[4] = (byte) (s1 >> 11); + s[5] = (byte) ((s1 >> 19) | (s2 << 2)); + s[6] = (byte) (s2 >> 6); + s[7] = (byte) ((s2 >> 14) | (s3 << 7)); + s[8] = (byte) (s3 >> 1); + s[9] = (byte) (s3 >> 9); + s[10] = (byte) ((s3 >> 17) | (s4 << 4)); + s[11] = (byte) (s4 >> 4); + s[12] = (byte) (s4 >> 12); + s[13] = (byte) ((s4 >> 20) | (s5 << 1)); + s[14] = (byte) (s5 >> 7); + s[15] = (byte) ((s5 >> 15) | (s6 << 6)); + s[16] = (byte) (s6 >> 2); + s[17] = (byte) (s6 >> 10); + s[18] = (byte) ((s6 >> 18) | (s7 << 3)); + s[19] = (byte) (s7 >> 5); + s[20] = (byte) (s7 >> 13); + s[21] = (byte) s8; + s[22] = (byte) (s8 >> 8); + s[23] = (byte) ((s8 >> 16) | (s9 << 5)); + s[24] = (byte) (s9 >> 3); + s[25] = (byte) (s9 >> 11); + s[26] = (byte) ((s9 >> 19) | (s10 << 2)); + s[27] = (byte) (s10 >> 6); + s[28] = (byte) ((s10 >> 14) | (s11 << 7)); + s[29] = (byte) (s11 >> 1); + s[30] = (byte) (s11 >> 9); + s[31] = (byte) (s11 >> 17); + } + + // The order of the generator as unsigned bytes in little endian order. + // (2^252 + 0x14def9dea2f79cd65812631a5cf5d3ed, cf. RFC 7748) + private static final byte[] GROUP_ORDER = {(byte) 0xed, (byte) 0xd3, (byte) 0xf5, (byte) 0x5c, + (byte) 0x1a, (byte) 0x63, (byte) 0x12, (byte) 0x58, (byte) 0xd6, (byte) 0x9c, (byte) 0xf7, + (byte) 0xa2, (byte) 0xde, (byte) 0xf9, (byte) 0xde, (byte) 0x14, (byte) 0x00, (byte) 0x00, + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x10}; + + // Checks whether s represents an integer smaller than the order of the group. + // This is needed to ensure that EdDSA signatures are non-malleable, as failing to check + // the range of S allows to modify signatures (cf. RFC 8032, Section 5.2.7 and Section 8.4.) + // @param s an integer in little-endian order. + private static boolean isSmallerThanGroupOrder(byte[] s) { + for (int j = Field25519.FIELD_LEN - 1; j >= 0; j--) { + // compare unsigned bytes + int a = s[j] & 0xff; + int b = GROUP_ORDER[j] & 0xff; + if (a != b) { + return a < b; + } + } + return false; + } + + /** + * Returns true if the EdDSA {@code signature} with {@code message}, can be verified with + * {@code publicKey}. + */ + public static boolean verify( + final byte[] message, final byte[] signature, final byte[] publicKey) { + try { + if (signature.length != SIGNATURE_LEN) { + return false; + } + if (publicKey.length != PUBLIC_KEY_LEN) { + return false; + } + byte[] s = Arrays.copyOfRange(signature, Field25519.FIELD_LEN, SIGNATURE_LEN); + if (!isSmallerThanGroupOrder(s)) { + return false; + } + MessageDigest digest = MessageDigest.getInstance("SHA-512"); + digest.update(signature, 0, Field25519.FIELD_LEN); + digest.update(publicKey); + digest.update(message); + byte[] h = digest.digest(); + reduce(h); + + XYZT negPublicKey = XYZT.fromBytesNegateVarTime(publicKey); + XYZ xyz = doubleScalarMultVarTime(h, negPublicKey, s); + byte[] expectedR = xyz.toBytes(); + for (int i = 0; i < Field25519.FIELD_LEN; i++) { + if (expectedR[i] != signature[i]) { + return false; + } + } + return true; + } catch (final GeneralSecurityException ignored) { + return false; + } + } +} diff --git a/client/android/src/com/wireguard/crypto/Key.java b/client/android/src/com/wireguard/crypto/Key.java new file mode 100644 index 00000000..a720624b --- /dev/null +++ b/client/android/src/com/wireguard/crypto/Key.java @@ -0,0 +1,283 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import com.wireguard.crypto.KeyFormatException.Type; + +import java.security.MessageDigest; +import java.security.SecureRandom; +import java.util.Arrays; + +/** + * Represents a WireGuard public or private key. This class uses specialized constant-time base64 + * and hexadecimal codec implementations that resist side-channel attacks. + *

+ * Instances of this class are immutable. + */ +@SuppressWarnings("MagicNumber") + +public final class Key { + private final byte[] key; + + /** + * Constructs an object encapsulating the supplied key. + * + * @param key an array of bytes containing a binary key. Callers of this constructor are + * responsible for ensuring that the array is of the correct length. + */ + private Key(final byte[] key) { + // Defensively copy to ensure immutability. + this.key = Arrays.copyOf(key, key.length); + } + + /** + * Decodes a single 4-character base64 chunk to an integer in constant time. + * + * @param src an array of at least 4 characters in base64 format + * @param srcOffset the offset of the beginning of the chunk in {@code src} + * @return the decoded 3-byte integer, or some arbitrary integer value if the input was not + * valid base64 + */ + private static int decodeBase64(final char[] src, final int srcOffset) { + int val = 0; + for (int i = 0; i < 4; ++i) { + final char c = src[i + srcOffset]; + val |= (-1 + ((((('A' - 1) - c) & (c - ('Z' + 1))) >>> 8) & (c - 64)) + + ((((('a' - 1) - c) & (c - ('z' + 1))) >>> 8) & (c - 70)) + + ((((('0' - 1) - c) & (c - ('9' + 1))) >>> 8) & (c + 5)) + + ((((('+' - 1) - c) & (c - ('+' + 1))) >>> 8) & 63) + + ((((('/' - 1) - c) & (c - ('/' + 1))) >>> 8) & 64)) + << (18 - 6 * i); + } + return val; + } + + /** + * Encodes a single 4-character base64 chunk from 3 consecutive bytes in constant time. + * + * @param src an array of at least 3 bytes + * @param srcOffset the offset of the beginning of the chunk in {@code src} + * @param dest an array of at least 4 characters + * @param destOffset the offset of the beginning of the chunk in {@code dest} + */ + private static void encodeBase64( + final byte[] src, final int srcOffset, final char[] dest, final int destOffset) { + final byte[] input = { + (byte) ((src[srcOffset] >>> 2) & 63), + (byte) ((src[srcOffset] << 4 | ((src[1 + srcOffset] & 0xff) >>> 4)) & 63), + (byte) ((src[1 + srcOffset] << 2 | ((src[2 + srcOffset] & 0xff) >>> 6)) & 63), + (byte) ((src[2 + srcOffset]) & 63), + }; + for (int i = 0; i < 4; ++i) { + dest[i + destOffset] = + (char) (input[i] + 'A' + (((25 - input[i]) >>> 8) & 6) - (((51 - input[i]) >>> 8) & 75) + - (((61 - input[i]) >>> 8) & 15) + (((62 - input[i]) >>> 8) & 3)); + } + } + + /** + * Decodes a WireGuard public or private key from its base64 string representation. This + * function throws a {@link KeyFormatException} if the source string is not well-formed. + * + * @param str the base64 string representation of a WireGuard key + * @return the decoded key encapsulated in an immutable container + */ + public static Key fromBase64(final String str) throws KeyFormatException { + final char[] input = str.toCharArray(); + if (input.length != Format.BASE64.length || input[Format.BASE64.length - 1] != '=') + throw new KeyFormatException(Format.BASE64, Type.LENGTH); + final byte[] key = new byte[Format.BINARY.length]; + int i; + int ret = 0; + for (i = 0; i < key.length / 3; ++i) { + final int val = decodeBase64(input, i * 4); + ret |= val >>> 31; + key[i * 3] = (byte) ((val >>> 16) & 0xff); + key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + key[i * 3 + 2] = (byte) (val & 0xff); + } + final char[] endSegment = { + input[i * 4], + input[i * 4 + 1], + input[i * 4 + 2], + 'A', + }; + final int val = decodeBase64(endSegment, 0); + ret |= (val >>> 31) | (val & 0xff); + key[i * 3] = (byte) ((val >>> 16) & 0xff); + key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + + if (ret != 0) + throw new KeyFormatException(Format.BASE64, Type.CONTENTS); + return new Key(key); + } + + /** + * Wraps a WireGuard public or private key in an immutable container. This function throws a + * {@link KeyFormatException} if the source data is not the correct length. + * + * @param bytes an array of bytes containing a WireGuard key in binary format + * @return the key encapsulated in an immutable container + */ + public static Key fromBytes(final byte[] bytes) throws KeyFormatException { + if (bytes.length != Format.BINARY.length) + throw new KeyFormatException(Format.BINARY, Type.LENGTH); + return new Key(bytes); + } + + /** + * Decodes a WireGuard public or private key from its hexadecimal string representation. This + * function throws a {@link KeyFormatException} if the source string is not well-formed. + * + * @param str the hexadecimal string representation of a WireGuard key + * @return the decoded key encapsulated in an immutable container + */ + public static Key fromHex(final String str) throws KeyFormatException { + final char[] input = str.toCharArray(); + if (input.length != Format.HEX.length) + throw new KeyFormatException(Format.HEX, Type.LENGTH); + final byte[] key = new byte[Format.BINARY.length]; + int ret = 0; + for (int i = 0; i < key.length; ++i) { + int c; + int cNum; + int cNum0; + int cAlpha; + int cAlpha0; + int cVal; + final int cAcc; + + c = input[i * 2]; + cNum = c ^ 48; + cNum0 = ((cNum - 10) >>> 8) & 0xff; + cAlpha = (c & ~32) - 55; + cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; + ret |= ((cNum0 | cAlpha0) - 1) >>> 8; + cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); + cAcc = cVal * 16; + + c = input[i * 2 + 1]; + cNum = c ^ 48; + cNum0 = ((cNum - 10) >>> 8) & 0xff; + cAlpha = (c & ~32) - 55; + cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; + ret |= ((cNum0 | cAlpha0) - 1) >>> 8; + cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); + key[i] = (byte) (cAcc | cVal); + } + if (ret != 0) + throw new KeyFormatException(Format.HEX, Type.CONTENTS); + return new Key(key); + } + + /** + * Generates a private key using the system's {@link SecureRandom} number generator. + * + * @return a well-formed random private key + */ + static Key generatePrivateKey() { + final SecureRandom secureRandom = new SecureRandom(); + final byte[] privateKey = new byte[Format.BINARY.getLength()]; + secureRandom.nextBytes(privateKey); + privateKey[0] &= 248; + privateKey[31] &= 127; + privateKey[31] |= 64; + return new Key(privateKey); + } + + /** + * Generates a public key from an existing private key. + * + * @param privateKey a private key + * @return a well-formed public key that corresponds to the supplied private key + */ + static Key generatePublicKey(final Key privateKey) { + final byte[] publicKey = new byte[Format.BINARY.getLength()]; + Curve25519.eval(publicKey, 0, privateKey.getBytes(), null); + return new Key(publicKey); + } + + @Override + public boolean equals(final Object obj) { + if (obj == this) + return true; + if (obj == null || obj.getClass() != getClass()) + return false; + final Key other = (Key) obj; + return MessageDigest.isEqual(key, other.key); + } + + /** + * Returns the key as an array of bytes. + * + * @return an array of bytes containing the raw binary key + */ + public byte[] getBytes() { + // Defensively copy to ensure immutability. + return Arrays.copyOf(key, key.length); + } + + @Override + public int hashCode() { + int ret = 0; + for (int i = 0; i < key.length / 4; ++i) + ret ^= (key[i * 4 + 0] >> 0) + (key[i * 4 + 1] >> 8) + (key[i * 4 + 2] >> 16) + + (key[i * 4 + 3] >> 24); + return ret; + } + + /** + * Encodes the key to base64. + * + * @return a string containing the encoded key + */ + public String toBase64() { + final char[] output = new char[Format.BASE64.length]; + int i; + for (i = 0; i < key.length / 3; ++i) encodeBase64(key, i * 3, output, i * 4); + final byte[] endSegment = { + key[i * 3], + key[i * 3 + 1], + 0, + }; + encodeBase64(endSegment, 0, output, i * 4); + output[Format.BASE64.length - 1] = '='; + return new String(output); + } + + /** + * Encodes the key to hexadecimal ASCII characters. + * + * @return a string containing the encoded key + */ + public String toHex() { + final char[] output = new char[Format.HEX.length]; + for (int i = 0; i < key.length; ++i) { + output[i * 2] = (char) (87 + (key[i] >> 4 & 0xf) + ((((key[i] >> 4 & 0xf) - 10) >> 8) & ~38)); + output[i * 2 + 1] = (char) (87 + (key[i] & 0xf) + ((((key[i] & 0xf) - 10) >> 8) & ~38)); + } + return new String(output); + } + + /** + * The supported formats for encoding a WireGuard key. + */ + public enum Format { + BASE64(44), + BINARY(32), + HEX(64); + + private final int length; + + Format(final int length) { + this.length = length; + } + + public int getLength() { + return length; + } + } +} diff --git a/client/android/src/com/wireguard/crypto/KeyFormatException.java b/client/android/src/com/wireguard/crypto/KeyFormatException.java new file mode 100644 index 00000000..cfd6f511 --- /dev/null +++ b/client/android/src/com/wireguard/crypto/KeyFormatException.java @@ -0,0 +1,32 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +/** + * An exception thrown when attempting to parse an invalid key (too short, too long, or byte + * data inappropriate for the format). The format being parsed can be accessed with the + * {@link #getFormat} method. + */ + +public final class KeyFormatException extends Exception { + private final Key.Format format; + private final Type type; + + KeyFormatException(final Key.Format format, final Type type) { + this.format = format; + this.type = type; + } + + public Key.Format getFormat() { + return format; + } + + public Type getType() { + return type; + } + + public enum Type { CONTENTS, LENGTH } +} diff --git a/client/android/src/com/wireguard/crypto/KeyPair.java b/client/android/src/com/wireguard/crypto/KeyPair.java new file mode 100644 index 00000000..221915aa --- /dev/null +++ b/client/android/src/com/wireguard/crypto/KeyPair.java @@ -0,0 +1,52 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +/** + * Represents a Curve25519 key pair as used by WireGuard. + *

+ * Instances of this class are immutable. + */ + +public class KeyPair { + private final Key privateKey; + private final Key publicKey; + + /** + * Creates a key pair using a newly-generated private key. + */ + public KeyPair() { + this(Key.generatePrivateKey()); + } + + /** + * Creates a key pair using an existing private key. + * + * @param privateKey a private key, used to derive the public key + */ + public KeyPair(final Key privateKey) { + this.privateKey = privateKey; + publicKey = Key.generatePublicKey(privateKey); + } + + /** + * Returns the private key from the key pair. + * + * @return the private key + */ + public Key getPrivateKey() { + return privateKey; + } + + /** + * Returns the public key from the key pair. + * + * @return the public key + */ + public Key getPublicKey() { + return publicKey; + } +} diff --git a/client/android/src/debug/res/mipmap-anydpi-v26/vpnicon.xml b/client/android/src/debug/res/mipmap-anydpi-v26/vpnicon.xml new file mode 100644 index 00000000..172bc624 --- /dev/null +++ b/client/android/src/debug/res/mipmap-anydpi-v26/vpnicon.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/client/android/src/debug/res/mipmap-anydpi-v26/vpnicon_round.xml b/client/android/src/debug/res/mipmap-anydpi-v26/vpnicon_round.xml new file mode 100644 index 00000000..7e0a9c7b --- /dev/null +++ b/client/android/src/debug/res/mipmap-anydpi-v26/vpnicon_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/client/android/src/debug/res/mipmap-hdpi/vpnicon.png b/client/android/src/debug/res/mipmap-hdpi/vpnicon.png new file mode 100644 index 00000000..fa762712 Binary files /dev/null and b/client/android/src/debug/res/mipmap-hdpi/vpnicon.png differ diff --git a/client/android/src/debug/res/mipmap-hdpi/vpnicon_foreground.png b/client/android/src/debug/res/mipmap-hdpi/vpnicon_foreground.png new file mode 100644 index 00000000..849f6ffd Binary files /dev/null and b/client/android/src/debug/res/mipmap-hdpi/vpnicon_foreground.png differ diff --git a/client/android/src/debug/res/mipmap-hdpi/vpnicon_round.png b/client/android/src/debug/res/mipmap-hdpi/vpnicon_round.png new file mode 100644 index 00000000..0fc3bb7f Binary files /dev/null and b/client/android/src/debug/res/mipmap-hdpi/vpnicon_round.png differ diff --git a/client/android/src/debug/res/mipmap-mdpi/vpnicon.png b/client/android/src/debug/res/mipmap-mdpi/vpnicon.png new file mode 100644 index 00000000..b5d3fe52 Binary files /dev/null and b/client/android/src/debug/res/mipmap-mdpi/vpnicon.png differ diff --git a/client/android/src/debug/res/mipmap-mdpi/vpnicon_foreground.png b/client/android/src/debug/res/mipmap-mdpi/vpnicon_foreground.png new file mode 100644 index 00000000..198212fd Binary files /dev/null and b/client/android/src/debug/res/mipmap-mdpi/vpnicon_foreground.png differ diff --git a/client/android/src/debug/res/mipmap-mdpi/vpnicon_round.png b/client/android/src/debug/res/mipmap-mdpi/vpnicon_round.png new file mode 100644 index 00000000..d6ce6de9 Binary files /dev/null and b/client/android/src/debug/res/mipmap-mdpi/vpnicon_round.png differ diff --git a/client/android/src/debug/res/mipmap-xhdpi/vpnicon.png b/client/android/src/debug/res/mipmap-xhdpi/vpnicon.png new file mode 100644 index 00000000..b6f63054 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xhdpi/vpnicon.png differ diff --git a/client/android/src/debug/res/mipmap-xhdpi/vpnicon_foreground.png b/client/android/src/debug/res/mipmap-xhdpi/vpnicon_foreground.png new file mode 100644 index 00000000..bc00fe12 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xhdpi/vpnicon_foreground.png differ diff --git a/client/android/src/debug/res/mipmap-xhdpi/vpnicon_round.png b/client/android/src/debug/res/mipmap-xhdpi/vpnicon_round.png new file mode 100644 index 00000000..ba4aa7b7 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xhdpi/vpnicon_round.png differ diff --git a/client/android/src/debug/res/mipmap-xxhdpi/vpnicon.png b/client/android/src/debug/res/mipmap-xxhdpi/vpnicon.png new file mode 100644 index 00000000..e6d3c7ed Binary files /dev/null and b/client/android/src/debug/res/mipmap-xxhdpi/vpnicon.png differ diff --git a/client/android/src/debug/res/mipmap-xxhdpi/vpnicon_foreground.png b/client/android/src/debug/res/mipmap-xxhdpi/vpnicon_foreground.png new file mode 100644 index 00000000..4f281d05 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xxhdpi/vpnicon_foreground.png differ diff --git a/client/android/src/debug/res/mipmap-xxhdpi/vpnicon_round.png b/client/android/src/debug/res/mipmap-xxhdpi/vpnicon_round.png new file mode 100644 index 00000000..deb1e253 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xxhdpi/vpnicon_round.png differ diff --git a/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon.png b/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon.png new file mode 100644 index 00000000..be451734 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon.png differ diff --git a/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon_foreground.png b/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon_foreground.png new file mode 100644 index 00000000..8b8bad14 Binary files /dev/null and b/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon_foreground.png differ diff --git a/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon_round.png b/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon_round.png new file mode 100644 index 00000000..ec01eaee Binary files /dev/null and b/client/android/src/debug/res/mipmap-xxxhdpi/vpnicon_round.png differ diff --git a/client/android/src/debug/res/values/vpnicon_background.xml b/client/android/src/debug/res/values/vpnicon_background.xml new file mode 100644 index 00000000..5f343ea1 --- /dev/null +++ b/client/android/src/debug/res/values/vpnicon_background.xml @@ -0,0 +1,4 @@ + + + #000000 + \ No newline at end of file diff --git a/client/android/src/org/amnezia/vpn/NotificationUtil.kt b/client/android/src/org/amnezia/vpn/NotificationUtil.kt new file mode 100644 index 00000000..15e706ed --- /dev/null +++ b/client/android/src/org/amnezia/vpn/NotificationUtil.kt @@ -0,0 +1,115 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn + +import android.app.NotificationChannel +import android.app.NotificationManager +import android.app.PendingIntent +import android.content.Context +import android.content.Intent +import android.os.Build +import android.os.Parcel +import androidx.core.app.NotificationCompat +import org.json.JSONObject + +object NotificationUtil { + var sCurrentContext: Context? = null + private var sNotificationBuilder: NotificationCompat.Builder? = null + + const val NOTIFICATION_CHANNEL_ID = "com.amnezia.vpnNotification" + const val CONNECTED_NOTIFICATION_ID = 1337 + const val tag = "NotificationUtil" + + /** + * Updates the current shown notification from a + * Parcel - Gets called from AndroidController.cpp + */ + fun update(data: Parcel) { + // [data] is here a json containing the noification content + val buffer = data.createByteArray() + val json = buffer?.let { String(it) } + val content = JSONObject(json) + + update(content.getString("title"), content.getString("message")) + } + + /** + * Updates the current shown notification + */ + fun update(heading: String, message: String) { + if (sCurrentContext == null) return + val notificationManager: NotificationManager = + sCurrentContext?.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager + + sNotificationBuilder?.let { + it.setContentTitle(heading) + .setContentText(message) + notificationManager.notify(CONNECTED_NOTIFICATION_ID, it.build()) + } + } + + /** + * Saves the default translated "connected" notification, in case the vpn gets started + * without the app. + */ + fun saveFallBackMessage(data: Parcel, context: Context) { + // [data] is here a json containing the notification content + val buffer = data.createByteArray() + val json = buffer?.let { String(it) } + val content = JSONObject(json) + + val prefs = Prefs.get(context) + prefs.edit() + .putString("fallbackNotificationHeader", content.getString("title")) + .putString("fallbackNotificationMessage", content.getString("message")) + .apply() + Log.v(tag, "Saved new fallback message -> ${content.getString("title")}") + } + + /* + * Creates a new Notification using the current set of Strings + * Shows the notification in the given {context} + */ + fun show(service: VPNService) { + sNotificationBuilder = NotificationCompat.Builder(service, NOTIFICATION_CHANNEL_ID) + sCurrentContext = service + val notificationManager: NotificationManager = + sCurrentContext?.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager + // From Oreo on we need to have a "notification channel" to post to. + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + val name = "vpn" + val descriptionText = " " + val importance = NotificationManager.IMPORTANCE_LOW + val channel = NotificationChannel(NOTIFICATION_CHANNEL_ID, name, importance).apply { + description = descriptionText + } + // Register the channel with the system + notificationManager.createNotificationChannel(channel) + } + // In case we do not have gotten a message to show from the Frontend + // try to populate the notification with a translated Fallback message + val prefs = Prefs.get(service) + val message = + "" + prefs.getString("fallbackNotificationMessage", "Running in the Background") + val header = "" + prefs.getString("fallbackNotificationHeader", "Mozilla VPN") + + // Create the Intent that Should be Fired if the User Clicks the notification + val mainActivityName = "org.amnezia.vpn.qt.VPNActivity" + val activity = Class.forName(mainActivityName) + val intent = Intent(service, activity) + val pendingIntent = PendingIntent.getActivity(service, 0, intent, 0) + // Build our notification + sNotificationBuilder?.let { + it.setSmallIcon(org.amnezia.vpn.R.drawable.ic_amnezia_round) + .setContentTitle(header) + .setContentText(message) + .setOnlyAlertOnce(true) + .setPriority(NotificationCompat.PRIORITY_DEFAULT) + .setContentIntent(pendingIntent) + + service.startForeground(CONNECTED_NOTIFICATION_ID, it.build()) + } + } +} diff --git a/client/android/src/org/amnezia/vpn/Prefs.kt b/client/android/src/org/amnezia/vpn/Prefs.kt new file mode 100644 index 00000000..343bf112 --- /dev/null +++ b/client/android/src/org/amnezia/vpn/Prefs.kt @@ -0,0 +1,35 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn + +import android.content.Context +import android.content.SharedPreferences +import android.util.Log +import androidx.security.crypto.EncryptedSharedPreferences +import androidx.security.crypto.MasterKey + +object Prefs { + // Opens and returns an instance of EncryptedSharedPreferences + fun get(context: Context): SharedPreferences { + try { + val mainKey = MasterKey.Builder(context.applicationContext) + .setKeyScheme(MasterKey.KeyScheme.AES256_GCM) + .build() + + val sharedPrefsFile = "com.amnezia.vpn.secure.prefs" + val sharedPreferences: SharedPreferences = EncryptedSharedPreferences.create( + context.applicationContext, + sharedPrefsFile, + mainKey, + EncryptedSharedPreferences.PrefKeyEncryptionScheme.AES256_SIV, + EncryptedSharedPreferences.PrefValueEncryptionScheme.AES256_GCM + ) + return sharedPreferences + } catch (e: Exception) { + Log.e("Android-Prefs", "Getting Encryption Storage failed, plaintext fallback") + return context.getSharedPreferences("com.amnezia.vpn.prefrences", Context.MODE_PRIVATE) + } + } +} diff --git a/client/android/src/org/amnezia/vpn/VPNLogger.kt b/client/android/src/org/amnezia/vpn/VPNLogger.kt new file mode 100644 index 00000000..05c9f79d --- /dev/null +++ b/client/android/src/org/amnezia/vpn/VPNLogger.kt @@ -0,0 +1,72 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn + +import android.content.Context +import java.io.File +import java.time.LocalDateTime +import android.util.Log as nativeLog + +/* + * Drop in replacement for android.util.Log + * Also stores a copy of all logs in tmp/mozilla_deamon_logs.txt +*/ +class Log { + val LOG_MAX_FILE_SIZE = 204800 + private var file: File + private constructor(context: Context) { + val tempDIR = context.cacheDir + file = File(tempDIR, "mozilla_deamon_logs.txt") + if (file.length() > LOG_MAX_FILE_SIZE) { + file.writeText("") + } + } + + companion object { + var instance: Log? = null + fun init(ctx: Context) { + if (instance == null) { + instance = Log(ctx) + } + } + fun i(tag: String, message: String) { + instance?.write("[info] - ($tag) - $message") + if (!BuildConfig.DEBUG) { return; } + nativeLog.i(tag, message) + } + fun v(tag: String, message: String) { + instance?.write("($tag) - $message") + if (!BuildConfig.DEBUG) { return; } + nativeLog.v(tag, message) + } + fun e(tag: String, message: String) { + instance?.write("[error] - ($tag) - $message") + if (!BuildConfig.DEBUG) { return; } + nativeLog.e(tag, message) + } + // Only Prints && Loggs when in debug, noop in release. + fun sensitive(tag: String, message: String?) { + if (!BuildConfig.DEBUG) { return; } + if (message == null) { return; } + e(tag, message) + } + + fun getContent(): String? { + return try { + instance?.file?.readText() + } catch (e: Exception) { + "=== Failed to read Daemon Logs === \n ${e.localizedMessage} " + } + } + + fun clearFile() { + instance?.file?.writeText("") + } + } + private fun write(message: String) { + LocalDateTime.now() + file.appendText("[${LocalDateTime.now()}] $message \n") + } +} diff --git a/client/android/src/org/amnezia/vpn/VPNService.kt b/client/android/src/org/amnezia/vpn/VPNService.kt new file mode 100644 index 00000000..1e18d898 --- /dev/null +++ b/client/android/src/org/amnezia/vpn/VPNService.kt @@ -0,0 +1,309 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn + +import android.content.Context +import android.content.Intent +import android.os.Build +import android.os.IBinder +import android.system.OsConstants +import com.wireguard.android.util.SharedLibraryLoader +import com.wireguard.config.* +import com.wireguard.crypto.Key +import org.json.JSONObject + +class VPNService : android.net.VpnService() { + private val tag = "VPNService" + private var mBinder: VPNServiceBinder = VPNServiceBinder(this) + private var mConfig: JSONObject? = null + private var mConnectionTime: Long = 0 + private var mAlreadyInitialised = false + + private var currentTunnelHandle = -1 + + fun init() { + if (mAlreadyInitialised) { + return + } + Log.init(this) + SharedLibraryLoader.loadSharedLibrary(this, "wg-go") + Log.i(tag, "loaded lib") + Log.e(tag, "Wireguard Version ${wgVersion()}") + mAlreadyInitialised = true + } + + override fun onUnbind(intent: Intent?): Boolean { + if (!isUp) { + // If the Qt Client got closed while we were not connected + // we do not need to stay as a foreground service. + stopForeground(true) + } + return super.onUnbind(intent) + } + + /** + * EntryPoint for the Service, gets Called when AndroidController.cpp + * calles bindService. Returns the [VPNServiceBinder] so QT can send Requests to it. + */ + override fun onBind(intent: Intent?): IBinder? { + Log.v(tag, "Got Bind request") + init() + return mBinder + } + + /** + * Might be the entryPoint if the Service gets Started via an + * Service Intent: Might be from Always-On-Vpn from Settings + * or from Booting the device and having "connect on boot" enabled. + */ + override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { + init() + intent?.let { + if (intent.getBooleanExtra("startOnly", false)) { + Log.i(tag, "Start only!") + return super.onStartCommand(intent, flags, startId) + } + } + // This start is from always-on + if (this.mConfig == null) { + // We don't have tunnel to turn on - Try to create one with last config the service got + val prefs = Prefs.get(this) + val lastConfString = prefs.getString("lastConf", "") + if (lastConfString.isNullOrEmpty()) { + // We have nothing to connect to -> Exit + Log.e( + tag, + "VPN service was triggered without defining a Server or having a tunnel" + ) + return super.onStartCommand(intent, flags, startId) + } + this.mConfig = JSONObject(lastConfString) + } + turnOn(this.mConfig!!) + return super.onStartCommand(intent, flags, startId) + } + + // Invoked when the application is revoked. + // At this moment, the VPN interface is already deactivated by the system. + override fun onRevoke() { + this.turnOff() + super.onRevoke() + } + + var connectionTime: Long = 0 + get() { + return mConnectionTime + } + + var isUp: Boolean + get() { + return currentTunnelHandle >= 0 + } + private set(value) { + if (value) { + mBinder.dispatchEvent(VPNServiceBinder.EVENTS.connected, "") + mConnectionTime = System.currentTimeMillis() + return + } + mBinder.dispatchEvent(VPNServiceBinder.EVENTS.disconnected, "") + mConnectionTime = 0 + } + val status: JSONObject + get() { + val deviceIpv4: String = "" + return JSONObject().apply { + putOpt("rx_bytes", getConfigValue("rx_bytes")) + putOpt("tx_bytes", getConfigValue("tx_bytes")) + putOpt("endpoint", mConfig?.getJSONObject("server")?.getString("ipv4Gateway")) + putOpt("deviceIpv4", mConfig?.getJSONObject("device")?.getString("ipv4Address")) + } + } + /* + * Checks if the VPN Permission is given. + * If the permission is given, returns true + * Requests permission and returns false if not. + */ + fun checkPermissions(): Boolean { + // See https://developer.android.com/guide/topics/connectivity/vpn#connect_a_service + // Call Prepare, if we get an Intent back, we dont have the VPN Permission + // from the user. So we need to pass this to our main Activity and exit here. + val intent = prepare(this) + if (intent == null) { + Log.e(tag, "VPN Permission Already Present") + return true + } + Log.e(tag, "Requesting VPN Permission") + return false + } + + fun turnOn(json: JSONObject) { + Log.sensitive(tag, json.toString()) + val wireguard_conf = buildWireugardConfig(json) + + if (!checkPermissions()) { + Log.e(tag, "turn on was called without no permissions present!") + isUp = false + return + } + Log.i(tag, "Permission okay") + if (currentTunnelHandle != -1) { + Log.e(tag, "Tunnel already up") + // Turn the tunnel down because this might be a switch + wgTurnOff(currentTunnelHandle) + } + val wgConfig: String = wireguard_conf!!.toWgUserspaceString() + val builder = Builder() + setupBuilder(wireguard_conf, builder) + builder.setSession("mvpn0") + builder.establish().use { tun -> + if (tun == null)return + Log.i(tag, "Go backend " + wgVersion()) + currentTunnelHandle = wgTurnOn("mvpn0", tun.detachFd(), wgConfig) + } + if (currentTunnelHandle < 0) { + Log.e(tag, "Activation Error Code -> $currentTunnelHandle") + isUp = false + return + } + protect(wgGetSocketV4(currentTunnelHandle)) + protect(wgGetSocketV6(currentTunnelHandle)) + mConfig = json + isUp = true + + // Store the config in case the service gets + // asked boot vpn from the OS + val prefs = Prefs.get(this) + prefs.edit() + .putString("lastConf", json.toString()) + .apply() + + NotificationUtil.show(this) // Go foreground + } + + fun turnOff() { + Log.v(tag, "Try to disable tunnel") + wgTurnOff(currentTunnelHandle) + currentTunnelHandle = -1 + stopForeground(false) + isUp = false + } + + /** + * Configures an Android VPN Service Tunnel + * with a given Wireguard Config + */ + private fun setupBuilder(config: Config, builder: Builder) { + // Setup Split tunnel + for (excludedApplication in config.`interface`.excludedApplications) + builder.addDisallowedApplication(excludedApplication) + + // Device IP + for (addr in config.`interface`.addresses) builder.addAddress(addr.address, addr.mask) + // DNS + for (addr in config.`interface`.dnsServers) builder.addDnsServer(addr.hostAddress) + // Add All routes the VPN may route tos + for (peer in config.peers) { + for (addr in peer.allowedIps) { + builder.addRoute(addr.address, addr.mask) + } + } + builder.allowFamily(OsConstants.AF_INET) + builder.allowFamily(OsConstants.AF_INET6) + builder.setMtu(config.`interface`.mtu.orElse(1280)) + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) builder.setMetered(false) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) setUnderlyingNetworks(null) + + builder.setBlocking(true) + } + + /** + * Gets config value for {key} from the Current + * running Wireguard tunnel + */ + private fun getConfigValue(key: String): String? { + if (!isUp) { + return null + } + val config = wgGetConfig(currentTunnelHandle) ?: return null + val lines = config.split("\n") + for (line in lines) { + val parts = line.split("=") + val k = parts.first() + val value = parts.last() + if (key == k) { + return value + } + } + return null + } + + /** + * Create a Wireguard [Config] from a [json] string - + * The [json] will be created in AndroidController.cpp + */ + private fun buildWireugardConfig(obj: JSONObject): Config { + val confBuilder = Config.Builder() + val jServer = obj.getJSONObject("server") + val peerBuilder = Peer.Builder() + val ep = + InetEndpoint.parse(jServer.getString("ipv4AddrIn") + ":" + jServer.getString("port")) + peerBuilder.setEndpoint(ep) + peerBuilder.setPublicKey(Key.fromBase64(jServer.getString("publicKey"))) + + val jAllowedIPList = obj.getJSONArray("allowedIPs") + if (jAllowedIPList.length() == 0) { + val internet = InetNetwork.parse("0.0.0.0/0") // aka The whole internet. + peerBuilder.addAllowedIp(internet) + } else { + (0 until jAllowedIPList.length()).toList().forEach { + val network = InetNetwork.parse(jAllowedIPList.getString(it)) + peerBuilder.addAllowedIp(network) + } + } + + confBuilder.addPeer(peerBuilder.build()) + + val privateKey = obj.getJSONObject("keys").getString("privateKey") + val jDevice = obj.getJSONObject("device") + + val ifaceBuilder = Interface.Builder() + ifaceBuilder.parsePrivateKey(privateKey) + ifaceBuilder.addAddress(InetNetwork.parse(jDevice.getString("ipv4Address"))) + ifaceBuilder.addAddress(InetNetwork.parse(jDevice.getString("ipv6Address"))) + ifaceBuilder.addDnsServer(InetNetwork.parse(obj.getString("dns")).address) + val jExcludedApplication = obj.getJSONArray("excludedApps") + (0 until jExcludedApplication.length()).toList().forEach { + val appName = jExcludedApplication.get(it).toString() + ifaceBuilder.excludeApplication(appName) + } + confBuilder.setInterface(ifaceBuilder.build()) + return confBuilder.build() + } + + companion object { + @JvmStatic + fun startService(c: Context) { + c.applicationContext.startService( + Intent(c.applicationContext, VPNService::class.java).apply { + putExtra("startOnly", true) + } + ) + } + + @JvmStatic + private external fun wgGetConfig(handle: Int): String? + @JvmStatic + private external fun wgGetSocketV4(handle: Int): Int + @JvmStatic + private external fun wgGetSocketV6(handle: Int): Int + @JvmStatic + private external fun wgTurnOff(handle: Int) + @JvmStatic + private external fun wgTurnOn(ifName: String, tunFd: Int, settings: String): Int + @JvmStatic + private external fun wgVersion(): String? + } +} diff --git a/client/android/src/org/amnezia/vpn/VPNServiceBinder.kt b/client/android/src/org/amnezia/vpn/VPNServiceBinder.kt new file mode 100644 index 00000000..d64e1939 --- /dev/null +++ b/client/android/src/org/amnezia/vpn/VPNServiceBinder.kt @@ -0,0 +1,170 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn +import android.os.Binder +import android.os.DeadObjectException +import android.os.IBinder +import android.os.Parcel +import com.wireguard.config.* +import org.json.JSONObject +import java.lang.Exception + +class VPNServiceBinder(service: VPNService) : Binder() { + + private val mService = service + private val tag = "VPNServiceBinder" + private var mListener: IBinder? = null + private var mResumeConfig: JSONObject? = null + + /** + * The codes this Binder does accept in [onTransact] + */ + object ACTIONS { + const val activate = 1 + const val deactivate = 2 + const val registerEventListener = 3 + const val requestStatistic = 4 + const val requestGetLog = 5 + const val requestCleanupLog = 6 + const val resumeActivate = 7 + const val setNotificationText = 8 + const val setFallBackNotification = 9 + } + + /** + * Gets called when the VPNServiceBinder gets a request from a Client. + * The [code] determines what action is requested. - see [ACTIONS] + * [data] may contain a utf-8 encoded json string with optional args or is null. + * [reply] is a pointer to a buffer in the clients memory, to reply results. + * we use this to send result data. + * + * returns true if the [code] was accepted + */ + override fun onTransact(code: Int, data: Parcel, reply: Parcel?, flags: Int): Boolean { + Log.i(tag, "GOT TRANSACTION $code") + + when (code) { + ACTIONS.activate -> { + try { + Log.i(tag, "Activiation Requested, parsing Config") + // [data] is here a json containing the wireguard conf + val buffer = data.createByteArray() + val json = buffer?.let { String(it) } + val config = JSONObject(json) + Log.v(tag, "Stored new Tunnel config in Service") + + if (!mService.checkPermissions()) { + mResumeConfig = config + // The Permission prompt was already + // send, in case it's accepted we will + // receive ACTIONS.resumeActivate + return true + } + this.mService.turnOn(config) + } catch (e: Exception) { + Log.e(tag, "An Error occurred while enabling the VPN: ${e.localizedMessage}") + dispatchEvent(EVENTS.activationError, e.localizedMessage) + } + return true + } + + ACTIONS.resumeActivate -> { + // [data] is empty + // Activate the current tunnel + try { + mResumeConfig?.let { this.mService.turnOn(it) } + } catch (e: Exception) { + Log.e(tag, "An Error occurred while enabling the VPN: ${e.localizedMessage}") + } + return true + } + + ACTIONS.deactivate -> { + // [data] here is empty + this.mService.turnOff() + return true + } + + ACTIONS.registerEventListener -> { + // [data] contains the Binder that we need to dispatch the Events + val binder = data.readStrongBinder() + mListener = binder + val obj = JSONObject() + obj.put("connected", mService.isUp) + obj.put("time", mService.connectionTime) + dispatchEvent(EVENTS.init, obj.toString()) + return true + } + + ACTIONS.requestStatistic -> { + dispatchEvent(EVENTS.statisticUpdate, mService.status.toString()) + return true + } + + ACTIONS.requestGetLog -> { + // Grabs all the Logs and dispatch new Log Event + dispatchEvent(EVENTS.backendLogs, Log.getContent()) + return true + } + ACTIONS.requestCleanupLog -> { + Log.clearFile() + return true + } + ACTIONS.setNotificationText -> { + NotificationUtil.update(data) + return true + } + ACTIONS.setFallBackNotification -> { + NotificationUtil.saveFallBackMessage(data, mService) + return true + } + IBinder.LAST_CALL_TRANSACTION -> { + Log.e(tag, "The OS Requested to shut down the VPN") + this.mService.turnOff() + return true + } + + else -> { + Log.e(tag, "Received invalid bind request \t Code -> $code") + // If we're hitting this there is probably something wrong in the client. + return false + } + } + return false + } + + /** + * Dispatches an Event to all registered Binders + * [code] the Event that happened - see [EVENTS] + * To register an Eventhandler use [onTransact] with + * [ACTIONS.registerEventListener] + */ + fun dispatchEvent(code: Int, payload: String?) { + try { + mListener?.let { + if (it.isBinderAlive) { + val data = Parcel.obtain() + data.writeByteArray(payload?.toByteArray(charset("UTF-8"))) + it.transact(code, data, Parcel.obtain(), 0) + } + } + } catch (e: DeadObjectException) { + // If the QT Process is killed (not just inactive) + // we cant access isBinderAlive, so nothing to do here. + } + } + + /** + * The codes we Are Using in case of [dispatchEvent] + */ + object EVENTS { + const val init = 0 + const val connected = 1 + const val disconnected = 2 + const val statisticUpdate = 3 + const val backendLogs = 4 + const val activationError = 5 + } +} diff --git a/client/android/src/org/amnezia/vpn/qt/PackageManagerHelper.java b/client/android/src/org/amnezia/vpn/qt/PackageManagerHelper.java new file mode 100644 index 00000000..ae0991f9 --- /dev/null +++ b/client/android/src/org/amnezia/vpn/qt/PackageManagerHelper.java @@ -0,0 +1,189 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn.qt; + +import android.Manifest; +import android.content.Context; +import android.content.Intent; +import android.content.pm.ApplicationInfo; +import android.content.pm.PackageInfo; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.graphics.Color; +import android.graphics.drawable.ColorDrawable; +import android.graphics.drawable.Drawable; +import android.Manifest.permission; +import android.net.Uri; +import android.os.Build; +import android.util.Log; +import android.webkit.WebView; + +import org.json.JSONException; +import org.json.JSONObject; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +// Gets used by /platforms/android/androidAppListProvider.cpp +public class PackageManagerHelper { + final static String TAG = "PackageManagerHelper"; + final static int MIN_CHROME_VERSION = 65; + + final static List CHROME_BROWSERS = Arrays.asList( + new String[] {"com.google.android.webview", "com.android.webview", "com.google.chrome"}); + + private static String getAllAppNames(Context ctx) { + JSONObject output = new JSONObject(); + PackageManager pm = ctx.getPackageManager(); + List browsers = getBrowserIDs(pm); + List packs = pm.getInstalledPackages(PackageManager.GET_PERMISSIONS); + for (int i = 0; i < packs.size(); i++) { + PackageInfo p = packs.get(i); + // Do not add ourselves and System Apps to the list, unless it might be a browser + if ((!isSystemPackage(p,pm) || browsers.contains(p.packageName)) + && !isSelf(p)) { + String appid = p.packageName; + String appName = p.applicationInfo.loadLabel(pm).toString(); + try { + output.put(appid, appName); + } catch (JSONException e) { + e.printStackTrace(); + } + } + } + return output.toString(); + } + + private static Drawable getAppIcon(Context ctx, String id) { + try { + return ctx.getPackageManager().getApplicationIcon(id); + } catch (PackageManager.NameNotFoundException e) { + e.printStackTrace(); + } + return new ColorDrawable(Color.TRANSPARENT); + } + + private static boolean isSystemPackage(PackageInfo pkgInfo, PackageManager pm) { + if( (pkgInfo.applicationInfo.flags & ApplicationInfo.FLAG_SYSTEM) == 0){ + // no system app + return false; + } + // For Systems Packages there are Cases where we want to add it anyway: + // Has the use Internet permission (otherwise makes no sense) + // Had at least 1 update (this means it's probably on any AppStore) + // Has a a launch activity (has a ui and is not just a system service) + + if(!usesInternet(pkgInfo)){ + return true; + } + if(!hadUpdate(pkgInfo)){ + return true; + } + if(pm.getLaunchIntentForPackage(pkgInfo.packageName) == null){ + // If there is no way to launch this from a homescreen, def a sys package + return true; + } + return false; + } + private static boolean isSelf(PackageInfo pkgInfo) { + return pkgInfo.packageName.equals("org.amnezia.vpn") + || pkgInfo.packageName.equals("org.amnezia.vpn.debug"); + } + private static boolean usesInternet(PackageInfo pkgInfo){ + if(pkgInfo.requestedPermissions == null){ + return false; + } + for(int i=0; i < pkgInfo.requestedPermissions.length; i++) { + String permission = pkgInfo.requestedPermissions[i]; + if(Manifest.permission.INTERNET.equals(permission)){ + return true; + } + } + return false; + } + private static boolean hadUpdate(PackageInfo pkgInfo){ + return pkgInfo.lastUpdateTime > pkgInfo.firstInstallTime; + } + + // Returns List of all Packages that can classify themselves as browsers + private static List getBrowserIDs(PackageManager pm) { + Intent intent = new Intent(Intent.ACTION_VIEW, Uri.parse("https://www.mozilla.org/")); + intent.addCategory(Intent.CATEGORY_BROWSABLE); + // We've tried using PackageManager.MATCH_DEFAULT_ONLY flag and found that browsers that + // are not set as the default browser won't be matched even if they had CATEGORY_DEFAULT set + // in the intent filter + + List resolveInfos = pm.queryIntentActivities(intent, PackageManager.MATCH_ALL); + List browsers = new ArrayList(); + for (int i = 0; i < resolveInfos.size(); i++) { + ResolveInfo info = resolveInfos.get(i); + String browserID = info.activityInfo.packageName; + browsers.add(browserID); + } + return browsers; + } + + // Gets called in AndroidAuthenticationListener; + public static boolean isWebViewSupported(Context ctx) { + Log.v(TAG, "Checking if installed Webview is compatible with FxA"); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { + // The default Webview is able do to FXA + return true; + } + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + PackageInfo pi = WebView.getCurrentWebViewPackage(); + if (CHROME_BROWSERS.contains(pi.packageName)) { + return isSupportedChromeBrowser(pi); + } + return isNotAncientBrowser(pi); + } + + // Before O the webview is hardcoded, but we dont know which package it is. + // Check if com.google.android.webview is installed + PackageManager pm = ctx.getPackageManager(); + try { + PackageInfo pi = pm.getPackageInfo("com.google.android.webview", 0); + return isSupportedChromeBrowser(pi); + } catch (PackageManager.NameNotFoundException e) { + } + // Otherwise check com.android.webview + try { + PackageInfo pi = pm.getPackageInfo("com.android.webview", 0); + return isSupportedChromeBrowser(pi); + } catch (PackageManager.NameNotFoundException e) { + } + Log.e(TAG, "Android System WebView is not found"); + // Giving up :( + return false; + } + + private static boolean isSupportedChromeBrowser(PackageInfo pi) { + Log.d(TAG, "Checking Chrome Based Browser: " + pi.packageName); + Log.d(TAG, "version name: " + pi.versionName); + Log.d(TAG, "version code: " + pi.versionCode); + try { + String versionCode = pi.versionName.split(Pattern.quote(" "))[0]; + String majorVersion = versionCode.split(Pattern.quote("."))[0]; + int version = Integer.parseInt(majorVersion); + return version >= MIN_CHROME_VERSION; + } catch (Exception e) { + Log.e(TAG, "Failed to check Chrome Version Code " + pi.versionName); + return false; + } + } + + private static boolean isNotAncientBrowser(PackageInfo pi) { + // Not a google chrome - So the version name is worthless + // Lets just make sure the WebView + // used is not ancient ==> Was updated in at least the last 365 days + Log.d(TAG, "Checking Chrome Based Browser: " + pi.packageName); + Log.d(TAG, "version name: " + pi.versionName); + Log.d(TAG, "version code: " + pi.versionCode); + double oneYearInMillis = 31536000000L; + return pi.lastUpdateTime > (System.currentTimeMillis() - oneYearInMillis); + } +} diff --git a/client/android/src/org/amnezia/vpn/qt/VPNActivity.java b/client/android/src/org/amnezia/vpn/qt/VPNActivity.java new file mode 100644 index 00000000..33ae513e --- /dev/null +++ b/client/android/src/org/amnezia/vpn/qt/VPNActivity.java @@ -0,0 +1,33 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn.qt; + +import android.view.KeyEvent; + +public class VPNActivity extends org.qtproject.qt5.android.bindings.QtActivity { + + @Override + public boolean onKeyDown(int keyCode, KeyEvent event) { + if (keyCode == KeyEvent.KEYCODE_BACK && event.getRepeatCount() == 0) { + onBackPressed(); + return true; + } + return super.onKeyDown(keyCode, event); + } + + @Override + public void onBackPressed() { + try { + if (!handleBackButton()) { + // Move the activity into paused state if back button was pressed + moveTaskToBack(true); + } + } catch (Exception e) { + } + } + + // Returns true if MVPN has handled the back button + native boolean handleBackButton(); +} diff --git a/client/android/src/org/amnezia/vpn/qt/VPNApplication.java b/client/android/src/org/amnezia/vpn/qt/VPNApplication.java new file mode 100644 index 00000000..3131349a --- /dev/null +++ b/client/android/src/org/amnezia/vpn/qt/VPNApplication.java @@ -0,0 +1,18 @@ +package org.amnezia.vpn.qt; + +import android.app.Activity; +import android.os.Bundle; + +import org.amnezia.vpn.BuildConfig; + +public class VPNApplication extends org.qtproject.qt5.android.bindings.QtApplication { + + private static VPNApplication instance; + + @Override + public void onCreate() { + super.onCreate(); + VPNApplication.instance = this; + } + +} diff --git a/client/android/src/org/amnezia/vpn/qt/VPNPermissionHelper.kt b/client/android/src/org/amnezia/vpn/qt/VPNPermissionHelper.kt new file mode 100644 index 00000000..7ae82f66 --- /dev/null +++ b/client/android/src/org/amnezia/vpn/qt/VPNPermissionHelper.kt @@ -0,0 +1,37 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package org.amnezia.vpn.qt + +import android.content.Context +import android.content.Intent + +class VPNPermissionHelper : android.net.VpnService() { + /** + * This small service does nothing else then checking if the vpn permission + * is present and prompting if not. + */ + override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { + val intent = prepare(this.applicationContext) + if (intent != null) { + startActivityForResult(intent) + } + return START_NOT_STICKY + } + + companion object { + @JvmStatic + fun startService(c: Context) { + val appC = c.applicationContext + appC.startService(Intent(appC, VPNPermissionHelper::class.java)) + } + } + + /** + * Fetches the Global QTAndroidActivity and calls startActivityForResult with the given intent + * Is used to request the VPN-Permission, if not given. + * Actually Implemented in src/platforms/android/AndroidController.cpp + */ + external fun startActivityForResult(i: Intent) +} diff --git a/client/client.pro b/client/client.pro index efc9708b..cec0abff 100644 --- a/client/client.pro +++ b/client/client.pro @@ -34,6 +34,7 @@ HEADERS += \ debug.h \ defines.h \ managementserver.h \ + protocols/android_vpnprotocol.h \ protocols/openvpnovercloakprotocol.h \ protocols/protocols_defs.h \ protocols/shadowsocksvpnprotocol.h \ @@ -87,6 +88,7 @@ SOURCES += \ debug.cpp \ main.cpp \ managementserver.cpp \ + protocols/android_vpnprotocol.cpp \ protocols/openvpnovercloakprotocol.cpp \ protocols/protocols_defs.cpp \ protocols/shadowsocksvpnprotocol.cpp \ @@ -174,10 +176,24 @@ macx { LIBS += -framework Cocoa -framework ApplicationServices -framework CoreServices -framework Foundation -framework AppKit -framework Security } +android { + QT += androidextras + + INCLUDEPATH += platforms/android + + DISTFILES += \ + android/AndroidManifest.xml \ + android/build.gradle \ + android/gradle/wrapper/gradle-wrapper.jar \ + android/gradle/wrapper/gradle-wrapper.properties \ + android/gradlew \ + android/gradlew.bat \ + android/res/values/libs.xml + + ANDROID_PACKAGE_SOURCE_DIR = $$PWD/android +} + REPC_REPLICA += ../ipc/ipc_interface.rep !ios: REPC_REPLICA += ../ipc/ipc_process_interface.rep -DISTFILES += \ - android/AndroidManifest.xml -ANDROID_PACKAGE_SOURCE_DIR = $$PWD/android diff --git a/client/core/defs.h b/client/core/defs.h index 4ea03dbe..02d6cad6 100644 --- a/client/core/defs.h +++ b/client/core/defs.h @@ -66,6 +66,7 @@ enum ErrorCode CloakExecutableCrashed }; + namespace config { // config keys const char key_openvpn_config_data[] = "openvpn_config_data"; @@ -78,4 +79,6 @@ const char key_wireguard_config_data[] = "wireguard_config_data"; } // namespace amnezia +Q_DECLARE_METATYPE(amnezia::ErrorCode) + #endif // DEFS_H diff --git a/client/protocols/android_vpnprotocol.cpp b/client/protocols/android_vpnprotocol.cpp new file mode 100644 index 00000000..1f44379d --- /dev/null +++ b/client/protocols/android_vpnprotocol.cpp @@ -0,0 +1,344 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "android_vpnprotocol.h" +#include "core/errorstrings.h" + +// Binder Codes for VPNServiceBinder +// See also - VPNServiceBinder.kt +// Actions that are Requestable +const int ACTION_ACTIVATE = 1; +const int ACTION_DEACTIVATE = 2; +const int ACTION_REGISTERLISTENER = 3; +const int ACTION_REQUEST_STATISTIC = 4; +const int ACTION_REQUEST_GET_LOG = 5; +const int ACTION_REQUEST_CLEANUP_LOG = 6; +const int ACTION_RESUME_ACTIVATE = 7; +const int ACTION_SET_NOTIFICATION_TEXT = 8; +const int ACTION_SET_NOTIFICATION_FALLBACK = 9; + +// Event Types that will be Dispatched after registration +const int EVENT_INIT = 0; +const int EVENT_CONNECTED = 1; +const int EVENT_DISCONNECTED = 2; +const int EVENT_STATISTIC_UPDATE = 3; +const int EVENT_BACKEND_LOGS = 4; +const int EVENT_ACTIVATION_ERROR = 5; + +namespace { +AndroidVpnProtocol* s_instance = nullptr; + +constexpr auto PERMISSIONHELPER_CLASS = + "org/amnezia/vpn/VPNPermissionHelper"; + +} // namespace + +AndroidVpnProtocol::AndroidVpnProtocol(Protocol protocol, const QJsonObject &configuration, QObject* parent) + : VpnProtocol(configuration, parent), + m_protocol(protocol), + m_binder(this) +{ + +} + +AndroidVpnProtocol* AndroidVpnProtocol::instance() { + return s_instance; +} + +void AndroidVpnProtocol::initialize() +{ + qDebug() << "Initializing"; + + // Hook in the native implementation for startActivityForResult into the JNI + JNINativeMethod methods[]{{"startActivityForResult", + "(Landroid/content/Intent;)V", + reinterpret_cast(startActivityForResult)}}; + QAndroidJniObject javaClass(PERMISSIONHELPER_CLASS); + QAndroidJniEnvironment env; + jclass objectClass = env->GetObjectClass(javaClass.object()); + env->RegisterNatives(objectClass, methods, + sizeof(methods) / sizeof(methods[0])); + env->DeleteLocalRef(objectClass); + + auto appContext = QtAndroid::androidActivity().callObjectMethod( + "getApplicationContext", "()Landroid/content/Context;"); + + QAndroidJniObject::callStaticMethod( + "org/amnezia/vpn/VPNService", "startService", + "(Landroid/content/Context;)V", appContext.object()); + + // Start the VPN Service (if not yet) and Bind to it + QtAndroid::bindService( + QAndroidIntent(appContext.object(), "org.amnezia.vpn.VPNService"), + *this, QtAndroid::BindFlag::AutoCreate); +} + +ErrorCode AndroidVpnProtocol::start() +{ + qDebug() << "Prompting for VPN permission"; + auto appContext = QtAndroid::androidActivity().callObjectMethod( + "getApplicationContext", "()Landroid/content/Context;"); + QAndroidJniObject::callStaticMethod( + PERMISSIONHELPER_CLASS, "startService", "(Landroid/content/Context;)V", + appContext.object()); + + +// QJsonObject jServer; +// jServer["ipv4AddrIn"] = server.ipv4AddrIn(); +// jServer["ipv4Gateway"] = server.ipv4Gateway(); +// jServer["ipv6AddrIn"] = server.ipv6AddrIn(); +// jServer["ipv6Gateway"] = server.ipv6Gateway(); +// jServer["publicKey"] = server.publicKey(); +// jServer["port"] = (int)server.choosePort(); + +// QJsonArray allowedIPs; +// foreach (auto item, allowedIPAddressRanges) { +// QJsonValue val; +// val = item.toString(); +// allowedIPs.append(val); +// } + +// QJsonArray excludedApps; +// foreach (auto appID, vpnDisabledApps) { +// excludedApps.append(QJsonValue(appID)); +// } + +// QJsonObject args; +// args["device"] = jDevice; +// args["keys"] = jKeys; +// args["server"] = jServer; +// args["reason"] = (int)reason; +// args["allowedIPs"] = allowedIPs; +// args["excludedApps"] = excludedApps; +// args["dns"] = dns.toString(); + + QAndroidParcel sendData; + sendData.writeData(QJsonDocument(m_rawConfig).toJson()); + m_serviceBinder.transact(ACTION_ACTIVATE, sendData, nullptr); +} + +// Activates the tunnel that is currently set +// in the VPN Service +void AndroidVpnProtocol::resume_start() { + QAndroidParcel nullData; + m_serviceBinder.transact(ACTION_RESUME_ACTIVATE, nullData, nullptr); +} + +void AndroidVpnProtocol::stop() { + qDebug() << "deactivation"; + +// if (reason != ReasonNone) { +// // Just show that we're disconnected +// // we're doing the actual disconnect once +// // the vpn-service has the new server ready in Action->Activate +// emit disconnected(); +// logger.warning() << "deactivation skipped for Switching"; +// return; +// } + + QAndroidParcel nullData; + m_serviceBinder.transact(ACTION_DEACTIVATE, nullData, nullptr); +} + +/* + * Sets the current notification text that is shown + */ +void AndroidVpnProtocol::setNotificationText(const QString& title, + const QString& message, + int timerSec) { + QJsonObject args; + args["title"] = title; + args["message"] = message; + args["sec"] = timerSec; + QJsonDocument doc(args); + QAndroidParcel data; + data.writeData(doc.toJson()); + m_serviceBinder.transact(ACTION_SET_NOTIFICATION_TEXT, data, nullptr); +} + +/* + * Sets fallback Notification text that should be shown in case the VPN + * switches into the Connected state without the app open + * e.g via always-on vpn + */ +void AndroidVpnProtocol::setFallbackConnectedNotification() { + QJsonObject args; + args["title"] = qtTrId("vpn.main.productName"); + //% "Ready for you to connect" + //: Refers to the app - which is currently running the background and waiting + args["message"] = qtTrId("vpn.android.notification.isIDLE"); + QJsonDocument doc(args); + QAndroidParcel data; + data.writeData(doc.toJson()); + m_serviceBinder.transact(ACTION_SET_NOTIFICATION_FALLBACK, data, nullptr); +} + +void AndroidVpnProtocol::checkStatus() { + qDebug() << "check status"; + + QAndroidParcel nullParcel; + m_serviceBinder.transact(ACTION_REQUEST_STATISTIC, nullParcel, nullptr); +} + +void AndroidVpnProtocol::getBackendLogs(std::function&& a_callback) { + qDebug() << "get logs"; + + m_logCallback = std::move(a_callback); + QAndroidParcel nullData, replyData; + m_serviceBinder.transact(ACTION_REQUEST_GET_LOG, nullData, &replyData); +} + +void AndroidVpnProtocol::cleanupBackendLogs() { + qDebug() << "cleanup logs"; + + QAndroidParcel nullParcel; + m_serviceBinder.transact(ACTION_REQUEST_CLEANUP_LOG, nullParcel, nullptr); +} + +void AndroidVpnProtocol::onServiceConnected( + const QString& name, const QAndroidBinder& serviceBinder) { + qDebug() << "Server connected"; + + Q_UNUSED(name); + + m_serviceBinder = serviceBinder; + + // Send the Service our Binder to recive incoming Events + QAndroidParcel binderParcel; + binderParcel.writeBinder(m_binder); + m_serviceBinder.transact(ACTION_REGISTERLISTENER, binderParcel, nullptr); +} + +void AndroidVpnProtocol::onServiceDisconnected(const QString& name) { + qDebug() << "Server disconnected"; + m_serviceConnected = false; + Q_UNUSED(name); + // TODO: Maybe restart? Or crash? +} + + +/** + * @brief AndroidController::VPNBinder::onTransact + * @param code the Event-Type we get From the VPNService See + * @param data - Might contain UTF-8 JSON in case the Event has a payload + * @param reply - always null + * @param flags - unused + * @return Returns true is the code was a valid Event Code + */ +bool AndroidVpnProtocol::VPNBinder::onTransact(int code, + const QAndroidParcel& data, + const QAndroidParcel& reply, + QAndroidBinder::CallType flags) { + Q_UNUSED(data); + Q_UNUSED(reply); + Q_UNUSED(flags); + + QJsonDocument doc; + QString buffer; + switch (code) { + case EVENT_INIT: + qDebug() << "Transact: init"; + doc = QJsonDocument::fromJson(data.readData()); + emit m_controller->initialized( + true, doc.object()["connected"].toBool(), + QDateTime::fromMSecsSinceEpoch( + doc.object()["time"].toVariant().toLongLong())); + // Pass a localised version of the Fallback string for the Notification + m_controller->setFallbackConnectedNotification(); + + break; + case EVENT_CONNECTED: + qDebug() << "Transact: connected"; + m_controller->setConnectionState(Connected); + break; + case EVENT_DISCONNECTED: + qDebug() << "Transact: disconnected"; + m_controller->setConnectionState(Disconnected); + break; + case EVENT_STATISTIC_UPDATE: + qDebug() << "Transact:: update"; + + // Data is here a JSON String + doc = QJsonDocument::fromJson(data.readData()); + // TODO update counters +// emit m_controller->statusUpdated(doc.object()["endpoint"].toString(), +// doc.object()["deviceIpv4"].toString(), +// doc.object()["totalTX"].toInt(), +// doc.object()["totalRX"].toInt()); + break; + case EVENT_BACKEND_LOGS: + qDebug() << "Transact: backend logs"; + + buffer = readUTF8Parcel(data); + if (m_controller->m_logCallback) { + m_controller->m_logCallback(buffer); + } + break; + case EVENT_ACTIVATION_ERROR: + m_controller->setConnectionState(Error); + default: + qWarning() << "Transact: Invalid!"; + break; + } + + return true; +} + +QString AndroidVpnProtocol::VPNBinder::readUTF8Parcel(QAndroidParcel data) { + // 106 is the Code for UTF-8 + return QTextCodec::codecForMib(106)->toUnicode(data.readData()); +} + +const int ACTIVITY_RESULT_OK = 0xffffffff; +/** + * @brief Starts the Given intent in Context of the QTActivity + * @param env + * @param intent + */ +void AndroidVpnProtocol::startActivityForResult(JNIEnv *env, jobject, jobject intent) +{ + qDebug() << "start activity"; + Q_UNUSED(env); + QtAndroid::startActivity(intent, 1337, + [](int receiverRequestCode, int resultCode, + const QAndroidJniObject& data) { + // Currently this function just used in + // VPNService.kt::checkPersmissions. So the result + // we're getting is if the User gave us the + // Vpn.bind permission. In case of NO we should + // abort. + Q_UNUSED(receiverRequestCode); + Q_UNUSED(data); + + AndroidVpnProtocol* controller = + AndroidVpnProtocol::instance(); + if (!controller) { + return; + } + + if (resultCode == ACTIVITY_RESULT_OK) { + qDebug() << "VPN PROMPT RESULT - Accepted"; + controller->resume_start(); + return; + } + // If the request got rejected abort the current + // connection. + qWarning() << "VPN PROMPT RESULT - Rejected"; + controller->setConnectionState(Disconnected); + + }); + return; +} diff --git a/client/protocols/android_vpnprotocol.h b/client/protocols/android_vpnprotocol.h new file mode 100644 index 00000000..70e1fc1a --- /dev/null +++ b/client/protocols/android_vpnprotocol.h @@ -0,0 +1,82 @@ +#ifndef ANDROID_VPNPROTOCOL_H +#define ANDROID_VPNPROTOCOL_H + +#include +#include + +#include "vpnprotocol.h" +#include "protocols/protocols_defs.h" + +using namespace amnezia; + + + +class AndroidVpnProtocol : public VpnProtocol, + public QAndroidServiceConnection +{ + Q_OBJECT + +public: + explicit AndroidVpnProtocol(Protocol protocol, const QJsonObject& configuration, QObject* parent = nullptr); + static AndroidVpnProtocol* instance(); + + virtual ~AndroidVpnProtocol() override = default; + + void initialize(); + + virtual ErrorCode start() override; + virtual void stop() override; + + void resume_start(); + + void checkStatus(); + + void setNotificationText(const QString& title, const QString& message, + int timerSec); + void setFallbackConnectedNotification(); + + void getBackendLogs(std::function&& callback); + + void cleanupBackendLogs(); + + // from QAndroidServiceConnection + void onServiceConnected(const QString& name, + const QAndroidBinder& serviceBinder) override; + void onServiceDisconnected(const QString& name) override; + +signals: + + +protected slots: + +protected: + + +private: + Protocol m_protocol; + + bool m_serviceConnected = false; + std::function m_logCallback; + + QAndroidBinder m_serviceBinder; + class VPNBinder : public QAndroidBinder { + public: + VPNBinder(AndroidVpnProtocol* controller) : m_controller(controller) {} + + bool onTransact(int code, const QAndroidParcel& data, + const QAndroidParcel& reply, + QAndroidBinder::CallType flags) override; + + QString readUTF8Parcel(QAndroidParcel data); + + private: + AndroidVpnProtocol* m_controller = nullptr; + }; + + VPNBinder m_binder; + + static void startActivityForResult(JNIEnv* env, jobject /*thiz*/, + jobject intent); +}; + +#endif // ANDROID_VPNPROTOCOL_H diff --git a/client/protocols/vpnprotocol.cpp b/client/protocols/vpnprotocol.cpp index f4b50a7b..2a6cac4b 100644 --- a/client/protocols/vpnprotocol.cpp +++ b/client/protocols/vpnprotocol.cpp @@ -3,7 +3,6 @@ #include "vpnprotocol.h" #include "core/errorstrings.h" -#include "containers/containers_defs.h" VpnProtocol::VpnProtocol(const QJsonObject &configuration, QObject* parent) : QObject(parent), diff --git a/client/protocols/vpnprotocol.h b/client/protocols/vpnprotocol.h index 0a12dd35..59c1278d 100644 --- a/client/protocols/vpnprotocol.h +++ b/client/protocols/vpnprotocol.h @@ -43,6 +43,13 @@ signals: void timeoutTimerEvent(); void protocolError(amnezia::ErrorCode e); + // This signal is emitted when the controller is initialized. Note that the + // VPN tunnel can be already active. In this case, "connected" should be set + // to true and the "connectionDate" should be set to the activation date if + // known. + // If "status" is set to false, the backend service is considered unavailable. + void initialized(bool status, bool connected, + const QDateTime& connectionDate); protected slots: virtual void onTimeout(); diff --git a/client/ui/pages_logic/protocols/OtherProtocolsLogic.cpp b/client/ui/pages_logic/protocols/OtherProtocolsLogic.cpp index 6b6ea53c..d9cf07e9 100644 --- a/client/ui/pages_logic/protocols/OtherProtocolsLogic.cpp +++ b/client/ui/pages_logic/protocols/OtherProtocolsLogic.cpp @@ -9,7 +9,9 @@ #include "../../uilogic.h" #include "utils.h" +#ifdef Q_OS_WINDOWS #include +#endif using namespace amnezia; using namespace PageEnumNS; @@ -22,12 +24,14 @@ OtherProtocolsLogic::OtherProtocolsLogic(UiLogic *logic, QObject *parent): OtherProtocolsLogic::~OtherProtocolsLogic() { - for (QProcess *p: m_sftpMpuntProcesses) { +#ifdef Q_OS_WINDOWS + for (QProcess *p: m_sftpMountProcesses) { if (p) Utils::signalCtrl(p->processId(), CTRL_C_EVENT); if (p) p->kill(); if (p) p->waitForFinished(); if (p) delete p; } +#endif } void OtherProtocolsLogic::updateProtocolPage(const QJsonObject &config, DockerContainer container, bool haveAuthData) diff --git a/client/ui/pages_logic/protocols/OtherProtocolsLogic.h b/client/ui/pages_logic/protocols/OtherProtocolsLogic.h index 0612b936..dc099186 100644 --- a/client/ui/pages_logic/protocols/OtherProtocolsLogic.h +++ b/client/ui/pages_logic/protocols/OtherProtocolsLogic.h @@ -34,7 +34,9 @@ private: Settings m_settings; UiLogic *m_uiLogic; - QList m_sftpMpuntProcesses; +#ifdef Q_OS_WINDOWS + QList m_sftpMountProcesses; +#endif }; #endif // OTHER_PROTOCOLS_LOGIC_H diff --git a/service/server/main.cpp b/service/server/main.cpp index bff4fc04..d0d7d82a 100644 --- a/service/server/main.cpp +++ b/service/server/main.cpp @@ -6,15 +6,16 @@ #include "systemservice.h" #include "utils.h" + int runApplication(int argc, char** argv) { QCoreApplication app(argc,argv); LocalServer localServer; -// if (!localServer.isRunning()) { -// return -1; -// } + return app.exec(); } + + int main(int argc, char **argv) { Utils::initializePath(Utils::systemLogPath()); @@ -24,19 +25,15 @@ int main(int argc, char **argv) if (argc == 2) { qInfo() << "Started as console application"; return runApplication(argc, argv); - } else { + } + else { qInfo() << "Started as system service"; #ifdef Q_OS_WIN SystemService systemService(argc, argv); return systemService.exec(); - #else - //daemon(0,0); - return runApplication(argc, argv); + return runApplication(argc, argv); #endif } - - // Never reached - return 0; } diff --git a/service/server/server.pro b/service/server/server.pro index e08fc80f..2c7097ec 100644 --- a/service/server/server.pro +++ b/service/server/server.pro @@ -64,13 +64,6 @@ SOURCES += \ include(../src/qtservice.pri) -#CONFIG(release, debug|release) { -# DESTDIR = $$PWD/../../../AmneziaVPN-build/server/release -# MOC_DIR = $$DESTDIR -# OBJECTS_DIR = $$DESTDIR -# RCC_DIR = $$DESTDIR -#} - INCLUDEPATH += "$$PWD/../../client" REPC_SOURCE += ../../ipc/ipc_interface.rep diff --git a/service/server/systemservice.cpp b/service/server/systemservice.cpp index 3aae8404..8ec36449 100644 --- a/service/server/systemservice.cpp +++ b/service/server/systemservice.cpp @@ -12,10 +12,6 @@ void SystemService::start() { QCoreApplication* app = application(); m_localServer = new LocalServer(); - -// if (!m_localServer->isRunning()) { -// app->quit(); -// } } void SystemService::stop()