Commit c37e1b06 authored by Tomas Pavlik
Browse files

initial commit of some resolvers

Took 16 seconds
parent 987ba566
Loading
Loading
Loading
Loading
+0 −0

Empty file added.

+199 −0
Original line number Diff line number Diff line
import configparser
import json
import os.path

import MDAnalysis
from src.utils.file_loader import SimulationFile

from ..base.fragment_resolver import IdentifierResolutionError, SignatureGenerationError
from ..base.fragment_resolver import ResolverBase, ResolverKind
from ..base.resolver_database import ResolverCache
from ...constants import (
    ATOMTABLE_ELEMENTS_MASSES,
    RESOLVER_ATOM_APPROX_FIND_SIMILAR_N_MASSES,
)
from ...utils.repair_simulation import fix_missing_elements


class AtomResolver(ResolverBase):
    """
    Resolver that treats the fragment as one standalone atom (possibly an ion).

    Works in one of two modes, selected at construction time:
     ... strict mode ("atom") - signature is the upper-cased element code of the atom
     ... undetermined mode ("atom_aprox") - signature is the atomic mass; only similarity matching
         against the table of element masses is available
    """

    RESOLVER_NAME = "atom"
    RESOLVER_NAME_UNDETERMINED = "atom_aprox"
    RESOLVER_IDENT = "ELEMENT"
    RESOLVER_IDENT_UNDETERMINED = "ELEMENT_APROX"

    def __init__(
        self,
        cache: ResolverCache,
        configuration: configparser.ConfigParser,
        try_fix_elements: bool = False,
        allow_undetermined_elements: bool = False,
    ) -> None:
        """
        Sets up the atom resolver in strict or undetermined mode
        :param cache: resolver cache
        :param configuration: for the resolver
        :param try_fix_elements: when element information is missing, attempt to derive it from atomic masses
        :param allow_undetermined_elements: disregard element information and only predict candidate
                                            elements from the atomic mass
        """
        # the mode decides under which name the resolver is registered and how its identifier is called
        if allow_undetermined_elements:
            resolver_name: str = AtomResolver.RESOLVER_NAME_UNDETERMINED
            ident_name: str = AtomResolver.RESOLVER_IDENT_UNDETERMINED
        else:
            resolver_name = AtomResolver.RESOLVER_NAME
            ident_name = AtomResolver.RESOLVER_IDENT

        super().__init__(
            resolver_name,
            ResolverKind.PRIMARY,
            ident_name,
            cache,
            configuration,
        )

        self.try_fix_elements: bool = try_fix_elements
        self.allow_undetermined_elements: bool = allow_undetermined_elements

    def _is_applicable(
        self, simulation: SimulationFile, fragment: MDAnalysis.AtomGroup
    ) -> bool:
        """
        Tells whether the fragment can be handled by the AtomResolver
        :param fragment: AtomGroup representing a fragment (single atom)
        :param simulation: of which fragment is part of
        :return: True if resolver can be used on a given fragment
        """

        segment_id: str = fragment.segids[0]

        if self.allow_undetermined_elements:
            # mass-only mode: atomic masses are the sole prerequisite
            if not simulation.supports_atomic_masses(segment_id):
                return False
        elif not simulation.supports_atomic_elements(segment_id):
            # element information is absent; the only way forward is to reconstruct it
            # from masses, and only if the resolver was allowed to do so
            elements_recovered = (
                simulation.supports_atomic_masses(segment_id)
                and self.try_fix_elements
                and fix_missing_elements(simulation, fragment)
            )
            if not elements_recovered:
                return False

        # anything larger than one atom is not an atom (otherwise chem resolver is applicable)
        return fragment.n_atoms == 1

    def get_signature(
        self, simulation: SimulationFile, fragment: MDAnalysis.AtomGroup
    ) -> str:
        """
        Produces the signature of the atom: its element code, or its mass when
        .allow_undetermined_elements is set
        :param simulation: of which the fragment is part of
        :param fragment: AtomGroup representing a single atom
        :return: upper-cased element code, or the atomic mass converted to string
        :raises: SignatureGenerationError if the element code is not in the table of elements
        """

        if self.allow_undetermined_elements:
            return str(fragment.masses[0])

        element_code: str = fragment.elements[0].upper()
        # comparison is case-insensitive: both sides are upper-cased
        known_codes: list[str] = [
            symbol.upper() for symbol in ATOMTABLE_ELEMENTS_MASSES
        ]
        if element_code in known_codes:
            return element_code

        raise SignatureGenerationError(
            "extracting atom element",
            f"atom element is invalid: '{element_code}'",
        )

    def try_fetch_identifiers_exact_match(self, signature: str) -> list[str]:
        """
        In strict mode the signature already is the identifier, so it is handed back unchanged.
        :param signature: element code produced by get_signature(...)
        :return: single-item list holding the signature
        :raises IdentifierResolutionError: if allow_undetermined_elements is True (no identity match in that mode)
        """
        if self.allow_undetermined_elements:
            raise IdentifierResolutionError(
                "resolving ATOM identity match",
                "not supported with .allow_undetermined_elements",
            )
        return [signature]

    def try_fetch_identifiers_relaxed_match(
        self, signature: str
    ) -> list[tuple[str, float]]:
        """
        In strict mode returns the signature itself with a score of 100.
        In undetermined mode ranks elements by how close their mass is to the mass in the signature.
        :param signature: element code (strict mode) or atomic mass as string (undetermined mode)
        :return: list of (element, similarity score) pairs, closest mass first
        """

        if not self.allow_undetermined_elements:
            return [(signature, 100)]

        atomic_mass: float = float(signature)

        # stable sort: elements with the same mass difference keep the order of the table
        closest_first = sorted(
            ATOMTABLE_ELEMENTS_MASSES.items(),
            key=lambda entry: abs(atomic_mass - entry[1]),
        )

        scored: list[tuple[str, float]] = []
        for element, element_mass in closest_first[
            :RESOLVER_ATOM_APPROX_FIND_SIMILAR_N_MASSES
        ]:
            ratio: float = atomic_mass / element_mass
            # fold ratios above 1 back below 1 so heavier and lighter candidates score symmetrically
            similarity: float = 1 / ratio if ratio > 1 else ratio
            scored.append((element, round(similarity * 100)))
        return scored

    def generate_debug_data(
        self,
        simulation: SimulationFile,
        fragment: MDAnalysis.AtomGroup,
        out_path: str,
        out_name: str,
    ) -> bool:
        """
        Dumps element, name, mass and charge of the atom into '<out_name>_ATOM_desc.json'
        :param out_path: directory where to save generated representations
        :param out_name: basename of the file to save representations to (will be suffixed with kind of export)
        :param simulation: of which fragment is part of
        :param fragment: AtomGroup representing a molecule
        :return: True if export went without problems
        """

        def first_value_or_none(attribute: str):
            # attribute missing on the fragment or holding no values -> exported as null
            if hasattr(fragment, attribute) and len(getattr(fragment, attribute)) > 0:
                return getattr(fragment, attribute)[0]
            return None

        # JSON key -> fragment attribute it is read from (order defines order in the output)
        exported_fields = (
            ("atom_element", "elements"),
            ("atom_name", "names"),
            ("atom_mass", "masses"),
            ("atom_charge", "charges"),
        )

        target_file: str = os.path.join(out_path, f"{out_name}_ATOM_desc.json")
        with open(target_file, "w") as fd:
            description = {
                key: first_value_or_none(attribute)
                for key, attribute in exported_fields
            }
            json.dump(description, fd)
        return True
