How to use a Layout Model on PDFs?

I have some PDFs that I want to extract “chunks” from, but I want the layout of the page to be taken into consideration.

I want to use a layout model to do so. How can I leverage one in Foundry?

This post refers to the following thread: https://community.palantir.com/t/how-to-do-pdf-annotation-in-workshop/3289/2

Yes — by running a layout model, you can extract the coordinates of the bounding boxes of the “elements” (titles, paragraphs, lists, tables, figures, …) on each page.

  1. Create a Code repository (Pipelines > Python)
  2. Install a few libraries: transforms-media, pdfplumber, pillow, pypdf2
  3. Extract the media references from the media set. The goal is the same as in Solution A: we want references (small JSON objects) that point to the original media.
    Note: you can actually do this step in Pipeline Builder. See steps 2 and 3 of the other post.
from transforms.api import Output, transform
from transforms.mediasets import MediaSetInput

@transform(
    output_document=Output("ri.foundry.main.dataset.543a13cd-8de1-49d5-b291-d6c9173622f1"),
    documents_media_set=MediaSetInput("ri.mio.main.media-set.85f88366-bd0c-441e-97d4-5b752e5eeae8"),
)
def generate_media_reference(ctx, output_document, documents_media_set):
    """Write one row per media item of the input media set, together with its media reference.

    The ``mediaReference`` column is tagged with the ``media_reference`` typeclass so the
    dataset preview can show in-line thumbnails of the referenced media.
    """
    typeclasses_by_column = {
        "mediaReference": [{"kind": "reference", "name": "media_reference"}],
    }
    listing = documents_media_set.list_media_items_by_path_with_media_reference(ctx)
    output_document.write_dataframe(listing, column_typeclasses=typeclasses_by_column)
  4. Extract each page of a document as a picture. Because the layout model needs to operate on pictures, each page of the PDF is converted into its own picture.
import base64
import io
from io import BytesIO
import pdfplumber
import pyspark.sql.functions as F
from PIL import Image
from transforms.api import Output, transform
from transforms.mediasets import MediaSetInput, MediaSetOutput

def load_pil_from_bytes(img_bytes: bytes) -> "Image.Image":
    """Open encoded image bytes (e.g. a JPEG) as a PIL image.

    The bytes are wrapped in an in-memory buffer. The return annotation names the
    ``Image.Image`` class: ``Image`` on its own is the PIL module, not a type.
    """
    return Image.open(BytesIO(img_bytes))

def image_dpi_resize(image: "Image.Image", target_length: float = 2048.0) -> "tuple[int, int]":
    """Compute the size ``image`` would have once scaled so its longer side is ``target_length``.

    The aspect ratio is preserved and the image itself is not modified. Each side of the
    result is truncated to an int.

    Args:
        image: an object exposing a PIL-style ``size`` attribute ``(width, height)``.
        target_length: desired length, in pixels, of the longer side.

    Returns:
        ``(width, height)`` of the scaled size.

    Raises:
        ValueError: if the image has no positive dimension, since no scale factor exists.
    """
    width_x, height_y = image.size
    longest_side = max(width_x, height_y)
    if longest_side <= 0:
        raise ValueError(f"Cannot scale an image of size {image.size!r}")
    factor = float(target_length / longest_side)
    return int(factor * width_x), int(factor * height_y)

