# test_dataloader.py
import unittest
from functools import partial

import pytest
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from olmocr.train.dataloader import (
    build_finetuning_dataset,
    extract_openai_batch_response,
    list_dataset_files,
    load_jsonl_into_ds,
)
from olmocr.train.dataprep import batch_prepare_data_for_qwen2_training
  14. @pytest.mark.nonci
  15. class TestBatchQueryResponseDataset(unittest.TestCase):
  16. def testLoadS3(self):
  17. ds = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
  18. print(f"Loaded {len(ds)} entries")
  19. print(ds)
  20. print(ds["train"])
  21. def testFinetuningDS(self):
  22. ds = build_finetuning_dataset(
  23. response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
  24. )
  25. print(ds)
  26. processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
  27. ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
  28. print(ds[0])
  29. def testPlotSequenceLengthHistogram(self):
  30. import plotly.express as px
  31. ds = build_finetuning_dataset(
  32. response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
  33. )
  34. processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
  35. ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
  36. processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
  37. initial_len = len(ds)
  38. train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)
  39. max_seen_len = 0
  40. steps = 0
  41. sequence_lengths = [] # List to store sequence lengths
  42. for entry in tqdm(train_dataloader):
  43. num_input_tokens = entry["input_ids"].shape[1]
  44. max_seen_len = max(max_seen_len, num_input_tokens)
  45. sequence_lengths.append(num_input_tokens) # Collecting sequence lengths
  46. if steps % 100 == 0:
  47. print(f"Max input len {max_seen_len}")
  48. steps += 1
  49. # model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
  50. print(f"Max input len {max_seen_len}")
  51. print(f"Total elements before filtering: {initial_len}")
  52. print(f"Total elements after filtering: {steps}")
  53. # Plotting the histogram using Plotly
  54. fig = px.histogram(
  55. sequence_lengths, nbins=100, title="Distribution of Input Sequence Lengths", labels={"value": "Sequence Length", "count": "Frequency"}
  56. )
  57. fig.write_image("sequence_lengths_histogram.png")