#712 Core: Fix possible OOM situation in new stream implementation

(cherry picked from commit 8f5c1b409f0ea99de8d5a71204ccb426828a80e4)
This commit is contained in:
Harald Kuhr 2022-11-21 16:15:21 +01:00
parent debf7d0207
commit 6840f31fa3
2 changed files with 35 additions and 4 deletions

View File

@ -17,8 +17,12 @@ final class MemoryCache implements Cache {
final static int BLOCK_SIZE = 1 << 13; final static int BLOCK_SIZE = 1 << 13;
private static final byte[] NULL_BLOCK = new byte[0];
private final List<byte[]> cache = new ArrayList<>(); private final List<byte[]> cache = new ArrayList<>();
private final ReadableByteChannel channel; private final ReadableByteChannel channel;
private int maxBlock = Integer.MAX_VALUE;
private long length; private long length;
private long position; private long position;
private long start; private long start;
@ -34,12 +38,14 @@ final class MemoryCache implements Cache {
byte[] fetchBlock() throws IOException { byte[] fetchBlock() throws IOException {
long currPos = position; long currPos = position;
long index = currPos / BLOCK_SIZE; long index = currPos / BLOCK_SIZE;
if (index >= Integer.MAX_VALUE) { if (index >= Integer.MAX_VALUE) {
throw new IOException("Memory cache max size exceeded"); throw new IOException("Memory cache max size exceeded");
} }
if (index > maxBlock) {
return NULL_BLOCK;
}
while (index >= cache.size()) { while (index >= cache.size()) {
byte[] block; byte[] block;
@ -51,7 +57,14 @@ final class MemoryCache implements Cache {
} }
cache.add(block); cache.add(block);
length += readBlock(block); int bytesRead = readBlock(block);
length += bytesRead;
if (bytesRead < BLOCK_SIZE) {
// Last block, EOF found
maxBlock = (int) index;
return block;
}
} }
return cache.get((int) index); return cache.get((int) index);
@ -63,7 +76,7 @@ final class MemoryCache implements Cache {
while (wrapped.hasRemaining()) { while (wrapped.hasRemaining()) {
int count = channel.read(wrapped); int count = channel.read(wrapped);
if (count == -1) { if (count == -1) {
// Last block // Last block, EOF found
break; break;
} }
} }
@ -84,12 +97,12 @@ final class MemoryCache implements Cache {
@Override @Override
public int read(ByteBuffer dest) throws IOException { public int read(ByteBuffer dest) throws IOException {
byte[] buffer = fetchBlock(); byte[] buffer = fetchBlock();
int bufferPos = (int) (position % BLOCK_SIZE);
if (position >= length) { if (position >= length) {
return -1; return -1;
} }
int bufferPos = (int) (position % BLOCK_SIZE);
int len = min(dest.remaining(), (int) min(BLOCK_SIZE - bufferPos, length - position)); int len = min(dest.remaining(), (int) min(BLOCK_SIZE - bufferPos, length - position));
dest.put(buffer, bufferPos, len); dest.put(buffer, bufferPos, len);

View File

@ -402,6 +402,24 @@ public class BufferedChannelImageInputStreamMemoryCacheTest {
assertEquals(-1, stream.read()); assertEquals(-1, stream.read());
} }
} }
@Test
public void testSeekWayPastEOFShouldNotThrowOOME() throws IOException {
byte[] bytes = new byte[9];
InputStream input = randomDataToInputStream(bytes);
try (final ImageInputStream stream = new BufferedChannelImageInputStream(new MemoryCache(input))) {
stream.seek(Integer.MAX_VALUE * 4L * 512L); // ~4 TB
assertEquals(-1, stream.read()); // No OOME should happen...
stream.seek(0);
for (byte value : bytes) {
assertEquals(value, stream.readByte());
}
assertEquals(-1, stream.read());
}
}
@Test @Test
public void testClose() throws IOException { public void testClose() throws IOException {