Making a Model Adapter for Image Data

As the subject states, I’m having problems making a model adapter for image data. Specifically, I can’t work out how to write the api method. All of the examples in the documentation use pm.Pandas(). For my task, I need images as inputs, not Pandas DataFrames. When I looked through the list of available functions in the palantir_models library, these seemed promising:

  • Object
  • ObjectSet
  • Pandas

However, I don’t know how to use them, and the lack of examples is making this workflow pretty difficult. Here’s what I have so far, with comments standing in for the parts I don’t know how to do:

class ImageGeolocatorV1Adapter(pm.ModelAdapter):

    @pm.auto_serialize(
        model=PytorchStateSerializer()
    )
    def __init__(self, model):
        self.model = model

        # Input image preprocessing steps
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @classmethod
    def api(cls):
        # Define inputs, outputs
        
        return inputs, outputs

    def predict(self, image_in):
        # Some way to load the images 

        image = self.transform(image).unsqueeze(0).to(device) # Convert to Tensor
        image = image.to(torch.device('cpu'))
        
        # Run the model to predict latitude and longitude
        self.model.eval()
        with torch.no_grad():
            prediction = self.model(image)  
            
        # Return the output (latitude, longitude) as the most convenient format

Any help with this?

Hi!

If you define your inputs (in #api) using a MediaReference (from palantir_models import MediaReference), this should allow you to pass in an image:

inputs = [
    ModelInput.Tabular(
        name="inference_data",
        df_type=DFType.PANDAS,
        columns=[
            ModelApiColumn(name="media_item_rid", type=str, required=True),
            ModelApiColumn(name="media_reference", type=MediaReference, required=True),
        ],
    )
]

You can then access the image in your #predict using:

input_media: MediaReference = row["media_reference"].get_media_item()
image = Image.open(input_media)

I hope this helps!

-Eirik

Hey @adampsu @eiriklt ,

Media reference model inputs are a little more complicated and would not work with just the rid and MediaReference input.

It is easier to pass the base64 encoded image as a string input and decode it in the adapter. This might look like:

class MyModelAdapter(pm.ModelAdapter):
    ...
    
    @classmethod
    def api(cls):
        inputs = {
            "image_base64": pm.Parameter(type=str),
        }
        outputs = {
            "your_output": pm.Parameter(type=str),
        }
        return inputs, outputs
    
    def predict(self, image_base64: str) -> str:
        # decode the base64 and pass it into the model
        return "..."
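As a rough sketch of the decode step, using only the standard library: the caller base64-encodes the raw image bytes into a string, and the adapter reverses that inside predict. The helper names encode_image_bytes and decode_image_base64 are made up for illustration and are not part of palantir_models.

```python
import base64


def encode_image_bytes(image_bytes: bytes) -> str:
    """Caller side: turn raw image bytes into a string for a str parameter."""
    return base64.b64encode(image_bytes).decode("ascii")


def decode_image_base64(image_base64: str) -> bytes:
    """Adapter side: recover the raw image bytes from the string input."""
    # validate=True raises binascii.Error on characters outside the base64 alphabet
    return base64.b64decode(image_base64, validate=True)


# Stand-in bytes; in practice these would come from open(path, "rb").read()
raw = b"\x89PNG\r\n\x1a\n" + b"not a real image"
payload = encode_image_bytes(raw)
restored = decode_image_base64(payload)
```

The decoded bytes can then be wrapped in io.BytesIO and opened with an image library before being handed to the model.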

If this does not work for your case or you have any other questions, feel free to reach out.

Hey @william,

Thanks for flagging! Are you sure this is still the case? It is explicitly listed as supported here, and the source code I posted above is from an example using it (disclaimer: I did not verify that example).

Of course, using this requires wrapping your input as tabular, so if you want to avoid that, using base64 is probably easier. Media references are also not supported as outputs.

-Eirik

Hey! This works for the most part, and it is relatively easy to implement. I’ve submitted the model, but I’m getting a different error when testing it out via the Sandbox tool.

TypeError: Compose.__call__() got an unexpected keyword argument 'image_base64'

I’m following an example, and my inputs and outputs match pretty closely. I must be overlooking something entirely; I can’t quite tell what the error is. The model’s .predict() function seems to work fine (I was able to pass in an input in my Code Workspace and it worked), but there seems to be an error elsewhere.

class ImageGeolocatorV1Adapter(pm.ModelAdapter):
    @pm.auto_serialize(
        model=pms.DillSerializer()
    )                  
    def __init__(self, model):
        self.model = model
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
        ])

    @classmethod
    def api(cls):
        inputs = {
            "image_base64": pm.Parameter(type=str),
        }
        outputs = {
            "coordinates": pm.Parameter(type=dict),
        }
        return inputs, outputs

    def predict(self, image_base64):
        image_data = base64.b64decode(image_base64)
        
        # Convert the byte data to an image
        image = Image.open(BytesIO(image_data)).convert('RGB')
        image = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(image)
            latitude, longitude = outputs.squeeze().cpu().numpy()

        result = {'latitude': latitude.item(), 'longitude': longitude.item()}

        return result

Any help would be appreciated.

Do you have a stacktrace? Also what does the JSON input you are passing in look like? Could you also send the imports for the model adapter?

This potentially looks like a PyTorch error.

Hello William,

I’m not sure it’s a PyTorch error. I’m able to use the model adapter in Jupyter Code Workspaces, and it works fine there. The image below shows my process:

And here’s the result when I paste the same image into the Sandbox environment:

The stack trace is included. Please let me know if you notice something; I’m grateful for your help thus far.

Could you send your entire model adapter (with imports)?

Also, in Code Workspaces, could you try calling image_geolocator_v1_adapter.transform(image_base64="...")?
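To make that suggestion concrete: the TypeError in the report is what Python raises when a callable is invoked with a keyword argument it does not declare. Below is a minimal sketch using a hypothetical stand-in class. It is named Compose only so the message looks similar; it is not torchvision's implementation, and it assumes the real __call__ takes a single argument named img.

```python
class Compose:
    """Hypothetical stand-in for torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Accepts the input as a single argument named img and applies each step
        for t in self.transforms:
            img = t(img)
        return img


# Two string functions stand in for the image transforms
transform = Compose([str.strip, str.upper])

# A positional call works
ok = transform("  abc ")

# A keyword call with a name that __call__ does not declare fails
try:
    transform(image_base64="...")
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

If the same call fails the same way in Code Workspaces, that would suggest something is invoking the transform attribute directly with the API inputs, rather than the failure coming from inside the model.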

I’ve found the issue! It seems the image_base64 input was getting passed into all constructor members, which wasn’t the intended effect. I simply had to move my self.transform variable into the predict() method. Thanks for the help!!
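For anyone landing here later, the shape of the change described above is roughly as follows. This is a simplified sketch with stand-ins: AdapterSketch, the lambda model, and _build_transform are placeholders, not the real adapter, palantir_models, or torchvision code. The only point is that the constructor keeps the model alone and the preprocessing is created inside predict.

```python
import base64


class AdapterSketch:
    """Hypothetical stand-in for the adapter; not a pm.ModelAdapter subclass."""

    def __init__(self, model):
        # Only the model is stored; there is no callable preprocessing attribute
        self.model = model

    def predict(self, image_base64):
        # Preprocessing is built locally instead of living on the instance
        transform = self._build_transform()
        image_bytes = base64.b64decode(image_base64)
        latitude, longitude = self.model(transform(image_bytes))
        return {"latitude": latitude, "longitude": longitude}

    @staticmethod
    def _build_transform():
        # Placeholder for transforms.Compose([...]) in the real adapter
        return lambda data: len(data)


# Stand-in model: maps the "preprocessed" input to a fixed coordinate pair
adapter = AdapterSketch(model=lambda x: (48.85, 2.35))
result = adapter.predict(base64.b64encode(b"fake image bytes").decode("ascii"))
```

Rebuilding the transform on every call costs a little extra work per prediction, which seemed an acceptable trade for keeping the constructor free of callable members.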