Yes, it works without any issue both in build and live deployment!
For dependencies, I mostly relied on PyPI for transformers, etc.
Here is my .ipynb — copy it as .txt and rename it to .ipynb. I had to manually obfuscate some variable names, so don't be surprised if they do not all tally. Also, I am using 1 T4 GPU for deployment:
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "000bccb4-f2fb-40ce-873a-cba8bda8dbdf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install -q transformers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3816cf09-e38f-4722-83c5-0924565c43a2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install -q transformers==4.46.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad58cc28-babd-4898-a932-cd0fdb320919",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install -q bitsandbytes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9804f454-df17-4cd3-af9d-a7223e8fdc56",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install -q 'accelerate>=0.26.0'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69c259cb-b109-4519-bf74-2eed66619c7e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install numpy==1.24.3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fc788c5-4d61-4651-8d95-d00b25bca194",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install -q pillow"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e0a77eb-bbe1-4ea9-864d-1d34b4c1b63a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install pymupdf"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cb3860e-88fd-41c7-9335-f7f3d0a77b5a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install requests"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a52ad22-8fc1-4591-964c-9da85d157be2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!maestro env conda install palantir_models palantir_models_serializers \"ontology_sdk>=0.1.0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5bc4876-d363-4150-a186-5ffb0eb49ae2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"from ontology_sdk import FoundryClient\n",
"\n",
"client = FoundryClient()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a2c4947-ad0b-476f-83fd-16e9985fd730",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from foundry.transforms import Dataset, Column\n",
"\n",
"smolvlm = Dataset.get(\"smolvlm\")\n",
"smolvlm_files = smolvlm.files().download()\n",
"model_path=smolvlm_files['README.md'].replace('README.md','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f4254ca-c8f1-4447-8235-88f6b22f0493",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# # Load the autoreload extension and automatically reload all modules\n",
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"from palantir_models.code_workspaces import ModelOutput\n",
"from SmolVLM_adapter_with_media_reference import SmolVLMAdapter # Update if class or file name changes\n",
"\n",
"# Wrap the trained model in a model adapter for Foundry\n",
"SmolVLM_adapter = SmolVLMAdapter(model_path)\n",
"\n",
"# Get a writable reference to your model resource.\n",
"model_output = ModelOutput(\"smolvlm___model_asset\")\n",
"model_output.publish(SmolVLM_adapter) # Publishes the model to Foundry"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23e25909-2e7b-4137-a8e3-0f33459905c5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Load the autoreload extension and automatically reload all modules\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"from palantir_models.code_workspaces import ModelOutput\n",
"from smolvlm___model_asset import SmolVLMAdapter # Update if class or file name changes\n",
"\n",
"# Wrap the trained model in a model adapter for Foundry\n",
"smolvlm_model_asset = SmolVLMAdapter(model_path)\n",
"\n",
"# Get a writable reference to your model resource.\n",
"model_output = ModelOutput(\"smolvlm___model_asset\")\n",
"model_output.publish(smolvlm_model_asset) # Publishes the model to Foundry"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f887e53e-f1fd-485b-9a37-80b1cb6b24cd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"smolvlm___model_asset_client.disable()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "293a65bb-af61-4b26-b119-569f6cda729d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from palantir_models.code_workspaces import ModelInput\n",
"smolvlm___model_asset = ModelInput(\"smolvlm___model_asset\")\n",
"# Specify deployment resources such as cpus, gpus, memory (in bytes), and maximum number of replicas\n",
"# Gpus can be allocated by passing the required number per replica to the gpu argument\n",
"# Specify deployment resources such as cpus, gpus, memory (in bytes), and maximum number of replicas\n",
"# Gpus can be allocated by passing the required number per replica to the gpu argument\n",
"smolvlm___model_asset_client = smolvlm___model_asset.deploy(cpu=2, memory=\"4G\",gpu=1, max_replicas=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82731d31-60cd-43c3-a3d2-bd8198de93c5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"print(smolvlm___model_asset_client.get_health_checks(timeout=5))\n",
"print(smolvlm___model_asset_client.is_ready(timeout=5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "927922fd-aecd-45a6-a336-2e416c0696ea",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"inference_data_df = {\n",
" \"inference_data\": [\n",
" {\n",
" \"prompt\": \"Describe the document\",\n",
" \"schema_str\": \"\",\n",
" \"max_new_tokens\": 128,\n",
" \"temperature\": 0.3\n",
" }\n",
" ],\n",
" \"media_item\": [ # ✅ ObjectSet expects a list\n",
" {\n",
" \"primaryKey\": {\n",
" \"uuid\": \"871d228d-9036-4994-964f-04d08209d040\"\n",
" }\n",
" }\n",
" ]\n",
"}\n",
"\n",
"smolvlm___model_asset_client.predict(inference_data_df)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e580376-9707-458c-b19c-461c0e42e349",
"metadata": {},
"outputs": [],
"source": [
"smolvlm___model_asset_client.disable()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c11e357e-fe7a-413a-8b81-bac6ceef8eac",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"smolvlm___model_asset_client"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [user-default]",
"language": "python",
"name": "conda-env-user-default-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Model Adapter with Media Reference:
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
import os
import json
import gc
from typing import List, Optional
import shutil
import pandas as pd
import torch
import tempfile
from io import BytesIO
import base64
from PIL import Image
import logging
from pdf2image import convert_from_bytes
import requests
from models_api.models_api_executable import (
ContainerizedApplicationContext,
ExternalModelExecutionContext,
)
from palantir_models.models import (
ModelAdapter,
PythonEnvironment,
ModelStateReader,
ModelStateWriter,
)
from palantir_models.models.api import (
DFType,
ModelApi,
ModelApiColumn,
ModelInput,
ModelOutput,
)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Auth token used by APIClient for the Foundry media-set API.
# Read from the environment so the secret is not hard-coded in source;
# the placeholder default preserves the original "fill in by hand" intent
# (the original line `auth_token = <YOUR_AUTH_TOKEN>` was not valid Python).
auth_token = os.environ.get("FOUNDRY_AUTH_TOKEN", "<YOUR_AUTH_TOKEN>")
class APIError(Exception):
    """Raised when the media-set API answers with a non-success HTTP status.

    Carries the HTTP status code and the raw response body so callers can
    log or inspect the failure.
    """

    def __init__(self, status_code, data):
        # Keep the raw status and payload available on the instance.
        self.status_code = status_code
        self.data = data

    def __str__(self):
        # Render as "[<status>]\n<body>" for readable log output.
        return f"[{self.status_code}]\n{self.data}"
class APIClient:
    """Minimal HTTP client for fetching media-set item content from Foundry."""

    def __init__(self, auth_token: str):
        self.auth_token = auth_token
        self.base_url = "https://palantir_hostname.com/api/v2"

    def get_binary_from_pdf(self, mediaset_rid: str, media_item_rid: str):
        """Return the raw bytes of a media item, raising APIError on failure."""
        url = f"{self.base_url}/mediasets/{mediaset_rid}/items/{media_item_rid}/content"
        request_headers = {
            "Authorization": self.auth_token,
            "Content-Type": "application/json",
        }
        # The preview=true query parameter is important for this endpoint!
        response = requests.get(url, headers=request_headers, params={"preview": "true"})
        # Accept any 1xx/2xx/3xx status; everything else is an API error.
        if response.status_code // 100 not in (1, 2, 3):
            raise APIError(response.status_code, response.text)
        return response.content
class SmolVLMAdapter(ModelAdapter):
    """Foundry ModelAdapter wrapping a SmolVLM vision-language model.

    Given a Foundry media reference (JSON string) and a text prompt, it
    downloads the media item (PDF or image), renders it to a PIL image,
    and runs image+text generation with the loaded processor/model pair.
    Errors are reported in-band as strings starting with "ERROR:".
    """

    # Name of the text input parameter consumed by the model API.
    INPUT_COLUMN: str = "prompt"
    # Name of the output parameter the generated text is written to.
    GENERATION_COLUMN: str = "prediction"

    def __init__(self, model_dir=None, processor=None, model=None, state_reader=None, auth_token=None):
        """Build the adapter from one of three sources, checked in order:
        a local checkpoint directory, a Foundry state reader, or
        pre-constructed processor/model objects.

        NOTE(review): when auth_token is None (e.g. the notebook call
        SmolVLMAdapter(model_path)), APIClient is built with a None
        Authorization header — confirm the deployed path always supplies
        a real token via load().
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.auth_token = auth_token
        self.api_client = APIClient(self.auth_token)
        if model_dir:
            # Load directly from an on-disk checkpoint directory.
            self.processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_dir,
                torch_dtype=torch.bfloat16,
                _attn_implementation="eager",
            ).to(self.device)
        elif state_reader:
            # Extract serialized model files into a temp dir, then load from it.
            # The TemporaryDirectory handle is kept on self so the extracted
            # files live as long as the adapter instance.
            self.pretrained_model_tmp_dir = tempfile.TemporaryDirectory()
            state_reader.extract(self.pretrained_model_tmp_dir.name)
            self.processor = AutoProcessor.from_pretrained(self.pretrained_model_tmp_dir.name, trust_remote_code=True)
            self.model = AutoModelForVision2Seq.from_pretrained(
                self.pretrained_model_tmp_dir.name,
                torch_dtype=torch.bfloat16,
                _attn_implementation="eager",
            ).to(self.device)
        else:
            # Caller supplied ready-made objects; both may be None if none were given.
            self.processor = processor
            self.model = model

    @classmethod
    def load(
        cls,
        state_reader: ModelStateReader,
        container_context: Optional[ContainerizedApplicationContext] = None,
        external_model_context: Optional[ExternalModelExecutionContext] = None,
    ):
        """Foundry entry point: restore the adapter from serialized state.

        Uses the module-level auth_token constant for media-API access.
        """
        return cls(state_reader=state_reader, auth_token=auth_token)

    def save(self, state_writer: ModelStateWriter) -> None:
        """Save the processor and model to state."""
        model_temp_dir = tempfile.mkdtemp()
        try:
            self.processor.save_pretrained(model_temp_dir, from_pt=True)
            self.model.save_pretrained(model_temp_dir, safe_serialization=True, from_pt=True)
            # Copy every file produced by save_pretrained into the state writer.
            for f in os.listdir(model_temp_dir):
                local_name = os.path.join(model_temp_dir, f)
                with state_writer.open(f, "wb") as remote_file:
                    with open(local_name, "rb") as local_file:
                        shutil.copyfileobj(local_file, remote_file)
        finally:
            # Always remove the scratch directory, even if serialization failed.
            shutil.rmtree(model_temp_dir)

    @classmethod
    def api(cls):
        """Declare the model's input/output contract for Foundry."""
        inputs = [
            ModelInput.Parameter(name="media_reference", type=str, required=True),
            ModelInput.Parameter(name="prompt", type=str, required=True),
            # NOTE(review): schema_str is declared and threaded through
            # run_inference but never used inside predict() below.
            ModelInput.Parameter(name="schema_str", type=str, required=False),
            ModelInput.Parameter(name="max_new_tokens", type=int, required=False, default=512),
            ModelInput.Parameter(name="temperature", type=float, required=False, default=0.1),
        ]
        outputs = [
            ModelOutput.Parameter(name=cls.GENERATION_COLUMN, type=str, required=True)
        ]
        return ModelApi(inputs, outputs)

    def extract_media_info(self, media_reference: str):
        """Extract media_set_rid and media_item_rid from media_reference JSON string.

        Raises ValueError on missing keys or malformed JSON.
        NOTE(review): predict() re-implements this parsing inline instead of
        calling this helper; consider consolidating.
        """
        try:
            media_reference_obj = json.loads(media_reference)
            # Extract media_set_rid and media_item_rid from the nested structure
            media_set_rid = media_reference_obj['reference']['mediaSetViewItem']['mediaSetRid']
            media_item_rid = media_reference_obj['reference']['mediaSetViewItem']['mediaItemRid']
            return media_set_rid, media_item_rid
        except KeyError as e:
            logger.error(f"Error extracting media information: Missing key {e}")
            raise ValueError("Invalid media_reference format")
        except json.JSONDecodeError:
            logger.error("Error decoding media_reference JSON")
            raise ValueError("Invalid JSON format for media_reference")

    def predict(self, media_reference, prompt, schema_str=None, max_new_tokens=512, temperature=0.1):
        """Implementation of predict method called by run_inference.

        Returns the generated text, or an in-band "ERROR: ..." string on any
        failure (this method never raises to the caller).
        """
        try:
            # Parse the media_reference JSON string to extract mediaSetRid and mediaItemRid
            try:
                media_reference_dict = json.loads(media_reference)
                mediaSetRid = media_reference_dict['reference']['mediaSetViewItem']['mediaSetRid']
                mediaItemRid = media_reference_dict['reference']['mediaSetViewItem']['mediaItemRid']
            except (KeyError, json.JSONDecodeError) as e:
                logger.error(f"Failed to parse media_reference: {str(e)}", exc_info=True)
                return f"ERROR: Failed to extract media reference details"
            logger.info(f"Fetched mediaSetRid: {mediaSetRid}, mediaItemRid: {mediaItemRid}")
            # Fetch media content from the API
            media_item = self.api_client.get_binary_from_pdf(mediaSetRid, mediaItemRid)
            if not media_item:
                logger.error("Media item not found or empty")
                return "ERROR: Media item not found"
            logger.info(f"Media file downloaded, size: {len(media_item)} bytes")
            # PDFs are detected by the %PDF magic bytes; anything else is
            # treated as a directly loadable image.
            if len(media_item) > 4 and media_item[:4] == b'%PDF':
                logger.info("Converting PDF to image...")
                images = convert_from_bytes(media_item, dpi=300)
                logger.info(f"Conversion returned {len(images) if images else 0} images")
                if not images or len(images) == 0:
                    logger.error("PDF conversion produced no images")
                    return "ERROR: Failed to convert PDF to image - no pages found"
                image = images[0]  # Use the first page
                logger.info(f"Using first page image, size: {image.size}")
            else:
                logger.info("Processing as direct image...")
                image = Image.open(BytesIO(media_item)).convert("RGB")
                logger.info(f"Image loaded, size: {image.size}")
            # Ensure the prompt carries the <image> placeholder expected by the processor.
            if "<image>" not in prompt:
                formatted_prompt = f"<image> {prompt}"
            else:
                formatted_prompt = prompt
            logger.info(f"Using prompt: {formatted_prompt}")
            model_inputs = self.processor(images=[image], text=[formatted_prompt], return_tensors="pt")
            # Move only the tensor entries to the target device; drop non-tensors.
            model_inputs = {k: v.to(self.device) for k, v in model_inputs.items() if isinstance(v, torch.Tensor)}
            logger.info(f"Model input keys: {list(model_inputs.keys())}")
            with torch.no_grad():
                # NOTE(review): temperature is passed but sampling is not enabled
                # (no do_sample=True) — presumably generation is greedy and the
                # temperature value has no effect; confirm intended.
                generated_outputs = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )
            logger.info(f"Generation complete, shape: {generated_outputs.shape}")
            results = self.processor.batch_decode(generated_outputs, skip_special_tokens=True)
            if not results or len(results) == 0:
                logger.error("No results from batch_decode")
                return "ERROR: No text was generated"
            logger.info(f"Decoded {len(results)} results")
            # Single image/prompt in, so the first decoded sequence is the answer.
            result = results[0]
            logger.info(f"Generated text: {result}")
            return result
        except Exception as e:
            logger.error(f"Error during inference: {str(e)}", exc_info=True)
            return f"ERROR: {str(e)}"

    def run_inference(self, inputs, outputs):
        """Overrides the default run_inference method."""
        result = self.predict(
            media_reference=inputs.media_reference,  # Pass media_reference instead of separate ids
            prompt=inputs.prompt,
            schema_str=inputs.schema_str,
            # Fall back to defaults when the optional parameters arrive as None/0.
            max_new_tokens=inputs.max_new_tokens or 512,
            temperature=inputs.temperature or 0.1
        )
        outputs.prediction.write(result)