# Source code for satellighte.datasets.resisc45

import os
from typing import List, Tuple

import numpy as np
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torchvision.datasets.utils import extract_archive

from ..core import _download_file_from_url
from .base import BaseDataset, Identity


class RESISC45(BaseDataset):
    """
    RESISC45 dataset is a dataset for Remote Sensing Image Scene Classification (RESISC).
    It contains 31,500 RGB images of size 256×256 divided into 45 scene classes,
    each class containing 700 images.

    The downloaded data is not pre-split, so a deterministic stratified split is
    derived on the fly (``random_state=1``): 60% train, 20% val, 20% test.

    Args:
        root_dir: directory expected to hold one sub-directory per class. If it
            does not exist, or contains no class sub-directories, the archive is
            downloaded and extracted into it.
        phase: one of ``__phases__`` ("train", "test", "val").
        transforms: transform forwarded to ``BaseDataset``; ``Identity()`` is
            used when None.
        **kwargs: forwarded unchanged to ``BaseDataset.__init__``.

    Raises:
        ValueError: if ``phase`` is not one of ``__phases__`` or no images are
            found under ``root_dir``.
    """

    __phases__ = (
        "train",
        "test",
        "val",
    )

    def __init__(
        self,
        root_dir: str = None,
        phase: str = "val",
        transforms=None,
        **kwargs,
    ):
        if root_dir is None:
            # NOTE(review): placeholder default kept from the original code;
            # callers should pass ``root_dir`` explicitly.
            root_dir = "TODO"
        # Fail fast on a bad phase, before a potentially large download.
        if phase not in self.__phases__:
            raise ValueError("Unknown phase")
        self.root_dir = root_dir
        self.phase = phase
        self.transforms = Identity() if transforms is None else transforms
        self.download()
        ids, targets = self._split_dataset(phase)
        # Forward the resolved transform (not the raw argument) so the
        # ``Identity`` default survives if ``BaseDataset`` stores what it gets.
        super().__init__(ids, targets, transforms=self.transforms, **kwargs)

    @property
    def num_classes(self) -> int:
        """Number of classes discovered in ``root_dir``."""
        return len(self.__classes)

    @property
    def classes(self) -> List[str]:
        """Class names in label-id order (sorted alphabetically)."""
        return self.__classes

    def name_to_id(self, name: str) -> int:
        """Return the integer label of class ``name`` (``ValueError`` if unknown)."""
        return self.__classes.index(name)

    def id_to_name(self, idx: int) -> str:
        """Return the class name for integer label ``idx``."""
        return self.__classes[idx]

    @staticmethod
    def _list_files(data_dir: str) -> Tuple[np.ndarray, np.ndarray]:
        """Collect ``(paths, class_names)`` for every file in each class sub-directory."""
        filenames = []
        labels = []
        for class_name in os.listdir(data_dir):
            class_dir = os.path.join(data_dir, class_name)
            # Only visible directories are classes; stray files (e.g. a leftover
            # archive) and hidden entries such as ``.ipynb_checkpoints`` are skipped.
            if class_name.startswith(".") or not os.path.isdir(class_dir):
                continue
            for file_name in os.listdir(class_dir):
                path = os.path.join(class_dir, file_name)
                # Skip hidden files (e.g. ``.DS_Store``) and nested directories.
                if file_name.startswith(".") or not os.path.isfile(path):
                    continue
                filenames.append(path)
                labels.append(class_name)
        return np.asarray(filenames), np.asarray(labels)

    def _split_dataset(self, phase: str) -> Tuple:
        """
        Build the deterministic train/val/test split and return the requested part.

        Side effect: records the sorted class names used for label encoding.

        Args:
            phase: one of ``__phases__``.

        Returns:
            ``(filenames, labels)`` as Python lists for ``phase``.

        Raises:
            ValueError: if ``phase`` is unknown or no images are found.
        """
        if phase not in self.__phases__:
            raise ValueError("Unknown phase")

        filenames, labels = self._list_files(self.root_dir)
        if filenames.size == 0:
            raise ValueError(f"No images found under {self.root_dir!r}")

        # Sort by path so the split does not depend on os.listdir ordering.
        order = filenames.argsort()
        filenames = filenames[order]
        labels = labels[order]

        # Convert class names to integer labels. LabelEncoder orders classes
        # alphabetically, so ids are stable across runs.
        label_encoder = preprocessing.LabelEncoder()
        labels = np.asarray(label_encoder.fit_transform(labels))
        self.__classes = list(label_encoder.classes_)

        # Split into train and test sets, since the provided data is not pre-split.
        x_train, x_test, y_train, y_test = train_test_split(
            filenames,
            labels,
            test_size=0.2,
            random_state=1,
            stratify=labels,
        )
        # 0.25 of the remaining 80% -> 20% of the full dataset for validation.
        x_train, x_val, y_train, y_val = train_test_split(
            x_train,
            y_train,
            test_size=0.25,
            random_state=1,
            stratify=y_train,
        )

        splits = {
            "train": (x_train, y_train),
            "test": (x_test, y_test),
            "val": (x_val, y_val),
        }
        x, y = splits[phase]
        return x.tolist(), y.tolist()

    def _check_exists(self) -> bool:
        """
        Check whether the extracted dataset appears to be present in ``root_dir``.

        The directory must exist AND contain at least one visible sub-directory.
        An empty directory, e.g. one left by an interrupted download, does not
        count, so ``download`` retries instead of silently yielding no data.
        """
        if not os.path.isdir(self.root_dir):
            return False
        # Use a context manager so the scandir handle is closed even when
        # ``any`` short-circuits.
        with os.scandir(self.root_dir) as entries:
            return any(
                not entry.name.startswith(".") and entry.is_dir()
                for entry in entries
            )

    def download(self) -> None:
        """
        Download and extract the dataset into ``root_dir`` unless it is already present.

        The archive is removed after a successful extraction.
        """
        if self._check_exists():
            return
        os.makedirs(self.root_dir, exist_ok=True)
        archive_path = os.path.join(self.root_dir, "resisc45.zip")
        _download_file_from_url(
            "https://drive.google.com/u/0/uc?id=1PCesRqeXYINcsulnTixVjR15xFNXropZ&export=download&confirm=t",
            archive_path,
        )
        extract_archive(
            archive_path,
            self.root_dir,
            remove_finished=True,
        )
if __name__ == "__main__":
    # Smoke test: build the default (validation) split and show one sample
    # plus the discovered class names and their count.
    dataset = RESISC45("satellighte/datas/resisc45")
    print(dataset[0])
    class_names = dataset.classes
    print(class_names)
    print(len(class_names))