diff --git a/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt b/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt index a24cecfc6..814cdff43 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt +++ b/app/src/main/java/com/kunzisoft/keepass/database/file/input/DatabaseInputKDBX.kt @@ -205,7 +205,7 @@ class DatabaseInputKDBX(cacheDirectory: File, } if (mDatabase.kdbxVersion.toKotlinLong() >= DatabaseHeaderKDBX.FILE_VERSION_32_4.toKotlinLong()) { - loadInnerHeader(inputStreamXml, header) + readInnerHeader(inputStreamXml, header) } try { @@ -237,57 +237,56 @@ class DatabaseInputKDBX(cacheDirectory: File, } @Throws(IOException::class) - private fun loadInnerHeader(inputStream: InputStream, header: DatabaseHeaderKDBX) { - val lis = LittleEndianDataInputStream(inputStream) + private fun readInnerHeader(inputStream: InputStream, + header: DatabaseHeaderKDBX) { - while (true) { - if (!readInnerHeader(lis, header)) break - } - } + val dataInputStream = LittleEndianDataInputStream(inputStream) - @Throws(IOException::class) - private fun readInnerHeader(dataInputStream: LittleEndianDataInputStream, - header: DatabaseHeaderKDBX): Boolean { - val fieldId = dataInputStream.read().toByte() + var readStream = true + while (readStream) { + val fieldId = dataInputStream.read().toByte() - val size = dataInputStream.readUInt().toKotlinInt() - if (size < 0) throw IOException("Corrupted file") + val size = dataInputStream.readUInt().toKotlinInt() + if (size < 0) throw IOException("Corrupted file") - var data = ByteArray(0) - if (size > 0) { - if (fieldId != DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.Binary) { - // TODO OOM here - data = dataInputStream.readBytes(size) + var data = ByteArray(0) + try { + if (size > 0) { + if (fieldId != DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.Binary) { + data = dataInputStream.readBytes(size) + } + } + } catch (e: Exception) { + // OOM only if corrupted file + throw IOException("Corrupted file") } - } - var result = true - when (fieldId) { - DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.EndOfHeader -> { - result = false - } - DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.InnerRandomStreamID -> { - header.setRandomStreamID(data) - } - DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.InnerRandomstreamKey -> { - header.innerRandomStreamKey = data - } - DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.Binary -> { - // Read in a file - val protectedFlag = dataInputStream.read().toByte() == DatabaseHeaderKDBX.KdbxBinaryFlags.Protected - val byteLength = size - 1 - // No compression at this level - val protectedBinary = mDatabase.buildNewAttachment( - isRAMSufficient.invoke(byteLength.toLong()), false, protectedFlag) - protectedBinary.getOutputDataStream(mDatabase.binaryCache).use { outputStream -> - dataInputStream.readBytes(byteLength) { buffer -> - outputStream.write(buffer) + readStream = true + when (fieldId) { + DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.EndOfHeader -> { + readStream = false + } + DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.InnerRandomStreamID -> { + header.setRandomStreamID(data) + } + DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.InnerRandomstreamKey -> { + header.innerRandomStreamKey = data + } + DatabaseHeaderKDBX.PwDbInnerHeaderV4Fields.Binary -> { + // Read in a file + val protectedFlag = dataInputStream.read().toByte() == DatabaseHeaderKDBX.KdbxBinaryFlags.Protected + val byteLength = size - 1 + // No compression at this level + val protectedBinary = mDatabase.buildNewAttachment( + isRAMSufficient.invoke(byteLength.toLong()), false, protectedFlag) + protectedBinary.getOutputDataStream(mDatabase.binaryCache).use { outputStream -> + dataInputStream.readBytes(byteLength) { buffer -> + outputStream.write(buffer) + } } } } } - - return result } private enum class KdbContext {