Skip to content

Commit e002eb2

Browse files
author
David Eigen
committed
add test
1 parent 0389776 commit e002eb2

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

tests/runners/test_stream_utils.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import io
2+
import unittest
3+
4+
from clarifai.runners.utils.stream_utils import (SeekableStreamingChunksReader,
5+
StreamingChunksReader)
6+
7+
8+
class TestStreamingChunksReader(unittest.TestCase):
9+
10+
def setUp(self):
11+
self.chunks = [b'hello', b'world', b'12345']
12+
#self.reader = BufferStream(iter(self.chunks), buffer_size=10)
13+
self.reader = StreamingChunksReader(iter(self.chunks))
14+
15+
def test_read(self):
16+
buffer = bytearray(5)
17+
self.assertEqual(self.reader.readinto(buffer), 5)
18+
self.assertEqual(buffer, b'hello')
19+
20+
def test_read_file(self):
21+
self.assertEqual(self.reader.read(5), b'hello')
22+
23+
def test_read_partial_chunk(self):
24+
"""Test reading fewer bytes than a chunk contains, across multiple reads."""
25+
buffer = bytearray(3)
26+
self.assertEqual(self.reader.readinto(buffer), 3)
27+
self.assertEqual(buffer, b'hel')
28+
self.assertEqual(self.reader.readinto(buffer), 2)
29+
self.assertEqual(buffer[:2], b'lo')
30+
self.assertEqual(self.reader.readinto(buffer), 3)
31+
self.assertEqual(buffer, b'wor')
32+
33+
def test_large_chunk(self):
34+
"""Test handling a chunk larger than the buffer size."""
35+
large_chunk = b'a' * 20
36+
reader = StreamingChunksReader(iter([large_chunk]))
37+
buffer = bytearray(10)
38+
self.assertEqual(reader.readinto(buffer), 10)
39+
self.assertEqual(buffer, b'a' * 10)
40+
self.assertEqual(reader.readinto(buffer), 10)
41+
self.assertEqual(buffer, b'a' * 10)
42+
43+
44+
class TestSeekableStreamingChunksReader(TestStreamingChunksReader):
45+
46+
def setUp(self):
47+
self.chunks = [b'hello', b'world', b'12345']
48+
self.reader = SeekableStreamingChunksReader(iter(self.chunks), buffer_size=10)
49+
50+
def test_interleaved_read_and_seek(self):
51+
"""Test alternating read and seek operations."""
52+
buffer = bytearray(5)
53+
self.reader.readinto(buffer)
54+
self.assertEqual(buffer, b'hello')
55+
buffer[:] = b'xxxxx'
56+
self.reader.seek(0)
57+
self.assertEqual(self.reader.readinto(buffer), 5)
58+
self.assertEqual(buffer, b'hello')
59+
self.reader.seek(7)
60+
n = self.reader.readinto(buffer)
61+
assert 1 <= n <= len(buffer)
62+
self.assertEqual(buffer[:n], b''.join(self.chunks)[7:7 + n])
63+
64+
def test_seek_and_tell(self):
65+
"""Test seeking to a position and confirming it with tell()."""
66+
self.reader.seek(5)
67+
self.assertEqual(self.reader.tell(), 5)
68+
self.reader.seek(-2, io.SEEK_CUR)
69+
self.assertEqual(self.reader.tell(), 3)
70+
71+
def test_seek_out_of_bounds(self):
72+
"""Test seeking to a negative position, which should raise an IOError."""
73+
with self.assertRaises(IOError):
74+
self.reader.seek(-1)
75+
76+
77+
if __name__ == '__main__':
78+
unittest.main()

0 commit comments

Comments
 (0)