"""
draw_boxplots.py

Boxplots of Elo ratings with 95% confidence intervals for each method.

Invocation:
    python draw_boxplots.py results.txt boxplots.png

@kylel
"""
  7. import hashlib
  8. import re
  9. from pathlib import Path
  10. import click
  11. import matplotlib.font_manager as font_manager
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. import requests
  15. # AI2 Colors
  16. AI2_PINK = "#f0529c"
  17. AI2_DARK_TEAL = "#0a3235"
  18. AI2_TEAL = "#105257"
  19. # Name mappings
  20. NAME_DISPLAY_MAP = {"pdelf": "olmOCR", "mineru": "MinerU", "marker": "Marker", "gotocr_format": "GOTOCR"}
  21. def download_and_cache_file(url, cache_dir=None):
  22. """Download a file and cache it locally."""
  23. if cache_dir is None:
  24. cache_dir = Path.home() / ".cache" / "elo_plot"
  25. cache_dir = Path(cache_dir)
  26. cache_dir.mkdir(parents=True, exist_ok=True)
  27. # Create filename from URL hash
  28. url_hash = hashlib.sha256(url.encode()).hexdigest()[:12]
  29. file_name = url.split("/")[-1]
  30. cached_path = cache_dir / f"{url_hash}_{file_name}"
  31. if not cached_path.exists():
  32. response = requests.get(url, stream=True)
  33. response.raise_for_status()
  34. with open(cached_path, "wb") as f:
  35. for chunk in response.iter_content(chunk_size=8192):
  36. f.write(chunk)
  37. return str(cached_path)
  38. def parse_elo_data(file_path):
  39. """Parse Elo ratings data from a text file."""
  40. with open(file_path, "r") as f:
  41. content = f.read()
  42. # Regular expression to match the data lines
  43. pattern = r"(\w+)\s+(\d+\.\d+)\s*±\s*(\d+\.\d+)\s*\[(\d+\.\d+),\s*(\d+\.\d+)\]"
  44. matches = re.finditer(pattern, content)
  45. # Initialize lists to store data
  46. names = []
  47. medians = []
  48. errors = []
  49. ci_low = []
  50. ci_high = []
  51. for match in matches:
  52. names.append(match.group(1))
  53. medians.append(float(match.group(2)))
  54. errors.append(float(match.group(3)))
  55. ci_low.append(float(match.group(4)))
  56. ci_high.append(float(match.group(5)))
  57. return names, medians, errors, ci_low, ci_high
  58. def create_boxplot(names, medians, errors, ci_low, ci_high, output_path, font_path):
  59. """Create and save a boxplot of Elo ratings."""
  60. # Set up Manrope font
  61. font_manager.fontManager.addfont(font_path)
  62. plt.rcParams["font.family"] = "Manrope"
  63. plt.rcParams["font.weight"] = "medium"
  64. # Define colors - pdelf in pink, others in shades of teal/grey based on performance
  65. max_median = max(medians)
  66. colors = []
  67. for i, median in enumerate(medians):
  68. if names[i] == "pdelf":
  69. colors.append(AI2_PINK)
  70. else:
  71. # Calculate a shade between dark teal and grey based on performance
  72. performance_ratio = (median - min(medians)) / (max_median - min(medians))
  73. base_color = np.array(tuple(int(AI2_DARK_TEAL[i : i + 2], 16) for i in (1, 3, 5))) / 255.0
  74. grey = np.array([0.7, 0.7, 0.7]) # Light grey
  75. color = tuple(np.clip(base_color * performance_ratio + grey * (1 - performance_ratio), 0, 1))
  76. colors.append(color)
  77. # Create box plot data
  78. box_data = []
  79. for i in range(len(names)):
  80. q1 = medians[i] - errors[i]
  81. q3 = medians[i] + errors[i]
  82. box_data.append([ci_low[i], q1, medians[i], q3, ci_high[i]])
  83. # Create box plot with smaller width and spacing
  84. plt.figure(figsize=(4, 4))
  85. bp = plt.boxplot(
  86. box_data,
  87. labels=[NAME_DISPLAY_MAP[name] for name in names],
  88. whis=1.5,
  89. patch_artist=True,
  90. widths=0.15, # Make boxes much narrower
  91. medianprops=dict(color="black"), # Make median line black
  92. positions=np.arange(len(names)) * 0.25,
  93. ) # Reduce spacing between boxes significantly
  94. # Color each box
  95. for patch, color in zip(bp["boxes"], colors):
  96. patch.set_facecolor(color)
  97. patch.set_alpha(0.8)
  98. # Style the plot
  99. # plt.ylabel("Elo Rating", fontsize=12, color=AI2_DARK_TEAL)
  100. plt.xticks(
  101. np.arange(len(names)) * 0.25, # Match positions from boxplot
  102. [NAME_DISPLAY_MAP[name] for name in names],
  103. rotation=45,
  104. ha="right",
  105. color=AI2_DARK_TEAL,
  106. )
  107. plt.yticks(color=AI2_DARK_TEAL)
  108. # Set x-axis limits to maintain proper spacing
  109. plt.xlim(-0.1, (len(names) - 1) * 0.25 + 0.1)
  110. # Remove the title and adjust the layout
  111. plt.tight_layout()
  112. # Remove spines
  113. for spine in plt.gca().spines.values():
  114. spine.set_visible(False)
  115. # Add left spine only
  116. plt.gca().spines["left"].set_visible(True)
  117. plt.gca().spines["left"].set_color(AI2_DARK_TEAL)
  118. plt.gca().spines["left"].set_linewidth(0.5)
  119. # Add bottom spine only
  120. plt.gca().spines["bottom"].set_visible(True)
  121. plt.gca().spines["bottom"].set_color(AI2_DARK_TEAL)
  122. plt.gca().spines["bottom"].set_linewidth(0.5)
  123. plt.savefig(output_path, dpi=300, bbox_inches="tight", transparent=True)
  124. plt.close()
  125. @click.command()
  126. @click.argument("input_file", type=click.Path(exists=True))
  127. @click.argument("output_file", type=click.Path())
  128. @click.option(
  129. "--manrope-medium-font-path",
  130. type=str,
  131. default="https://dolma-artifacts.org/Manrope-Medium.ttf",
  132. help="Path to the Manrope Medium font file (local path or URL)",
  133. )
  134. def main(input_file, output_file, manrope_medium_font_path):
  135. """Generate a boxplot from Elo ratings data.
  136. INPUT_FILE: Path to the text file containing Elo ratings data
  137. OUTPUT_FILE: Path where the plot should be saved
  138. """
  139. try:
  140. # Handle font path - download and cache if it's a URL
  141. if manrope_medium_font_path.startswith(("http://", "https://")):
  142. font_path = download_and_cache_file(manrope_medium_font_path)
  143. else:
  144. font_path = manrope_medium_font_path
  145. # Parse the data
  146. names, medians, errors, ci_low, ci_high = parse_elo_data(input_file)
  147. # Create and save the plot
  148. create_boxplot(names, medians, errors, ci_low, ci_high, output_file, font_path)
  149. click.echo(f"Plot successfully saved to {output_file}")
  150. except Exception as e:
  151. click.echo(f"Error: {str(e)}", err=True)
  152. raise click.Abort()
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()