From 9130a3851f19557a73dddca6f2a87d5f0fc248df Mon Sep 17 00:00:00 2001 From: J-Jamet Date: Mon, 3 Sep 2018 20:06:50 +0200 Subject: [PATCH] Fix Out of memory for large attachment #115 --- .../keepass/database/BinaryPool.java | 26 +++++++----- .../kunzisoft/keepass/database/Database.java | 3 +- .../database/load/ImporterFactory.java | 9 ++-- .../keepass/database/load/ImporterV4.java | 41 +++++++++++-------- .../keepass/database/save/PwDbV4Output.java | 2 +- .../database/security/ProtectedBinary.java | 16 +++++++- .../keepass/stream/LEDataInputStream.java | 28 ++++++++++++- 7 files changed, 91 insertions(+), 34 deletions(-) diff --git a/app/src/main/java/com/kunzisoft/keepass/database/BinaryPool.java b/app/src/main/java/com/kunzisoft/keepass/database/BinaryPool.java index 84fed9719..5acfd27c9 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/BinaryPool.java +++ b/app/src/main/java/com/kunzisoft/keepass/database/BinaryPool.java @@ -63,30 +63,36 @@ public class BinaryPool { @Override public boolean operate(PwEntryV4 entry) { for (PwEntryV4 histEntry : entry.getHistory()) { - poolAdd(histEntry.getBinaries()); + add(histEntry.getBinaries()); } - poolAdd(entry.getBinaries()); + add(entry.getBinaries()); return true; } } - private void poolAdd(Map dict) { + private void add(Map dict) { for (ProtectedBinary pb : dict.values()) { - poolAdd(pb); + add(pb); } } - public void poolAdd(ProtectedBinary pb) { - assert(pb != null); + public void add(ProtectedBinary pb) { + assert(pb != null); + if (findKey(pb) != -1) return; - if (poolFind(pb) != -1) return; - - pool.put(pool.size(), pb); + pool.put(findUnusedKey(), pb); } + + public int findUnusedKey() { + int unusedKey = pool.size(); + while(get(unusedKey) != null) + unusedKey++; + return unusedKey; + } - public int poolFind(ProtectedBinary pb) { + public int findKey(ProtectedBinary pb) { for (Entry pair : pool.entrySet()) { if (pair.getValue().equals(pb)) return pair.getKey(); } diff --git a/app/src/main/java/com/kunzisoft/keepass/database/Database.java b/app/src/main/java/com/kunzisoft/keepass/database/Database.java index 6dd484e49..fc822cc6e 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/Database.java +++ b/app/src/main/java/com/kunzisoft/keepass/database/Database.java @@ -151,7 +151,8 @@ public class Database { // We'll end up reading 8 bytes to identify the header. Might as well use two extra. bis.mark(10); - Importer databaseImporter = ImporterFactory.createImporter(bis, debug); + // Get the file directory to save the attachments + Importer databaseImporter = ImporterFactory.createImporter(bis, ctx.getFilesDir(), debug); bis.reset(); // Return to the start diff --git a/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterFactory.java b/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterFactory.java index 385159fd7..02cefcb7e 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterFactory.java +++ b/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterFactory.java @@ -24,15 +24,16 @@ import com.kunzisoft.keepass.database.PwDbHeaderV4; import com.kunzisoft.keepass.database.exception.InvalidDBSignatureException; import com.kunzisoft.keepass.stream.LEDataInputStream; +import java.io.File; import java.io.IOException; import java.io.InputStream; public class ImporterFactory { - public static Importer createImporter(InputStream is) throws InvalidDBSignatureException, IOException { - return createImporter(is, false); + public static Importer createImporter(InputStream is, File streamDir) throws InvalidDBSignatureException, IOException { + return createImporter(is, streamDir,false); } - public static Importer createImporter(InputStream is, boolean debug) throws InvalidDBSignatureException, IOException { + public static Importer createImporter(InputStream is, File streamDir, boolean debug) throws InvalidDBSignatureException, IOException { int sig1 = LEDataInputStream.readInt(is); int sig2 = LEDataInputStream.readInt(is); @@ -43,7 +44,7 @@ public class ImporterFactory { return new ImporterV3(); } else if ( PwDbHeaderV4.matchesHeader(sig1, sig2) ) { - return new ImporterV4(); + return new ImporterV4(streamDir); } throw new InvalidDBSignatureException(); diff --git a/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterV4.java b/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterV4.java index 7719019a4..8d88771c9 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterV4.java +++ b/app/src/main/java/com/kunzisoft/keepass/database/load/ImporterV4.java @@ -23,6 +23,7 @@ import com.kunzisoft.keepass.R; import com.kunzisoft.keepass.crypto.CipherFactory; import com.kunzisoft.keepass.crypto.PwStreamCipherFactory; import com.kunzisoft.keepass.crypto.engine.CipherEngine; +import com.kunzisoft.keepass.database.BinaryPool; import com.kunzisoft.keepass.database.ITimeLogger; import com.kunzisoft.keepass.database.PwCompressionAlgorithm; import com.kunzisoft.keepass.database.PwDatabase; @@ -54,6 +55,8 @@ import org.xmlpull.v1.XmlPullParser; import org.xmlpull.v1.XmlPullParserException; import org.xmlpull.v1.XmlPullParserFactory; +import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; @@ -83,9 +86,11 @@ public class ImporterV4 extends Importer { private byte[] pbHeader = null; private long version; Calendar utcCal; + private File streamDir; - public ImporterV4() { - utcCal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + public ImporterV4(File streamDir) { + this.utcCal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + this.streamDir = streamDir; } @Override @@ -223,7 +228,8 @@ public class ImporterV4 extends Importer { byte[] data = new byte[0]; if (size > 0) { - data = lis.readBytes(size); + if (fieldId != PwDbHeaderV4.PwDbInnerHeaderV4Fields.Binary) + data = lis.readBytes(size); } boolean result = true; @@ -238,20 +244,23 @@ public class ImporterV4 extends Importer { header.innerRandomStreamKey = data; break; case PwDbHeaderV4.PwDbInnerHeaderV4Fields.Binary: - if (data.length < 1) throw new IOException("Invalid binary format"); - byte flag = data[0]; - boolean prot = (flag & PwDbHeaderV4.KdbxBinaryFlags.Protected) != - PwDbHeaderV4.KdbxBinaryFlags.None; - - byte[] bin = new byte[data.length - 1]; - System.arraycopy(data, 1, bin, 0, data.length-1); - ProtectedBinary pb = new ProtectedBinary(prot, bin); - db.getBinPool().poolAdd(pb); - - if (prot) { - Arrays.fill(data, (byte)0); - } + byte flag = lis.readBytes(1)[0]; + boolean protectedFlag = (flag & PwDbHeaderV4.KdbxBinaryFlags.Protected) != + PwDbHeaderV4.KdbxBinaryFlags.None; + // Read in a file + BinaryPool binaryPool = db.getBinPool(); + int binaryKey = binaryPool.findUnusedKey(); + File file = new File(streamDir, String.valueOf(binaryKey)); + FileOutputStream outputStream = new FileOutputStream(file); + try { + lis.readBytes(size - 1, outputStream::write); + } finally { + outputStream.close(); + } + ProtectedBinary protectedBinary = new ProtectedBinary(protectedFlag, file); + binaryPool.add(protectedBinary); break; + default: assert(false); break; diff --git a/app/src/main/java/com/kunzisoft/keepass/database/save/PwDbV4Output.java b/app/src/main/java/com/kunzisoft/keepass/database/save/PwDbV4Output.java index 53b4f5701..59f8ac0a3 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/save/PwDbV4Output.java +++ b/app/src/main/java/com/kunzisoft/keepass/database/save/PwDbV4Output.java @@ -413,7 +413,7 @@ public class PwDbV4Output extends PwDbOutput { xml.startTag(null, PwDatabaseV4XML.ElemValue); String strRef = null; if (allowRef) { - int ref = mPM.getBinPool().poolFind(value); + int ref = mPM.getBinPool().findKey(value); strRef = Integer.toString(ref); } diff --git a/app/src/main/java/com/kunzisoft/keepass/database/security/ProtectedBinary.java b/app/src/main/java/com/kunzisoft/keepass/database/security/ProtectedBinary.java index 463b7b604..3b6e6aed6 100644 --- a/app/src/main/java/com/kunzisoft/keepass/database/security/ProtectedBinary.java +++ b/app/src/main/java/com/kunzisoft/keepass/database/security/ProtectedBinary.java @@ -22,6 +22,7 @@ package com.kunzisoft.keepass.database.security; import android.os.Parcel; import android.os.Parcelable; +import java.io.File; import java.util.Arrays; public class ProtectedBinary implements Parcelable { @@ -30,12 +31,14 @@ public class ProtectedBinary implements Parcelable { private boolean protect; private byte[] data; + private File dataFile; public boolean isProtected() { return protect; } public int length() { + // TODO File length if (data == null) { return 0; } @@ -49,11 +52,19 @@ public class ProtectedBinary implements Parcelable { public ProtectedBinary(boolean enableProtection, byte[] data) { this.protect = enableProtection; this.data = data; + this.dataFile = null; } + public ProtectedBinary(boolean enableProtection, File dataFile) { + this.protect = enableProtection; + this.data = new byte[0]; + this.dataFile = dataFile; + } + public ProtectedBinary(Parcel in) { protect = in.readByte() != 0; in.readByteArray(data); + dataFile = new File(in.readString()); } // TODO: replace the byte[] with something like ByteBuffer to make the return @@ -63,7 +74,9 @@ public class ProtectedBinary implements Parcelable { } public boolean equals(ProtectedBinary rhs) { - return (protect == rhs.protect) && Arrays.equals(data, rhs.data); + return (protect == rhs.protect) + && Arrays.equals(data, rhs.data) + && dataFile.equals(rhs.dataFile); } @Override @@ -75,6 +88,7 @@ public class ProtectedBinary implements Parcelable { public void writeToParcel(Parcel dest, int flags) { dest.writeByte((byte) (protect ? 1 : 0)); dest.writeByteArray(data); + dest.writeString(dataFile.getAbsolutePath()); } public static final Creator CREATOR = new Creator() { diff --git a/app/src/main/java/com/kunzisoft/keepass/stream/LEDataInputStream.java b/app/src/main/java/com/kunzisoft/keepass/stream/LEDataInputStream.java index a213f7c76..dab42aca6 100644 --- a/app/src/main/java/com/kunzisoft/keepass/stream/LEDataInputStream.java +++ b/app/src/main/java/com/kunzisoft/keepass/stream/LEDataInputStream.java @@ -106,8 +106,9 @@ public class LEDataInputStream extends InputStream { } public byte[] readBytes(int length) throws IOException { + // TODO Exception max length < buffer size byte[] buf = new byte[length]; - + int count = 0; while ( count < length ) { int read = read(buf, count, length - count); @@ -126,6 +127,31 @@ public class LEDataInputStream extends InputStream { return buf; } + public void readBytes(int length, ActionReadBytes actionReadBytes) throws IOException { + byte[] buffer = new byte[1024]; + + int offset = 0; + int read = 0; + while ( offset < length && read != -1) { + + int tempLength = buffer.length; + // If buffer not needed + if (length - offset < tempLength) + tempLength = length - offset; + read = read(buffer, 0, tempLength); + actionReadBytes.doAction(buffer); + offset += read; + } + } + + public interface ActionReadBytes { + /** + * Called after each buffer fill + * @param buffer filled + */ + void doAction(byte[] buffer) throws IOException; + } + public static int readUShort(InputStream is) throws IOException { byte[] buf = new byte[2];