Commit 7ac12bd9 authored by Václav Novák
Browse files

fix: remove concurrency, refactor

parent c997f706
Loading
Loading
Loading
Loading
+1 −14
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ from src.evaluation import evaluate
# DataLoader worker processes: disabled on Windows (multiprocessing workers
# are avoided there), 6 elsewhere.
NUM_WORKERS = 0 if platform.system() == "Windows" else 6
# True when torch can see a CUDA device.
CUDA_CAPABLE = torch.cuda.is_available()
# Per-trial time budget in seconds; default 7200, overridable via argv[1].
MAX_SECONDS_PER_TRIAL = 7200 if len(sys.argv) < 2 else int(sys.argv[1])
# Path of the Optuna journal-storage log file; default "results.log",
# overridable via argv[2].
STORAGE_PATH = "results.log" if len(sys.argv) < 3 else sys.argv[2]

###############################################################################

@@ -55,19 +54,7 @@ def objective(trial):
###############################################################################


# Build the Optuna journal storage used to share trials between processes.
if platform.system() != "Windows":
    # Non-Windows: the default file lock of JournalFileStorage is used.
    storage = optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage(STORAGE_PATH),
    )
    
else:
    # Windows: pass an explicit open lock — presumably the default lock
    # mechanism is unavailable on Windows; TODO confirm against the Optuna
    # JournalFileOpenLock documentation.
    lock_obj = optuna.storages.JournalFileOpenLock(STORAGE_PATH)
    storage = optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage(STORAGE_PATH, lock_obj=lock_obj),
    )


if __name__ == '__main__':
    # NOTE(review): the first create_study (storage-backed) is immediately
    # shadowed by the second assignment, so the JournalStorage backend is
    # never used — this looks like the pre-/post-refactor lines of the diff
    # rendered together; verify which line should remain.
    study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(), storage=storage)
    study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler())

    # Run the hyper-parameter search; no n_trials/timeout is passed here, so
    # optimization presumably runs until externally stopped — TODO confirm.
    study.optimize(objective)
+2 −7
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ from torch.utils.data import DataLoader, Dataset
from typing import Iterator, Tuple
from optuna import Trial

import torchvision.transforms as transforms
import torch
import torch.nn as nn

@@ -14,15 +13,11 @@ def build_dataset(dataset_dir: str, is_train: bool, image_input_size: int) -> Si

    @param dataset_dir: str, path to dataset
    @param is_train: bool, indicates whether the dataset should load test or train data
    @param image_input_size: int, dimension of image taken from dataset
    @param image_input_size: int, dimension to which data should be scaled

    @return: SiameseNetworkDataset, dataset
    """
    return SiameseNetworkDataset(dataset_dir, is_train,
                                 transform=transforms.Compose([transforms.Resize((image_input_size,
                                                                                  image_input_size)),
                                                               transforms.ToTensor()
                                                               ]))
    return SiameseNetworkDataset(dataset_dir, is_train, image_input_size)


def build_dataloader(dataset: Dataset, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
+16 −8
Original line number Diff line number Diff line
@@ -3,9 +3,11 @@ import csv

from PIL import Image
from platform import system
from typing import Tuple, List, Optional
from typing import Tuple, List
from torchvision.transforms import Compose
from torch.utils.data import Dataset
import torchvision.transforms as transforms



class SiameseNetworkDataset(Dataset):
@@ -14,18 +16,19 @@ class SiameseNetworkDataset(Dataset):
    return two images instead of one.
    """

    def __init__(self, data_dir: str, is_train: bool = False, transform: Optional[Compose] = None):
    def __init__(self, data_dir: str, is_train: bool, 
                 image_input_size: int):
        """
        Constructor.

        @param data_dir: str, path to directory with dataset, has to contain index files with paths to images
        (e.g. train_unix.csv, test_unix.csv...)
        @param is_train: bool, indicates whether to load train or test data
        @param image_input_size: int, dimension to which data should be scaled
        """
        self.data_dir = data_dir
        # Image-pair index (path0, path1, label) loaded from the CSV index file.
        self.variations = self.get_variations(is_train)
        # NOTE(review): `transform` is not a parameter of this signature — this
        # looks like a leftover removed line of the diff; applying it as-is
        # would raise NameError. Verify against the committed version.
        self.transform = transform
        self.image_input_size = image_input_size

    def get_variations(self, is_train) -> List[Tuple[str, str, str]]:
        """
@@ -53,12 +56,17 @@ class SiameseNetworkDataset(Dataset):
        img0 = Image.open(image_variation[0])
        img1 = Image.open(image_variation[1])
        
        # preprocessing
        img0 = img0.convert("L").point(lambda p: 255 if p > 242 else p)
        img1 = img1.convert("L").point(lambda p: 255 if p > 242 else p)
        
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        current_transform = transforms.Compose([transforms.Resize((self.image_input_size,
                                                                   self.image_input_size)),
                                                transforms.ToTensor()])
        
        img0 = current_transform(img0)
        img1 = current_transform(img1)


        return img0, img1, int(image_variation[2])