+0 −0

Empty file added.

+280 −0
Original line number Diff line number Diff line
import configparser
import logging
from enum import Enum

import MDAnalysis

from .resolver_database import ResolverCache
from ...utils.file_loader import SimulationFile
from ...utils.timeout import timeout

# module-level logger named "resolver"; ResolverBase reports configuration problems through it
log = logging.getLogger("resolver")


class SignatureGenerationError(Exception):
    """
    Error raised when the signature (fingerprint) of a fragment cannot be created.
    Keeps both parts of the message accessible as attributes.
    """

    def __init__(self, generation_step, cause):
        """
        :param generation_step: description of the step that failed (placed after "Error while")
        :param cause: description of what went wrong
        """
        self.generation_step, self.cause = generation_step, cause
        message = f"[RES.SIGN] Error while {generation_step}: {cause}"
        super().__init__(message)


class IdentifierResolutionError(Exception):
    """
    Error raised when a signature cannot be resolved into identifiers.
    Keeps both parts of the message accessible as attributes.
    """

    def __init__(self, generation_step, cause):
        """
        :param generation_step: description of the step that failed (placed after "Error while")
        :param cause: description of what went wrong
        """
        self.generation_step, self.cause = generation_step, cause
        message = f"[RES.IDENT] Error while {generation_step}: {cause}"
        super().__init__(message)