# Split a PDF document into its pages and convert it to base64
def process(row, documents_media_set, pages_pictures_media_set):
    """Split one PDF media item into per-page JPEG pictures and describe each page.

    For every page of the document referenced by ``row.mediaItemRid``, this function:
      * renders the page to JPEG through the media set (height 2048),
      * saves a copy of that JPEG into ``pages_pictures_media_set``,
      * builds a dict holding:
          - a base64 data URI of the JPEG,
          - the page dimensions from the media set metadata,
          - the projected dimensions computed by ``image_dpi_resize``,
          - the page rotation read with pdfplumber.

    Args:
        row: a row of the media-reference listing; ``mediaItemRid`` and ``path`` are read.
        documents_media_set: the input media set holding the PDFs.
        pages_pictures_media_set: the output media set receiving one JPEG per page.

    Returns:
        A list of dicts, one per page, in page order (``page_number`` is 1-based).
    """
    rows = []
    # Page count and per-page dimensions come from the media set metadata.
    document_metadata = documents_media_set.get_media_item_metadata(row.mediaItemRid).document
    num_pages = document_metadata.pages
    dims = document_metadata.dimensions.page_dimensions

    # Drop only a trailing ".pdf" extension, case-insensitively. str.replace(".pdf", "")
    # would also strip ".pdf" occurring elsewhere in the path, and would miss ".PDF".
    path = row["path"]
    stem = path[: -len(".pdf")] if path.lower().endswith(".pdf") else path

    # Load the media in memory.
    item = documents_media_set.get_media_item(row.mediaItemRid)
    # pdfplumber is used instead of PyPDF2's PdfReader. With PyPDF2 the job failed with
    # "PickleException: expected zero arguments for construction of ClassDict
    # (for PyPDF2.generic._base.NumberObject)".
    with pdfplumber.open(io.BytesIO(item.read())) as pdf_doc:
        for i in range(num_pages):
            rotation = pdf_doc.pages[i].rotation
            # transform_document_to_jpg produces JPEG bytes.
            img_bytes = documents_media_set.transform_document_to_jpg(
                row.mediaItemRid, page_number=i, height=2048
            ).read()
            # Optional - save a copy of the page picture to the output media set.
            file_base_name = stem + "_page_" + str(i).zfill(5)
            jpg_rid = pages_pictures_media_set.put_media_item(
                img_bytes, file_base_name + ".jpeg"
            ).media_item_rid
            # Only the dimensions are needed, so the image is closed right after.
            with load_pil_from_bytes(img_bytes) as image:
                # Open question from the original author: does the JPG rendering honour the
                # page rotation? The rotation is recorded in "page_rotation" below so that
                # coordinates can be corrected in post-processing.
                projected_width, projected_height = image_dpi_resize(image, target_length=2048.0)
            img_str = base64.b64encode(img_bytes).decode("ascii")
            rows.append({
                "mediaItemRid": row.mediaItemRid,
                # The bytes are JPEG, so the data URI must declare image/jpeg (not image/png).
                "base64_encoded_img": f"data:image/jpeg;base64,{img_str}",
                "page_number": i + 1,
                "original_width": dims[i].width,
                "original_height": dims[i].height,
                "new_width": projected_width,
                "new_height": projected_height,
                "page_as_picture_media_id": jpg_rid,
                "page_rotation": rotation,
            })
    return rows

@transform(
    output_page=Output("ri.foundry.main.dataset.4513c366-4f8e-4631-bb3d-2c4733af867d"),
    documents_media_set=MediaSetInput("ri.mio.main.media-set.85f88366-bd0c-441e-97d4-5b752e5eeae8"),
    pages_pictures_media_set=MediaSetOutput(
        "ri.mio.main.media-set.e7d8e686-4507-4c60-97bd-48f6a5145e93",
        media_set_schema={"schema_type": "imagery", "primary_format": "jpg"},
        additional_allowed_input_formats=["png"],
        storage_configuration={"type": "native"},
        retention_policy="forever",
        write_mode="transactional",
    ),
)
def extract_pages(ctx, output_page, documents_media_set, pages_pictures_media_set):
    """Write one row per PDF page, each page rendered as a picture.

    Every media item of ``documents_media_set`` is expanded by ``process`` into per-page
    rows. As a side effect, the page JPEGs are written to ``pages_pictures_media_set``.
    Each page then gets a primary key ``page_pk``: the SHA-256 hex digest of
    "<mediaItemRid>-<page_number>".
    """
    def _pages_of(media_row):
        return process(media_row, documents_media_set, pages_pictures_media_set)

    listing = documents_media_set.list_media_items_by_path_with_media_reference(ctx)
    pages = listing.rdd.flatMap(_pages_of).toDF()

    pk_source = F.concat_ws("-", F.col("mediaItemRid"), F.col("page_number").cast("string"))
    pages = pages.withColumn("page_pk", F.sha2(pk_source, numBits=256))
    output_page.write_dataframe(pages)
  5. Run the layout model on each page.
