import os
from typing import Dict, List, Union
import numpy as np
import pytorch_lightning as pl
import torch
from . import api
from .core import _get_arch_cls, _get_model_dir, _parse_saved_model_name
from .utils import configure_batch, convert_json
import torchmetrics as tm
[docs]class Classifier(pl.LightningModule):
"""Generic pl.LightningModule definition for image classification"""
# pylint: disable=no-member
# pylint: disable=not-callable
def __init__(
    self,
    model: torch.nn.Module,
    hparams: Dict = None,
):
    """Wrap an architecture model in a LightningModule.

    Args:
        model (torch.nn.Module): Architecture model that is trained or run.
        hparams (Dict, optional): Hyperparameters handed to
            `save_hyperparameters`. Defaults to None.
    """
    super().__init__()

    # Name -> metric registry used by the validation and test hooks
    self.__metrics = {}
    self.model = model

    self.save_hyperparameters(hparams)

    # Defined outside this view; presumably sets up the `normalizer`,
    # `mean` and `std` values read by the step methods -- TODO confirm
    self.configure_preprocess()
@property
def input_size(self):
    """Input size declared in the configuration of the wrapped model.

    Returns:
        The value stored under ``config["input"]["input_size"]``, or None
        when the "input" section or the "input_size" key is missing.
    """
    # `or {}` keeps the chained lookup from raising AttributeError when the
    # configuration has no "input" section (or it is explicitly None).
    input_config = self.model.config.get("input") or {}
    return input_config.get("input_size")
@property
def labels(self):
    """Labels exposed by the wrapped architecture model."""
    wrapped_model = self.model
    return wrapped_model.labels
def summarize(self):
    """Print the PyTorch Lightning model summary of this module."""
    summary = pl.utilities.model_summary.model_summary.summarize(self)
    print(summary)
# WARNING: `forward` runs the model under `torch.no_grad()`, so it should only be used for inference, not training.
def forward(
    self,
    batch: torch.Tensor,
) -> torch.Tensor:
    """
    Preprocess a batch, run the wrapped model and post-process its logits.

    Args:
        batch (torch.Tensor): Batch of tensors of shape (B x C x H x W).

    Returns:
        torch.Tensor: Prediction of the model.
    """
    # Preprocess step by step: scale, centre, then standardise
    scaled = batch / self.normalizer
    centred = scaled - self.mean
    model_input = centred / self.std

    # No gradients are tracked for the model call and the post-processing
    with torch.no_grad():
        raw_logits = self.model.forward(model_input)
        predictions = self.model.logits_to_preds(raw_logits)

    return predictions
@torch.jit.unused
def predict(
    self,
    data: Union[np.ndarray, List],
    target_size: int = None,
):
    """
    Classify a single image or a list of images.

    Args:
        data (Union[np.ndarray, List]): Numpy array or list of numpy arrays. In the form of RGB.
        target_size (int, optional): If it is not None, the image will be resized to the target size.
            When None, the model's `input_size` is used if it is set. Defaults to None.

    Returns:
        The predictions of `forward`, converted by `convert_json` together
        with the model labels.
    """
    # Turn the image(s) into a list of tensors first
    tensors = self.to_tensor(data)

    # Fall back to the model's own input size when the caller gave none
    model_input_size = self.input_size
    if target_size is None and model_input_size:
        target_size = model_input_size

    # Adaptive batching is requested only when no size could be determined
    batch = configure_batch(
        tensors,
        target_size=target_size,
        adaptive_batch=target_size is None,
    )

    # Run the model, then convert the predictions to the json format
    return convert_json(self.forward(batch), self.labels)
@classmethod
def build(
    cls,
    arch: str,
    config: str = None,
    hparams: Dict = None,
    **kwargs,
) -> "Classifier":
    """
    Build the model with given architecture and configuration.

    Args:
        arch (str): Model architecture name.
        config (str, optional): Model configuration. Defaults to None.
        hparams (Dict, optional): Hyperparameters. Defaults to None, which
            is treated as an empty dict.
        **kwargs: Extra keyword arguments forwarded to `build_arch`.

    Returns:
        Classifier: Instance of `cls` wrapping the model returned by `build_arch`.
    """
    # Create a fresh dict per call: a `{}` default is a single object that
    # would be shared by every instance built without hyperparameters.
    if hparams is None:
        hparams = {}

    model = cls.build_arch(arch, config=config, **kwargs)
    return cls(model, hparams=hparams)
