diff --git a/imageio/imageio-core/src/main/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStream.java b/imageio/imageio-core/src/main/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStream.java index dc0144c9..dca1a5ca 100644 --- a/imageio/imageio-core/src/main/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStream.java +++ b/imageio/imageio-core/src/main/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStream.java @@ -255,6 +255,7 @@ public final class BufferedImageInputStream extends ImageInputStreamImpl impleme } int val = buffer.get() & 0xff; + streamPos++; accum <<= 8; accum |= val; @@ -264,9 +265,7 @@ public final class BufferedImageInputStream extends ImageInputStreamImpl impleme // Move byte position back if in the middle of a byte if (newBitOffset != 0) { buffer.position(buffer.position() - 1); - } - else { - streamPos++; + streamPos--; } this.bitOffset = newBitOffset; @@ -281,26 +280,26 @@ public final class BufferedImageInputStream extends ImageInputStreamImpl impleme } @Override - public void seek(long pPosition) throws IOException { + public void seek(long position) throws IOException { checkClosed(); bitOffset = 0; - if (streamPos == pPosition) { + if (streamPos == position) { return; } // Optimized to not invalidate buffer if new position is within current buffer - long newBufferPos = buffer.position() + pPosition - streamPos; + long newBufferPos = buffer.position() + position - streamPos; if (newBufferPos >= 0 && newBufferPos <= buffer.limit()) { buffer.position((int) newBufferPos); } else { // Will invalidate buffer buffer.limit(0); - stream.seek(pPosition); + stream.seek(position); } - streamPos = pPosition; + streamPos = position; } @Override @@ -332,7 +331,9 @@ public final class BufferedImageInputStream extends ImageInputStreamImpl impleme @Override public void close() throws IOException { if (stream != null) { - //stream.close(); + // TODO: FixMe: Need to close underlying stream here! + // For call sites that relies on not closing, we should instead not close the buffered stream. +// stream.close(); stream = null; buffer = null; } diff --git a/imageio/imageio-core/src/test/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStreamTest.java b/imageio/imageio-core/src/test/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStreamTest.java index da0acc1a..17dab84c 100644 --- a/imageio/imageio-core/src/test/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStreamTest.java +++ b/imageio/imageio-core/src/test/java/com/twelvemonkeys/imageio/stream/BufferedImageInputStreamTest.java @@ -32,16 +32,19 @@ package com.twelvemonkeys.imageio.stream; import com.twelvemonkeys.io.ole2.CompoundDocument; import com.twelvemonkeys.io.ole2.Entry; + import org.junit.Test; import javax.imageio.stream.ImageInputStream; import javax.imageio.stream.MemoryCacheImageInputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Random; import static java.util.Arrays.fill; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; /** * BufferedImageInputStreamTest @@ -72,6 +75,257 @@ public class BufferedImageInputStreamTest { } } + @Test + public void testReadBit() throws IOException { + byte[] bytes = new byte[] {(byte) 0xF0, (byte) 0x0F}; + + // Create wrapper stream + BufferedImageInputStream stream = new BufferedImageInputStream(new ByteArrayImageInputStream(bytes)); + + // Read all bits + assertEquals(1, stream.readBit()); + assertEquals(1, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(2, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(3, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(4, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); + assertEquals(5, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); + assertEquals(6, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); + assertEquals(7, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); // last bit + assertEquals(0, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + // Full reset, read same sequence again + stream.seek(0); + assertEquals(0, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + assertEquals(0, stream.readBit()); + assertEquals(0, stream.readBit()); + assertEquals(0, stream.readBit()); + assertEquals(0, stream.readBit()); + + assertEquals(0, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + // Full reset, read partial + stream.seek(0); + + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + + // Byte reset, read same sequence again + stream.setBitOffset(0); + + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + assertEquals(1, stream.readBit()); + assertEquals(0, stream.readBit()); + + // Byte reset, read partial sequence again + stream.setBitOffset(3); + + assertEquals(1, stream.readBit()); + assertEquals(0, stream.readBit()); + assertEquals(0, stream.getStreamPosition()); + + // Byte reset, read partial sequence again + stream.setBitOffset(6); + + assertEquals(0, stream.readBit()); + assertEquals(0, stream.readBit()); + assertEquals(1, stream.getStreamPosition()); + + // Read all bits, second byte + assertEquals(0, stream.readBit()); + assertEquals(1, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); + assertEquals(2, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); + assertEquals(3, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(0, stream.readBit()); + assertEquals(4, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(5, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(6, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); + assertEquals(7, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(1, stream.readBit()); // last bit + assertEquals(0, stream.getBitOffset()); + assertEquals(2, stream.getStreamPosition()); + } + + @Test + public void testReadBits() throws IOException { + byte[] bytes = new byte[] {(byte) 0xF0, (byte) 0xCC, (byte) 0xAA}; + + // Create wrapper stream + BufferedImageInputStream stream = new BufferedImageInputStream(new ByteArrayImageInputStream(bytes)); + + // Read all bits, first byte + assertEquals(3, stream.readBits(2)); + assertEquals(2, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(3, stream.readBits(2)); + assertEquals(4, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0, stream.readBits(2)); + assertEquals(6, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0, stream.readBits(2)); + assertEquals(0, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + // Read all bits, second byte + assertEquals(3, stream.readBits(2)); + assertEquals(2, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(0, stream.readBits(2)); + assertEquals(4, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(3, stream.readBits(2)); + assertEquals(6, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(0, stream.readBits(2)); + assertEquals(0, stream.getBitOffset()); + assertEquals(2, stream.getStreamPosition()); + + // Read all bits, third byte + assertEquals(2, stream.readBits(2)); + assertEquals(2, stream.getBitOffset()); + assertEquals(2, stream.getStreamPosition()); + + assertEquals(2, stream.readBits(2)); + assertEquals(4, stream.getBitOffset()); + assertEquals(2, stream.getStreamPosition()); + + assertEquals(2, stream.readBits(2)); + assertEquals(6, stream.getBitOffset()); + assertEquals(2, stream.getStreamPosition()); + + assertEquals(2, stream.readBits(2)); + assertEquals(0, stream.getBitOffset()); + assertEquals(3, stream.getStreamPosition()); + + // Full reset, read same sequence again + stream.seek(0); + assertEquals(0, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + // Read all bits, increasing size + assertEquals(7, stream.readBits(3)); // 111 + assertEquals(3, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(8, stream.readBits(4)); // 1000 + assertEquals(7, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(12, stream.readBits(5)); // 01100 + assertEquals(4, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(50, stream.readBits(6)); // 110010 + assertEquals(2, stream.getBitOffset()); + assertEquals(2, stream.getStreamPosition()); + + assertEquals(42, stream.readBits(6)); // 101010 + assertEquals(0, stream.getBitOffset()); + assertEquals(3, stream.getStreamPosition()); + + // Full reset, read same sequence again + stream.seek(0); + assertEquals(0, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + // Read all bits multi-byte + assertEquals(0xF0C, stream.readBits(12)); // 111100001100 + assertEquals(4, stream.getBitOffset()); + assertEquals(1, stream.getStreamPosition()); + + assertEquals(0xCAA, stream.readBits(12)); // 110010101010 + assertEquals(0, stream.getBitOffset()); + assertEquals(3, stream.getStreamPosition()); + + // Full reset, read same sequence again, all bits in one go + stream.seek(0); + assertEquals(0, stream.getBitOffset()); + assertEquals(0, stream.getStreamPosition()); + + assertEquals(0xF0CCAA, stream.readBits(24)); + } + + @Test + public void testReadBitsRandom() throws IOException { + long value = random.nextLong(); + byte[] bytes = new byte[8]; + ByteBuffer.wrap(bytes).putLong(value); + + // Create wrapper stream + BufferedImageInputStream stream = new BufferedImageInputStream(new ByteArrayImageInputStream(bytes)); + + for (int i = 1; i < 64; i++) { + stream.seek(0); + assertEquals(i + " bits differ", value >>> (64L - i), stream.readBits(i)); + } + } + + @Test + public void testClose() throws IOException { + // Create wrapper stream + ImageInputStream mock = mock(ImageInputStream.class); + BufferedImageInputStream stream = new BufferedImageInputStream(mock); + + stream.close(); + verify(mock, never()).close(); + } + // TODO: Write other tests // TODO: Create test that exposes read += -1 (eof) bug