import json
import logging

import pyspark.sql.functions as F
import pyspark.sql.types as T
from transforms.api import Input, Output, transform
from transforms.mediasets import MediaSetInput

def resolve_media_bounding_boxes(response):
    """Merge overlapping "Image"/"Table" elements of a layout response, in place.

    Only elements whose ``type`` is "Image" or "Table" take part. They are merged by
    ``merge_overlapping_boxes``. Each surviving element is written back at its original
    position, and the absorbed elements are removed from ``response``.

    Args:
        response: list of element dicts, each with ``type``, ``element_id`` and
            ``metadata.coordinates.points``.

    Returns:
        The same ``response`` list, mutated.
    """
    position_of = {}
    media_boxes = []
    for position, element in enumerate(response):
        if element["type"] not in ["Image", "Table"]:
            continue
        position_of[element["element_id"]] = position
        media_boxes.append(element)

    survivors, absorbed_ids = merge_overlapping_boxes(media_boxes)
    for survivor in survivors:
        response[position_of[survivor["element_id"]]] = survivor

    # Delete from the highest index down so the remaining indices stay valid.
    for position in sorted((position_of[eid] for eid in absorbed_ids), reverse=True):
        del response[position]

    return response
    


def merge_overlapping_boxes(json_list):
    """Repeatedly merge overlapping boxes until no two boxes in ``json_list`` overlap.

    Boxes are compared by the axis-aligned extent of
    ``element["metadata"]["coordinates"]["points"]``. Boxes that merely touch count as
    overlapping.

    The scan visits pairs (i, j), i < j, in order. The first overlapping pair found is
    merged: element i receives the enclosing rectangle as four corner points, and
    element j is deleted. The scan then restarts from the beginning.

    ``json_list`` and the elements it contains are modified in place.

    Returns:
        ``(json_list, removed)``, where ``removed`` lists the ``element_id`` of every
        deleted element, in deletion order.
    """
    def extent(points):
        # (x_min, y_min, x_max, y_max) of the points.
        return (
            min(point[0] for point in points),
            min(point[1] for point in points),
            max(point[0] for point in points),
            max(point[1] for point in points),
        )

    def find_overlap():
        # First overlapping (i, j) pair in scan order, with both extents; None if none.
        for i, first in enumerate(json_list[:-1]):
            first_points = first["metadata"]["coordinates"]["points"]
            for j in range(i + 1, len(json_list)):
                second_points = json_list[j]["metadata"]["coordinates"]["points"]
                a = extent(first_points)
                b = extent(second_points)
                disjoint = a[0] > b[2] or a[2] < b[0] or a[1] > b[3] or a[3] < b[1]
                if not disjoint:
                    return i, j, a, b
        return None

    removed = []
    while True:
        hit = find_overlap()
        if hit is None:
            break
        i, j, (ax_min, ay_min, ax_max, ay_max), (bx_min, by_min, bx_max, by_max) = hit
        x_min, y_min = min(ax_min, bx_min), min(ay_min, by_min)
        x_max, y_max = max(ax_max, bx_max), max(ay_max, by_max)
        json_list[i]["metadata"]["coordinates"]["points"] = [
            [x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]
        ]
        removed.append(json_list[j]["element_id"])
        del json_list[j]

    return json_list, removed


logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the root logger (level DEBUG) for
# the whole process, not only for this module.
logging.basicConfig(level=logging.DEBUG)

# Rows yielded by process_page, one per page: the JSON-encoded normalised response, the
# JSON-encoded raw model response, and an error flag.
INFERENCE_SCHEMA = T.StructType(
    [T.StructField(column, T.StringType()) for column in ("page_pk", "response", "response_raw")]
    + [T.StructField("error", T.BooleanType())]
)

