We are using transforms for an API integration with an external data source. This data source requires a unique access token per entity in our data, so we have been storing the access tokens in a dataset. Access tokens expire every 24 hours and must be refreshed through a call that returns a new access token and refresh token (again, unique per entity).
We are running into an issue where our ideal transform would create a circular dependency, since it needs to reference the access token and refresh token returned by the prior API call (which are themselves stored as a dataset). How do we achieve our desired outcome without a circular dependency?
Here is how we solved this, using an incremental transform. The key idea is to write the latest refresh token to a second output dataset and, on subsequent incremental runs, read it back from the previous version of that same output, which breaks the circular dependency. You may also need to add the @use_external_systems decorator if your setup requires egress policies, and make sure to secure the credentials dataset with a Marking.
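For reference, this is roughly how that decorator can be wired in. Treat it as a sketch only: the import path, the EgressPolicy helper, and the policy RID are assumptions based on the usual external-transforms setup, so verify them against the transforms version available in your repository.

from transforms.api import transform, Output
from transforms.external.systems import use_external_systems, EgressPolicy  # assumed import path

@use_external_systems(
    # Placeholder RID; use the egress policy configured for your external system
    egress=EgressPolicy("ri.resource-policy-manager.global.network-egress-policy.<policy-uuid>"),
)
@transform(
    output_dataset=Output("/.../path-to-output-dataset"),
)
def compute(egress, output_dataset):
    # The decorator injects the egress policy as an argument; outbound requests made
    # inside the transform are then allowed through that policy.
    ...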
import logging
from typing import Union
import requests
import pyspark.sql.types as T
import pyspark
from transforms.api import transform, Input, Output, incremental

TransformInput = Union[
    "transforms.api.TransformInput", "transforms.api.IncrementalTransformInput"
]
TransformOutput = Union[
    "transforms.api.TransformOutput", "transforms.api.IncrementalTransformOutput"
]
LOGGER = logging.getLogger(__name__)

def _incremental_refresh_token_schema():
    return T.StructType([T.StructField("refresh_token", T.StringType())])

@incremental(
    require_incremental=True,
    semantic_version=1,
    snapshot_inputs=["credentials"],
)
@transform(
    credentials=Input("/.../credentials"),
    output_dataset=Output("/.../path-to-output-dataset"),
    incremental_refresh_token=Output("/.../incremental_refresh_token"),
)
def ingest_transform(
    ctx,
    credentials: TransformInput,
    output_dataset: TransformOutput,
    incremental_refresh_token: TransformOutput,
):
    client_id, client_secret, initial_refresh_token = _load_credentials(credentials)
    if not ctx.is_incremental:
        LOGGER.info("First run of incremental code")
        refresh_token = initial_refresh_token
    else:
        LOGGER.info(
            "Running in incremental mode, getting refresh_token from output incremental_refresh_token"
        )
        spark_df = incremental_refresh_token.dataframe(
            "previous", schema=_incremental_refresh_token_schema()
        )
        refresh_token_pdf = spark_df.select("refresh_token").toPandas()
        if refresh_token_pdf.shape[0] == 0:
            LOGGER.info(
                "Running incremental, but incremental_refresh_token is still empty. "
                "Falling back to initial_refresh_token."
            )
            refresh_token = initial_refresh_token
        else:
            refresh_token = refresh_token_pdf["refresh_token"].iloc[0]
    fresh_credentials = _refresh_grant(client_id, client_secret, refresh_token)
    token_spark_df_to_write = ctx.spark_session.createDataFrame(
        [[fresh_credentials["refresh_token"]]],
        schema=_incremental_refresh_token_schema(),
    )
    incremental_refresh_token.set_mode("replace")
    incremental_refresh_token.write_dataframe(token_spark_df_to_write)
    # Pass ctx through so the ingestion helper can build a Spark dataframe
    ingested_data_spark_df = _ingestion(ctx, fresh_credentials["access_token"])
    output_dataset.set_mode("replace")
    output_dataset.write_dataframe(ingested_data_spark_df)

def _load_credentials(credentials: TransformInput) -> tuple:
    as_pd = credentials.dataframe().toPandas()
    return (
        as_pd["client_id"].iloc[0],
        as_pd["client_secret"].iloc[0],
        as_pd["initial_refresh_token"].iloc[0],
    )

def _refresh_grant(client_id: str, client_secret: str, refresh_token: str) -> dict:
    data = {
        "grant_type": "refresh_token",
        "scope": "...",
        "refresh_token": refresh_token,
        "client_id": client_id,
        "client_secret": client_secret,
    }
    response = requests.post("https://.../token", data=data)
    response.raise_for_status()
    return response.json()

def _ingestion(ctx, access_token: str) -> pyspark.sql.DataFrame:
    # Use access_token to ingest the data from the external API and return it as a
    # Spark dataframe built via ctx.spark_session.createDataFrame(...).
    # See the sketch after this listing for one possible implementation.
    ...
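Finally, a rough sketch of what _ingestion could look like for a paginated JSON API, reusing the requests and pyspark.sql.types imports from above. The endpoint URL, the pagination field ("next"), and the record fields ("id", "value") are hypothetical and need to be adapted to your data source.

def _ingestion(ctx, access_token: str) -> pyspark.sql.DataFrame:
    # Hypothetical endpoint and response shape; adjust to the real API.
    url = "https://.../data"
    headers = {"Authorization": f"Bearer {access_token}"}
    rows = []
    while url is not None:
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        payload = response.json()
        rows.extend(payload["results"])
        url = payload.get("next")  # follow pagination links until exhausted
    schema = T.StructType(
        [
            T.StructField("id", T.StringType()),
            T.StructField("value", T.StringType()),
        ]
    )
    return ctx.spark_session.createDataFrame(
        [[record.get("id"), record.get("value")] for record in rows], schema=schema
    )

One thing to keep in mind: both outputs are written with set_mode("replace"), so incremental_refresh_token always holds exactly one row (the latest refresh token) and output_dataset is fully rewritten on each run. If you want the main output to accumulate records across runs instead, switch it to set_mode("modify") so the incremental write appends.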