Commit 6f312216 authored by xsedmid's avatar xsedmid
Browse files

classifier

parent fb9ee074
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
/__pycache__/
/backup/
/data/
 No newline at end of file
+136 −180
Original line number Diff line number Diff line
from PIL import Image
from matplotlib.backends.backend_agg import FigureCanvasAgg
import io

from torchvision.transforms import ToTensor, Lambda, Compose
import torchvision.transforms as transforms

from matplotlib.colors import Normalize
import matplotlib.pyplot as plt

import json
import numpy as np
import pickle
from PIL import Image
import os
import pandas as pd
import torch

from models import MLP, binary_resnet18, binary_resnet50
from utils import load_kNN_classifier_sklearn
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

import feature_extractor as feat_ext

from sklearn.preprocessing import StandardScaler

from torch.utils.data import Dataset, DataLoader, TensorDataset, SequentialSampler, Subset
from torchvision import transforms

import pickle
import models as models
from utils import load_sklearn_kNN_classifier

# Configuration constants for the classification script.
# NOTE(review): meta_dir / out_models_dir point at fixed network-share paths;
# main() also reads 'meta_dir' and 'models_dir' from properties.json — confirm
# which source is authoritative before running on another machine.
meta_dir = "y:/datasets/dyslex/experiment-final/meta"
out_models_dir = "y:/trials/xsedmid/dyslex/experiment-final/v2.1/models/"

# Replace missing (null) metric values by 0 when True.
fill_null_values_by_zero = True
# Aggregate per-trial characteristics by their mean when True.
summarize_characteristics_by_mean = True
# Apply feature-value normalization when True.
normalize_values = False
# Generate one fixation image per trial id (True) vs. one per experiment (False).
generate_fixation_image_for_each_trialid = False




#
def classify(task_def, model_def, feature_file_path_fixations, feature_file_path_saccades, feature_file_path_metrics):
    """Classify a subject's recording for the given task and model definition.

    Currently a stub: none of the feature files are read yet and a fixed
    chance-level score is reported.

    Args:
        task_def: task definition (unused for now).
        model_def: model definition (unused for now).
        feature_file_path_fixations: path to the fixations CSV (unused for now).
        feature_file_path_saccades: path to the saccades CSV (unused for now).
        feature_file_path_metrics: path to the metrics CSV (unused for now).

    Returns:
        float: the placeholder accuracy score, always 0.5 (chance level for
        a binary task).
    """
    # TODO: implement the actual classification pipeline here.
    return 0.5

# main
def main():
    
    # Read task-invariant properties from a json file
    with open(os.path.join(meta_dir, 'properties.json')) as property_file:
    # Read configuration properties from a json file
    with open('properties.json') as property_file:
        properties = json.loads(property_file.read())

    # Path to directory with metainformation about the dataset and features for individual tasks
    meta_dir = properties['meta_dir']
    # Path to directory with trained models
    models_dir = properties['models_dir']

    # Data properties
    aoi_id_col_name = properties['aoi_id_col_name']
    subject_id_col_name = properties['subject_id_col_name']
    degrees_visual_angle_pixels = properties['degrees_visual_angle_pixels']
