test_anchor.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import glob
  2. import io
  3. import json
  4. import os
  5. import unittest
  6. from pypdf import PdfReader
  7. from olmocr.data.renderpdf import get_pdf_media_box_width_height
  8. from olmocr.prompts.anchor import _linearize_pdf_report, _pdf_report, get_anchor_text
  9. class AnchorTest(unittest.TestCase):
  10. def testExtractText(self):
  11. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "some_ocr1.pdf")
  12. reader = PdfReader(local_pdf_path)
  13. page = reader.pages[0]
  14. def visitor_body(text, cm, tm, font_dict, font_size):
  15. print(repr(text), cm, tm, font_size)
  16. def visitor_op(op, args, cm, tm):
  17. # print(op, args, cm, tm)
  18. pass
  19. page.extract_text(visitor_text=visitor_body, visitor_operand_before=visitor_op)
  20. def testAnchorBase(self):
  21. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf")
  22. report = _pdf_report(local_pdf_path, 2)
  23. print(report)
  24. print(get_anchor_text(local_pdf_path, 2, pdf_engine="pdfreport"))
  25. def testAnchorImage(self):
  26. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "some_ocr1.pdf")
  27. report = _pdf_report(local_pdf_path, 1)
  28. print(report)
  29. print(get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport"))
  30. def testSmallPage(self):
  31. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "small_page_size.pdf")
  32. report = _pdf_report(local_pdf_path, 1)
  33. print(report)
  34. print(get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport"))
  35. def testBadUTFSurrogatePairsGeneration(self):
  36. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "badlines.pdf")
  37. anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
  38. jsondata = json.dumps({"text": anchor_text})
  39. import pyarrow as pa
  40. import pyarrow.compute as pc
  41. import pyarrow.json as paj
  42. buffer = io.BytesIO(jsondata.encode("utf-8"))
  43. paj.read_json(buffer, read_options=paj.ReadOptions(use_threads=False, block_size=len(jsondata)))
  44. def testLargePromptHint1(self):
  45. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint1.pdf")
  46. anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
  47. print(anchor_text)
  48. print(len(anchor_text))
  49. self.assertLessEqual(len(anchor_text), 1000)
  50. def testLargePromptHint2(self):
  51. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint2.pdf")
  52. anchor_text = get_anchor_text(local_pdf_path, 2, pdf_engine="pdfreport")
  53. print(anchor_text)
  54. print(len(anchor_text))
  55. self.assertLessEqual(len(anchor_text), 4000)
  56. def testLargePromptHint3(self):
  57. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint3.pdf")
  58. anchor_text = get_anchor_text(local_pdf_path, 2, pdf_engine="pdfreport")
  59. print(anchor_text)
  60. print(len(anchor_text))
  61. self.assertLessEqual(len(anchor_text), 4000)
  62. def testNewsPaperPromptHint(self):
  63. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "newspaper.pdf")
  64. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport")
  65. print(anchor_text)
  66. print(len(anchor_text))
  67. self.assertLessEqual(len(anchor_text), 4000)
  68. def testTobaccoPaperMissingParagraphs(self):
  69. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
  70. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport")
  71. print(anchor_text)
  72. print(len(anchor_text))
  73. self.assertLessEqual(len(anchor_text), 4000)
  74. def testAnchorOtherLengths(self):
  75. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
  76. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=2000)
  77. print(anchor_text)
  78. print(len(anchor_text))
  79. self.assertLessEqual(len(anchor_text), 2000)
  80. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
  81. print(anchor_text)
  82. print(len(anchor_text))
  83. self.assertLessEqual(len(anchor_text), 6000)
  84. def testFailingAnchor(self):
  85. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "failing_anchor_pg4.pdf")
  86. anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
  87. print(anchor_text)
  88. print(len(anchor_text))
  89. self.assertLessEqual(len(anchor_text), 4000)
  90. def testEmptyAnchor(self):
  91. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
  92. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=0)
  93. self.assertEqual(anchor_text.strip(), "Page dimensions: 612.0x792.0")
  94. def testCannotLoad(self):
  95. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "load_v_error.pdf")
  96. reader = PdfReader(local_pdf_path)
  97. page = 5
  98. anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000)
  99. print(anchor_text)
  100. print(len(anchor_text))
  101. self.assertLessEqual(len(anchor_text), 6000)
  102. @unittest.skip("TODO, this unit test still fails, the map text is too large.")
  103. def testExcessiveMapAnchor(self):
  104. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "map1.pdf")
  105. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
  106. print(anchor_text)
  107. print(len(anchor_text))
  108. self.assertLessEqual(len(anchor_text), 4000)
  109. def testKyleOnePageAnchors1(self):
  110. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "dolma-page-1.pdf")
  111. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
  112. print(anchor_text)
  113. print(len(anchor_text))
  114. self.assertLessEqual(len(anchor_text), 6000)
  115. def testKyleOnePageAnchors2(self):
  116. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "olmo-page-1.pdf")
  117. anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
  118. print(anchor_text)
  119. print(len(anchor_text))
  120. self.assertLessEqual(len(anchor_text), 6000)
  121. class BuildSilverTest(unittest.TestCase):
  122. def testSmallPage(self):
  123. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "small_page_size.pdf")
  124. from olmocr.data.buildsilver import build_page_query
  125. result = build_page_query(local_pdf_path, "s3://test.pdf", 1)
  126. from olmocr.data.renderpdf import get_png_dimensions_from_base64
  127. base64data = result["body"]["messages"][0]["content"][1]["image_url"]["url"]
  128. if base64data.startswith("data:image/png;base64,"):
  129. base64data = base64data[22:]
  130. width, height = get_png_dimensions_from_base64(base64data)
  131. print(width, height)
  132. assert max(width, height) == 2048
  133. class TestRenderPdf(unittest.TestCase):
  134. def testFastMediaBoxMatchesPyPdf(self):
  135. for file in glob.glob(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "*.pdf")):
  136. reader = PdfReader(file)
  137. print("checking", file)
  138. for page_num in range(1, len(reader.pages) + 1):
  139. w1, h1 = get_pdf_media_box_width_height(file, page_num)
  140. pypdfpage = reader.pages[page_num - 1]
  141. self.assertAlmostEqual(w1, pypdfpage.mediabox.width, places=3)
  142. self.assertAlmostEqual(h1, pypdfpage.mediabox.height, places=3)
  143. class TestOutputSamplePage(unittest.TestCase):
  144. def testTobaccoPaper(self):
  145. local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
  146. anchor_text = get_anchor_text(local_pdf_path, 1, "pdfreport", target_length=6000)
  147. print("")
  148. print(anchor_text)
  149. print("")