# Skip to content
# Snippets Groups Projects
# trip_similarity.py 7.71 KiB
# Newer Older
# Jiří Vrbka's avatar
# Jiří Vrbka committed
import collections
import copy
import random
from typing import List, Dict
from sympy.utilities.iterables import multiset_permutations

# DEBUG: when True, small debug logs are printed (to standard output).
DEBUG = True

# SHUFFLE: controls the order of cities that share the same number of
# occurrences. When True, the cities within each such group are shuffled.
SHUFFLE = True


class TripSimilarity:
    """
    Recommends cities based on the similarity of people's trips.

    A trip ("group") is a list of city names. Cities are recommended when they appear in trips
    that also contain all of the cities the user has already visited; cities that appear in more
    of those trips are recommended first.
    """

    @staticmethod
    def find_recommendation_with_removal(data: List[List[str]], user_data: List[str], min_number_of_rec: int = 5) -> List[str]:
        """
        Find recommendations based on similarity with the given trips, relaxing the user's data
        when too few recommendations are found.

        The first attempt uses all of ``user_data`` and returns at most ``min_number_of_rec`` cities.
        If it yields fewer than ``min_number_of_rec`` cities, the search is retried with smaller
        sets of user cities: first every way of removing one city, then two, and so on, always
        keeping at least two of the user's cities. Each retry asks for at most
        ``min_number_of_rec + 1`` cities and the first retry that yields at least
        ``min_number_of_rec`` cities is returned.
        :param data: groups of cities to find similarity with
        :param user_data: cities to find similarity by
        :param min_number_of_rec: minimal number of cities required for a result to be accepted
        :return: list of recommended cities; when no attempt reaches ``min_number_of_rec``, the
                 result of the last attempt is returned
        """
        result = TripSimilarity.find_recommendation(data, user_data, min_number_of_rec)
        if len(result) >= min_number_of_rec:
            return result

        total = len(user_data)
        for count_to_remove in range(1, total - 1):
            # One flag per user city: True means "drop this city for this attempt".
            removal_mask = [True] * count_to_remove + [False] * (total - count_to_remove)

            # Every distinct arrangement of the flags is one way of choosing which cities to
            # drop; iterate the generator lazily instead of materialising all arrangements.
            for mutation in multiset_permutations(removal_mask):
                removed = [city for city, drop in zip(user_data, mutation) if drop]
                kept = [city for city, drop in zip(user_data, mutation) if not drop]

                if DEBUG:
                    print("\nfor this round removing: ", end=" ")
                    for city in removed:
                        print(city, end=", ")

                result = TripSimilarity.find_recommendation(data, kept, min_number_of_rec + 1)
                if len(result) >= min_number_of_rec:
                    return result

        return result

    @staticmethod
    def find_recommendation(data: List[List[str]], user_data: List[str], max_number_of_rec: int = 5) -> List[str]:
        """
        Find recommendations based on similarity with the given trips and return not yet
        travelled places found in the similar trips.

        Only trips containing every city of ``user_data`` are considered. The remaining cities of
        those trips are ranked by the number of trips they appear in (most frequent first).
        ``data`` is not modified.
        :param data: groups of cities to find similarity with
        :param user_data: cities to find similarity by
        :param max_number_of_rec: maximum number of cities that will be returned
        :return: list of recommended cities
        """
        # The de-duplication step builds new inner lists, so the caller's data is never mutated
        # and no defensive deep copy is needed.
        list_of_groups = TripSimilarity.__remove_duplicities_in_group(data)
        list_of_groups = TripSimilarity.__get_lists_containing(list_of_groups, user_data)
        list_of_groups = TripSimilarity.__remove_given_values_from_each_group(list_of_groups, user_data)
        cities_occurrences = TripSimilarity.__get_occurrences_of_cities(list_of_groups)
        if SHUFFLE:
            cities_occurrences = TripSimilarity.__shuffle_in_groups(cities_occurrences)

        return TripSimilarity.__get_first_x_cities(cities_occurrences, max_number_of_rec)

    @staticmethod
    def __get_occurrences_of_cities(groups: List[List[str]]) -> Dict[int, List[str]]:
        """
        Group cities by the number of groups they appear in.

        A city is counted at most once per group (duplicates inside one group are ignored).
        :param groups: groups of cities (aka trips)
        :return: key: number, value: cities that appeared number-times in different groups
        """
        occurrences = collections.Counter()
        for group in groups:
            # dict.fromkeys drops duplicates while keeping the first-seen order.
            for city in dict.fromkeys(group):
                occurrences[city] += 1

        cities_by_count = collections.defaultdict(list)
        for city, count in occurrences.items():
            cities_by_count[count].append(city)

        return dict(cities_by_count)

    @staticmethod
    def __get_first_x_cities(groups: Dict[int, List[str]], x=10) -> List[str]:
        """
        Get the first X cities ordered by their key (higher = better).

        Cities sharing the same key keep the order they have in their list.
        :param groups: key: number, value: cities that appeared number-times in different groups
        :param x: number of cities to be returned
        :return: up to x cities with the highest number; empty list when x is not positive
        """
        if x <= 0:
            return []

        result = []
        # Rank by the occurrence count (the dict key), highest first - not by the city lists.
        for count in sorted(groups, reverse=True):
            for city in groups[count]:
                result.append(city)

                if DEBUG:
                    print("Adding city [{}] {} ".format(count, city))

                if len(result) >= x:
                    return result

        return result

    @staticmethod
    def __shuffle_in_groups(groups: Dict[int, List[str]]) -> Dict[int, List[str]]:
        """
        Shuffle every list in groups in place.
        :param groups: dict whose values (lists) will be shuffled
        :return: the same dict object that was passed in
        """
        for cities in groups.values():
            random.shuffle(cities)
        return groups

    @staticmethod
    def __remove_duplicities_in_group(groups: List[List[str]]) -> List[List[str]]:
        """
        Remove duplicates in inner lists, keeping the first occurrence of each value.
        :param groups: groups of cities; not modified
        :return: new list of new de-duplicated lists
        """
        return [list(dict.fromkeys(group)) for group in groups]

    @staticmethod
    def __remove_given_values_from_each_group(groups: List[List[str]], to_remove: List[str]) -> List[List[str]]:
        """
        Remove the given values from every group.
        :param groups: to be removed from; not modified
        :param to_remove: values to be removed
        :return: new list of new lists without the removed values
        """
        unwanted = set(to_remove)  # built once so each membership test is O(1)
        return [[elem for elem in group if elem not in unwanted] for group in groups]

    @staticmethod
    def __get_lists_containing(groups: List[List[str]], to_contain: List[str]) -> List[List[str]]:
        """
        Get the lists that contain all of the given values.
        :param groups: lists to search in
        :param to_contain: values that must all be contained
        :return: lists that contain every value given in to_contain (all lists when it is empty)
        """
        required = set(to_contain)
        return [group for group in groups if required.issubset(group)]


