Fix Out of memory for large attachment #115

This commit is contained in:
J-Jamet
2018-09-03 20:06:50 +02:00
parent bb3fb26847
commit 9130a3851f
7 changed files with 91 additions and 34 deletions

View File

@@ -63,30 +63,36 @@ public class BinaryPool {
@Override @Override
public boolean operate(PwEntryV4 entry) { public boolean operate(PwEntryV4 entry) {
for (PwEntryV4 histEntry : entry.getHistory()) { for (PwEntryV4 histEntry : entry.getHistory()) {
poolAdd(histEntry.getBinaries()); add(histEntry.getBinaries());
} }
poolAdd(entry.getBinaries()); add(entry.getBinaries());
return true; return true;
} }
} }
private void poolAdd(Map<String, ProtectedBinary> dict) { private void add(Map<String, ProtectedBinary> dict) {
for (ProtectedBinary pb : dict.values()) { for (ProtectedBinary pb : dict.values()) {
poolAdd(pb); add(pb);
} }
} }
public void poolAdd(ProtectedBinary pb) { public void add(ProtectedBinary pb) {
assert(pb != null); assert(pb != null);
if (findKey(pb) != -1) return;
if (poolFind(pb) != -1) return; pool.put(findUnusedKey(), pb);
pool.put(pool.size(), pb);
} }
public int findUnusedKey() {
int unusedKey = pool.size();
while(get(unusedKey) != null)
unusedKey++;
return unusedKey;
}
public int poolFind(ProtectedBinary pb) { public int findKey(ProtectedBinary pb) {
for (Entry<Integer, ProtectedBinary> pair : pool.entrySet()) { for (Entry<Integer, ProtectedBinary> pair : pool.entrySet()) {
if (pair.getValue().equals(pb)) return pair.getKey(); if (pair.getValue().equals(pb)) return pair.getKey();
} }

View File

@@ -151,7 +151,8 @@ public class Database {
// We'll end up reading 8 bytes to identify the header. Might as well use two extra. // We'll end up reading 8 bytes to identify the header. Might as well use two extra.
bis.mark(10); bis.mark(10);
Importer databaseImporter = ImporterFactory.createImporter(bis, debug); // Get the file directory to save the attachments
Importer databaseImporter = ImporterFactory.createImporter(bis, ctx.getFilesDir(), debug);
bis.reset(); // Return to the start bis.reset(); // Return to the start

View File

@@ -24,15 +24,16 @@ import com.kunzisoft.keepass.database.PwDbHeaderV4;
import com.kunzisoft.keepass.database.exception.InvalidDBSignatureException; import com.kunzisoft.keepass.database.exception.InvalidDBSignatureException;
import com.kunzisoft.keepass.stream.LEDataInputStream; import com.kunzisoft.keepass.stream.LEDataInputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
public class ImporterFactory { public class ImporterFactory {
public static Importer createImporter(InputStream is) throws InvalidDBSignatureException, IOException { public static Importer createImporter(InputStream is, File streamDir) throws InvalidDBSignatureException, IOException {
return createImporter(is, false); return createImporter(is, streamDir,false);
} }
public static Importer createImporter(InputStream is, boolean debug) throws InvalidDBSignatureException, IOException { public static Importer createImporter(InputStream is, File streamDir, boolean debug) throws InvalidDBSignatureException, IOException {
int sig1 = LEDataInputStream.readInt(is); int sig1 = LEDataInputStream.readInt(is);
int sig2 = LEDataInputStream.readInt(is); int sig2 = LEDataInputStream.readInt(is);
@@ -43,7 +44,7 @@ public class ImporterFactory {
return new ImporterV3(); return new ImporterV3();
} else if ( PwDbHeaderV4.matchesHeader(sig1, sig2) ) { } else if ( PwDbHeaderV4.matchesHeader(sig1, sig2) ) {
return new ImporterV4(); return new ImporterV4(streamDir);
} }
throw new InvalidDBSignatureException(); throw new InvalidDBSignatureException();

View File

@@ -23,6 +23,7 @@ import com.kunzisoft.keepass.R;
import com.kunzisoft.keepass.crypto.CipherFactory; import com.kunzisoft.keepass.crypto.CipherFactory;
import com.kunzisoft.keepass.crypto.PwStreamCipherFactory; import com.kunzisoft.keepass.crypto.PwStreamCipherFactory;
import com.kunzisoft.keepass.crypto.engine.CipherEngine; import com.kunzisoft.keepass.crypto.engine.CipherEngine;
import com.kunzisoft.keepass.database.BinaryPool;
import com.kunzisoft.keepass.database.ITimeLogger; import com.kunzisoft.keepass.database.ITimeLogger;
import com.kunzisoft.keepass.database.PwCompressionAlgorithm; import com.kunzisoft.keepass.database.PwCompressionAlgorithm;
import com.kunzisoft.keepass.database.PwDatabase; import com.kunzisoft.keepass.database.PwDatabase;
@@ -54,6 +55,8 @@ import org.xmlpull.v1.XmlPullParser;
import org.xmlpull.v1.XmlPullParserException; import org.xmlpull.v1.XmlPullParserException;
import org.xmlpull.v1.XmlPullParserFactory; import org.xmlpull.v1.XmlPullParserFactory;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
@@ -83,9 +86,11 @@ public class ImporterV4 extends Importer {
private byte[] pbHeader = null; private byte[] pbHeader = null;
private long version; private long version;
Calendar utcCal; Calendar utcCal;
private File streamDir;
public ImporterV4() { public ImporterV4(File streamDir) {
utcCal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); this.utcCal = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
this.streamDir = streamDir;
} }
@Override @Override
@@ -223,7 +228,8 @@ public class ImporterV4 extends Importer {
byte[] data = new byte[0]; byte[] data = new byte[0];
if (size > 0) { if (size > 0) {
data = lis.readBytes(size); if (fieldId != PwDbHeaderV4.PwDbInnerHeaderV4Fields.Binary)
data = lis.readBytes(size);
} }
boolean result = true; boolean result = true;
@@ -238,20 +244,23 @@ public class ImporterV4 extends Importer {
header.innerRandomStreamKey = data; header.innerRandomStreamKey = data;
break; break;
case PwDbHeaderV4.PwDbInnerHeaderV4Fields.Binary: case PwDbHeaderV4.PwDbInnerHeaderV4Fields.Binary:
if (data.length < 1) throw new IOException("Invalid binary format"); byte flag = lis.readBytes(1)[0];
byte flag = data[0]; boolean protectedFlag = (flag & PwDbHeaderV4.KdbxBinaryFlags.Protected) !=
boolean prot = (flag & PwDbHeaderV4.KdbxBinaryFlags.Protected) != PwDbHeaderV4.KdbxBinaryFlags.None;
PwDbHeaderV4.KdbxBinaryFlags.None; // Read in a file
BinaryPool binaryPool = db.getBinPool();
byte[] bin = new byte[data.length - 1]; int binaryKey = binaryPool.findUnusedKey();
System.arraycopy(data, 1, bin, 0, data.length-1); File file = new File(streamDir, String.valueOf(binaryKey));
ProtectedBinary pb = new ProtectedBinary(prot, bin); FileOutputStream outputStream = new FileOutputStream(file);
db.getBinPool().poolAdd(pb); try {
lis.readBytes(size - 1, outputStream::write);
if (prot) { } finally {
Arrays.fill(data, (byte)0); outputStream.close();
} }
ProtectedBinary protectedBinary = new ProtectedBinary(protectedFlag, file);
binaryPool.add(protectedBinary);
break; break;
default: default:
assert(false); assert(false);
break; break;

View File

@@ -413,7 +413,7 @@ public class PwDbV4Output extends PwDbOutput<PwDbHeaderV4> {
xml.startTag(null, PwDatabaseV4XML.ElemValue); xml.startTag(null, PwDatabaseV4XML.ElemValue);
String strRef = null; String strRef = null;
if (allowRef) { if (allowRef) {
int ref = mPM.getBinPool().poolFind(value); int ref = mPM.getBinPool().findKey(value);
strRef = Integer.toString(ref); strRef = Integer.toString(ref);
} }

View File

@@ -22,6 +22,7 @@ package com.kunzisoft.keepass.database.security;
import android.os.Parcel; import android.os.Parcel;
import android.os.Parcelable; import android.os.Parcelable;
import java.io.File;
import java.util.Arrays; import java.util.Arrays;
public class ProtectedBinary implements Parcelable { public class ProtectedBinary implements Parcelable {
@@ -30,12 +31,14 @@ public class ProtectedBinary implements Parcelable {
private boolean protect; private boolean protect;
private byte[] data; private byte[] data;
private File dataFile;
public boolean isProtected() { public boolean isProtected() {
return protect; return protect;
} }
public int length() { public int length() {
// TODO File length
if (data == null) { if (data == null) {
return 0; return 0;
} }
@@ -49,11 +52,19 @@ public class ProtectedBinary implements Parcelable {
public ProtectedBinary(boolean enableProtection, byte[] data) { public ProtectedBinary(boolean enableProtection, byte[] data) {
this.protect = enableProtection; this.protect = enableProtection;
this.data = data; this.data = data;
this.dataFile = null;
} }
public ProtectedBinary(boolean enableProtection, File dataFile) {
this.protect = enableProtection;
this.data = new byte[0];
this.dataFile = dataFile;
}
public ProtectedBinary(Parcel in) { public ProtectedBinary(Parcel in) {
protect = in.readByte() != 0; protect = in.readByte() != 0;
in.readByteArray(data); in.readByteArray(data);
dataFile = new File(in.readString());
} }
// TODO: replace the byte[] with something like ByteBuffer to make the return // TODO: replace the byte[] with something like ByteBuffer to make the return
@@ -63,7 +74,9 @@ public class ProtectedBinary implements Parcelable {
} }
public boolean equals(ProtectedBinary rhs) { public boolean equals(ProtectedBinary rhs) {
return (protect == rhs.protect) && Arrays.equals(data, rhs.data); return (protect == rhs.protect)
&& Arrays.equals(data, rhs.data)
&& dataFile.equals(rhs.dataFile);
} }
@Override @Override
@@ -75,6 +88,7 @@ public class ProtectedBinary implements Parcelable {
public void writeToParcel(Parcel dest, int flags) { public void writeToParcel(Parcel dest, int flags) {
dest.writeByte((byte) (protect ? 1 : 0)); dest.writeByte((byte) (protect ? 1 : 0));
dest.writeByteArray(data); dest.writeByteArray(data);
dest.writeString(dataFile.getAbsolutePath());
} }
public static final Creator<ProtectedBinary> CREATOR = new Creator<ProtectedBinary>() { public static final Creator<ProtectedBinary> CREATOR = new Creator<ProtectedBinary>() {

View File

@@ -106,8 +106,9 @@ public class LEDataInputStream extends InputStream {
} }
public byte[] readBytes(int length) throws IOException { public byte[] readBytes(int length) throws IOException {
// TODO Exception max length < buffer size
byte[] buf = new byte[length]; byte[] buf = new byte[length];
int count = 0; int count = 0;
while ( count < length ) { while ( count < length ) {
int read = read(buf, count, length - count); int read = read(buf, count, length - count);
@@ -126,6 +127,31 @@ public class LEDataInputStream extends InputStream {
return buf; return buf;
} }
public void readBytes(int length, ActionReadBytes actionReadBytes) throws IOException {
byte[] buffer = new byte[1024];
int offset = 0;
int read = 0;
while ( offset < length && read != -1) {
int tempLength = buffer.length;
// If buffer not needed
if (length - offset < tempLength)
tempLength = length - offset;
read = read(buffer, 0, tempLength);
actionReadBytes.doAction(buffer);
offset += read;
}
}
public interface ActionReadBytes {
/**
* Called after each buffer fill
* @param buffer filled
*/
void doAction(byte[] buffer) throws IOException;
}
public static int readUShort(InputStream is) throws IOException { public static int readUShort(InputStream is) throws IOException {
byte[] buf = new byte[2]; byte[] buf = new byte[2];