Commit 49276b31 authored by Adéla Štěpková's avatar Adéla Štěpková
Browse files

improved plot exporting

parent 41e9f384
Loading
Loading
Loading
Loading
+91 −3
Original line number Diff line number Diff line
import xml.etree.ElementTree as ET
import csv
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def write_to_xml(data_dict: dict, path: str):
    root = ET.Element("automaton")
@@ -21,7 +24,7 @@ def xml_to_df(xml_path: str):

def xml_to_csv(xml_path: str, csv_path: str):
    """Convert the XML file at ``xml_path`` into a CSV file at ``csv_path``.

    The XML is loaded through ``xml_to_df`` and written once, without the
    DataFrame index column.
    """
    df = xml_to_df(xml_path)
    # A single write is enough: an earlier write that included the index
    # would be overwritten by this one anyway.
    df.to_csv(csv_path, index=False)


def mode_to_str(mode: str):
@@ -35,6 +38,8 @@ def mode_to_str(mode: str):
        return "pow_f"
    elif mode == "r":
        return "pow_rev"
    elif mode == "p2":
        return "port_2comp"
    return ""


@@ -42,4 +47,87 @@ def get_stats_file_name(base_name: str, mode: str, options: list[str]) -> str:
    return base_name + "_" + mode_to_str(mode) + "_" +  "_".join(options)


def merge_dataframes(dfs: list[pd.DataFrame], on_columns = ["name", "orig_size"]) -> pd.DataFrame:
    merged_df = dfs[0]

    for df in dfs[1:]:
        merged_df = pd.merge(merged_df, df, on=on_columns)

    return merged_df


def extension(path: str):
    """Return the extension of the final component of ``path``.

    The result includes the leading dot (e.g. ``".csv"``) and is an empty
    string when the final component has no extension.
    """
    return os.path.splitext(os.path.basename(path))[1]


def load_dataframes(dir: str) -> list[pd.DataFrame]:
    """Load every ``.csv`` and ``.xml`` file directly inside ``dir``.

    Args:
        dir: Directory to scan (not recursive).

    Returns:
        One DataFrame per recognised file, ordered by file name. Entries
        with any other extension, and entries that are not regular files,
        are ignored.
    """
    dfs = []
    # The context manager releases the directory handle even when one of the
    # readers below raises.
    with os.scandir(dir) as entries:
        # scandir yields entries in arbitrary order; sorting by name keeps the
        # returned list (and anything merged from it) reproducible.
        for entry in sorted(entries, key=lambda e: e.name):
            if not entry.is_file():
                continue

            ext = extension(entry.path)
            if ext == ".csv":
                dfs.append(pd.read_csv(entry.path))
            elif ext == ".xml":
                dfs.append(pd.read_xml(entry.path))

    return dfs
        

def compose_into_one_csv(dir: str, merged_csv_path: str | None = None, drop_columns: list[str] | None = None) -> None:
    """Merge all data files found in ``dir`` into a single CSV file.

    Every frame returned by ``load_dataframes(dir)`` has ``drop_columns``
    removed, then the frames are combined with ``merge_dataframes`` and
    written without the index column.

    Args:
        dir: Directory containing the input files.
        merged_csv_path: Destination of the merged CSV. Defaults to
            ``all.csv`` inside ``dir``.
        drop_columns: Column names to remove from each frame before merging.
            Defaults to dropping nothing.

    NOTE(review): with the default destination the output lands in the same
    directory that is scanned for inputs, so a second run will pick up the
    previous ``all.csv`` as an input as well. Pass ``merged_csv_path``
    pointing elsewhere (or delete the old file) when re-running.
    """
    # Use a None sentinel instead of a shared mutable list default.
    if drop_columns is None:
        drop_columns = []

    dfs = load_dataframes(dir)
    dfs = remove_useess_columns(dfs, drop_columns)
    merged_df = merge_dataframes(dfs)

    if merged_csv_path is None:
        merged_csv_path = os.path.join(dir, "all.csv")
    merged_df.to_csv(merged_csv_path, index=False)


def remove_useess_columns(dfs: list[pd.DataFrame], useless_columns: list[str]) -> list[pd.DataFrame]:
    """Return the frames of ``dfs`` with the listed columns removed.

    Column names that a frame does not contain are skipped for that frame.
    The input frames are not modified; a frame that contains none of the
    listed columns is passed through as the same object.
    """
    def _without_useless(frame: pd.DataFrame) -> pd.DataFrame:
        # Keep each requested name once, and only if this frame has it.
        present = [name for name in dict.fromkeys(useless_columns) if name in frame.columns]
        return frame.drop(columns=present) if present else frame

    return [_without_useless(frame) for frame in dfs]


def graph_basic(data: pd.DataFrame, x_col: str, y_col: str, 
                x_col_name: str, y_col_name: str, 
                plot_limit: int | None = None):
    """Draw a scatter plot of ``y_col`` against ``x_col`` with a red ``y = x`` line.

    Args:
        data: Frame holding both columns.
        x_col: Name of the column plotted on the x axis.
        y_col: Name of the column plotted on the y axis.
        x_col_name: Label shown on the x axis.
        y_col_name: Label shown on the y axis.
        plot_limit: When given and smaller than the largest plotted value,
            both axes are clipped to ``[0, plot_limit]``.

    Returns:
        The ``(fig, ax)`` pair of the created plot, so the caller can save
        or further customise it.
    """
    fig, ax = plt.subplots()

    sns.scatterplot(data=data, x=x_col, y=y_col, ax=ax)

    # Draw the diagonal from the origin up to the largest value on either axis.
    max_val = np.max([data[x_col].max(), data[y_col].max()])
    ax.plot([0, max_val], [0, max_val], "-r")

    # set axis names
    ax.set_xlabel(x_col_name)
    ax.set_ylabel(y_col_name)

    if plot_limit is not None and plot_limit < max_val:
        # Set the limits on this Axes directly rather than through pyplot's
        # implicit "current axes" state.
        ax.set_xlim(0, plot_limit)
        ax.set_ylim(0, plot_limit)

    return fig, ax


def convert_cols_to_bool(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
    """Cast the given columns of ``df`` to ``bool`` in place and return ``df``.

    Args:
        df: Frame to modify; it is mutated, not copied.
        columns: Names of the columns to convert with ``astype(bool)``.

    Returns:
        The same ``df`` object, matching the declared return type.
    """
    for col in columns:
        df[col] = df[col].astype(bool)

    return df


if __name__ == "__main__":
    import sys

    # The directory to process may be given as the first command-line
    # argument; without one, fall back to the original hard-coded location.
    default_directory = "/home/xstepkov/aligater/src/experiments/csv_files/automatark"
    directory = sys.argv[1] if len(sys.argv) > 1 else default_directory
    compose_into_one_csv(directory, drop_columns=["success", "Unnamed: 0"])
+1646 −0

File added.

Preview size limit exceeded, changes collapsed.