def load_from_csv(filepath: str) -> List[List[str]]:
    """
    Load trips from a comma-separated file and group the visited cities by user.

    The first line is treated as a header and skipped. In every following line the second
    column is read as the user and the third column as the city. Surrounding whitespace
    (including the trailing newline when the city is the last column) is stripped from both.
    Blank lines are ignored; lines with fewer than three columns are reported on standard
    output and skipped.
    :param filepath: path of the file to read
    :return: one list of cities per user, in the order the users first appear in the file
    """
    cities_by_user = {}

    with open(filepath, 'r') as csv_file:
        csv_file.readline()  # skip the header
        for line in csv_file:
            if not line.strip():
                continue

            parts = line.split(",")
            try:
                user = parts[1].strip()
                city = parts[2].strip()
            except IndexError:
                # Best effort: report the malformed line and carry on with the rest.
                print("Fail at " + line)
                continue

            cities_by_user.setdefault(user, []).append(city)

    return list(cities_by_user.values())


def main():
    """
    Demo entry point: load the trips from ``../../data/trips.csv`` and print the cities
    recommended for a fixed sample list of visited cities.
    """
    trips = load_from_csv("../../data/trips.csv")
    visited = ['prague', 'london', 'jakubov', 'berlin', 'amsterdam', 'madrid', "znojmo"]
    print(TripSimilarity.find_recommendation_with_removal(trips, visited))


# Run the demo only when executed as a script, so that importing this module
# does not trigger file I/O and printing.
if __name__ == "__main__":
    main()