Remove Little Endian input stream

This commit is contained in:
J-Jamet
2021-03-24 11:55:41 +01:00
parent cfcfd47705
commit 844588a0d4
7 changed files with 39 additions and 168 deletions

View File

@@ -19,28 +19,20 @@
*/
package com.kunzisoft.keepass.tests.crypto
import com.kunzisoft.keepass.crypto.CipherFactory
import com.kunzisoft.keepass.crypto.engine.AesEngine
import com.kunzisoft.keepass.stream.BetterCipherInputStream
import com.kunzisoft.keepass.stream.readBytesLength
import junit.framework.TestCase
import org.junit.Assert.assertArrayEquals
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.security.InvalidAlgorithmParameterException
import java.security.InvalidKeyException
import java.security.NoSuchAlgorithmException
import java.util.Random
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.CipherOutputStream
import javax.crypto.IllegalBlockSizeException
import javax.crypto.NoSuchPaddingException
import junit.framework.TestCase
import com.kunzisoft.keepass.crypto.CipherFactory
import com.kunzisoft.keepass.crypto.engine.AesEngine
import com.kunzisoft.keepass.stream.BetterCipherInputStream
import com.kunzisoft.keepass.stream.LittleEndianDataInputStream
import java.util.*
import javax.crypto.*
class CipherTest : TestCase() {
private val rand = Random()
@@ -92,9 +84,8 @@ class CipherTest : TestCase() {
val bis = ByteArrayInputStream(secrettext)
val cis = BetterCipherInputStream(bis, decrypt)
val lis = LittleEndianDataInputStream(cis)
val decrypttext = lis.readBytes(MESSAGE_LENGTH)
val decrypttext = cis.readBytesLength(MESSAGE_LENGTH)
assertArrayEquals("Encryption and decryption failed", plaintext, decrypttext)
}

View File

@@ -150,23 +150,22 @@ class DatabaseHeaderKDBX(private val databaseV4: DatabaseKDBX) : DatabaseHeader(
val headerBOS = ByteArrayOutputStream()
val copyInputStream = CopyInputStream(inputStream, headerBOS)
val digestInputStream = DigestInputStream(copyInputStream, messageDigest)
val littleEndianDataInputStream = LittleEndianDataInputStream(digestInputStream)
val sig1 = littleEndianDataInputStream.readUInt()
val sig2 = littleEndianDataInputStream.readUInt()
val sig1 = digestInputStream.readBytes4ToUInt()
val sig2 = digestInputStream.readBytes4ToUInt()
if (!matchesHeader(sig1, sig2)) {
throw VersionDatabaseException()
}
version = littleEndianDataInputStream.readUInt() // Erase previous value
version = digestInputStream.readBytes4ToUInt() // Erase previous value
if (!validVersion(version)) {
throw VersionDatabaseException()
}
var done = false
while (!done) {
done = readHeaderField(littleEndianDataInputStream)
done = readHeaderField(digestInputStream)
}
val hash = messageDigest.digest()
@@ -174,13 +173,13 @@ class DatabaseHeaderKDBX(private val databaseV4: DatabaseKDBX) : DatabaseHeader(
}
@Throws(IOException::class)
private fun readHeaderField(dis: LittleEndianDataInputStream): Boolean {
private fun readHeaderField(dis: InputStream): Boolean {
val fieldID = dis.read().toByte()
val fieldSize: Int = if (version.toKotlinLong() < FILE_VERSION_32_4.toKotlinLong()) {
dis.readUShort()
dis.readBytes2ToUShort()
} else {
dis.readUInt().toKotlinInt()
dis.readBytes4ToUInt().toKotlinInt()
}
var fieldData: ByteArray? = null

View File

@@ -155,11 +155,10 @@ class DatabaseInputKDBX(cacheDirectory: File,
val isPlain: InputStream
if (mDatabase.kdbxVersion.toKotlinLong() < DatabaseHeaderKDBX.FILE_VERSION_32_4.toKotlinLong()) {
val decrypted = CipherInputStream(databaseInputStream, cipher)
val dataDecrypted = LittleEndianDataInputStream(decrypted)
val dataDecrypted = CipherInputStream(databaseInputStream, cipher)
val storedStartBytes: ByteArray?
try {
storedStartBytes = dataDecrypted.readBytes(32)
storedStartBytes = dataDecrypted.readBytesLength(32)
if (storedStartBytes.size != 32) {
throw InvalidCredentialsDatabaseException()
}
@@ -173,15 +172,14 @@ class DatabaseInputKDBX(cacheDirectory: File,
isPlain = HashedBlockInputStream(dataDecrypted)
} else { // KDBX 4
val isData = LittleEndianDataInputStream(databaseInputStream)
val storedHash = isData.readBytes(32)
val storedHash = databaseInputStream.readBytesLength(32)
if (!Arrays.equals(storedHash, hashOfHeader)) {
throw InvalidCredentialsDatabaseException()
}
val hmacKey = mDatabase.hmacKey ?: throw LoadDatabaseException()
val headerHmac = DatabaseHeaderKDBX.computeHeaderHmac(pbHeader, hmacKey)
val storedHmac = isData.readBytes(32)
val storedHmac = databaseInputStream.readBytesLength(32)
if (storedHmac.size != 32) {
throw InvalidCredentialsDatabaseException()
}
@@ -190,7 +188,7 @@ class DatabaseInputKDBX(cacheDirectory: File,
throw InvalidCredentialsDatabaseException()
}
val hmIs = HmacBlockInputStream(isData, true, hmacKey)
val hmIs = HmacBlockInputStream(databaseInputStream, true, hmacKey)
isPlain = CipherInputStream(hmIs, cipher)
}
@@ -231,23 +229,21 @@ class DatabaseInputKDBX(cacheDirectory: File,
}
@Throws(IOException::class)
private fun readInnerHeader(inputStream: InputStream,
private fun readInnerHeader(dataInputStream: InputStream,
header: DatabaseHeaderKDBX) {
val dataInputStream = LittleEndianDataInputStream(inputStream)
var readStream = true
while (readStream) {
val fieldId = dataInputStream.read().toByte()
val size = dataInputStream.readUInt().toKotlinInt()
val size = dataInputStream.readBytes4ToUInt().toKotlinInt()
if (size < 0) throw IOException("Corrupted file")
var data = ByteArray(0)
try {
if (size > 0) {
if (fieldId != DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.Binary) {
data = dataInputStream.readBytes(size)
data = dataInputStream.readBytesLength(size)
}
}
} catch (e: Exception) {

View File

@@ -27,9 +27,8 @@ import java.security.NoSuchAlgorithmException
import java.util.*
class HashedBlockInputStream(inputStream: InputStream) : InputStream() {
class HashedBlockInputStream(private val baseStream: InputStream) : InputStream() {
private val baseStream: LittleEndianDataInputStream = LittleEndianDataInputStream(inputStream)
private var bufferPos = 0
private var buffer: ByteArray = ByteArray(0)
private var bufferIndex: Long = 0
@@ -80,13 +79,13 @@ class HashedBlockInputStream(inputStream: InputStream) : InputStream() {
bufferPos = 0
val index = baseStream.readUInt()
val index = baseStream.readBytes4ToUInt()
if (index.toKotlinLong() != bufferIndex) {
throw IOException("Invalid data format")
}
bufferIndex++
val storedHash = baseStream.readBytes(32)
val storedHash = baseStream.readBytesLength(32)
if (storedHash.size != HASH_SIZE) {
throw IOException("Invalid data format")
}
@@ -104,7 +103,7 @@ class HashedBlockInputStream(inputStream: InputStream) : InputStream() {
return false
}
buffer = baseStream.readBytes(bufferSize)
buffer = baseStream.readBytesLength(bufferSize)
if (buffer.size != bufferSize) {
throw IOException("Invalid data format")
}

View File

@@ -28,9 +28,8 @@ import java.util.*
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
class HmacBlockInputStream(baseStream: InputStream, private val verify: Boolean, private val key: ByteArray) : InputStream() {
class HmacBlockInputStream(private val baseStream: InputStream, private val verify: Boolean, private val key: ByteArray) : InputStream() {
private val baseStream: LittleEndianDataInputStream = LittleEndianDataInputStream(baseStream)
private var buffer: ByteArray = ByteArray(0)
private var bufferPos = 0
private var blockIndex: Long = 0
@@ -88,20 +87,20 @@ class HmacBlockInputStream(baseStream: InputStream, private val verify: Boolean,
private fun readSafeBlock(): Boolean {
if (endOfStream) return false
val storedHmac = baseStream.readBytes(32)
val storedHmac = baseStream.readBytesLength(32)
if (storedHmac.size != 32) {
throw IOException("File corrupted")
}
val pbBlockIndex = longTo8Bytes(blockIndex)
val pbBlockSize = baseStream.readBytes(4)
val pbBlockSize = baseStream.readBytesLength(4)
if (pbBlockSize.size != 4) {
throw IOException("File corrupted")
}
val blockSize = bytes4ToUInt(pbBlockSize)
bufferPos = 0
buffer = baseStream.readBytes(blockSize.toKotlinInt())
buffer = baseStream.readBytesLength(blockSize.toKotlinInt())
if (verify) {
val cmpHmac: ByteArray

View File

@@ -1,114 +0,0 @@
/*
* Copyright 2019 Jeremy Jamet / Kunzisoft.
*
* This file is part of KeePassDX.
*
* KeePassDX is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* KeePassDX is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with KeePassDX. If not, see <http://www.gnu.org/licenses/>.
*
*/
package com.kunzisoft.keepass.stream
import com.kunzisoft.keepass.utils.UnsignedInt
import java.io.IOException
import java.io.InputStream
/**
* Little endian version of the DataInputStream
*/
class LittleEndianDataInputStream(private val baseStream: InputStream) : InputStream() {
/**
* Read a 32-bit value and return it as a long, so that it can
* be interpreted as an unsigned integer.
*/
@Throws(IOException::class)
fun readUInt(): UnsignedInt {
return baseStream.readBytes4ToUInt()
}
@Throws(IOException::class)
fun readUShort(): Int {
val buf = ByteArray(2)
if (baseStream.read(buf, 0, 2) != 2)
throw IOException("Unable to read UShort value")
return bytes2ToUShort(buf)
}
@Throws(IOException::class)
override fun available(): Int {
return baseStream.available()
}
@Throws(IOException::class)
override fun close() {
baseStream.close()
}
override fun mark(readlimit: Int) {
baseStream.mark(readlimit)
}
override fun markSupported(): Boolean {
return baseStream.markSupported()
}
@Throws(IOException::class)
override fun read(): Int {
return baseStream.read()
}
@Throws(IOException::class)
override fun read(b: ByteArray, offset: Int, length: Int): Int {
return baseStream.read(b, offset, length)
}
@Throws(IOException::class)
override fun read(b: ByteArray): Int {
return baseStream.read(b)
}
@Synchronized
@Throws(IOException::class)
override fun reset() {
baseStream.reset()
}
@Throws(IOException::class)
override fun skip(n: Long): Long {
return baseStream.skip(n)
}
@Throws(IOException::class)
fun readBytes(length: Int): ByteArray {
// TODO Exception max length < buffer size
val buf = ByteArray(length)
var count = 0
while (count < length) {
val read = read(buf, count, length - count)
// Reached end
if (read == -1) {
// Stop early
val early = ByteArray(count)
System.arraycopy(buf, 0, early, 0, count)
return early
}
count += read
}
return buf
}
}

View File

@@ -24,6 +24,7 @@ import com.kunzisoft.keepass.stream.*
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.nio.charset.Charset
import java.util.*
@@ -115,7 +116,7 @@ open class VariantDictionary {
@Throws(IOException::class)
fun deserialize(data: ByteArray): VariantDictionary {
val inputStream = LittleEndianDataInputStream(ByteArrayInputStream(data))
val inputStream = ByteArrayInputStream(data)
return deserialize(inputStream)
}
@@ -128,9 +129,9 @@ open class VariantDictionary {
}
@Throws(IOException::class)
fun deserialize(inputStream: LittleEndianDataInputStream): VariantDictionary {
fun deserialize(inputStream: InputStream): VariantDictionary {
val dictionary = VariantDictionary()
val version = inputStream.readUShort()
val version = inputStream.readBytes2ToUShort()
if (version and VdmCritical > VdVersion and VdmCritical) {
throw IOException("Invalid format")
}
@@ -143,14 +144,14 @@ open class VariantDictionary {
if (bType == VdType.None) {
break
}
val nameLen = inputStream.readUInt().toKotlinInt()
val nameBuf = inputStream.readBytes(nameLen)
val nameLen = inputStream.readBytes4ToUInt().toKotlinInt()
val nameBuf = inputStream.readBytesLength(nameLen)
if (nameLen != nameBuf.size) {
throw IOException("Invalid format")
}
val name = String(nameBuf, UTF8Charset)
val valueLen = inputStream.readUInt().toKotlinInt()
val valueBuf = inputStream.readBytes(valueLen)
val valueLen = inputStream.readBytes4ToUInt().toKotlinInt()
val valueBuf = inputStream.readBytesLength(valueLen)
if (valueLen != valueBuf.size) {
throw IOException("Invalid format")
}