# test_s3_work_queue.py
import asyncio
import datetime
import functools
import hashlib
import unittest
from typing import Dict, List
from unittest.mock import Mock, call, patch

from botocore.exceptions import ClientError

# Import the classes we're testing
from olmocr.work_queue import S3WorkQueue, WorkItem
  10. class TestS3WorkQueue(unittest.TestCase):
  11. def setUp(self):
  12. """Set up test fixtures before each test method."""
  13. self.s3_client = Mock()
  14. self.s3_client.exceptions.ClientError = ClientError
  15. self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
  16. self.sample_paths = [
  17. "s3://test-bucket/data/file1.pdf",
  18. "s3://test-bucket/data/file2.pdf",
  19. "s3://test-bucket/data/file3.pdf",
  20. ]
  21. def tearDown(self):
  22. """Clean up after each test method."""
  23. pass
  24. def test_compute_workgroup_hash(self):
  25. """Test hash computation is deterministic and correct"""
  26. paths = [
  27. "s3://test-bucket/data/file2.pdf",
  28. "s3://test-bucket/data/file1.pdf",
  29. ]
  30. # Hash should be the same regardless of order
  31. hash1 = S3WorkQueue._compute_workgroup_hash(paths)
  32. hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
  33. self.assertEqual(hash1, hash2)
  34. def test_init(self):
  35. """Test initialization of S3WorkQueue"""
  36. client = Mock()
  37. queue = S3WorkQueue(client, "s3://test-bucket/workspace/")
  38. self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
  39. self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
  40. self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
  41. def asyncSetUp(self):
  42. """Set up async test fixtures"""
  43. self.loop = asyncio.new_event_loop()
  44. asyncio.set_event_loop(self.loop)
  45. def asyncTearDown(self):
  46. """Clean up async test fixtures"""
  47. self.loop.close()
  48. def async_test(f):
  49. """Decorator for async test methods"""
  50. def wrapper(*args, **kwargs):
  51. loop = asyncio.new_event_loop()
  52. asyncio.set_event_loop(loop)
  53. try:
  54. return loop.run_until_complete(f(*args, **kwargs))
  55. finally:
  56. loop.close()
  57. return wrapper
  58. @async_test
  59. async def test_populate_queue_new_items(self):
  60. """Test populating queue with new items"""
  61. # Mock empty existing index
  62. with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
  63. with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
  64. await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)
  65. # Verify upload was called with correct data
  66. self.assertEqual(mock_upload.call_count, 1)
  67. _, _, lines = mock_upload.call_args[0]
  68. # Should create 2 work groups (2 files + 1 file)
  69. self.assertEqual(len(lines), 2)
  70. # Verify format of uploaded lines
  71. for line in lines:
  72. parts = line.split(",")
  73. self.assertGreaterEqual(len(parts), 2) # Hash + at least one path
  74. self.assertEqual(len(parts[0]), 40) # SHA1 hash length
  75. @async_test
  76. async def test_populate_queue_existing_items(self):
  77. """Test populating queue with mix of new and existing items"""
  78. existing_paths = ["s3://test-bucket/data/existing1.pdf"]
  79. new_paths = ["s3://test-bucket/data/new1.pdf"]
  80. # Create existing index content
  81. existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
  82. existing_line = f"{existing_hash},{existing_paths[0]}"
  83. with patch("olmocr.work_queue.download_zstd_csv", return_value=[existing_line]):
  84. with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
  85. await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)
  86. # Verify upload called with both existing and new items
  87. _, _, lines = mock_upload.call_args[0]
  88. self.assertEqual(len(lines), 2)
  89. self.assertIn(existing_line, lines)
  90. @async_test
  91. async def test_initialize_queue(self):
  92. """Test queue initialization"""
  93. # Mock work items and completed items
  94. work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
  95. work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
  96. work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"
  97. completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]
  98. with patch("olmocr.work_queue.download_zstd_csv", return_value=[work_line]):
  99. with patch("olmocr.work_queue.expand_s3_glob", return_value=completed_items):
  100. await self.work_queue.initialize_queue()
  101. # Queue should be empty since all work is completed
  102. self.assertTrue(self.work_queue._queue.empty())
  103. @async_test
  104. async def test_is_completed(self):
  105. """Test completed work check"""
  106. work_hash = "testhash123"
  107. # Test completed work
  108. self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
  109. self.assertTrue(await self.work_queue.is_completed(work_hash))
  110. # Test incomplete work
  111. self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
  112. self.assertFalse(await self.work_queue.is_completed(work_hash))
  113. @async_test
  114. async def test_get_work(self):
  115. """Test getting work items"""
  116. # Setup test data
  117. work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
  118. await self.work_queue._queue.put(work_item)
  119. # Test getting available work
  120. self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
  121. result = await self.work_queue.get_work()
  122. self.assertEqual(result, work_item)
  123. # Verify lock file was created
  124. self.s3_client.put_object.assert_called_once()
  125. bucket, key = self.s3_client.put_object.call_args[1]["Bucket"], self.s3_client.put_object.call_args[1]["Key"]
  126. self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
  127. @async_test
  128. async def test_get_work_completed(self):
  129. """Test getting work that's already completed"""
  130. work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
  131. await self.work_queue._queue.put(work_item)
  132. # Simulate completed work
  133. self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
  134. result = await self.work_queue.get_work()
  135. self.assertIsNone(result) # Should skip completed work
  136. @async_test
  137. async def test_get_work_locked(self):
  138. """Test getting work that's locked by another worker"""
  139. work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
  140. await self.work_queue._queue.put(work_item)
  141. # Simulate active lock
  142. recent_time = datetime.datetime.now(datetime.timezone.utc)
  143. self.s3_client.head_object.side_effect = [
  144. ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"), # Not completed
  145. {"LastModified": recent_time}, # Active lock
  146. ]
  147. result = await self.work_queue.get_work()
  148. self.assertIsNone(result) # Should skip locked work
  149. @async_test
  150. async def test_get_work_stale_lock(self):
  151. """Test getting work with a stale lock"""
  152. work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
  153. await self.work_queue._queue.put(work_item)
  154. # Simulate stale lock
  155. stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
  156. self.s3_client.head_object.side_effect = [
  157. ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"), # Not completed
  158. {"LastModified": stale_time}, # Stale lock
  159. ]
  160. result = await self.work_queue.get_work()
  161. self.assertEqual(result, work_item) # Should take work with stale lock
  162. @async_test
  163. async def test_mark_done(self):
  164. """Test marking work as done"""
  165. work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
  166. await self.work_queue._queue.put(work_item)
  167. await self.work_queue.mark_done(work_item)
  168. # Verify lock file was deleted
  169. self.s3_client.delete_object.assert_called_once()
  170. bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
  171. self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
  172. @async_test
  173. async def test_paths_with_commas(self):
  174. """Test handling of paths that contain commas"""
  175. # Create paths with commas in them
  176. paths_with_commas = ["s3://test-bucket/data/file1,with,commas.pdf", "s3://test-bucket/data/file2,comma.pdf", "s3://test-bucket/data/file3.pdf"]
  177. # Mock empty existing index for initial population
  178. with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
  179. with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
  180. # Populate the queue with these paths
  181. await self.work_queue.populate_queue(paths_with_commas, items_per_group=3)
  182. # Capture what would be written to the index
  183. _, _, lines = mock_upload.call_args[0]
  184. # Now simulate reading back these lines (which have commas in the paths)
  185. with patch("olmocr.work_queue.download_zstd_csv", return_value=lines):
  186. with patch("olmocr.work_queue.expand_s3_glob", return_value=[]):
  187. # Initialize a fresh queue from these lines
  188. await self.work_queue.initialize_queue()
  189. # Mock ClientError for head_object (file doesn't exist)
  190. self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
  191. # Get a work item
  192. work_item = await self.work_queue.get_work()
  193. # Now verify we get a work item
  194. self.assertIsNotNone(work_item, "Should get a work item")
  195. # Verify the work item has the correct number of paths
  196. self.assertEqual(len(work_item.work_paths), len(paths_with_commas), "Work item should have the correct number of paths")
  197. # Check that all original paths with commas are preserved
  198. for path in paths_with_commas:
  199. print(path)
  200. self.assertIn(path, work_item.work_paths, f"Path with commas should be preserved: {path}")
  201. def test_queue_size(self):
  202. """Test queue size property"""
  203. self.assertEqual(self.work_queue.size, 0)
  204. self.loop = asyncio.new_event_loop()
  205. asyncio.set_event_loop(self.loop)
  206. self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"])))
  207. self.assertEqual(self.work_queue.size, 1)
  208. self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"])))
  209. self.assertEqual(self.work_queue.size, 2)
  210. self.loop.close()
# Allow running this test module directly: `python test_s3_work_queue.py`.
if __name__ == "__main__":
    unittest.main()