How can I process a dataset by chunk?

I have a dataset with a large number of rows.
For some reason (e.g. my particular processing is slow because of a library I’m using), I want to process this dataset “chunk by chunk”.

I can’t use standard incremental transforms because my source dataset is not an incremental dataset.

How can I process this dataset by chunk?

You can still use the @incremental decorator and tweak the transform’s behavior so that it:

  • Reads the input dataset
  • Removes the part that was already processed (if any), by comparing it with the current output
  • Limits the remaining part to the size of the chunk we want to process
from pyspark.sql import functions as F
from pyspark.sql import types as T
from transforms.api import transform, Input, Output, incremental

'''
Intent: If downstream processing is too heavy to pass in one build, we can "chunk" the input dataset into multiple smaller transactions.
At each run, the goal is to write N more rows to the output, so that downstream can process them, and then repeat (more rows, process, more rows, process, ...)
'''

@incremental(snapshot_inputs=["input_dataset"], semantic_version=1)
# @transform to have more control over inputs and outputs.
@transform(
    output_dataset=Output("/path/dataset_only_50_by_50"),
    input_dataset=Input("/path/dataset_snapshot"),
)
def example_chunking_transform(ctx, input_dataset, output_dataset):
    # We read the full input dataframe (as a snapshot), thanks to snapshot_inputs in @incremental
    input_df_all_dataframe = input_dataset.dataframe(mode="current")

    if ctx.is_incremental:
        # We read the current output to see what we already processed in previous builds
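        # Note: we pass a schema so that an empty dataframe can be built if the output has no data yet.
        # It should match the columns written below (here: key + incremental_ts); adapt it to your data.
        out_schema = T.StructType([
            T.StructField('key', T.StringType()),
            T.StructField('incremental_ts', T.TimestampType())
        ])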
        output_df_previous_dataframe = output_dataset.dataframe('current', out_schema)

        # ==== Example processing here ====
        # We diff the input with the current output, to find the "new rows".
        # We do this with a LEFT ANTI join: A - B <==> A LEFT ANTI B
        KEY = ["key"]
        new_rows_df = input_df_all_dataframe.join(output_df_previous_dataframe, how="left_anti", on=KEY)
    else:
        # First run (or any non-incremental run): the whole input is new
        new_rows_df = input_df_all_dataframe

    # We add a timestamp for easier tracking/understanding of the example
    new_rows_df = new_rows_df.withColumn('incremental_ts', F.current_timestamp())

    # We keep only N rows (here N = 50)
    new_rows_df = new_rows_df.limit(50)

    # ==== End of example processing ====

    # In incremental mode, write_dataframe appends rows ("modify" mode) by default,
    # so there is no need to call set_mode explicitly; on a non-incremental run the output is replaced.
    # output_dataset.set_mode("modify")
    output_dataset.write_dataframe(new_rows_df)
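The three steps (read the input, drop the rows already present in the output, keep one chunk) can be sanity-checked outside Foundry with a plain-Python sketch. Here lists of dicts stand in for the DataFrames and a set lookup plays the role of the LEFT ANTI join; `next_chunk` and the 120-row input are hypothetical, and this is not Foundry or Spark code.

```python
def next_chunk(input_rows, output_rows, key="key", chunk_size=50):
    """Return up to `chunk_size` input rows whose key is not in the output yet.

    Hypothetical helper: lists of dicts stand in for the Spark DataFrames.
    """
    # Keys already written by previous builds (the "B" in A LEFT ANTI B)
    done = {row[key] for row in output_rows}
    # Drop the rows that were already processed
    remaining = [row for row in input_rows if row[key] not in done]
    # Keep a single chunk; sorting by key makes the choice deterministic
    remaining.sort(key=lambda row: row[key])
    return remaining[:chunk_size]


# "Read" the input: here 120 rows with string keys
input_rows = [{"key": f"k{i:03d}"} for i in range(120)]
output_rows = []  # the output is empty before the first build

first = next_chunk(input_rows, output_rows)   # build 1: k000..k049
output_rows.extend(first)                     # the build appends its chunk
second = next_chunk(input_rows, output_rows)  # build 2: k050..k099, no overlap
output_rows.extend(second)
```

Unlike this sketch, the transform above calls `.limit(50)` without ordering, so Spark may pick any 50 new rows; that is fine there, since the anti join already guarantees that no row is written twice.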

Downstream, you can use a standard incremental transform. Whenever you build, you build both the dataset that “adds a chunk” and the downstream one that “processes the chunks”, and you repeat the operation as many times as needed. Of course, this can be automated with a schedule.
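A minimal simulation of that cycle, assuming one scheduled run builds both datasets in order: the upstream step appends one chunk, and the downstream step processes only the rows that chunk added. The function names, the `processed` flag and the 120-row source are hypothetical placeholders for the real transforms and processing.

```python
CHUNK_SIZE = 50  # hypothetical chunk size, mirroring .limit(50)


def upstream_build(source, chunked):
    """Append the next chunk of not-yet-copied source rows to `chunked`.

    Returns the rows added in this build, i.e. what a downstream
    incremental transform would read as the new rows.
    """
    done = {row["key"] for row in chunked}
    pending = sorted((r for r in source if r["key"] not in done), key=lambda r: r["key"])
    added = pending[:CHUNK_SIZE]
    chunked.extend(added)
    return added


def downstream_build(added, results):
    """Process only the rows added upstream in this run."""
    for row in added:
        # Placeholder for the slow, library-based processing
        results.append({"key": row["key"], "processed": True})


source = [{"key": f"k{i:03d}"} for i in range(120)]
chunked, results = [], []
runs = 0
# Each iteration stands for one scheduled run that builds both datasets
while True:
    added = upstream_build(source, chunked)
    if not added:
        break  # everything has been chunked: further runs change nothing
    downstream_build(added, results)
    runs += 1

print(runs, len(results))  # 3 120
```

Once the source is exhausted, scheduled runs keep succeeding but add nothing, so the schedule can simply be left in place or disabled.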


[edit] The reply above is pretty much the same, but with better code!

Hey Vincent,

One way to do it would be to process the data batch by batch with two incremental outputs: one with the result of your processing (assuming the processing is row by row and doesn’t require the full dataset), and one logging the rows that were already processed, so they can be filtered out in the next build.

This assumes you have a way to know which rows you’ve already processed, such as a UUID or artificial batch IDs.

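You would use something like the following (paths and the batch size of 50 are placeholders):

from pyspark.sql import types as T
from transforms.api import transform, incremental, Input, Output

@incremental(snapshot_inputs=["input"])
@transform(
    logs=Output("/path/processed_ids"),
    out=Output("/path/processed_output"),
    input=Input("/path/input"),
)
def compute(input, logs, out):
    # IDs processed in previous builds (assumes my_id is a string); the schema covers the empty first build
    log_schema = T.StructType([T.StructField("my_id", T.StringType())])
    processed = logs.dataframe("previous", log_schema)
    # Rows not processed yet; sorting makes the limited batch deterministic across both writes
    current_batch = (input.dataframe("current")
                     .join(processed, "my_id", how="left_anti")
                     .orderBy("my_id").limit(50))
    out.write_dataframe(current_batch)  # replace with your actual row-by-row processing
    logs.write_dataframe(current_batch.select("my_id"))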

(pardon my code written without any kind of validation)
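The two-output bookkeeping can be illustrated in plain Python, with lists standing in for the `out` and `logs` outputs. This is a sketch under assumptions: `process`, `batch_build`, the `text` field and the batch size of 2 are hypothetical, and each call to `batch_build` stands for one build.

```python
def process(row):
    """Hypothetical row-by-row processing (stands in for the slow library call)."""
    return {"my_id": row["my_id"], "text": row["text"].upper()}


def batch_build(input_rows, logs, out, batch_size=2):
    """One build: process the next batch, then log the processed IDs.

    `logs` and `out` are lists standing in for the two incremental outputs.
    """
    processed_ids = {entry["my_id"] for entry in logs}
    # Left anti join against the log, then cap the batch (sorted for determinism)
    pending = sorted((r for r in input_rows if r["my_id"] not in processed_ids),
                     key=lambda r: r["my_id"])
    batch = pending[:batch_size]
    out.extend(process(r) for r in batch)               # results output
    logs.extend({"my_id": r["my_id"]} for r in batch)   # log output (IDs only)
    return len(batch)


input_rows = [{"my_id": f"id{i}", "text": f"row {i}"} for i in range(5)]
logs, out = [], []
counts = [batch_build(input_rows, logs, out) for _ in range(4)]
print(counts)  # [2, 2, 1, 0]
```

Keeping the log separate from the results means the results output can have any schema, while the filter only ever needs the ID column.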


If this approach isn’t possible because of the scale of the data, you can process things file by file, using pyspark.sql.functions.input_file_name or file-based processing.
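The file-by-file variant can be sketched the same way: rows are grouped by the file they came from, and each build takes the next unprocessed file. This is plain Python, not Spark; the `source_file` field, the file names and the `next_files` helper are hypothetical, and in a Spark transform the field would come from a column added with `input_file_name()`. In a real pipeline, the set of processed files would itself have to be persisted between builds (for example in a log output).

```python
from collections import defaultdict


def next_files(rows, processed_files, files_per_build=1):
    """Return the next unprocessed file name(s) and their rows.

    Each row carries a "source_file" field, standing in for a column
    added with pyspark.sql.functions.input_file_name().
    """
    by_file = defaultdict(list)
    for row in rows:
        by_file[row["source_file"]].append(row)
    pending = sorted(f for f in by_file if f not in processed_files)
    chosen = pending[:files_per_build]
    return chosen, [row for f in chosen for row in by_file[f]]


rows = [
    {"source_file": "part-0.parquet", "value": 1},
    {"source_file": "part-1.parquet", "value": 2},
    {"source_file": "part-0.parquet", "value": 3},
    {"source_file": "part-2.parquet", "value": 4},
]
processed_files = set()
while True:
    chosen, batch = next_files(rows, processed_files)
    if not chosen:
        break
    # Process `batch` here, then record its file(s) as done
    processed_files.update(chosen)
    print(chosen, [r["value"] for r in batch])
```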