class ResolverKind(Enum):
    """
    Kind of a resolver, expressing how much an identity match produced by it can be trusted
    """

    PRIMARY = 1  # identity match can be considered ground truth
    FALLBACK = 2  # identity match can be considered true with very high likelihood


class ResolverBase:
    """
    Abstract class for fragment resolver. The workflow for resolvers is as follows:
    input: MDAnalysis' fragment (one molecule)
     ... .is_applicable() - can this resolver parse this fragment? (a quick check, but signature generation might fail)
     ... .get_signature() - translates atomic fragment into its string signature
     ... .get_identifiers() - attempts to identity match signature into standardized identifier(s)
     ... .get_similar_identifiers() - attempts to similarity match signature into standardized identifier(s)
     ... .generate_debug_data() - generates representation of the atomic fragment in paradigm of the resolver
     !! Due to built-in timeout with signals, calling get_identifiers or get_similar_identifiers is not thread safe !!
     !! run only one concurrent instance in a single process !!

    Subclasses implement _is_applicable(), get_signature(), try_fetch_identifiers_exact_match(),
    try_fetch_identifiers_relaxed_match() and generate_debug_data(); the base implementations
    raise NotImplementedError.
    """

    def __init__(
        self,
        resolver_name: str,
        resolver_kind: ResolverKind,
        ident_name: str,
        cache: ResolverCache,
        configuration: configparser.ConfigParser,
    ) -> None:
        """
        Initializes base resolver. Every switch starts disabled and is then
        overridden by the section of the configuration named after the resolver.
        :param resolver_name: name of the resolver (for caching purposes, also name of its config section)
        :param resolver_kind: type of the resolver
        :param ident_name: name of the identifier (for result purposes)
        :param cache: cache for the resolver
        :param configuration: global configuration for the application
        """
        self.resolver_name: str = resolver_name
        self.resolver_kind: ResolverKind = resolver_kind
        self.ident_name: str = ident_name
        self.cache: ResolverCache = cache
        self.config: configparser.ConfigParser = configuration

        self.allow_resolution: bool = False  # 'activates' the resolver
        self.allow_partial_resolution: bool = (
            False  # on non-identity hit, still returns partial results and scores
        )
        self.resolve_through_cache: bool = False  # allows using cache for resolution
        self.resolve_through_search: bool = (
            False  # allows using online or offline databases for resolution
        )
        self.allow_debug_output: bool = (
            False  # allows resolver during debug simulation export
        )
        self.fetch_idents_timeout_sec: int | None = (
            None  # maximum time for fetching identifiers, None for unlimited
        )
        self._load_config(configuration)

    def _load_config(self, configuration: configparser.ConfigParser) -> None:
        """
        Loads resolver settings from the section named after the resolver.
        Options that are not present keep their (disabled) defaults; if the whole
        section is missing, an error is logged and the resolver stays disabled.
        :param configuration: loaded ConfigParser() instance
        :return: None
        """
        if not configuration.has_section(self.resolver_name):
            log.error(
                f"Resolver {self.resolver_name} is not in the configuration! Disabled by default"
            )
            return

        if configuration.has_option(self.resolver_name, "enabled"):
            self.allow_resolution = configuration.getboolean(
                self.resolver_name, "enabled"
            )

        if configuration.has_option(self.resolver_name, "enabled_similarity"):
            self.allow_partial_resolution = configuration.getboolean(
                self.resolver_name, "enabled_similarity"
            )

        if configuration.has_option(self.resolver_name, "from_cache"):
            self.resolve_through_cache = configuration.getboolean(
                self.resolver_name, "from_cache"
            )

        if configuration.has_option(self.resolver_name, "from_search"):
            self.resolve_through_search = configuration.getboolean(
                self.resolver_name, "from_search"
            )

        if configuration.has_option(self.resolver_name, "allow_export"):
            self.allow_debug_output = configuration.getboolean(
                self.resolver_name, "allow_export"
            )

        if configuration.has_option(self.resolver_name, "timeout_sec"):
            self.fetch_idents_timeout_sec = configuration.getint(
                self.resolver_name, "timeout_sec"
            )
            # a configured value of 0 means "no limit"
            if self.fetch_idents_timeout_sec == 0:
                self.fetch_idents_timeout_sec = None

    def is_applicable(
        self, simulation: SimulationFile, fragment: MDAnalysis.AtomGroup
    ) -> bool:
        """
        Decides whether the resolver can be used on the given fragment.
        A resolver that is not enabled in the configuration is never applicable.
        :param fragment: AtomGroup representing a molecule
        :param simulation: of which fragment is part of
        :return: True if resolver can be used on a given fragment
        """
        if not self.allow_resolution:
            return False
        return self._is_applicable(simulation, fragment)

    def _is_applicable(
        self, simulation: SimulationFile, fragment: MDAnalysis.AtomGroup
    ) -> bool:
        """
        Decides whether the resolver can be used on the given fragment.
        To be implemented by resolvers
        :param fragment: AtomGroup representing a molecule
        :param simulation: of which fragment is part of
        :return: True if resolver can be used on a given fragment
        """
        raise NotImplementedError

    def get_signature(
        self, simulation: SimulationFile, fragment: MDAnalysis.AtomGroup
    ) -> str:
        """
        Translates fragment into its string fingerprint
        To be implemented by resolvers
        :param fragment: AtomGroup representing a molecule
        :param simulation: of which fragment is part of
        :return: string fingerprint of the fragment given the resolver
        :raises: SignatureGenerationError if fingerprint cannot be created
        """
        raise NotImplementedError

    def get_identifiers(self, signature: str) -> list[str]:
        """
        Fetches <<identity-level>> identifiers associated with the given fingerprint.
        The cache is consulted first (if enabled), then the search (if enabled);
        identifiers found by the search are stored into the cache.
        :param signature: string fingerprint of the fragment given the resolver
        :return: list of standardized identifiers of the fingerprint
        :raises: IdentifierResolutionError if no identifier could be found or the search timed out
        """
        if self.resolve_through_cache:
            cached_identifiers: list[str] = self.cache.fetch_identity_identifiers(
                self.resolver_name, signature
            )
            if cached_identifiers:
                return cached_identifiers

        if self.resolve_through_search:
            # error handed to the timeout helper for the case the search takes too long
            search_timeout_error = IdentifierResolutionError(
                "resolving identifiers", "fetching identity matches timed out"
            )
            with timeout(
                seconds=self.fetch_idents_timeout_sec, error=search_timeout_error
            ):
                fetched_identifiers: list[str] = self.try_fetch_identifiers_exact_match(
                    signature
                )

            if fetched_identifiers:
                self.cache.save_identity_identifiers(
                    self.resolver_name, signature, fetched_identifiers
                )
                return fetched_identifiers

        raise IdentifierResolutionError(
            "resolving identifiers",
            "no identifiers found for identity match of the fingerprint",
        )

    def get_similar_identifiers(self, signature: str) -> list[tuple[str, float]]:
        """
        Fetches similar identifiers associated with the given fingerprint.
        The cache is consulted first (if enabled), then the search (if enabled).
        Results of the search are not written into the cache by this method.
        :param signature: string fingerprint of the fragment given the resolver
        :return: list of standardized identifiers of the fingerprint and their scores
        :raises: IdentifierResolutionError if no identifier could be found, the search timed out,
                 or this functionality is disabled
        """
        if not self.allow_partial_resolution:
            raise IdentifierResolutionError(
                "resolving similar identifiers", "functionality is disabled"
            )

        if self.resolve_through_cache:
            cached_identifiers: list[tuple[str, float]] = (
                self.cache.fetch_similarity_identifiers(self.resolver_name, signature)
            )
            if cached_identifiers:
                return cached_identifiers

        if self.resolve_through_search:
            # error handed to the timeout helper for the case the search takes too long
            search_timeout_error = IdentifierResolutionError(
                "resolving identifiers", "fetching similarity matches timed out"
            )
            with timeout(
                seconds=self.fetch_idents_timeout_sec, error=search_timeout_error
            ):
                fetched_identifiers: list[tuple[str, float]] = (
                    self.try_fetch_identifiers_relaxed_match(signature)
                )

            if fetched_identifiers:
                return fetched_identifiers

        # reached whenever neither source produced a result - including the case where
        # the search is disabled - so the method never falls through and returns None
        raise IdentifierResolutionError(
            "resolving identifiers",
            "no identifiers found for similar match of the fingerprint",
        )

    def try_fetch_identifiers_exact_match(self, signature: str) -> list[str]:
        """
        Attempts to resolve signature into standardized identifiers, requiring identity match of the fingerprint
        To be implemented by resolvers
        :param signature: to be translated into identifiers
        :return: list of standardized identifiers of the fingerprint
        :raises IdentifierResolutionError: on internal failure during fetching
        """
        raise NotImplementedError

    def try_fetch_identifiers_relaxed_match(
        self, signature: str
    ) -> list[tuple[str, float]]:
        """
        Attempts to resolve signature into standardized identifiers, based on similarity match(es)
        To be implemented by resolvers
        :param signature: to be translated into identifiers
        :return: list of standardized identifiers of the fingerprint and their scores
        :raises IdentifierResolutionError: on internal failure during fetching
        """
        raise NotImplementedError

    def generate_debug_data(
        self,
        simulation: SimulationFile,
        fragment: MDAnalysis.AtomGroup,
        out_path: str,
        out_name: str,
    ) -> bool:
        """
        If resolver is applicable and can generate signature,
        generates representation of the fragment AtomGroup in resolver's paradigm
        To be implemented by resolvers
        :param out_path: directory where to save generated representations
        :param out_name: basename of the file to save representations to (will be suffixed with kind of export)
        :param simulation: of which fragment is part of
        :param fragment: AtomGroup representing a molecule
        :return: True if export went without problems
        """
        raise NotImplementedError