@@ -69,150 +46,129 @@ def main():
        ]
    fixation_image_visual_params = eval(properties['fixation_image_visual_params'])

    #task_definitions = [('T1', 'Syllables'), ('T2', 'Prosaccades'), ('T3', 'Antisaccades'), ('T4', 'Meaningful_Text'), ('T5', 'Pseudo_Text'), ('T6', 'Visual_Diff_1'), ('T6', 'Visual_Diff_2')]
    # Definitions of tasks and trained models
    tasks_def = eval(properties['tasks'])

    models_def = eval(properties['models'])
    
    task_feature_def_dict = feat_ext.load_task_feature_definition_dict(tasks_def, meta_dir)

    # 0
    # feature_file_path_metrics = "y:/datasets/dyslex/experiment-final/v2.1/original/Subject_1257_T1_Syllables_metrics.csv"
    # feature_file_path_fixations = "y:/datasets/dyslex/experiment-final/v2.1/original/Subject_1257_T1_Syllables_fixations.csv"
    feature_file_path_metrics = "y:/datasets/dyslex/experiment-final/v3.0/original/Subject_1999_T1_Syllables_metrics.csv"
    feature_file_path_fixations = "y:/datasets/dyslex/experiment-final/v3.0/original/Subject_1999_T1_Syllables_fixations.csv"

    # 1
    # feature_file_path_metrics = "y:/datasets/dyslex/experiment-final/v2.1/original/Subject_1038_T1_Syllables_metrics.csv"
    # feature_file_path_fixations = "y:/datasets/dyslex/experiment-final/v2.1/original/Subject_1038_T1_Syllables_fixations.csv"
    # feature_file_path_metrics = "y:/datasets/dyslex/experiment-final/v3.0/original/Subject_1879_T1_Syllables_metrics.csv"
    # feature_file_path_fixations = "y:/datasets/dyslex/experiment-final/v3.0/original/Subject_1879_T1_Syllables_fixations.csv"

    task_type_id = 'T1'
    
    task_models_def = eval(properties['task_models'])

    # Load the trained models
    for task_id in task_models_def.keys():
        for task_model_def in task_models_def[task_id]:
            type_id = task_model_def['type_id']
            file_name = task_model_def['file_name']
            desc_short = task_model_def['desc_short']
            print(f'Loading model: {desc_short} ({file_name})')

            # Load the data scaler from file
            if 'scaler_file_name' not in task_model_def.keys():
                scaler = None
            else:
                scaler = pickle.load(open(os.path.join(models_dir, task_model_def['scaler_file_name']), 'rb'))
            task_model_def['scaler'] = scaler
            
            # Load the classifier
            # Case 1: kNN
            if type_id == 'kNN':
                classifier = load_sklearn_kNN_classifier(models_dir, file_name)
            # Case 2: MLP
            elif type_id == 'MLP':
                params = task_model_def['params'].split(';')
                classifier = models.MLP(int(params[0]), int(params[1]), int(params[2]), int(params[3]), float(params[4]))
                classifier.load_state_dict(torch.load(os.path.join(models_dir, file_name)))
                classifier.eval()
            # Case 3: ResNet18
            elif type_id == 'CNN-RN18':
                classifier = models.binary_resnet18()
                classifier.load_state_dict(torch.load(os.path.join(models_dir, file_name)))
                classifier.eval()
            elif type_id == 'CNN-RN50':
                classifier = models.binary_resnet50()
                classifier.load_state_dict(torch.load(os.path.join(models_dir, file_name)))
                classifier.eval()
            else:
                classifier = None
            
            task_model_def['loaded_model'] = classifier
            if classifier is not None:
                print(f'Classifier successfully loaded.')
            else:
                print(f'Classifier not loaded.')

    # Classification
    for task_def in tasks_def:
        task_id = task_def['id']
        task_type_id = task_def['type_id']
        print(f'Classification task: {task_id}')

        # Path to sample data files
        # 0 (non-dyslexic):
        # feature_file_path_metrics = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1892_{task_id}_metrics.csv"
        # feature_file_path_fixations = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1892_{task_id}_fixations.csv"
        # feature_file_path_metrics = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1901_{task_id}_metrics.csv"
        # feature_file_path_fixations = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1901_{task_id}_fixations.csv"
        # 1 (dyslexic):
        # feature_file_path_metrics = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1009_{task_id}_metrics.csv"
        # feature_file_path_fixations = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1009_{task_id}_fixations.csv"
        feature_file_path_metrics = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1113_{task_id}_metrics.csv"
        feature_file_path_fixations = f"y:/datasets/dyslex/experiment-final/v3.0/original-all/Subject_1113_{task_id}_fixations.csv"

        # Feature extraction -- pre-defined features
        subject_id, subject_data_dict = feat_ext.load_and_transform_subject_characteristics(task_type_id, task_feature_def_dict, fill_null_values_by_zero, subject_id_col_name, aoi_id_col_name, summarize_characteristics_by_mean, feature_file_path_metrics)
        df = feat_ext.create_subject_characteristics_profile(subject_id, subject_data_dict)

    #df.to_csv('c:/temp/out.csv', sep=';', index=False)

        X = df.drop('subject_id', axis=1).astype(np.float64)
    print(f'Feature count: {len(X.columns)}')
    print(X)

    classifier_knn = load_kNN_classifier_sklearn(
        'T1_Syllables_3NN',
        out_models_dir
    )
    X_preds = classifier_knn.predict(X.values)
    print(f'kNN: {X_preds}')

    print(X.values.shape)
    n_features = len(X.columns)

    # MLP
    #classifier_mlp = torch.load(os.path.join(out_models_dir, 'T1_Syllables_MLP_lc2_lf2_d0.25-e20_lr0.001.pt'))
    classifier_mlp = MLP(n_features, n_features // 2, 2, 2, 0.25)
    classifier_mlp.load_state_dict(torch.load(os.path.join(out_models_dir, 'T1_Syllables_MLP_lc2_lf2_d0.25-e20_lr0.001.pt')))
    classifier_mlp.eval()

    # Load scaler from file
    scaler = pickle.load(open(os.path.join(out_models_dir, 'T1_Syllables_MLP_lc2_lf2_d0.25-e20_lr0.001.scaler.pkl'), 'rb'))

    # Classify with MLP
    with torch.no_grad():
        feature = torch.from_numpy(scaler.transform(X.values))
        # Convert feature to float64
        feature = feature.float()
        #feature = X.values
        outputs = classifier_mlp(feature)
        print(f'MLP: {outputs}')
        outputs = outputs.softmax(dim=1)
        probs, preds = torch.max(outputs, 1)
        print(f'MLP: {probs} {preds}')

    # with torch.no_grad():
    #     for inputs in test_loader:

    #         #print(len(inputs))

    #         outputs = classifier_mlp(inputs)
    #         #outputs = classifier_mlp(X.values)
    #         print(f'MLP: {outputs}')





        #print(f'Feature count: {len(X.columns)}')

        # Feature extraction -- fixation images
        # Task-specific variables for the construction of fixation image
        x_min, x_max, y_min, y_max, d_max = fixation_image_visual_params[task_type_id]
        fixation_duration_color_norm = Normalize(0, d_max)

    # Load the fixation image
        # Subject data
        subject_id, figs_dict, df_fixations_all = feat_ext.generate_subject_fixation_images(generate_fixation_image_for_each_trialid, fixation_image_characteristics_names, fill_null_values_by_zero, subject_id_col_name, degrees_visual_angle_pixels, fixation_duration_color_norm, x_min, x_max, y_min, y_max, feature_file_path_fixations)
    print(f'subject_id: {subject_id}')
        # Expect a single fixation image per experiment (i.e., generate_fixation_image_for_each_trialid = False)
        fig = figs_dict[-1]
    print(f'Fixation image: {fig}')
    # fixima_path = 'c:/temp/out-fixima.png'
    # fig.savefig(fixima_path, bbox_inches='tight', facecolor='white', transparent=False, pad_inches=0)
    # plt.close(fig)

    #fig = Image.open(fixima_path).convert('RGB')
    
    # Create a FigureCanvasAgg instance and render the figure to it
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png') # Create a PIL (Pillow) Image object from the byte array
        plt.close(fig)
        fig = Image.open(img_buf)
    
        fig = fig.convert('RGB')
    transform=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
    )
    fig = transform(fig)
        fixation_image_transforms = feat_ext.FixationImageTransform()
        fig = fixation_image_transforms(fig)
        fig = fig.unsqueeze(0)
    #print(f'fig: {fig}')


    classifier_cnn = binary_resnet18()
    classifier_cnn.load_state_dict(torch.load(os.path.join(out_models_dir, 'T1_Syllables_ResNet18-e20_lr0.001.pt')))
    classifier_cnn.eval()

    outputs = classifier_cnn(fig)
    print(f'CNN: {outputs}')
        # Classification
        for task_model_def in task_models_def[task_id]:
            type_id = task_model_def['type_id']
            classifier = task_model_def['loaded_model']

            # Data scaling
            scaler = task_model_def['scaler']
            if scaler is None:
                X_scaled = X
            else:
                X_scaled = torch.from_numpy(scaler.transform(X.values))
                #X_scaled = torch.from_numpy(scaler.transform(X))
                X_scaled = X_scaled.float() # Convert X to float64

            # Case 1: kNN
            if type_id == 'kNN':
                probs = torch.from_numpy(classifier.predict_proba(X_scaled))
                probs, preds = torch.max(probs, 1)

            # Case 2: MLP
            if type_id == 'MLP':
                with torch.no_grad():
                    outputs = classifier(X_scaled)
                outputs = outputs.softmax(dim=1)
                probs, preds = torch.max(outputs, 1)
    print(f'CNN: {probs} {preds}')


    classifier_cnn = binary_resnet50()
    classifier_cnn.load_state_dict(torch.load(os.path.join(out_models_dir, 'T1_Syllables_ResNet50-e20_lr0.001.pt')))
    classifier_cnn.eval()

    outputs = classifier_cnn(fig)
    print(f'CNN: {outputs}')
            # Case 3: CNN
            if type_id.startswith('CNN-'):
                outputs = classifier(fig)
                outputs = outputs.softmax(dim=1)
                probs, preds = torch.max(outputs, 1)
    print(f'CNN: {probs} {preds}')


    
    # # get task definitions
    # tasks = get_task_definitions()

    # # get model definitions
    # models = get_model_definitions()

    # # read trained models from files
    # trained_models = []

    # # classify
    # classify(tasks[0]['id'], models[0]['file_name'])
            
    # # print results
            # Prints the classification results
            print(f'  {task_model_def["desc_short"]}: {preds.squeeze()} ({probs.squeeze():.3f})')

# Script entry point: run the classification pipeline when executed directly.
if __name__ == "__main__":
    main()
+10 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import json
import numpy as np
import os
import pandas as pd
from torchvision import transforms

import matplotlib.cm as cm
from matplotlib.colors import Normalize
@@ -137,6 +138,15 @@ def create_subject_characteristics_profile(subject_id, characteristics_dict):
    
    return pd.concat(subject_dfs, axis=1)

class FixationImageTransform(transforms.Compose):
    """Preprocessing pipeline for fixation images fed to the CNN classifiers.

    Resizes to 224x224, converts the PIL image to a tensor, and normalizes
    with the standard ImageNet channel statistics (matching the pretrained
    ResNet backbones used elsewhere in the project).
    """

    def __init__(self):
        # Assemble the fixed transform chain and hand it to Compose.
        steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        super().__init__(steps)

def generate_fixation_image(df_fixations, degrees_visual_angle_pixels, fixation_duration_color_norm, x_min, x_max, y_min, y_max):

    # Plot the fixations

model_training.ipynb

0 → 100644
+1583 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading