Skip to content

Commit eb377bc

Browse files
committed
Handle images more flexibly
1 parent 49d5f32 commit eb377bc

File tree

1 file changed

+149
-52
lines changed

1 file changed

+149
-52
lines changed

pymupdf4llm/pymupdf4llm/helpers/pymupdf_rag.py

Lines changed: 149 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import os
3030
import string
31+
import typing
3132

3233
try:
3334
import pymupdf as fitz # available with v1.24.3
@@ -40,14 +41,19 @@
4041
if fitz.pymupdf_version_tuple < (1, 24, 2):
4142
raise NotImplementedError("PyMuPDF version 1.24.2 or later is needed.")
4243

43-
bullet = ("* ", chr(0xF0B7), chr(0xB7), chr(8226), chr(9679))
44+
bullet = ("- ", "* ", chr(0xF0A7), chr(0xF0B7), chr(0xB7), chr(8226), chr(9679))
4445
GRAPHICS_TEXT = "\n![%s](%s)\n"
4546

4647

4748
class IdentifyHeaders:
4849
"""Compute data for identifying header text."""
4950

50-
def __init__(self, doc, pages: list = None, body_limit: float = None):
51+
def __init__(
52+
self,
53+
doc: fitz.Document | str,
54+
pages: list | range | None = None,
55+
body_limit: float = 12,
56+
):
5157
"""Read all text and make a dictionary of fontsizes.
5258
5359
Args:
@@ -85,29 +91,33 @@ def __init__(self, doc, pages: list = None, body_limit: float = None):
8591
self.header_id = {}
8692

8793
# If not provided, choose the most frequent font size as body text.
88-
# If no text at all on all pages, just use 12
89-
if body_limit is None:
90-
temp = sorted(
91-
[(k, v) for k, v in fontsizes.items()],
92-
key=lambda i: i[1],
93-
reverse=True,
94-
)
95-
if temp:
96-
body_limit = temp[0][0]
97-
else:
98-
body_limit = 12
94+
# If no text at all on all pages, just use 12.
95+
# In any case, all font sizes not exceeding "body_limit" are treated as body text.
96+
temp = sorted(
97+
[(k, v) for k, v in fontsizes.items()],
98+
key=lambda i: i[1],
99+
reverse=True,
100+
)
101+
if temp:
102+
b_limit = max(body_limit, temp[0][0])
103+
else:
104+
b_limit = body_limit
99105

100-
sizes = sorted([f for f in fontsizes.keys() if f > body_limit], reverse=True)
106+
# identify up to 6 font sizes as header candidates
107+
sizes = sorted(
108+
[f for f in fontsizes.keys() if f > b_limit],
109+
reverse=True,
110+
)[:6]
101111

102112
# make the header tag dictionary
103113
for i, size in enumerate(sizes):
104114
self.header_id[size] = "#" * (i + 1) + " "
105115

106-
def get_header_id(self, span):
116+
def get_header_id(self, span: dict, **kwargs) -> str:
107117
"""Return appropriate markdown header prefix.
108118
109-
Given a text span from a "dict"/"radict" extraction, determine the
110-
markdown header prefix string of 0 to many concatenated '#' characters.
119+
Given a text span from a "dict"/"rawdict" extraction, determine the
120+
markdown header prefix string of 0 to n concatenated '#' characters.
111121
"""
112122
fontsize = round(span["size"]) # compute fontsize
113123
hdr_id = self.header_id.get(fontsize, "")
@@ -118,20 +128,37 @@ def to_markdown(
118128
doc: fitz.Document | str,
119129
*,
120130
pages: list | range | None = None,
121-
hdr_info: IdentifyHeaders | None = None,
131+
hdr_info: typing.Any = None,
122132
write_images: bool = False,
123133
page_chunks: bool = False,
134+
margins: float | typing.Iterable = (0, 50, 0, 50),
124135
) -> str | list[dict]:
125136
"""Process the document and return the text of its selected pages."""
126137

127138
if isinstance(doc, str):
128139
doc = fitz.open(doc)
129140

130-
if not pages: # use all pages if argument not given
131-
pages = range(doc.page_count)
132-
133-
if not isinstance(hdr_info, IdentifyHeaders):
141+
if pages is None: # use all pages if no selection given
142+
pages = list(range(doc.page_count))
143+
144+
if hasattr(margins, "__float__"):
145+
margins = [margins] * 4
146+
if len(margins) == 2:
147+
margins = (0, margins[0], 0, margins[1])
148+
if len(margins) != 4:
149+
raise ValueError("margins must have length 2 or 4 or be a number.")
150+
elif not all([hasattr(m, "__float__") for m in margins]):
151+
raise ValueError("margin values must be numbers")
152+
153+
# If "hdr_info" is neither a callable nor an object having a callable method
154+
# "get_header_id", scan the document and use font sizes as header level indicators.
155+
if callable(hdr_info):
156+
get_header_id = hdr_info
157+
elif hasattr(hdr_info, "get_header_id") and callable(hdr_info.get_header_id):
158+
get_header_id = hdr_info.get_header_id
159+
else:
134160
hdr_info = IdentifyHeaders(doc)
161+
get_header_id = hdr_info.get_header_id
135162

136163
def resolve_links(links, span):
137164
"""Accept a span and return a markdown link string."""
@@ -146,17 +173,15 @@ def resolve_links(links, span):
146173
return text
147174

148175
def save_image(page, rect, i):
149-
"""Optionally render the rect part of a page.
150-
151-
In any case return the image filename.
152-
"""
176+
"""Optionally render the rect part of a page."""
153177
filename = page.parent.name.replace("\\", "/")
154178
image_path = f"{filename}-{page.number}-{i}.png"
155179
if write_images is True:
156180
pix = page.get_pixmap(clip=rect)
157181
pix.save(image_path)
158182
del pix
159-
return os.path.basename(image_path)
183+
return os.path.basename(image_path)
184+
return ""
160185

161186
def write_text(
162187
page: fitz.Page,
@@ -166,7 +191,6 @@ def write_text(
166191
tab_rects: dict | None = None,
167192
img_rects: dict | None = None,
168193
links: list | None = None,
169-
hdr_info=None,
170194
) -> string:
171195
"""Output the text found inside the given clip.
172196
@@ -227,7 +251,8 @@ def write_text(
227251
key=lambda j: (j[1].y1, j[1].x0),
228252
):
229253
pathname = save_image(page, img_rect, i)
230-
out_string += GRAPHICS_TEXT % (pathname, pathname)
254+
if pathname:
255+
out_string += GRAPHICS_TEXT % (pathname, pathname)
231256
del img_rects[i]
232257