+134 −0
Original line number Diff line number Diff line
import configparser
import logging
import os.path
import sqlite3
import warnings

# values of the IS_IDENTITY column: distinguishes identity matches from similarity matches
IS_IDENTITY = 1
IS_SIMILARITY = 0

log = logging.getLogger("base")


class ResolverCache:
    """
    SQLite-backed cache of identifiers already resolved for a (resolver, signature) pair.
    Rows are either identity matches (IS_IDENTITY) or similarity matches with a score (IS_SIMILARITY).
    """

    def __init__(self, configuration: configparser.ConfigParser) -> None:
        """
        Initializes new or loads existing resolver cache database
        :param configuration: for determining cache directory and cache file
                              (options 'directory' and 'file' of section 'cache')
        """

        cache_directory: str = configuration.get("cache", "directory")
        cache_file: str = configuration.get("cache", "file")

        self.cache_path: str = os.path.join(cache_directory, cache_file)
        if not os.path.exists(self.cache_path):
            log.warning(
                f"resolver cache not found at '{self.cache_path}'. Initializing empty"
            )
            self._create_empty()

        self.conn: sqlite3.Connection = sqlite3.connect(self.cache_path)

    def _create_empty(self) -> None:
        """
        Creates an empty cache database file (and its parent directory, if needed)
        :return: None
        """
        directory_path: str = os.path.dirname(self.cache_path)
        # an empty dirname means the cache file lives in the current working directory;
        # os.makedirs("") would raise, so there is nothing to create in that case
        if directory_path:
            os.makedirs(directory_path, exist_ok=True)

        # the connection context manager only commits/rolls back, it does not close
        # the connection - close it explicitly so the bootstrap handle is not leaked
        conn = sqlite3.connect(self.cache_path)
        try:
            with conn:
                conn.execute(
                    """CREATE TABLE "CACHE" (
                            "RESOLVER_NAME"	TEXT NOT NULL,
                            "SIGNATURE"	TEXT NOT NULL,
                            "IDENTIFIER"	TEXT NOT NULL,
                            "IS_IDENTITY"	INTEGER NOT NULL,
	                        "SIMILARITY"	REAL,
                            PRIMARY KEY("IDENTIFIER","SIGNATURE","RESOLVER_NAME"));
                        """
                )
        finally:
            conn.close()

    def fetch_identity_identifiers(
        self, resolver_name: str, signature: str
    ) -> list[str]:
        """
        Finds all identifiers already resolved on identity level and cached
        :param resolver_name: of resolver trying to identify the signature
        :param signature: of the molecule
        :return: list of all identifiers associated with the signature on identity level
        """
        with self.conn as conn:
            records = conn.execute(
                "SELECT IDENTIFIER FROM CACHE WHERE RESOLVER_NAME=? AND SIGNATURE=? AND IS_IDENTITY=?",
                (resolver_name, signature, IS_IDENTITY),
            ).fetchall()
            return [record[0] for record in records]

    def fetch_similarity_identifiers(
        self, resolver_name: str, signature: str
    ) -> list[tuple[str, float]]:
        """
        Finds all similar identifiers already resolved on similarity level and cached
        :param resolver_name: of resolver trying to identify the signature
        :param signature: of the molecule
        :return: list of all (identifier, similarity) pairs associated with the fingerprint on similarity level
        """
        with self.conn as conn:
            records = conn.execute(
                "SELECT IDENTIFIER, SIMILARITY FROM CACHE "
                "WHERE RESOLVER_NAME=? AND SIGNATURE=? AND IS_IDENTITY=?",
                (resolver_name, signature, IS_SIMILARITY),
            ).fetchall()
            return records

    def save_identity_identifiers(
        self, resolver_name: str, signature: str, identifiers: list[str]
    ) -> None:
        """
        Saves all identifiers found into the cache as identical resolutions of the signature.
        Identifiers already cached on identity level, as well as repeated entries
        within the given list, are skipped.
        :param resolver_name: of the resolver that found the identifiers
        :param signature: of which identifiers were found
        :param identifiers: to be saved
        :return: None
        """
        # NOTE(review): the primary key does not include IS_IDENTITY, so an identifier already
        # stored as a similarity match of the same signature would still collide here - confirm
        # whether both kinds are ever expected for one (resolver, signature, identifier)

        # leaving the 'with' block commits the inserts (or rolls them back on error)
        with self.conn as conn:
            known_identifiers: set[str] = set(
                self.fetch_identity_identifiers(resolver_name, signature)
            )
            new_rows: list[tuple] = []
            for ident in identifiers:
                if ident in known_identifiers:
                    continue
                # remember it, so a duplicate later in the list does not violate the primary key
                known_identifiers.add(ident)
                new_rows.append((resolver_name, signature, ident, IS_IDENTITY, None))
            conn.executemany("INSERT INTO CACHE VALUES (?, ?, ?, ?, ?)", new_rows)

    def save_similarity_identifiers(
        self, resolver_name: str, signature: str, identifiers: list[tuple[str, float]]
    ) -> None:
        """
        Saves all identifiers found into the cache as similar resolutions of the signature.
        Identifiers already cached on similarity level, as well as repeated entries
        within the given list (the first occurrence wins), are skipped.
        :param resolver_name: of the resolver that found the identifiers
        :param signature: of which identifiers were found
        :param identifiers: (identifier, similarity) pairs to be saved
        :return: None
        """
        # NOTE(review): the primary key does not include IS_IDENTITY, so an identifier already
        # stored as an identity match of the same signature would still collide here - confirm
        # whether both kinds are ever expected for one (resolver, signature, identifier)

        # leaving the 'with' block commits the inserts (or rolls them back on error)
        with self.conn as conn:
            known_identifiers: set[str] = {
                _ident
                for _ident, _ in self.fetch_similarity_identifiers(
                    resolver_name, signature
                )
            }
            new_rows: list[tuple] = []
            for _ident, _similarity in identifiers:
                if _ident in known_identifiers:
                    continue
                # remember it, so a duplicate later in the list does not violate the primary key
                known_identifiers.add(_ident)
                new_rows.append(
                    (resolver_name, signature, _ident, IS_SIMILARITY, _similarity)
                )
            conn.executemany("INSERT INTO CACHE VALUES (?, ?, ?, ?, ?)", new_rows)
Loading