@classmethod
def build_arch(
    cls,
    arch: str,
    config: str = None,
    **kwargs,
) -> torch.nn.Module:
    """
    Build the bare architecture model for the given configuration.

    Args:
        arch (str): Model architecture name.
        config (str, optional): Model configuration. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to the architecture's `build`.

    Returns:
        torch.nn.Module: Architecture model instance with randomly initialized weights.
    """
    # Resolve the architecture class from its name, then delegate to it
    architecture = _get_arch_cls(arch)
    model = architecture.build(config=config, **kwargs)
    return model
@classmethod
def from_pretrained(
    cls,
    model_name: str,
    version: int = None,
    hparams: Dict = None,
) -> "Classifier":
    """
    Create the module around a pretrained architecture model.

    Args:
        model_name (str): Model name in the format of {arch}_{config}_{dataset}
        version (int, optional): Model version. Defaults to None.
        hparams (Dict, optional): Hyperparameters. Defaults to None, which
            is treated as an empty dict.

    Returns:
        Classifier: Instance of `cls` wrapping the model returned by
        `from_pretrained_arch`.
    """
    # Create a fresh dict per call: a `{}` default is a single object that
    # would be shared by every instance created without hyperparameters.
    if hparams is None:
        hparams = {}

    model = cls.from_pretrained_arch(model_name, version=version)
    return cls(model, hparams=hparams)