# NOTE(review): not referenced by the code visible in this file.
OUTPUT_SCHEMA = T.StructType([
    T.StructField(column, column_type)
    for column, column_type in (
        ('page_pk', T.StringType()),
        ('mediaItemRid', T.StringType()),
        ('path', T.StringType()),
        ('mediaReference', T.StringType()),
        ('preview_image', T.StringType()),
        ('base64_encoded_img', T.StringType()),
        ('page_number', T.LongType()),
        ('response', T.StringType()),
        ('response_raw', T.StringType()),
        ('error', T.BooleanType()),
    )
])

# Presumably the DPI of the page dimensions reported by the media set (MIO) and of the
# layout model's bounding boxes. Their ratio rescales the boxes in
# map_new_to_old_api_response.
MIO_RESPONSE_DPI = 72
LAYOUT_RESPONSE_DPI = 72

# Layout-model block types -> element type names of the legacy response format.
# Types missing from this mapping become "UncategorizedText" at the lookup site.
type_mapping = {
    "LAYOUT_LIST": "List",
    "LAYOUT_TEXT": "NarrativeText",
    "LAYOUT_TITLE": "Title",
    "LAYOUT_HEADER": "Header",
    "LAYOUT_TABLE": "Table",
    "LAYOUT_FIGURE": "Image",
}

def map_new_to_old_api_response(new_response, height, width, original_height, original_width):
    """Convert a layout-model response into the legacy list-of-elements format.

    Each entry of ``new_response["blocks"]`` becomes a dict with these keys:
      * ``element_id`` and ``text``;
      * ``type``: mapped through ``type_mapping``, defaulting to "UncategorizedText";
      * ``metadata``: the block's languages, page number, optional
        ``detection_class_prob``, and its bounding box as four corner points.

    The corner points are ordered top-left, bottom-left, bottom-right, top-right. They
    are rescaled from the model's page space (original page size times
    LAYOUT_RESPONSE_DPI / MIO_RESPONSE_DPI) onto a ``height`` x ``width`` picture.

    Args:
        new_response: parsed model response holding a ``"blocks"`` list.
        height, width: size of the projected picture the points are expressed in.
        original_height, original_width: page size reported by the media set.

    Returns:
        A list of element dicts, in block order (empty if there are no blocks).
    """
    blocks = new_response["blocks"]
    if not blocks:
        return []

    # These factors do not depend on the block, so compute them once.
    dpi_ratio = LAYOUT_RESPONSE_DPI / MIO_RESPONSE_DPI
    model_height, model_width = dpi_ratio * original_height, dpi_ratio * original_width
    scale_factor_h = height / model_height
    scale_factor_w = width / model_width

    old_response = []
    for block in blocks:
        old_block = {
            "element_id": block["id"],
            "text": block["text"],
            "type": type_mapping.get(block["blockType"], "UncategorizedText"),
            "metadata": {
                "coordinates": {
                    "layout_height": height,
                    "layout_width": width,
                    "points": [],  # filled in below from the bounding box
                    "system": "PixelSpace",
                },
                "filetype": "PNG",
                "languages": block["languages"],
                "page_number": block["page"],
            },
        }
        # left/top/width/height box, presumably with a top-left origin.
        bbox = block["geometry"]["boundingBox"]
        left = bbox["left"] * scale_factor_w
        top = bbox["top"] * scale_factor_h
        right = (bbox["left"] + bbox["width"]) * scale_factor_w
        bottom = (bbox["top"] + bbox["height"]) * scale_factor_h
        old_block["metadata"]["coordinates"]["points"] = [
            [left, top],
            [left, bottom],
            [right, bottom],
            [right, top],
        ]
        if "confidence" in block:
            old_block["metadata"]["detection_class_prob"] = block["confidence"]
        old_response.append(old_block)
    return old_response



