FastAPI Service
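
The script below serves a pretrained Satellighte classifier over HTTP with FastAPI. At startup it loads the mobilenetv2_default_eurosat model, placing it on the GPU when one is available; a GET on / returns library information, and a POST to /predict/ accepts an image upload and returns the model's prediction. The server runs under Uvicorn, with host, port, and worker settings exposed as command-line arguments.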

import argparse
from io import BytesIO
import torch
import numpy as np
import satellighte as sat
import uvicorn
from PIL import Image

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.openapi.utils import get_openapi

tags_metadata = [
    {
        "name": "Predict",
        "description": "Satellighte is an image classification.",
        "externalDocs": {
            "description": "External Docs for Library: ",
            "url": "https://satellighte.readthedocs.io/",
        },
    },
    {
        "name": "Information",
        "description": "Information about the library and the service.",
        "externalDocs": {
            "description": "Project Homepage: ",
            "url": "https://canturan10.github.io/satellighte/",
        },
    },
]

app = FastAPI()


def custom_openapi():
    # Build the OpenAPI schema once and cache it on the app instance.
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title="Satellighte API",
        version=sat.__version__,
        description=sat.__description__,
        routes=app.routes,
        tags=tags_metadata,
        license_info={
            "name": sat.__license__,
            "url": sat.__license_url__,
        },
        contact={
            "name": sat.__author__,
        },
    )
    openapi_schema["info"]["x-logo"] = {
        "url": "https://raw.githubusercontent.com/canturan10/satellighte/master/src/satellighte.png"
    }
    app.openapi_schema = openapi_schema
    return app.openapi_schema


# Replace FastAPI's default OpenAPI schema generator with the customized one.
app.openapi = custom_openapi


@app.on_event("startup")
def load_artifacts():
    # Load the pretrained classifier once at startup and move it to the GPU if available.
    if not hasattr(app.state, "model"):
        app.state.model = sat.Classifier.from_pretrained("mobilenetv2_default_eurosat")
        app.state.model.eval()
        app.state.model.to("cuda" if torch.cuda.is_available() else "cpu")


@app.on_event("shutdown")
def empty_cache():
    # Release cached CUDA memory on shutdown.
    torch.cuda.empty_cache()


def read_imagefile(data: bytes) -> Image.Image:
    # Decode raw upload bytes into a PIL image.
    image = Image.open(BytesIO(data))
    return image


@app.get("/", tags=["Information"])
def read_root():
    return {
        "Satellighte": f"{sat.__version__}",
        "Description": f"{sat.__description__}",
        "Author": f"{sat.__author__}",
    }


@app.post("/predict/", tags=["Predict"])
async def predict(file: UploadFile = File(...)):
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"File '{file.filename}' is not an image.",
        )

    try:
        contents = await file.read()
        image = np.array(read_imagefile(contents).convert("RGB"))
        predicted_class = app.state.model.predict(image)

        return predicted_class
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=str(exc),
        ) from exc


if __name__ == "__main__":
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Runs the API server.")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to run the API on.",
    )
    parser.add_argument(
        "--port",
        help="The port to listen for requests on.",
        type=int,
        default=8080,
    )
    parser.add_argument(
        "--workers",
        help="Number of workers to use.",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--reload",
        help="Reload the model on each request.",
        action="store_true",
    )
    parser.add_argument(
        "--use-colors",
        help="Enable user-friendly color output.",
        action="store_true",
    )

    args = parser.parse_args()
    # Note: uvicorn ignores the workers setting when reload is enabled.
    uvicorn.run(
        f"{Path(__file__).stem}:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=args.reload,
        use_colors=args.use_colors,
    )
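
For a quick smoke test, the sketch below exercises the running service with the requests library. It assumes the service was started with the defaults above and is reachable at http://localhost:8080, and that test.jpg is an arbitrary local image file; both are illustrative assumptions, not part of the service itself.

import requests

BASE_URL = "http://localhost:8080"  # assumed: argparse defaults above, reached via localhost

# Information endpoint: library version, description, and author.
info = requests.get(f"{BASE_URL}/")
print(info.json())

# Predict endpoint: upload an image as multipart/form-data under the "file" field.
with open("test.jpg", "rb") as image_file:  # test.jpg is a placeholder example image
    response = requests.post(
        f"{BASE_URL}/predict/",
        files={"file": ("test.jpg", image_file, "image/jpeg")},
    )
print(response.status_code, response.json())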