233258
text = " ".join([s["text"] for s in spans])
@@ -247,11 +272,11 @@ def write_text(
247272
out_string += indent + text + "\n"
248273
continue # done with this line
249274

250-
bno = spans[0]["block"] # block number of line
275+
span0 = spans[0]
276+
bno = span0["block"] # block number of line
251277
if bno != prev_bno:
252278
out_string += "\n"
253279
prev_bno = bno
254-
span0 = spans[0]
255280

256281
if ( # check if we need another line break
257282
prev_lrect
@@ -264,19 +289,24 @@ def write_text(
264289
prev_lrect = lrect
265290

266291
# if line is a header, this will return multiple "#" characters
267-
hdr_string = hdr_info.get_header_id(spans[0])
292+
hdr_string = get_header_id(span0)
268293

269294
# intercept if header text has been broken in multiple lines
270295
if hdr_string and hdr_string == prev_hdr_string:
271296
out_string = out_string[:-1] + " " + text + "\n"
272297
continue
298+
273299
prev_hdr_string = hdr_string
300+
if hdr_string.startswith("#"): # if a header, output it and skip the rest
301+
out_string += hdr_string + text + "\n"
302+
continue
303+
304+
# this line is not all-mono, so switch off "code" mode
305+
if code: # still in code output mode?
306+
out_string += "```\n" # switch off code mode
307+
code = False
274308

275309
for i, s in enumerate(spans): # iterate spans of the line
276-
# this line is not all-mono, so switch off "code" mode
277-
if code: # still in code output mode?
278-
out_string += "```\n" # switch of code mode
279-
code = False
280310
# decode font properties
281311
mono = s["flags"] & 8
282312
bold = s["flags"] & 16
@@ -312,6 +342,7 @@ def write_text(
312342
if code:
313343
out_string += "```\n" # switch off code mode
314344
code = False
345+
315346
return (
316347
out_string.replace(" \n", "\n").replace(" ", " ").replace("\n\n\n", "\n\n")
317348
)
@@ -361,7 +392,8 @@ def output_images(page, text_rect, img_rects):
361392
key=lambda j: (j[1].y1, j[1].x0),
362393
):
363394
pathname = save_image(page, img_rect, i)
364-
this_md += GRAPHICS_TEXT % (pathname, pathname)
395+
if pathname:
396+
this_md += GRAPHICS_TEXT % (pathname, pathname)
365397
del img_rects[i] # do not touch this image twice
366398

367399
else: # output all remaining table
@@ -370,7 +402,8 @@ def output_images(page, text_rect, img_rects):
370402
key=lambda j: (j[1].y1, j[1].x0),
371403
):
372404
pathname = save_image(page, img_rect, i)
373-
this_md += GRAPHICS_TEXT % (pathname, pathname)
405+
if pathname:
406+
this_md += GRAPHICS_TEXT % (pathname, pathname)
374407
del img_rects[i] # do not touch this image twice
375408
return this_md
376409

@@ -381,8 +414,20 @@ def get_metadata(doc, pno):
381414
meta["page"] = pno + 1
382415
return meta
383416

384-
def get_page_output(doc, pno, textflags):
385-
"""Process one page."""
417+
def get_page_output(doc, pno, margins, textflags):
418+
"""Process one page.
419+
420+
Args:
421+
doc: fitz.Document
422+
pno: 0-based page number
423+
textflags: text extraction flag bits
424+
images: store image information here
425+
tables: store table information here
426+
graphics: store graphics information here
427+
428+
Returns:
429+
Tuple of (Markdown string of page content, images, tables, graphics).
430+
"""
386431
page = doc[pno]
387432
md_string = ""
388433

@@ -392,37 +437,71 @@ def get_page_output(doc, pno, textflags):
392437
# make a TextPage for all later extractions
393438
textpage = page.get_textpage(flags=textflags)
394439

440+
img_info = page.get_image_info()
441+
images = img_info[:]
442+
tables = []
443+
graphics = []
444+
395445
# Locate all tables on page
396-
tabs = page.find_tables()
446+
tabs = page.find_tables(strategy="lines_strict")
397447

398448
# Make a list of table boundary boxes.
399449
# Must include the header bbox (may exist outside tab.bbox)
400450
tab_rects = {}
401451
for i, t in enumerate(tabs):
402452
tab_rects[i] = fitz.Rect(t.bbox) | fitz.Rect(t.header.bbox)
453+
tab_dict = {
454+
"bbox": tuple(tab_rects[i]),
455+
"rows": t.row_count,
456+
"columns": t.col_count,
457+
}
458+
tables.append(tab_dict)
403459
tab_rects0 = list(tab_rects.values())
404460

405461
# Select paths that are not contained in any table
406462
page_clip = page.rect + (36, 36, -36, -36) # ignore full page graphics
407463
paths = [
408464
p
409465
for p in page.get_drawings()
410-
if not intersects_rects(p["rect"], tab_rects0) and p["rect"] in page_clip
466+
if not intersects_rects(p["rect"], tab_rects0)
467+
and p["rect"] in page_clip
468+
and p["rect"].width < page_clip.width
469+
and p["rect"].height < page_clip.height
411470
]
412471

413-
# determine vector graphics outside any tables
414-
vg_clusters = page.cluster_drawings(drawings=paths)
472+
# Determine vector graphics outside any tables, filtering out any
473+
# which contain no stroked paths
474+
vg_clusters = []
475+
for bbox in page.cluster_drawings(drawings=paths):
476+
include = False
477+
for p in [p for p in paths if p["rect"] in bbox]:
478+
if p["type"] != "f":
479+
include = True
480+
break
481+
if [item[0] for item in p["items"] if item[0] == "c"]:
482+
include = True
483+
break
484+
if include is True:
485+
vg_clusters.append(bbox)
486+
487+
actual_paths = [p for p in paths if is_in_rects(p["rect"], vg_clusters)]
488+
print(f"before: {len(vg_clusters)=}")
415489
vg_clusters0 = [
416490
r
417491
for r in vg_clusters
418492
if not intersects_rects(r, tab_rects0) and r.height > 20
419-
] + [fitz.Rect(i["bbox"]) for i in page.get_image_info()]
493+
]
494+
495+
if write_images is True:
496+
vg_clusters0 += [fitz.Rect(i["bbox"]) for i in img_info]
420497

421498
vg_clusters = dict((i, r) for i, r in enumerate(vg_clusters0))
422499
# Determine text column bboxes on page, avoiding tables and graphics
500+
print(f"{len(tab_rects0)=}, {len(vg_clusters0)=}")
423501
text_rects = column_boxes(
424502
page,
425-
paths=paths,
503+
paths=actual_paths,
504+
no_image_text=write_images,
426505
textpage=textpage,
427506
avoid=tab_rects0 + vg_clusters0,
428507
)
@@ -444,28 +523,46 @@ def get_page_output(doc, pno, textflags):
444523
tab_rects=tab_rects,
445524
img_rects=vg_clusters,
446525
links=links,
447-
hdr_info=hdr_info,
448526
)
449527