@classmethod
def from_pretrained_arch(
    cls,
    model_name: str,
    version: int = None,
) -> torch.nn.Module:
    """
    Load the pretrained architecture model that belongs to a model name.

    Args:
        model_name (str): Model name in the format of {arch}_{config}_{dataset}
        version (int, optional): Model version. Defaults to None.

    Returns:
        torch.nn.Module: Architecture model instance.
    """
    # A falsy version (None or 0) means: ask the api for the latest version
    version = version or api.get_model_latest_version(model_name)

    # The model name carries the arch name and the config name
    arch_name, config_name, _ = _parse_saved_model_name(model_name)
    architecture = _get_arch_cls(arch_name)

    # Called before the path is used; presumably makes the saved model
    # available in the local model directory -- TODO confirm
    api.get_saved_model(model_name, version)

    # Pretrained model path: <model dir>/<model name>/v<version>
    version_dir = f"v{version}"
    model_path = os.path.join(_get_model_dir(), model_name, version_dir)

    return architecture.from_pretrained(model_path, config=config_name)
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch.

    Args:
        batch: Pair of (inputs, targets).
        batch_idx: Index of the batch (unused).

    Returns:
        The value returned by the model's `compute_loss`.
    """
    inputs, targets = batch

    # Preprocess: scale first, then centre and standardise
    scaled = inputs / self.normalizer
    model_input = (scaled - self.mean) / self.std

    # Logits are computed with gradients enabled here
    logits = self.model.forward(model_input)

    return self.model.compute_loss(logits, targets, hparams=self.hparams)
def training_epoch_end(self, outputs):
    """Log the mean of every loss collected during the training epoch.

    Args:
        outputs: Step outputs; each one is either a dict of named values
            or a single value, which is counted under the name "loss".
    """
    collected = {}
    for step_output in outputs:
        # A bare (non-dict) output is handled like {"loss": output}
        if isinstance(step_output, dict):
            named_values = step_output
        else:
            named_values = {"loss": step_output}
        for key, value in named_values.items():
            collected.setdefault(key, []).append(value)

    # One log entry per name: "<name>/training" -> mean over the epoch
    for key, values in collected.items():
        self.log("{}/training".format(key), sum(values) / len(values))
def on_validation_epoch_start(self):
    """Reset every registered metric before a validation epoch starts."""
    registered = self.__metrics
    for metric_name in registered:
        registered[metric_name].reset()
def validation_step(self, batch, batch_idx):
    """Compute the validation loss for one batch and update the metrics.

    Args:
        batch: Pair of (inputs, targets).
        batch_idx: Index of the batch (unused).

    Returns:
        The value returned by the model's `compute_loss`.
    """
    inputs, targets = batch

    # Preprocess: scale first, then centre and standardise
    scaled = inputs / self.normalizer
    model_input = (scaled - self.mean) / self.std

    # Nothing below needs gradients
    with torch.no_grad():
        logits = self.model.forward(model_input)
        loss = self.model.compute_loss(logits, targets, hparams=self.hparams)
        # Post-process the logits into predictions for the metrics
        preds = self.model.logits_to_preds(logits)

    # Every metric receives its own cpu copies of predictions and targets
    for metric in self.__metrics.values():
        metric.update(preds.cpu(), targets.cpu())

    return loss
def validation_epoch_end(self, outputs):
    """Log the mean validation losses and the computed metrics.

    Args:
        outputs: Step outputs; each one is either a dict of named values
            or a single value, which is counted under the name "loss".
    """
    collected = {}
    for step_output in outputs:
        # A bare (non-dict) output is handled like {"loss": output}
        if isinstance(step_output, dict):
            named_values = step_output
        else:
            named_values = {"loss": step_output}
        for key, value in named_values.items():
            collected.setdefault(key, []).append(value)

    # One log entry per name: "<name>/validation" -> mean over the epoch
    for key, values in collected.items():
        self.log("{}/validation".format(key), sum(values) / len(values))

    # Metrics are logged after the losses and shown on the progress bar
    for metric_name, metric in self.__metrics.items():
        metric_value = metric.compute()
        self.log("metrics/{}".format(metric_name), metric_value, prog_bar=True)
def on_test_epoch_start(self):
    """Reset every registered metric before a test epoch starts."""
    for _, registered_metric in self.__metrics.items():
        registered_metric.reset()
[docs] def test_step(self, batch, batch_idx):
batch, targets = batch
# Apply preprocess with the help of registered buffer
batch = ((batch / self.normalizer) - self.mean) / self.std
with torch.no_grad():
# Get logits from the model
logits = self.model.forward(batch)
# Compute loss
loss = self.model.compute_loss(
logits,
targets,
hparams=self.hparams,
)
# Apply postprocess for the logits that are returned from model and get predictions
preds = self.model.logits_to_preds(logits)
for metric in self.__metrics.values():
metric.update(preds.cpu(), targets.cpu())
return loss
[docs] def test_epoch_end(self, outputs):
metric_results = {}
for name, metric in self.__metrics.items():
metric_results[name] = metric.compute()
for name, metric in self.__metrics.items():
self.log(
"metrics/{}".format(name),
metric.compute(),
prog_bar=True,
)
return metric_results
def add_metric(self, name: str, metric: tm.Metric):
    """Adds given metric with name key

    A metric that is already registered under the same name is replaced;
    a ``UserWarning`` is emitted in that case so the override is not silent.

    Args:
        name (str): name of the metric
        metric (tm.Metric): Metric object
    """
    # Local import: `warnings` is only needed by this method
    import warnings

    if name in self.__metrics:
        warnings.warn(
            "metric '{}' is already defined and will be overridden".format(name),
            UserWarning,
            stacklevel=2,
        )

    self.__metrics[name] = metric
def get_metrics(self) -> Dict[str, tm.Metric]:
    """Return the metrics defined in this `Classifier` instance

    Returns:
        Dict[str, tm.Metric]: a new dict that maps the metric names to the
        registered metric objects
    """
    # Shallow copy: callers get their own dict holding the same metric objects
    return dict(self.__metrics)
def to_tensor(self, images: Union[np.ndarray, List]) -> List[torch.Tensor]:
    """Converts given image or list of images to list of tensors

    Args:
        images (Union[np.ndarray, List]): RGB image of shape (H x W x C) or
            list of such RGB images

    Returns:
        List[torch.Tensor]: list of torch.Tensor(C x H x W), created with the
        dtype and device of this module

    Raises:
        AssertionError: if `images` is neither a list nor a numpy array, or
            if an image does not have 3 dimensions with 3 channels last.

    This method is taken from fastface repositories.
    `fastface.module.to_tensor`
    Here is a link for it: `github.com/borhanMorphy/fastface`
    """
    # Explicit raises instead of `assert` statements, so the validation also
    # runs under `python -O`; AssertionError is kept for existing callers.
    if not isinstance(images, (list, np.ndarray)):
        raise AssertionError(
            "given images must be either list of numpy arrays or numpy array"
        )

    # A single image is handled as a list with one element
    if isinstance(images, np.ndarray):
        images = [images]

    batch: List[torch.Tensor] = []

    for img in images:
        if len(img.shape) != 3:
            raise AssertionError(
                "image shape must be height, width, channel "
                "with length of 3 but found {}".format(len(img.shape))
            )
        if img.shape[2] != 3:
            raise AssertionError(
                "channel size of the image must be 3 but found {}".format(
                    img.shape[2]
                )
            )

        batch.append(
            # h,w,c => c,h,w
            torch.tensor(img, dtype=self.dtype, device=self.device).permute(2, 0, 1)
        )

    return batch