import base64
import os
import random
import re
import unittest
from io import BytesIO
from unittest.mock import patch

import numpy as np
import pytest
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
    batch_prepare_data_for_molmo_training,
    build_finetuning_prompt,
    prepare_data_for_molmo_training,
    prepare_data_for_qwen2_training,
)
from olmocr.train.utils import make_dataset
@pytest.mark.nonci
class TestDataprep(unittest.TestCase):
    def testFullDataloader(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )

        train_dataset, valid_dataset = make_dataset(config, processor)

        im_end_token_ids = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]

        # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)

        for entry in train_dataset:
            print({x: (y.shape, y.dtype) for (x, y) in entry.items()})

            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Extract input_ids and labels
            input_ids = entry["input_ids"]
            labels = entry["labels"]

            # 1. Verify that the last token is the end token

            # Ensure input_ids is long enough
            self.assertTrue(len(input_ids) >= len(im_end_token_ids), "Input IDs are shorter than the end token sequence.")

            # Compare the last tokens of input_ids with im_end_token_ids
            self.assertEqual(
                input_ids[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last tokens of input_ids do not match the end token sequence."
            )

            # 2. Ensure labels are masked correctly and match input_ids after the mask
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(labels != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")
            first_label_index = label_indices[0]

            # Ensure the masked (prompt) portion is at least 10 tokens long
            self.assertTrue(first_label_index >= 10, "Masked portion of labels is less than 10 tokens.")

            # Check that all values before first_label_index are -100
            self.assertTrue(np.all(labels[:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Check that the unmasked labels match the corresponding input_ids
            self.assertTrue(
                np.array_equal(labels[first_label_index:], input_ids[first_label_index:]), "Unmasked labels do not match the corresponding input_ids."
            )

            # Optionally, verify that the last unmasked tokens in labels match the end token IDs
            unmasked_labels = labels[labels != -100]
            self.assertEqual(
                unmasked_labels[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last unmasked tokens in labels do not match the end token sequence."
            )
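
    # Illustrative only: a minimal sketch of the collate_fn that the DataLoader
    # commented out in testFullDataloader would need, since each entry is a dict
    # of variable-length numpy arrays. The pad value of 0 for input_ids and
    # attention_mask and -100 for labels are assumptions consistent with the
    # masking checked above; this helper is not part of olmocr, and the image
    # fields (pixel_values, image_grid_thw) are omitted for brevity.
    @staticmethod
    def _pad_collate_sketch(batch):
        max_len = max(len(entry["input_ids"]) for entry in batch)
        padded = {
            "input_ids": np.stack([np.pad(e["input_ids"], (0, max_len - len(e["input_ids"])), constant_values=0) for e in batch]),
            "attention_mask": np.stack([np.pad(e["attention_mask"], (0, max_len - len(e["attention_mask"])), constant_values=0) for e in batch]),
            "labels": np.stack([np.pad(e["labels"], (0, max_len - len(e["labels"])), constant_values=-100) for e in batch]),
        }
        # Usage sketch: DataLoader(train_dataset, batch_size=4, collate_fn=TestDataprep._pad_collate_sketch)
        return {key: torch.from_numpy(value) for key, value in padded.items()}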

    def testListTargetAnchorLength(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )

        # Set a fixed seed for reproducibility
        random.seed(42)

        train_dataset, valid_dataset = make_dataset(config, processor)

        zero_count = 0
        full_count = 0
        num_iterations = 100

        for i in range(num_iterations):
            entry = train_dataset[0]  # Get the first entry repeatedly

            # Basic type checks
            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Get the input text before the response
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(entry["labels"] != -100)[0]
            first_label_index = label_indices[0] if len(label_indices) > 0 else len(entry["input_ids"])

            # Decode the input portion to check the prompt; the regex matches only
            # when nothing but the page dimensions sits between the RAW_TEXT
            # markers, i.e. when the zero-length anchor text was chosen
            input_text = processor.tokenizer.decode(entry["input_ids"][:first_label_index])
            pattern = r"RAW_TEXT_START\nPage dimensions: (\d+\.?\d*)x(\d+\.?\d*)\s+RAW_TEXT_END"
            match = re.search(pattern, input_text, flags=re.MULTILINE)

            if match:
                zero_count += 1
            else:
                full_count += 1

        # Verify the distribution: the anchor length is picked from [0, 6000],
        # so each variant should appear roughly 50% of the time
        zero_ratio = zero_count / num_iterations
        full_ratio = full_count / num_iterations

        print(zero_count, full_count)

        self.assertTrue(0.45 <= zero_ratio <= 0.55, f"Expected zero-length ratio around 0.5, got {zero_ratio:.2f}")
        self.assertTrue(0.45 <= full_ratio <= 0.55, f"Expected full-length ratio around 0.5, got {full_ratio:.2f}")

        # Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations, "Total count should equal number of iterations")
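
    # Illustrative self-check of the zero-anchor regex used above, run on
    # synthetic strings rather than real olmocr prompts (the exact prompt
    # wording below is an assumption): the pattern should match only when
    # nothing but the page dimensions sits between the RAW_TEXT markers.
    def testZeroAnchorRegexSketch(self):
        pattern = r"RAW_TEXT_START\nPage dimensions: (\d+\.?\d*)x(\d+\.?\d*)\s+RAW_TEXT_END"
        zero_anchor = "RAW_TEXT_START\nPage dimensions: 612.0x792.0\nRAW_TEXT_END"
        full_anchor = "RAW_TEXT_START\nPage dimensions: 612.0x792.0\n[Some anchor text here]\nRAW_TEXT_END"
        self.assertIsNotNone(re.search(pattern, zero_anchor, flags=re.MULTILINE), "Zero-length anchor prompt should match")
        self.assertIsNone(re.search(pattern, full_anchor, flags=re.MULTILINE), "Prompt with anchor text should not match")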


@pytest.mark.nonci
class TestMolmoDataPrep(unittest.TestCase):
    def testMolmoDefaultSetup(self):
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Process the image and text
        inputs = processor.process(images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)], text="Describe this image.")

        print(inputs.keys())
        print(inputs["input_ids"])
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))

        labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")

        print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))

    def testMolmoDataPrep(self):
        # Initialize the processor for Molmo
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock example
        example = {
            "local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
            "page_num": 1,
            "response": "This is the response text.",
        }

        # Define target dimensions and anchor text lengths
        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Set a fixed seed for reproducibility
        random.seed(42)

        # Mock the functions that require actual PDF files
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            # Set return values for the mocked functions
            mock_get_anchor_text.return_value = "This is the anchor text."

            # Create a red square image and encode it in base64
            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the example using the prepare_data_for_molmo_training function
            processed_example = prepare_data_for_molmo_training(
                example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Basic type checks
            self.assertIsInstance(processed_example["input_ids"], torch.Tensor, "input_ids should be a torch.Tensor")
            self.assertIsInstance(processed_example["attention_mask"], torch.Tensor, "attention_mask should be a torch.Tensor")
            self.assertIsInstance(processed_example["labels"], torch.Tensor, "labels should be a torch.Tensor")
            self.assertIsInstance(processed_example["images"], torch.Tensor, "images should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_input_idx"], torch.Tensor, "image_input_idx should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_masks"], torch.Tensor, "image_masks should be a torch.Tensor")

            # Check tensor dimensions
            self.assertEqual(len(processed_example["input_ids"].shape), 1, "input_ids should be a 1D tensor")
            self.assertEqual(
                processed_example["input_ids"].shape, processed_example["attention_mask"].shape, "input_ids and attention_mask should have the same shape"
            )
            self.assertEqual(processed_example["input_ids"].shape, processed_example["labels"].shape, "input_ids and labels should have the same shape")

            # Verify label masking
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = torch.where(processed_example["labels"] != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")
            first_label_index = label_indices[0]

            # Ensure the masked portion is reasonable (at least a few tokens long)
            self.assertTrue(first_label_index >= 5, "Masked portion of labels is too short")

            # Check that all values before first_label_index are -100
            self.assertTrue(torch.all(processed_example["labels"][:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Verify attention mask
            self.assertTrue(torch.all(processed_example["attention_mask"] == 1), "All attention mask values should be 1")

            # Verify image input indices
            self.assertTrue(
                torch.all(processed_example["image_input_idx"] < len(processed_example["input_ids"])),
                "Image input indices should be within the range of input_ids length",
            )

            # Decode and verify content structure
            decoded_input = processor.tokenizer.decode(processed_example["input_ids"])
            self.assertIn("This is the anchor text", decoded_input, "Anchor text should be present in the decoded input")

            # Verify that unmasked labels decode to the response text
            unmasked_labels = processed_example["labels"][processed_example["labels"] != -100]
            decoded_labels = processor.tokenizer.decode(unmasked_labels)
            self.assertIn("This is the response text", decoded_labels, "Response text should be present in the decoded labels")

    def testBatchMolmoDataPrep(self):
        """Test the batch preparation function for Molmo."""
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock batch
        batch = {
            "local_pdf_path": [os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")],
            "page_num": [1],
            "response": ["This is the response text."],
        }

        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Mock the necessary functions
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            mock_get_anchor_text.return_value = "This is the anchor text."

            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the batch
            processed_batch = batch_prepare_data_for_molmo_training(
                batch, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Verify batch structure
            self.assertEqual(len(processed_batch["input_ids"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["attention_mask"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["labels"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["images"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_input_idx"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_masks"]), 1, "Batch size should be 1")

            # Verify the first item in the batch
            first_item = {k: v[0] for k, v in processed_batch.items()}
            self.assertIsInstance(first_item["input_ids"], torch.Tensor, "Batch item should contain torch.Tensor")
            self.assertTrue(torch.all(first_item["attention_mask"] == 1), "All attention mask values should be 1")
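

# Standard unittest entry point so the file can also be run directly
# (e.g. `python test_dataprep.py`) instead of only through pytest.
if __name__ == "__main__":
    unittest.main()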