diff --git a/app/src/androidTest/java/com/kunzisoft/keepass/tests/crypto/CipherTest.kt b/app/src/androidTest/java/com/kunzisoft/keepass/tests/crypto/CipherTest.kt index 680f48608..b5e3fca2a 100644 --- a/app/src/androidTest/java/com/kunzisoft/keepass/tests/crypto/CipherTest.kt +++ b/app/src/androidTest/java/com/kunzisoft/keepass/tests/crypto/CipherTest.kt @@ -19,28 +19,20 @@ */ package com.kunzisoft.keepass.tests.crypto +import com.kunzisoft.keepass.crypto.CipherFactory +import com.kunzisoft.keepass.crypto.engine.AesEngine +import com.kunzisoft.keepass.stream.BetterCipherInputStream +import com.kunzisoft.keepass.stream.readBytesLength +import junit.framework.TestCase import org.junit.Assert.assertArrayEquals - import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.IOException import java.security.InvalidAlgorithmParameterException import java.security.InvalidKeyException import java.security.NoSuchAlgorithmException -import java.util.Random - -import javax.crypto.BadPaddingException -import javax.crypto.Cipher -import javax.crypto.CipherOutputStream -import javax.crypto.IllegalBlockSizeException -import javax.crypto.NoSuchPaddingException - -import junit.framework.TestCase - -import com.kunzisoft.keepass.crypto.CipherFactory -import com.kunzisoft.keepass.crypto.engine.AesEngine -import com.kunzisoft.keepass.stream.BetterCipherInputStream -import com.kunzisoft.keepass.stream.LittleEndianDataInputStream +import java.util.* +import javax.crypto.* class CipherTest : TestCase() { private val rand = Random() @@ -92,9 +84,8 @@ class CipherTest : TestCase() { val bis = ByteArrayInputStream(secrettext) val cis = BetterCipherInputStream(bis, decrypt) - val lis = LittleEndianDataInputStream(cis) - val decrypttext = lis.readBytes(MESSAGE_LENGTH) + val decrypttext = cis.readBytesLength(MESSAGE_LENGTH) assertArrayEquals("Encryption and decryption failed", plaintext, decrypttext) } diff --git a/app/src/main/java/com/kunzisoft/keepass/database/file/DatabaseHeaderKDBX.kt b/app/src/main/java/com/kunzisoft/keepass/database/file/DatabaseHeaderKDBX.kt index 3f0434fa2..3881217e5 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/file/DatabaseHeaderKDBX.kt +++ b/app/src/main/java/com/kunzisoft/keepass/database/file/DatabaseHeaderKDBX.kt @@ -150,23 +150,22 @@ class DatabaseHeaderKDBX(private val databaseV4: DatabaseKDBX) : DatabaseHeader( val headerBOS = ByteArrayOutputStream() val copyInputStream = CopyInputStream(inputStream, headerBOS) val digestInputStream = DigestInputStream(copyInputStream, messageDigest) - val littleEndianDataInputStream = LittleEndianDataInputStream(digestInputStream) - val sig1 = littleEndianDataInputStream.readUInt() - val sig2 = littleEndianDataInputStream.readUInt() + val sig1 = digestInputStream.readBytes4ToUInt() + val sig2 = digestInputStream.readBytes4ToUInt() if (!matchesHeader(sig1, sig2)) { throw VersionDatabaseException() } - version = littleEndianDataInputStream.readUInt() // Erase previous value + version = digestInputStream.readBytes4ToUInt() // Erase previous value if (!validVersion(version)) { throw VersionDatabaseException() } var done = false while (!done) { - done = readHeaderField(littleEndianDataInputStream) + done = readHeaderField(digestInputStream) } val hash = messageDigest.digest() @@ -174,13 +173,13 @@ class DatabaseHeaderKDBX(private val databaseV4: DatabaseKDBX) : DatabaseHeader( } @Throws(IOException::class) - private fun readHeaderField(dis: LittleEndianDataInputStream): Boolean { + private fun readHeaderField(dis: InputStream): Boolean { val fieldID = dis.read().toByte() val fieldSize: Int = if (version.toKotlinLong() < FILE_VERSION_32_4.toKotlinLong()) { - dis.readUShort() + dis.readBytes2ToUShort() } else { - dis.readUInt().toKotlinInt() + dis.readBytes4ToUInt().toKotlinInt() } var fieldData: ByteArray? = null diff --git a/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt b/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt index cf7947ea6..dc3ef3fc0 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt +++ b/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt @@ -155,11 +155,10 @@ class DatabaseInputKDBX(cacheDirectory: File, val isPlain: InputStream if (mDatabase.kdbxVersion.toKotlinLong() < DatabaseHeaderKDBX.FILE_VERSION_32_4.toKotlinLong()) { - val decrypted = CipherInputStream(databaseInputStream, cipher) - val dataDecrypted = LittleEndianDataInputStream(decrypted) + val dataDecrypted = CipherInputStream(databaseInputStream, cipher) val storedStartBytes: ByteArray? try { - storedStartBytes = dataDecrypted.readBytes(32) + storedStartBytes = dataDecrypted.readBytesLength(32) if (storedStartBytes.size != 32) { throw InvalidCredentialsDatabaseException() } @@ -173,15 +172,14 @@ class DatabaseInputKDBX(cacheDirectory: File, isPlain = HashedBlockInputStream(dataDecrypted) } else { // KDBX 4 - val isData = LittleEndianDataInputStream(databaseInputStream) - val storedHash = isData.readBytes(32) + val storedHash = databaseInputStream.readBytesLength(32) if (!Arrays.equals(storedHash, hashOfHeader)) { throw InvalidCredentialsDatabaseException() } val hmacKey = mDatabase.hmacKey ?: throw LoadDatabaseException() val headerHmac = DatabaseHeaderKDBX.computeHeaderHmac(pbHeader, hmacKey) - val storedHmac = isData.readBytes(32) + val storedHmac = databaseInputStream.readBytesLength(32) if (storedHmac.size != 32) { throw InvalidCredentialsDatabaseException() } @@ -190,7 +188,7 @@ class DatabaseInputKDBX(cacheDirectory: File, throw InvalidCredentialsDatabaseException() } - val hmIs = HmacBlockInputStream(isData, true, hmacKey) + val hmIs = HmacBlockInputStream(databaseInputStream, true, hmacKey) isPlain = CipherInputStream(hmIs, cipher) } @@ -231,23 +229,21 @@ class DatabaseInputKDBX(cacheDirectory: File, } @Throws(IOException::class) - private fun readInnerHeader(inputStream: InputStream, + private fun readInnerHeader(dataInputStream: InputStream, header: DatabaseHeaderKDBX) { - val dataInputStream = LittleEndianDataInputStream(inputStream) - var readStream = true while (readStream) { val fieldId = dataInputStream.read().toByte() - val size = dataInputStream.readUInt().toKotlinInt() + val size = dataInputStream.readBytes4ToUInt().toKotlinInt() if (size < 0) throw IOException("Corrupted file") var data = ByteArray(0) try { if (size > 0) { if (fieldId != DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.Binary) { - data = dataInputStream.readBytes(size) + data = dataInputStream.readBytesLength(size) } } } catch (e: Exception) { diff --git a/app/src/main/java/com/kunzisoft/keepass/stream/HashedBlockInputStream.kt b/app/src/main/java/com/kunzisoft/keepass/stream/HashedBlockInputStream.kt index 7c72b8093..2b68efb48 100644 --- a/app/src/main/java/com/kunzisoft/keepass/stream/HashedBlockInputStream.kt +++ b/app/src/main/java/com/kunzisoft/keepass/stream/HashedBlockInputStream.kt @@ -27,9 +27,8 @@ import java.security.NoSuchAlgorithmException import java.util.* -class HashedBlockInputStream(inputStream: InputStream) : InputStream() { +class HashedBlockInputStream(private val baseStream: InputStream) : InputStream() { - private val baseStream: LittleEndianDataInputStream = LittleEndianDataInputStream(inputStream) private var bufferPos = 0 private var buffer: ByteArray = ByteArray(0) private var bufferIndex: Long = 0 @@ -80,13 +79,13 @@ class HashedBlockInputStream(inputStream: InputStream) : InputStream() { bufferPos = 0 - val index = baseStream.readUInt() + val index = baseStream.readBytes4ToUInt() if (index.toKotlinLong() != bufferIndex) { throw IOException("Invalid data format") } bufferIndex++ - val storedHash = baseStream.readBytes(32) + val storedHash = baseStream.readBytesLength(32) if (storedHash.size != HASH_SIZE) { throw IOException("Invalid data format") } @@ -104,7 +103,7 @@ class HashedBlockInputStream(inputStream: InputStream) : InputStream() { return false } - buffer = baseStream.readBytes(bufferSize) + buffer = baseStream.readBytesLength(bufferSize) if (buffer.size != bufferSize) { throw IOException("Invalid data format") } diff --git a/app/src/main/java/com/kunzisoft/keepass/stream/HmacBlockInputStream.kt b/app/src/main/java/com/kunzisoft/keepass/stream/HmacBlockInputStream.kt index d38abf45e..373507c1d 100644 --- a/app/src/main/java/com/kunzisoft/keepass/stream/HmacBlockInputStream.kt +++ b/app/src/main/java/com/kunzisoft/keepass/stream/HmacBlockInputStream.kt @@ -28,9 +28,8 @@ import java.util.* import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec -class HmacBlockInputStream(baseStream: InputStream, private val verify: Boolean, private val key: ByteArray) : InputStream() { +class HmacBlockInputStream(private val baseStream: InputStream, private val verify: Boolean, private val key: ByteArray) : InputStream() { - private val baseStream: LittleEndianDataInputStream = LittleEndianDataInputStream(baseStream) private var buffer: ByteArray = ByteArray(0) private var bufferPos = 0 private var blockIndex: Long = 0 @@ -88,20 +87,20 @@ class HmacBlockInputStream(baseStream: InputStream, private val verify: Boolean, private fun readSafeBlock(): Boolean { if (endOfStream) return false - val storedHmac = baseStream.readBytes(32) + val storedHmac = baseStream.readBytesLength(32) if (storedHmac.size != 32) { throw IOException("File corrupted") } val pbBlockIndex = longTo8Bytes(blockIndex) - val pbBlockSize = baseStream.readBytes(4) + val pbBlockSize = baseStream.readBytesLength(4) if (pbBlockSize.size != 4) { throw IOException("File corrupted") } val blockSize = bytes4ToUInt(pbBlockSize) bufferPos = 0 - buffer = baseStream.readBytes(blockSize.toKotlinInt()) + buffer = baseStream.readBytesLength(blockSize.toKotlinInt()) if (verify) { val cmpHmac: ByteArray diff --git a/app/src/main/java/com/kunzisoft/keepass/stream/LittleEndianDataInputStream.kt b/app/src/main/java/com/kunzisoft/keepass/stream/LittleEndianDataInputStream.kt deleted file mode 100644 index 46e030941..000000000 --- a/app/src/main/java/com/kunzisoft/keepass/stream/LittleEndianDataInputStream.kt +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright 2019 Jeremy Jamet / Kunzisoft. - * - * This file is part of KeePassDX. - * - * KeePassDX is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * KeePassDX is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with KeePassDX. If not, see . - * - */ -package com.kunzisoft.keepass.stream - -import com.kunzisoft.keepass.utils.UnsignedInt -import java.io.IOException -import java.io.InputStream - -/** - * Little endian version of the DataInputStream - */ -class LittleEndianDataInputStream(private val baseStream: InputStream) : InputStream() { - - /** - * Read a 32-bit value and return it as a long, so that it can - * be interpreted as an unsigned integer. - */ - @Throws(IOException::class) - fun readUInt(): UnsignedInt { - return baseStream.readBytes4ToUInt() - } - - @Throws(IOException::class) - fun readUShort(): Int { - val buf = ByteArray(2) - if (baseStream.read(buf, 0, 2) != 2) - throw IOException("Unable to read UShort value") - return bytes2ToUShort(buf) - } - - @Throws(IOException::class) - override fun available(): Int { - return baseStream.available() - } - - @Throws(IOException::class) - override fun close() { - baseStream.close() - } - - override fun mark(readlimit: Int) { - baseStream.mark(readlimit) - } - - override fun markSupported(): Boolean { - return baseStream.markSupported() - } - - @Throws(IOException::class) - override fun read(): Int { - return baseStream.read() - } - - @Throws(IOException::class) - override fun read(b: ByteArray, offset: Int, length: Int): Int { - return baseStream.read(b, offset, length) - } - - @Throws(IOException::class) - override fun read(b: ByteArray): Int { - return baseStream.read(b) - } - - @Synchronized - @Throws(IOException::class) - override fun reset() { - baseStream.reset() - } - - @Throws(IOException::class) - override fun skip(n: Long): Long { - return baseStream.skip(n) - } - - @Throws(IOException::class) - fun readBytes(length: Int): ByteArray { - // TODO Exception max length < buffer size - val buf = ByteArray(length) - - var count = 0 - while (count < length) { - val read = read(buf, count, length - count) - - // Reached end - if (read == -1) { - // Stop early - val early = ByteArray(count) - System.arraycopy(buf, 0, early, 0, count) - return early - } - - count += read - } - - return buf - } -} diff --git a/app/src/main/java/com/kunzisoft/keepass/utils/VariantDictionary.kt b/app/src/main/java/com/kunzisoft/keepass/utils/VariantDictionary.kt index 5c51d636c..5386f4465 100644 --- a/app/src/main/java/com/kunzisoft/keepass/utils/VariantDictionary.kt +++ b/app/src/main/java/com/kunzisoft/keepass/utils/VariantDictionary.kt @@ -24,6 +24,7 @@ import com.kunzisoft.keepass.stream.* import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.IOException +import java.io.InputStream import java.nio.charset.Charset import java.util.* @@ -115,7 +116,7 @@ open class VariantDictionary { @Throws(IOException::class) fun deserialize(data: ByteArray): VariantDictionary { - val inputStream = LittleEndianDataInputStream(ByteArrayInputStream(data)) + val inputStream = ByteArrayInputStream(data) return deserialize(inputStream) } @@ -128,9 +129,9 @@ open class VariantDictionary { } @Throws(IOException::class) - fun deserialize(inputStream: LittleEndianDataInputStream): VariantDictionary { + fun deserialize(inputStream: InputStream): VariantDictionary { val dictionary = VariantDictionary() - val version = inputStream.readUShort() + val version = inputStream.readBytes2ToUShort() if (version and VdmCritical > VdVersion and VdmCritical) { throw IOException("Invalid format") } @@ -143,14 +144,14 @@ open class VariantDictionary { if (bType == VdType.None) { break } - val nameLen = inputStream.readUInt().toKotlinInt() - val nameBuf = inputStream.readBytes(nameLen) + val nameLen = inputStream.readBytes4ToUInt().toKotlinInt() + val nameBuf = inputStream.readBytesLength(nameLen) if (nameLen != nameBuf.size) { throw IOException("Invalid format") } val name = String(nameBuf, UTF8Charset) - val valueLen = inputStream.readUInt().toKotlinInt() - val valueBuf = inputStream.readBytes(valueLen) + val valueLen = inputStream.readBytes4ToUInt().toKotlinInt() + val valueBuf = inputStream.readBytesLength(valueLen) if (valueLen != valueBuf.size) { throw IOException("Invalid format") }