Commit 608802ba authored by Marek Medved's avatar Marek Medved
Browse files

new context

parent 78443ecf
Loading
Loading
Loading
Loading
+15 −1
Original line number Diff line number Diff line
@@ -50,7 +50,15 @@ def name_phrases(text, title, vocabulary, context_window, num_phr_per_sent, w2v)
            phr = persistent.list.PersistentList()
            for token in p.split('\n'):
                if not struct_re.match(token):
                    word, lemma, tag = token.strip().split('\t')[:3]
                    spl = token.strip().split('\t')
                    try:
                        word, lemma, tag = spl[:3]
                    except ValueError as e:
                        print(f'Something goes wrong while splitting line: "{token}" in:\n'
                              f'{p}\n'
                              f'splitted as {spl}')
                        print(e)
                        sys.exit()
                    wid = word2id(vocabulary, word, lemma, tag, w2v)
                    phr.append(wid)
            phrases.append(phr)
@@ -149,6 +157,12 @@ def main():
    db = SqadDb(args.db_path)
    add_np_phrases(db, args.context_window, args.num_phr_per_sent, w2v)

    db.root['__ctx_types__'].append(f'name_phrs_w{args.context_window}_n{args.num_phr_per_sent}')
    db._p_changed = True
    transaction.commit()
    print(db.root['__ctx_types__'])
    db.close()


if __name__ == "__main__":
    main()
+46 −0
Original line number Diff line number Diff line
#!/usr/bin/python3
# coding: utf-8
from sqad_db import SqadDb
from sqad_db import id2word
import persistent.list
import transaction


def add_ctx(db, number):
    vocabulary, _, kb = db.get_dicts()
    for url, text in kb.url2doc.items():
        for sent_num, sent in enumerate(text['text']):
            print(f"s:{' '.join([id2word(vocabulary, x)['word'] for x in sent['sent']])}")

            if not sent['ctx'].get(f'prev_sent_n{number}'):
                if sent_num == 0:
                    print(f"\tc:{' '.join([id2word(vocabulary, x)['word'] for x in text['title'][0]])}")
                    sent['ctx'][f'prev_sent_n{number}'] = persistent.list.PersistentList([text['title'][0]])
                else:
                    print(f"\tc:{' '.join([id2word(vocabulary, x)['word'] for x in text['text'][sent_num - 1]['sent']])}")
                    sent['ctx'][f'prev_sent_n{number}'] = persistent.list.PersistentList([text['text'][sent_num - 1]['sent']])
                db._p_changed = True
                transaction.commit()


def main():
    import argparse
    parser = argparse.ArgumentParser(description='Add noun phrases as context to sentences')
    parser.add_argument('-n', '--number', type=int,
                        required=False, default=1,
                        help='Number of previous sentences as a context')
    parser.add_argument('-d', '--db_path', type=str,
                        required=True,
                        help='Database path')
    args = parser.parse_args()

    db = SqadDb(args.db_path)
    add_ctx(db, args.number)
    db.root['__ctx_types__'].append(f'prev_sent_n{args.number}')
    db._p_changed = True
    transaction.commit()
    print(db.root['__ctx_types__'])


if __name__ == "__main__":
    main()
+6 −1
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import persistent.list
from BTrees.OOBTree import BTree
from persistent import Persistent
from get_vector import Word2vec
import sys


# =============================================
@@ -133,7 +134,11 @@ class Record(Persistent):
class SqadDb:
    def __init__(self, file_name, read_only=False):
        self.file_name = file_name
        try:
            self.storage = ZODB.FileStorage.FileStorage(self.file_name, read_only=read_only)
        except BlockingIOError:
            print('ERROR: database currently unavailable.')
            sys.exit()
        self.db = ZODB.DB(self.storage)
        self.connection = self.db.open()
        self.root = self.connection.root()