Commit b37ac96b authored by Marek Medved's avatar Marek Medved
Browse files

new approach with separate KB - working

parent 61ba9ef8
Loading
Loading
Loading
Loading

__init__.py

0 → 100644
+0 −0

Empty file added.

context.py

0 → 100644
+98 −0
Original line number Diff line number Diff line
#!/usr/bin/python2
# coding: utf-8
import locale
import os
import re
import sys
from BTrees.OOBTree import BTree
import persistent.list
from sqad_db import word2id

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(f'{dir_path}/set')
from grammar import Grammar
from setparser import Parser
from segment import Segment, Tree

locale.setlocale(locale.LC_ALL, '')
struct_re = re.compile('^<[^>]+>$')


class SetInterface:
    """Thin wrapper around the SET parser configured with one grammar file."""

    def __init__(self, grammar_path):
        """
        :param grammar_path: path to the SET grammar file (e.g. grammar.set)
        """
        self.grammar_path = grammar_path
        # The original called os.path.join() with a single argument,
        # which is a no-op; keep the plain path instead.
        self.grammar = grammar_path

    def parse_input(self, lines):
        """Parse one sentence of vertical-format lines with SET.

        :param lines: list of vertical-format lines forming one sentence
        :return: filtered phrases in vertical format, as produced by
                 ``Segment.get_marx_phrases_vert(filter_phr=True)``
        """
        g = Grammar(self.grammar)
        p = Parser(g)
        s = Segment(lines)
        p.parse(s)
        return s.get_marx_phrases_vert(filter_phr=True)


def get_struct(data, struct_id):
    """Yield one structure at a time from vertical-format *data*.

    A structure is the span of lines from an opening tag
    (``<struct_id>`` or ``<struct_id attr...>``) through the matching
    closing tag ``</struct_id>``, both inclusive, with each line stripped.

    :param data: iterable of vertical-format lines
    :param struct_id: structure tag name, e.g. ``'s'`` for sentences
    :yields: list of stripped lines of one complete structure

    Fixes over the original: the patterns are compiled once instead of
    being re-formatted and matched per line; *struct_id* is escaped so
    regex metacharacters in a tag name cannot corrupt the pattern; and a
    stray closing tag seen *before* any opening tag is ignored (previously
    it latched ``struct_end`` and caused a premature one-line yield at the
    next opening tag).
    """
    tag = re.escape(struct_id)
    open_re = re.compile(r'^<{0}>$|^<{0} .*?>$'.format(tag))
    close_re = re.compile(r'^</{}>$'.format(tag))

    struct = []
    inside = False
    for line in data:
        stripped = line.strip()
        if open_re.match(stripped):
            inside = True
        if inside:
            struct.append(stripped)
            # Only a close tag encountered while inside ends the structure.
            if close_re.match(stripped):
                yield struct
                struct = []
                inside = False


def name_phrases(text, title, context_window, num_phr_per_sent, vocabulary, w2v,
                 grammar_path="/nlp/projekty/set/set/grammar.set"):
    """Build per-sentence name-phrase contexts for a document.

    Each sentence of *text* is parsed with SET; its phrases are converted
    to lists of vocabulary word ids.  For every sentence the function then
    collects up to *num_phr_per_sent* phrases from each of the preceding
    *context_window* sentences as its context.  The document *title* is
    added as context for the first sentence.

    :param text: iterable of vertical-format lines (with <s>...</s> marks)
    :param title: iterable of vertical-format tokens of the document title
    :param context_window: how many preceding sentences contribute context
    :param num_phr_per_sent: max phrases taken from each context sentence
    :param vocabulary: vocabulary mapping used by word2id
    :param w2v: word-embedding model passed through to word2id
    :param grammar_path: SET grammar file; defaults to the original
        hard-coded project path so existing callers are unaffected
    :return: PersistentList with one PersistentList of context phrases
        per sentence
    """

    def _tokens_to_ids(tokens):
        # Convert vertical tokens ("word\tlemma\ttag...") to word ids,
        # skipping structural tags such as <s>; shared by phrase and
        # title processing (previously duplicated inline).
        phr = persistent.list.PersistentList()
        for token in tokens:
            if not struct_re.match(token):
                word, lemma, tag = token.strip().split('\t')[:3]
                phr.append(word2id(vocabulary, word, lemma, tag, w2v))
        return phr

    set_parser = SetInterface(grammar_path)
    text_context = persistent.list.PersistentList()

    # Parse every sentence and collect its phrases as lists of word ids.
    phrases_per_sentence = []
    for sentence in get_struct(text, 's'):
        phrases = persistent.list.PersistentList()
        for p in set_parser.parse_input(sentence):
            phrases.append(_tokens_to_ids(p.split('\n')))
        phrases_per_sentence.append(phrases)

    # Creating context according to context_window and num_phr_per_sent.
    for curr_sent_pos in range(len(phrases_per_sentence)):
        context_phrases = persistent.list.PersistentList()
        context_position = curr_sent_pos - 1
        while (context_position >= 0) and (curr_sent_pos - context_position <= context_window):
            context_phrases += phrases_per_sentence[context_position][:num_phr_per_sent]
            context_position -= 1

        # Title serves as the context for the first sentence in a document.
        if curr_sent_pos == 0:
            context_phrases.append(_tokens_to_ids(title))

        text_context.append(context_phrases)

    return text_context