# Process one page of one document
def extract_all_text(media_item_rid, documents_media_set,
                     page_number, img_str, new_height, new_width, original_height, original_width):
    """Run the layout model on one page and post-process its response.

    Args:
        media_item_rid: RID of the PDF media item.
        documents_media_set: media set holding the PDFs. Its ``transform_media_item`` runs
            the ``extractLayoutAwareContent`` operation.
        page_number: 1-based page number; the media set call receives ``page_number - 1``.
        img_str: not used; kept so the call signature stays unchanged.
        new_height, new_width: projected picture size the coordinates are mapped onto.
        original_height, original_width: page size reported by the media set.

    Returns:
        ``(response_json, raw_response_json, error)``.

        On success:
          * ``response_json`` is the JSON-encoded normalised elements, each tagged with an
            ``id`` equal to its position;
          * ``raw_response_json`` is the JSON-encoded raw model response;
          * ``error`` is ``False``.

        On failure:
          * ``response_json`` is ``["<error message>"]``;
          * ``raw_response_json`` is the raw model response if the model call itself
            succeeded, otherwise ``["<error message>"]``;
          * ``error`` is ``True``.
    """
    # Stays None unless the model call succeeds. This lets a failure in post-processing
    # still report what the model returned.
    raw_layout_features = None
    try:
        # Run the layout model
        raw_layout_features = documents_media_set.transform_media_item(
            media_item_rid,
            str(page_number - 1),
            {
                "type": "documentToText",
                "documentToText": {
                    "operation": {
                        "type": "extractLayoutAwareContent",
                        "extractLayoutAwareContent": {"parameters": {"languages": ["ENG"]}},
                    }
                },
            },
            timeout_seconds=200,
        ).json()

        # Convert to the legacy element format, in the projected picture's pixel space.
        mapped_layout_features = map_new_to_old_api_response(
            raw_layout_features, new_height, new_width, original_height, original_width
        )
        # Merge overlapping Image/Table boxes.
        normalised_layout_features = resolve_media_bounding_boxes(mapped_layout_features)

        results = []
        for idx, box in enumerate(normalised_layout_features):
            box["id"] = idx
            results.append(box)
    except Exception as e:  # per-page boundary: one bad page must not fail the whole job
        logger.exception(
            "Layout extraction failed for media item %s, page %s", media_item_rid, page_number
        )
        error_payload = [str(e)]
        raw_payload = error_payload if raw_layout_features is None else raw_layout_features
        return json.dumps(error_payload), json.dumps(raw_payload), True

    return json.dumps(results), json.dumps(raw_layout_features), False


# Process one partition
def process_page(iterator, pdfs):
    """Yield one ``(page_pk, response, response_raw, error)`` tuple per page row.

    Meant for ``mapPartitions``: ``iterator`` yields the page rows of one partition, and
    ``pdfs`` is the media set holding the source PDFs. Rows are handled one at a time.
    """
    for page in iterator:
        key = page["page_pk"]
        outcome = extract_all_text(
            page.mediaItemRid,
            pdfs,
            page.page_number,
            page.base64_encoded_img,
            page.new_height,
            page.new_width,
            page.original_height,
            page.original_width,
        )
        yield (key,) + tuple(outcome)



@transform(
    output=Output("/.../layout_success"),
    error_output=Output("/.../layout_error"),
    relevant_pages=Input("ri.foundry.main.dataset.4513c366-4f8e-4631-bb3d-2c4733af867d"),
    pdfs=MediaSetInput("ri.mio.main.media-set.85f88366-bd0c-441e-97d4-5b752e5eeae8"),
)
def run_layout_model(ctx, output, error_output, relevant_pages, pdfs):
    """Run the layout model on every page, writing successes and failures separately.

    Each page row of ``relevant_pages`` goes through ``process_page``. The model outputs
    are left-joined back onto the page rows by ``page_pk``. Rows routing:
      * ``error`` False -> ``output``;
      * ``error`` True -> ``error_output``;
      * ``error`` null -> neither.
    """
    relevant_pages = relevant_pages.dataframe()

    # process_page handles a partition's rows one at a time, so the partition count caps
    # how many layout-model calls can be in flight at once.
    CONCURRENT_CALL_LIMIT = 12
    relevant_pages = relevant_pages.repartition(CONCURRENT_CALL_LIMIT)
    retrieved_layout_results = relevant_pages.rdd.mapPartitions(
        lambda rows: process_page(rows, pdfs)
    ).toDF(INFERENCE_SCHEMA)

    # Re-enrich the original dataset. The result feeds the two writes below. Without
    # caching, each write would re-run mapPartitions and call the (expensive) layout
    # model again for every page.
    retrieved_layout_results = relevant_pages.join(
        retrieved_layout_results, "page_pk", "left"
    ).cache()

    # Split the successes and failures and store them separately.
    output.write_dataframe(retrieved_layout_results.filter(~F.col("error")))
    error_output.write_dataframe(retrieved_layout_results.filter(F.col("error")))
  6. Post-process the output of the layout model so that Workshop can display the rectangles, taking care of edge cases such as page rotation. Some pages of a PDF can be rotated, and because the pages were converted into pictures before the coordinates were extracted, the coordinates must be rotated back.
