| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- import asyncio
- import datetime
- import hashlib
- import unittest
- from typing import Dict, List
- from unittest.mock import Mock, call, patch
- from botocore.exceptions import ClientError
- # Import the classes we're testing
- from olmocr.work_queue import S3WorkQueue, WorkItem
class TestS3WorkQueue(unittest.TestCase):
    """Unit tests for ``S3WorkQueue`` driven by a mocked S3 client.

    All S3 calls go through ``self.s3_client`` (a ``Mock``), and the
    ``download_zstd_csv`` / ``upload_zstd_csv`` / ``expand_s3_glob`` helpers in
    ``olmocr.work_queue`` are patched per test, so no network access happens.
    """

    def setUp(self):
        """Set up test fixtures before each test method."""
        self.s3_client = Mock()
        # Attach the real botocore ClientError to the mock so that code under
        # test which catches ``s3_client.exceptions.ClientError`` (presumably
        # how S3WorkQueue detects missing objects) matches the ClientError
        # instances raised by the tests below.
        self.s3_client.exceptions.ClientError = ClientError
        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
        self.sample_paths = [
            "s3://test-bucket/data/file1.pdf",
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file3.pdf",
        ]

    def tearDown(self):
        """Clean up after each test method."""
        pass

    def test_compute_workgroup_hash(self):
        """Test hash computation is deterministic, order-independent and input-sensitive"""
        paths = [
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file1.pdf",
        ]

        hash1 = S3WorkQueue._compute_workgroup_hash(paths)

        # Deterministic: hashing the same input twice gives the same result
        self.assertEqual(hash1, S3WorkQueue._compute_workgroup_hash(paths))

        # Hash should be the same regardless of order
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)

        # Same 40-character (SHA1-length) hash that test_populate_queue_new_items
        # expects as the first field of every index line
        self.assertEqual(len(hash1), 40)

        # A different work group must not collide with this one
        self.assertNotEqual(hash1, S3WorkQueue._compute_workgroup_hash(paths[:1]))

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        client = Mock()
        # Trailing slash on the workspace path is expected to be normalized away
        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")
        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")

    def asyncSetUp(self):
        """Set up async test fixtures.

        NOTE(review): ``unittest.TestCase`` never calls this hook (only
        ``IsolatedAsyncioTestCase`` does, and it expects a coroutine). The
        ``@async_test`` decorator below manages event loops instead.
        """
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def asyncTearDown(self):
        """Clean up async test fixtures.

        NOTE(review): like ``asyncSetUp``, this is never invoked by
        ``unittest.TestCase``.
        """
        self.loop.close()

    def async_test(f):
        """Decorator for async test methods.

        ``unittest.TestCase`` runs test methods synchronously, so this wraps the
        coroutine function ``f`` in a plain function. The plain function runs
        ``f`` to completion on a brand-new event loop (installed as the current
        loop) and always closes that loop afterwards.

        It is a plain function in the class body, rather than a method, so it
        can be applied as ``@async_test`` to the methods defined below it.
        """
        import functools

        # functools.wraps keeps the wrapped test's __name__ and __doc__, so
        # unittest's verbose output and shortDescription() still report the
        # coroutine's docstring instead of the wrapper's (which has none).
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(f(*args, **kwargs))
            finally:
                loop.close()

        return wrapper

    @async_test
    async def test_populate_queue_new_items(self):
        """Test populating queue with new items"""
        # Mock empty existing index
        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)

                # Verify upload was called with correct data
                self.assertEqual(mock_upload.call_count, 1)
                _, _, lines = mock_upload.call_args[0]

                # Should create 2 work groups (2 files + 1 file)
                self.assertEqual(len(lines), 2)

                # Verify format of uploaded lines
                for line in lines:
                    parts = line.split(",")
                    self.assertGreaterEqual(len(parts), 2)  # Hash + at least one path
                    self.assertEqual(len(parts[0]), 40)  # SHA1 hash length

    @async_test
    async def test_populate_queue_existing_items(self):
        """Test populating queue with mix of new and existing items"""
        existing_paths = ["s3://test-bucket/data/existing1.pdf"]
        new_paths = ["s3://test-bucket/data/new1.pdf"]

        # Create existing index content
        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = f"{existing_hash},{existing_paths[0]}"

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[existing_line]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)

                # Verify upload called with both existing and new items
                _, _, lines = mock_upload.call_args[0]
                self.assertEqual(len(lines), 2)
                self.assertIn(existing_line, lines)

    @async_test
    async def test_initialize_queue(self):
        """Test queue initialization"""
        # Mock work items and completed items
        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"

        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[work_line]):
            with patch("olmocr.work_queue.expand_s3_glob", return_value=completed_items):
                await self.work_queue.initialize_queue()

                # Queue should be empty since all work is completed
                self.assertTrue(self.work_queue._queue.empty())

    @async_test
    async def test_is_completed(self):
        """Test completed work check"""
        work_hash = "testhash123"

        # Test completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
        self.assertTrue(await self.work_queue.is_completed(work_hash))

        # Test incomplete work
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        self.assertFalse(await self.work_queue.is_completed(work_hash))

    @async_test
    async def test_get_work(self):
        """Test getting work items"""
        # Setup test data
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Test getting available work
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)

        # Verify lock file was created (as a keyword-argument put_object call)
        self.s3_client.put_object.assert_called_once()
        put_kwargs = self.s3_client.put_object.call_args[1]
        self.assertIn("Bucket", put_kwargs)
        self.assertTrue(put_kwargs["Key"].endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_get_work_completed(self):
        """Test getting work that's already completed"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip completed work

    @async_test
    async def test_get_work_locked(self):
        """Test getting work that's locked by another worker"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate active lock: first head_object (completion check) misses,
        # second (lock check) finds a freshly written lock file
        recent_time = datetime.datetime.now(datetime.timezone.utc)
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # Not completed
            {"LastModified": recent_time},  # Active lock
        ]

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip locked work

    @async_test
    async def test_get_work_stale_lock(self):
        """Test getting work with a stale lock"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate stale lock: the lock file is an hour old
        stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # Not completed
            {"LastModified": stale_time},  # Stale lock
        ]

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)  # Should take work with stale lock

    @async_test
    async def test_mark_done(self):
        """Test marking work as done"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        await self.work_queue.mark_done(work_item)

        # Verify lock file was deleted (as a keyword-argument delete_object call)
        self.s3_client.delete_object.assert_called_once()
        delete_kwargs = self.s3_client.delete_object.call_args[1]
        self.assertIn("Bucket", delete_kwargs)
        self.assertTrue(delete_kwargs["Key"].endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_paths_with_commas(self):
        """Test handling of paths that contain commas"""
        # Create paths with commas in them
        paths_with_commas = ["s3://test-bucket/data/file1,with,commas.pdf", "s3://test-bucket/data/file2,comma.pdf", "s3://test-bucket/data/file3.pdf"]

        # Mock empty existing index for initial population
        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                # Populate the queue with these paths
                await self.work_queue.populate_queue(paths_with_commas, items_per_group=3)

                # Capture what would be written to the index
                _, _, lines = mock_upload.call_args[0]

        # Now simulate reading back these lines (which have commas in the paths)
        with patch("olmocr.work_queue.download_zstd_csv", return_value=lines):
            with patch("olmocr.work_queue.expand_s3_glob", return_value=[]):
                # Initialize a fresh queue from these lines
                await self.work_queue.initialize_queue()

                # Mock ClientError for head_object (file doesn't exist)
                self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")

                # Get a work item
                work_item = await self.work_queue.get_work()

                # Now verify we get a work item
                self.assertIsNotNone(work_item, "Should get a work item")

                # Verify the work item has the correct number of paths
                self.assertEqual(len(work_item.work_paths), len(paths_with_commas), "Work item should have the correct number of paths")

                # Check that all original paths with commas are preserved
                for path in paths_with_commas:
                    self.assertIn(path, work_item.work_paths, f"Path with commas should be preserved: {path}")

    def test_queue_size(self):
        """Test queue size property"""
        self.assertEqual(self.work_queue.size, 0)

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"])))
            self.assertEqual(self.work_queue.size, 1)

            self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"])))
            self.assertEqual(self.work_queue.size, 2)
        finally:
            # Close the loop even if an assertion above fails, so it is not leaked
            self.loop.close()
if __name__ == "__main__":
    # Executed directly as a script: discover and run every TestCase in this module.
    unittest.main(module="__main__")
|