# ONNX Export

import os
import tempfile
import torch
import argparse
import onnx

import torch

import satellighte as sat


def parse_arguments():
    """
    Build the command line interface and parse ``sys.argv``.

    Registered options: ``--model_name`` (restricted to the names returned by
    ``sat.available_models()``, defaulting to the first one), ``--version``,
    ``--target``/``-t``, ``--quantize``/``-q`` (boolean flag) and
    ``--opset-version`` (integer, default 11).

    Returns: Parsed arguments
    """
    parser = argparse.ArgumentParser()

    # Each entry is (flag names, keyword options) forwarded to add_argument.
    option_specs = (
        (
            ("--model_name",),
            {
                "type": str,
                "default": sat.available_models()[0],
                "choices": sat.available_models(),
                "help": "Model architecture",
            },
        ),
        (
            ("--version",),
            {"type": str, "help": "Model version"},
        ),
        (
            ("--target", "-t"),
            {"type": str, "help": "Target path to save the model"},
        ),
        (
            ("--quantize", "-q"),
            {"action": "store_true", "help": "Quantize the model"},
        ),
        (
            ("--opset-version",),
            {"type": int, "default": 11, "help": "Onnx opset version"},
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()


def main(args):
    """
    Export a pretrained classifier to an ONNX file and embed its labels.

    The model is loaded with ``sat.Classifier.from_pretrained``, exported via
    ``model.to_onnx`` (optionally passed through ``quantize_qat``) and the
    resulting file is re-saved with a ``labels`` metadata entry holding the
    newline-joined ``model.labels``.

    Args:
        args : Parsed arguments. Reads ``model_name``, ``version``, ``target``,
            ``quantize`` and ``opset_version``.

    Raises:
        ValueError: If ``args.version`` is given but is not among the versions
            reported for ``args.model_name``.
        AssertionError: If ``args.quantize`` is set and
            ``onnxruntime.quantization`` cannot be imported.
    """
    # pylint: disable=no-member
    if args.version:
        available_versions = sat.get_model_versions(args.model_name)
        if args.version not in available_versions:
            raise ValueError(
                f"model version {args.version} not available for model {args.model_name}, available versions: {available_versions}"
            )
        version = args.version
    else:
        version = sat.get_model_latest_version(args.model_name)

    model = sat.Classifier.from_pretrained(
        args.model_name,
        version=version,
    )
    model.eval()

    if args.target:
        target_path = args.target
    else:
        target_path = os.path.join(
            sat.core._get_model_dir(),
            args.model_name,
            f"v{version}",
        )

    print(f"Target Path: {target_path}")
    # The exporter opens the output file directly, so the directory must
    # already exist; create it for user-supplied targets that do not yet.
    os.makedirs(target_path, exist_ok=True)

    # Keyword arguments shared by both the plain and the quantized export.
    export_kwargs = {
        "input_sample": torch.rand(1, 3, model.input_size, model.input_size),
        "opset_version": args.opset_version,
        "input_names": ["input_data"],
        "output_names": ["preds"],
        "dynamic_axes": {
            "input_data": {0: "batch", 2: "height", 3: "width"},  # write axis names
            "preds": {0: "batch"},
        },
        "export_params": True,
    }

    if args.quantize:
        try:
            from onnxruntime.quantization import quantize_qat
        except ImportError as err:
            # Keep the original exception type but preserve the real cause.
            raise AssertionError("run `pip install onnxruntime`") from err

        target_model_path = os.path.join(
            target_path,
            "{}_quantize.onnx".format(args.model_name),
        )
        # Export the float model into a scratch directory rather than an open
        # NamedTemporaryFile: an open temp file cannot be reopened by name on
        # Windows, which is what the exporter and the quantizer both do.
        with tempfile.TemporaryDirectory() as scratch_dir:
            float_model_path = os.path.join(scratch_dir, "model.onnx")
            model.to_onnx(float_model_path, **export_kwargs)
            quantize_qat(float_model_path, target_model_path)
    else:
        target_model_path = os.path.join(
            target_path,
            "{}.onnx".format(args.model_name),
        )
        model.to_onnx(target_model_path, **export_kwargs)

    # Re-open the exported file to attach the class labels as metadata.
    onnx_model = onnx.load(target_model_path)
    meta = onnx_model.metadata_props.add()
    meta.key = "labels"
    meta.value = "\n".join(model.labels)
    onnx.save(onnx_model, target_model_path)
    print("Model saved")


if __name__ == "__main__":
    # Script entry point: parse the command line and run the export.
    main(parse_arguments())