import json
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from transforms.api import Input, Output, transform


def adjust_for_rotation(x1, y1, x2, y2, width, height, rotation):
    """Rotate a box from the rendered (rotated) page back to the unrotated page.

    Args:
        x1, y1, x2, y2: opposite corners of an axis-aligned box.
        width, height: size of the coordinate space used by the rotation formulas.
        rotation: page rotation in degrees. It should be a multiple of 90, as for a PDF
            page's /Rotate entry; ``None`` is treated as 0.

    Returns:
        ``(x1, y1, x2, y2)`` with ``x1 <= x2`` and ``y1 <= y2``. The corners are
        re-ordered after rotating: a 90/180/270-degree rotation moves the top-left corner
        to another corner, which would otherwise produce an inverted box.
    """
    # Normalise equivalent angles (e.g. -90 == 270, 450 == 90).
    rotation = (rotation or 0) % 360
    if rotation == 90:
        x1, y1 = y1, width - x1
        x2, y2 = y2, width - x2
    elif rotation == 180:
        x1, y1 = width - x1, height - y1
        x2, y2 = width - x2, height - y2
    elif rotation == 270:
        x1, y1 = height - y1, x1
        x2, y2 = height - y2, x2
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)

def adjust_coordinates(original_width, original_height, image_width, image_height, x1, y1, x2, y2, rotation):
    """Map a box from picture pixel space back to the original page's coordinate space.

    The corners are first scaled by ``original / image`` on each axis, then passed
    through ``adjust_for_rotation`` with the original page size.

    NOTE(review): for 90/270-degree pages, x is scaled by ``original_width`` even though
    the picture may have been rendered rotated. Confirm which orientation
    ``original_width``/``original_height`` describe.
    """
    sx = original_width / image_width
    sy = original_height / image_height
    scaled_x1, scaled_y1 = x1 * sx, y1 * sy
    scaled_x2, scaled_y2 = x2 * sx, y2 * sy
    return adjust_for_rotation(
        scaled_x1, scaled_y1, scaled_x2, scaled_y2, original_width, original_height, rotation
    )
    
    
