# test_dataprep.py
# Standard library
import base64
import os
import random
import re
import unittest
from io import BytesIO
from unittest.mock import patch

# Third-party
import numpy as np
import pytest
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

# Project-local
from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
    batch_prepare_data_for_molmo_training,
    build_finetuning_prompt,
    prepare_data_for_molmo_training,
    prepare_data_for_qwen2_training,
)
from olmocr.train.utils import make_dataset
@pytest.mark.nonci
class TestDataprep(unittest.TestCase):
    """Integration tests for the Qwen2-VL training data-preparation path.

    NOTE(review): these tests download the Qwen2-VL processor from the
    HuggingFace hub and read response files from S3, so they need network
    access — hence the ``nonci`` marker.
    """

    def testFullDataloader(self):
        """Iterate the full training dataset and validate every prepared entry.

        For each entry this checks array dtypes, that input_ids end with the
        "<|im_end|>\\n" token sequence, and that labels are masked with -100
        up to the response and match input_ids afterwards.
        """
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )
        train_dataset, valid_dataset = make_dataset(config, processor)

        # Token ids of the sequence every assistant response should end with.
        im_end_token_ids = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]

        # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
        for entry in train_dataset:
            print({x: (y.shape, y.dtype) for (x, y) in entry.items()})

            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Extract input_ids and labels
            input_ids = entry["input_ids"]
            labels = entry["labels"]

            # 1. Verify that the last token is the end token
            # Ensure input_ids is long enough
            self.assertTrue(len(input_ids) >= len(im_end_token_ids), "Input IDs are shorter than the end token sequence.")

            # Compare the last tokens of input_ids with im_end_token_ids
            self.assertEqual(
                input_ids[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last tokens of input_ids do not match the end token sequence."
            )

            # 2. Ensure labels are masked correctly and match input_ids after the mask
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(labels != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is at least 10 tokens long
            self.assertTrue(first_label_index >= 10, "Masked portion of labels is less than 10 tokens.")

            # Check that all values before first_label_index are -100
            self.assertTrue(np.all(labels[:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Check that the unmasked labels match the corresponding input_ids
            self.assertTrue(
                np.array_equal(labels[first_label_index:], input_ids[first_label_index:]), "Unmasked labels do not match the corresponding input_ids."
            )

            # Optionally, verify that the last unmasked tokens in labels match the end token IDs
            unmasked_labels = labels[labels != -100]
            self.assertEqual(
                unmasked_labels[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last unmasked tokens in labels do not match the end token sequence."
            )

    def testListTargetAnchorLength(self):
        """Check that a list-valued ``target_anchor_text_len`` samples uniformly.

        With ``target_anchor_text_len=[0, 6000]`` each prepared example should
        get a zero-length anchor text about half of the time.  The zero-length
        case is detected by an empty RAW_TEXT_START/RAW_TEXT_END section in the
        decoded prompt (dimensions line immediately followed by RAW_TEXT_END).
        """
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )

        # Set a fixed seed for reproducibility
        random.seed(42)

        train_dataset, valid_dataset = make_dataset(config, processor)

        zero_count = 0
        full_count = 0
        num_iterations = 100

        for i in range(num_iterations):
            entry = train_dataset[0]  # Get the first entry repeatedly

            # Basic type checks
            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Get the input text before the response
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(entry["labels"] != -100)[0]
            first_label_index = label_indices[0] if len(label_indices) > 0 else len(entry["input_ids"])

            # Decode the input portion to check the prompt
            input_text = processor.tokenizer.decode(entry["input_ids"][:first_label_index])

            # A match means the RAW_TEXT section is empty, i.e. the zero-length
            # anchor text was sampled for this draw.
            pattern = r"RAW_TEXT_START\nPage dimensions: (\d+\.?\d*)x(\d+\.?\d*)\s+RAW_TEXT_END"
            match = re.search(pattern, input_text, flags=re.MULTILINE)

            if match:
                zero_count += 1
            else:
                full_count += 1

        # Verify the distribution: with two candidate lengths, each should be
        # sampled roughly half of the time (the assertions below allow 45-55%).
        zero_ratio = zero_count / num_iterations
        full_ratio = full_count / num_iterations

        print(zero_count, full_count)

        self.assertTrue(0.45 <= zero_ratio <= 0.55, f"Expected zero-length ratio around 0.5, got {zero_ratio:.2f}")
        self.assertTrue(0.45 <= full_ratio <= 0.55, f"Expected full-length ratio around 0.5, got {full_ratio:.2f}")

        # Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations, "Total count should equal number of iterations")
@pytest.mark.nonci
class TestMolmoDataPrep(unittest.TestCase):
    """Tests for the Molmo data-preparation helpers.

    NOTE(review): these tests download the Molmo processor from the
    HuggingFace hub (and testMolmoDefaultSetup additionally fetches an image
    over HTTP), so they need network access — hence the ``nonci`` marker.
    """

    def testMolmoDefaultSetup(self):
        """Smoke-test the stock Molmo processor on an image + text prompt.

        Prints the processed inputs and their decoded token strings; makes no
        assertions — this is exploratory output only.
        """
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # process the image and text
        inputs = processor.process(images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)], text="Describe this image.")

        print(inputs.keys())
        print(inputs["input_ids"])
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))

        labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")
        print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))

    def testMolmoDataPrep(self):
        """Validate prepare_data_for_molmo_training on a single mocked example.

        PDF rendering and anchor-text extraction are patched out so the test
        exercises only the tokenization/label-masking logic.
        """
        # Initialize the processor for Molmo
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock example
        example = {
            "local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
            "page_num": 1,
            "response": "This is the response text.",
        }

        # Define target dimensions and anchor text lengths
        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Set a fixed seed for reproducibility
        random.seed(42)

        # Mock the functions that require actual PDF files
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            # Set return values for the mocked functions
            mock_get_anchor_text.return_value = "This is the anchor text."

            # Create a red square image and encode it in base64
            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the example using the prepare_data_for_molmo_training function
            processed_example = prepare_data_for_molmo_training(
                example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Basic type checks
            self.assertIsInstance(processed_example["input_ids"], torch.Tensor, "input_ids should be a torch.Tensor")
            self.assertIsInstance(processed_example["attention_mask"], torch.Tensor, "attention_mask should be a torch.Tensor")
            self.assertIsInstance(processed_example["labels"], torch.Tensor, "labels should be a torch.Tensor")
            self.assertIsInstance(processed_example["images"], torch.Tensor, "images should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_input_idx"], torch.Tensor, "image_input_idx should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_masks"], torch.Tensor, "image_masks should be a torch.Tensor")

            # Check tensor dimensions
            self.assertEqual(len(processed_example["input_ids"].shape), 1, "input_ids should be a 1D tensor")
            self.assertEqual(
                processed_example["input_ids"].shape, processed_example["attention_mask"].shape, "input_ids and attention_mask should have the same shape"
            )
            self.assertEqual(processed_example["input_ids"].shape, processed_example["labels"].shape, "input_ids and labels should have the same shape")

            # Verify label masking
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = torch.where(processed_example["labels"] != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is reasonable (at least a few tokens long)
            self.assertTrue(first_label_index >= 5, "Masked portion of labels is too short")

            # Check that all values before first_label_index are -100
            self.assertTrue(torch.all(processed_example["labels"][:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Verify attention mask
            self.assertTrue(torch.all(processed_example["attention_mask"] == 1), "All attention mask values should be 1")

            # Verify image input indices
            self.assertTrue(
                torch.all(processed_example["image_input_idx"] < len(processed_example["input_ids"])),
                "Image input indices should be within the range of input_ids length",
            )

            # Decode and verify content structure
            decoded_input = processor.tokenizer.decode(processed_example["input_ids"])
            self.assertIn("This is the anchor text", decoded_input, "Anchor text should be present in the decoded input")

            # Verify that unmasked labels decode to the response text
            unmasked_labels = processed_example["labels"][processed_example["labels"] != -100]
            decoded_labels = processor.tokenizer.decode(unmasked_labels)
            self.assertIn("This is the response text", decoded_labels, "Response text should be present in the decoded labels")

    def testBatchMolmoDataPrep(self):
        """Test the batch preparation function for Molmo"""
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock batch
        batch = {
            "local_pdf_path": [os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")],
            "page_num": [1],
            "response": ["This is the response text."],
        }

        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Mock the necessary functions
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            mock_get_anchor_text.return_value = "This is the anchor text."

            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the batch
            processed_batch = batch_prepare_data_for_molmo_training(
                batch, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Verify batch structure
            self.assertEqual(len(processed_batch["input_ids"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["attention_mask"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["labels"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["images"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_input_idx"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_masks"]), 1, "Batch size should be 1")

            # Verify the first item in the batch
            first_item = {k: v[0] for k, v in processed_batch.items()}
            self.assertIsInstance(first_item["input_ids"], torch.Tensor, "Batch item should contain torch.Tensor")
            self.assertTrue(torch.all(first_item["attention_mask"] == 1), "All attention mask values should be 1")