450-
# write remaining tables.
528+
# write any remaining tables and images
451529
md_string += output_tables(tabs, None, tab_rects)
452530
md_string += output_images(None, tab_rects, None)
453531
md_string += "\n-----\n\n"
454-
return md_string
532+
while md_string.startswith("\n"):
533+
md_string = md_string[1:]
534+
return md_string, images, tables, graphics
455535

456536
if page_chunks is False:
457537
document_output = ""
458538
else:
459539
document_output = []
460540

541+
# read the Table of Contents
542+
toc = doc.get_toc()
461543
textflags = fitz.TEXT_DEHYPHENATE | fitz.TEXT_MEDIABOX_CLIP
462-
for pno in list(pages):
463-
page_output = get_page_output(doc, pno, textflags)
544+
for pno in pages:
545+
546+
page_output, images, tables, graphics = get_page_output(
547+
doc, pno, margins, textflags
548+
)
464549
if page_chunks is False:
465550
document_output += page_output
466551
else:
552+
# build subset of TOC for this page
553+
page_tocs = [t for t in toc if t[-1] == pno + 1]
554+
467555
metadata = get_metadata(doc, pno)
468-
document_output.append({"metadata": metadata, "text": page_output})
556+
document_output.append(
557+
{
558+
"metadata": metadata,
559+
"toc_items": page_tocs,
560+
"tables": tables,
561+
"images": images,
562+
"graphics": graphics,
563+
"text": page_output,
564+
}
565+
)
469566

470567
return document_output
471568

0 commit comments

Comments
 (0)