# Explode json string response and pivot into cols
def process(row):
    """Explode one page's layout response into one row per detected element.

    Args:
        row: a row of the layout-success dataset. It reads ``response`` (a JSON list of
            elements), ``mediaItemRid``, ``page_pk``, ``page_number``,
            ``original_width``, ``original_height``, ``new_width``, ``new_height`` and
            ``page_rotation``.

    Returns:
        A list of dicts, one per element, in response order:
          * ``bounding_box``: the picture-space corners "x1, y1, x2, y2" as a string;
          * ``pdf_bounding_box``: a one-element list holding the same box mapped back to
            the original page, as a JSON string.
    """
    rows = []
    layout_response = json.loads(row.response)
    for idx, parsed_input in enumerate(layout_response):
        points = parsed_input["metadata"]["coordinates"]["points"]
        # points[0] is the top-left corner and points[2] the bottom-right one.
        box_coordinates = [*points[0], *points[2]]
        # The page went through a PDF -> picture conversion, so undo both the scaling and
        # the rotation for the box to be valid on the original PDF page.
        box_coordinates_adjusted = adjust_coordinates(
            row.original_width, row.original_height, row.new_width, row.new_height,
            box_coordinates[0], box_coordinates[1], box_coordinates[2], box_coordinates[3],
            row.page_rotation,
        )
        # Keep the column uniformly float. toDF() infers the column type from the data,
        # so mixing an int default with float confidences yields an inconsistent column.
        class_probability = parsed_input["metadata"].get("detection_class_prob")
        class_probability = 0.0 if class_probability is None else float(class_probability)
        rows.append({
            "mediaItemRid": row.mediaItemRid,
            "page_pk": row.page_pk,
            "page_number": row.page_number,
            "ordering_in_page": idx + 1,
            "content": parsed_input["text"],
            "content_type": parsed_input["type"].title(),
            "class_probability": class_probability,
            "bounding_box": ", ".join([str(round(coord, 2)) for coord in box_coordinates]),
            "pdf_bounding_box": [
                json.dumps({
                    "x1": round(box_coordinates_adjusted[0], 2),
                    "y1": round(box_coordinates_adjusted[1], 2),
                    "x2": round(box_coordinates_adjusted[2], 2),
                    "y2": round(box_coordinates_adjusted[3], 2),
                })
            ],  # e.g.  [ "{"x1":0.0, "y1":2.5, "x2":50.0, "y2":10.1}" ]
        })
    return rows

@transform(
    output=Output(
        "/.../layout_processed"
    ),
    responses=Input(
        "/.../layout_success"
    ),
)
def compute(output, responses):
    """Turn per-page layout responses into one row per layout chunk.

    Rows with a null ``response`` are skipped. The remaining ones are exploded by
    ``process``, and these columns are added:
      * ``content_id``: SHA-1 of page_pk, content_type, content and ordering_in_page;
      * ``title``: "<Content type> <ordering_in_page> - Page <page_number>";
      * ``content_format``: "CSV" for tables, "Detailed Description" otherwise;
      * ``ordering_in_document``: the 1-based position of the chunk within its document
        (``mediaItemRid``), ordered by page then by position on the page.
    """
    responses = responses.dataframe().filter(F.col("response").isNotNull())
    responses = responses.rdd.flatMap(process).toDF()
    responses = responses.withColumn(
        "content_id", F.sha1(F.concat_ws("-", "page_pk", "content_type", "content", "ordering_in_page"))
    )
    responses = responses.withColumn(
        "title",
        F.concat_ws(
            "",
            F.initcap(F.col("content_type")),
            F.lit(" "),
            F.col("ordering_in_page"),
            F.lit(" - "),
            F.lit("Page "),
            "page_number",
        ),
    )
    responses = responses.withColumn(
        "content_format",
        F.when(F.col("content_type") == "Table", F.lit("CSV")).otherwise(F.lit("Detailed Description")),
    )
    # Define the order of columns to sort by
    order_of_columns = ["page_number", "ordering_in_page"]
    # Number chunks within each document. Without partitionBy, the numbering would run
    # across every document in the dataset and force all rows into a single partition.
    window_spec = Window.partitionBy("mediaItemRid").orderBy(
        *[F.asc(col_name) for col_name in order_of_columns]
    )
    # Cast long columns to integer
    responses = responses.withColumn("ordering_in_page", F.col("ordering_in_page").cast("integer"))
    responses = responses.withColumn("page_number", F.col("page_number").cast("integer"))
    # Add the 1-based position of each chunk within its document
    responses = responses.withColumn("ordering_in_document", F.row_number().over(window_spec))
    output.write_dataframe(responses)

At this stage, we have a pipeline like the one in the picture below: the media set of documents is split into pages, then into layout chunks. We can use this output to create a “Layout chunk” Object Type, which we will use in Workshop.

Hence we have:

  • a dataset containing “the document references”
  • a dataset of chunks (“pieces of the document”), each with its content, the page and the document it belongs to, and the exact coordinates (rectangle) of the chunk
1 Like