Trouble getting concurrency with Asyncio in Code Repos

Does anyone have an example of an external transform using asyncio to sync data concurrently? I am trying to set up an example on my end, but am not getting it to loop through the pages of the API I’m hitting, even though it works fine when I use multithreading instead.

Code:

from transforms.api import transform, Output, incremental, lightweight
from transforms.external.systems import external_systems, Source
import pandas as pd
import yaml
import logging
import asyncio
import aiohttp

log = logging.getLogger(__name__)


@lightweight()
@incremental(require_incremental=True)
@external_systems(
    pokeSource=Source("ri.magritte..source.72b1293a-2cd3-467a-af37-79e83888a565")
)
@transform(
    out=Output("ri.foundry.main.dataset.0ca50951-c5a8-4722-a168-9ad3bce02fd5"),
)
async def compute(pokeSource, out):
    out_fs = out.filesystem()
    state_filename = "_state.yaml"
    state = {"start_url": "https://pokeapi.co/api/v2/pokemon?limit=100&offset=0"}

    try:
        with out_fs.open(state_filename, mode='r') as state_file:
            state = yaml.safe_load(state_file)
            logging.info(f"state file found, continuing from : {state}")
    except Exception:
        logging.warn("state file not found, starting over from default state")

    async def fetch_pokemon_data(session, url):
        async with session.get(url) as response:
            return await response.json()

    async with aiohttp.ClientSession() as session:
        response0 = await fetch_pokemon_data(session, state["start_url"])
        urls = [response0["next"]]  # Start with the next URL

        # Fetch more URLs if needed
        while len(urls) < 5 and urls[-1] is not None:  # Adjust the number of parallel requests as needed
            resp = await fetch_pokemon_data(session, urls[-1])
            if resp["next"]:
                urls.append(resp["next"])
            else:
                break

        tasks = [fetch_pokemon_data(session, url) for url in urls]
        responses = await asyncio.gather(*tasks)

    new_data = []
    for url, responseJson in zip(urls, responses):
        for pokemon in responseJson["results"]:
            new_data.append({"name": pokemon["name"], "url": url})

    new_df = pd.DataFrame(new_data)

    # Write the DataFrame directly using the lightweight API
    out.write_table(new_df)

    state["start_url"] = urls[-1] if urls else state["start_url"]

    with out_fs.open(state_filename, "w") as state_file:
        yaml.dump(state, state_file)

Here are a few code snippets using asyncio to perform a lot of API calls quickly (hitting >20k/min). I’m not sure how it will help you in your particular case, but at least that’s an example of working Asyncio implementation in lightweight.

import asyncio
from transforms.api import Input, Output, lightweight, transform
from myproject.datasets import utils


@lightweight()
@transform(
    out=Output(
        "ri.foundry.main.dataset.AAA"
    ),
    out_error=Output(
        "ri.foundry.main.dataset.BBB"
    ),
    source_dataset=Input("ri.foundry.main.dataset.CCC"),
)
def compute(source_dataset, out, out_error):
    input_df = source_dataset.polars(lazy=False)

    # Array of results
    good_responses_array = []
    bad_responses_array = []

    # Compute in a for loop
    headers = {"some_header": "example"}
    good_responses_array, bad_responses_array = asyncio.run(
        utils.asyncio_threadpool_with_semaphore(input_df, "some_url.com", headers, 100)
    )

    # Write to outputs
    utils.write_to_output(good_responses_array, out)
    utils.write_to_output(bad_responses_array, out_error)

async def asyncio_threadpool_with_semaphore(input_df, url, headers, max_concurrent_tasks=5):
    semaphore = asyncio.Semaphore(max_concurrent_tasks)
    good_responses_array = []  # To collect all responses
    bad_responses_array = []

    async with aiohttp.ClientSession() as session:

        async def limit_parallelism(pk, message):
            async with semaphore:
                logging.info(f"Task started for {pk}")
                result = await fake_api_call_async(url, message, headers)
                return pk, result

        tasks = []
        for curr_row in input_df.rows(named=True):
            message = json.loads(curr_row["data_payload"])
            pk = curr_row["pk"]
            task = asyncio.create_task(limit_parallelism(pk, message))
            tasks.append(task)
            logging.info(f"Task submitted for {pk}")

        responses = await asyncio.gather(*tasks, return_exceptions=True)

        for pk, result in responses:
            if isinstance(result, Exception) or result.status_code != 200:
                bad_responses_array.append(result)
            else:
                good_responses_array.append(result)

    return good_responses_array, bad_responses_array


def write_to_output(responses_array, response_dataset):
    # Write all the bad responses to the a dedicated output dataset
    if responses_array:
        response_df = pl.DataFrame(responses_array)
        response_df = response_df.with_columns(pl.lit(date.today()).alias("api_call_date"))
        response_dataset.write_table(response_df)
    else:
        # Define the schema as a list of tuples (column name, data type)
        schema = [("pk", pl.String), ("status", pl.String), ("error", pl.String)]
        # Create an empty DataFrame with the specified schema
        df = pl.DataFrame(schema=schema)
        response_dataset.write_table(df)