def get_context(text, title, context_window, num_phr_per_sent, vocabulary, w2v):
    """Build the context BTree for one document.

    Wraps the name-phrase contexts produced by ``name_phrases`` under the
    ``'name_phr'`` key of a BTree, so callers can look up contexts by type.
    All parameters are forwarded unchanged to ``name_phrases``.
    """
    phrase_contexts = name_phrases(text, title, context_window,
                                   num_phr_per_sent, vocabulary, w2v)
    return BTree({'name_phr': phrase_contexts})
+28 −16
Original line number Diff line number Diff line
@@ -52,32 +52,44 @@ def get_record(db, record_id):

def print_record(db, record_id):
    record = db.get_record(record_id)
    vocabulary, qa_type_dict = db.get_dicts()
    vocabulary, qa_type_dict, kb = db.get_dicts()

    # get word, lemma, tag, vector
    def get_wltv(data):
        for sentence in data:
    def get_ctx(data):
        for ctx_type, phrs in data.items():
            print(f'\t\tctx_type: {ctx_type}')
            for p in phrs:
                p_content = []
                for w_id_cx in p:
                    p_content.append(id2word(vocabulary, w_id_cx)['word'])
                print(f'\t\t\tc: {" ".join(p_content)}')

    def get_senence(data):
        sent = []
            for w_id in sentence['sentence']:
        for w_id in data:
            sent.append(id2word(vocabulary, w_id)['word'])
        print(f'\ts: {" ".join(sent)}')
            for phr in sentence['phrs']:
                p = []
                for w_id in phr:
                    p.append(id2word(vocabulary, w_id)['word'])
                print(f'\t\tc: {" ".join(p)}')

    def get_content(data):
        for sentence in data:
            get_senence(sentence)

    def get_content_ctx(url, kb):
        for sentence in kb.url2doc.get(url):
            get_senence(sentence['content'])
            get_ctx(sentence['context'])


    print(f'rec_id: {record.rec_id}')
    print(f'q_type: {id2qt(qa_type_dict, record.q_type)}')
    print(f'a_type: {id2qt(qa_type_dict, record.a_type)}')
    print('question:')
    get_wltv(record.question)
    get_content(record.question)
    print('a_sel:')
    get_wltv(record.answer_selection)
    get_content(record.answer_selection)
    print('a_ext:')
    get_wltv(record.answer_extraction)
    get_content(record.answer_extraction)
    print('text:')
    get_wltv(record.text)
    get_content_ctx(record.text, kb)


def main():

set/.gitignore

0 → 100644
+24 −0
Original line number Diff line number Diff line
pdtgrammar*
*.brief
brief
set
set*.tar.gz
pdt_results
tmp*
/run_set.sh
*.pyc
*.swp
grammar.set.dev
compare.py
data/collocations2.collx
data/collocations2.dict
run.sh
compareetest.py
nohup.etest
nohup.out
runetest.sh
tree.eps
run2.sh
oldparser.py
cmpltmpl
results

set/COPYING

0 → 100644
+674 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading