Commit d55f9b68 authored by Marek Medved's avatar Marek Medved
Browse files

Sqad to ZODB database - compressed

parents
Loading
Loading
Loading
Loading

.gitignore

0 → 100644
+1 −0
Original line number Diff line number Diff line
sqad_db

get_vector.py

0 → 100644
+45 −0
Original line number Diff line number Diff line
#! /usr/bin/python3
import fasttext
import os
import sys
# Absolute directory that contains this module (symlinks resolved); the
# fastText model paths below are built relative to it.
dir_path = os.path.split(os.path.realpath(__file__))[0]


class Word2vec():
    """Holds three fastText models and looks up per-token vectors in them.

    The models are stored as ``model_100``, ``model_300`` and ``model_500``
    and are selected through the ``dim`` argument of :meth:`get_vector`.
    """

    def __init__(self):
        # Model files are located relative to this module's directory.
        self.model_100 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_100_5_5_0.05_skip')
        self.model_300 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_300_5_5_0.05_skip')
        self.model_500 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_500_5_5_0.05_skip')

    def get_vector(self, data, dim):
        """Return one model lookup per space-separated token of ``data``.

        ``data`` is stripped and split on single spaces.  ``dim`` picks the
        model (100, 300 or 500); for any other value ``None`` is returned.
        """
        model_by_dim = ((100, 'model_100'), (300, 'model_300'), (500, 'model_500'))
        for size, attr_name in model_by_dim:
            if dim == size:
                tokens = data.strip().split(' ')
                model = getattr(self, attr_name)
                return [model[token] for token in tokens]
        return None

def main():
    """Command-line entry point: print vectors for the words of the input.

    Reads the whole input (``-i``, default stdin), looks every
    space-separated token up with ``dim=100`` and writes the resulting list
    to the output (``-o``, default stdout).
    """
    import argparse
    parser = argparse.ArgumentParser(description='Word2vec fastext')
    parser.add_argument('-i', '--input', type=argparse.FileType('r'),
                        required=False, default=sys.stdin,
                        help='Input')
    parser.add_argument('-o', '--output', type=argparse.FileType('w'),
                        required=False, default=sys.stdout,
                        help='Output')

    args = parser.parse_args()
    w2v = Word2vec()
    # Honour --output: previously the option was parsed but the result was
    # always printed to stdout.
    print(w2v.get_vector(args.input.read(), 100), file=args.output)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

query_database.py

0 → 100755
+92 −0
Original line number Diff line number Diff line
#! /usr/bin/python3
# coding: utf-8
from sqad_db import SqadDb
from sqad_db import id2word
from sqad_db import id2qt
from pprint import pprint


def get_record(db, record_id):
    """Load one record and expand all of its word ids and type ids.

    Returns a dict with the keys ``rec_id``, ``q_type``, ``a_type``,
    ``question``, ``a_sel``, ``a_ext`` and ``text``.  Each of the last four
    is a list of ``{'sentence': [...], 'phrs': [[...], ...]}`` dicts whose
    items are the ``id2word`` results for the stored word ids.
    """
    record = db.get_record(record_id)
    vocabulary, qa_type_dict = db.get_dicts()

    def expand(word_ids):
        # Resolve a sequence of word ids through the vocabulary.
        return [id2word(vocabulary, w_id) for w_id in word_ids]

    # get word, lemma, tag, vector
    def get_wltv(data):
        return [
            {
                'sentence': expand(sentence['sentence']),
                'phrs': [expand(phr) for phr in sentence['phrs']],
            }
            for sentence in data
        ]

    return {
        'rec_id': record.rec_id,
        'q_type': id2qt(qa_type_dict, record.q_type),
        'a_type': id2qt(qa_type_dict, record.a_type),
        'question': get_wltv(record.question),
        'a_sel': get_wltv(record.answer_selection),
        'a_ext': get_wltv(record.answer_extraction),
        'text': get_wltv(record.text),
    }


def print_record(db, record_id):
    """Print a human-readable dump of one record to stdout.

    The record id and the question/answer types come first, followed by the
    question, answer selection, answer extraction and text parts.  Every
    sentence is printed as ``s:`` and each of its phrases as ``c:``, using
    only the word forms.
    """
    record = db.get_record(record_id)
    vocabulary, qa_type_dict = db.get_dicts()

    def words_of(word_ids):
        # Space-joined word forms for a sequence of word ids.
        return " ".join(id2word(vocabulary, w_id)['word'] for w_id in word_ids)

    # get word, lemma, tag, vector
    def get_wltv(data):
        for sentence in data:
            sentence_text = words_of(sentence['sentence'])
            print(f'\ts: {sentence_text}')
            for phr in sentence['phrs']:
                phrase_text = words_of(phr)
                print(f'\t\tc: {phrase_text}')

    print(f'rec_id: {record.rec_id}')
    print(f'q_type: {id2qt(qa_type_dict, record.q_type)}')
    print(f'a_type: {id2qt(qa_type_dict, record.a_type)}')

    sections = (
        ('question:', 'question'),
        ('a_sel:', 'answer_selection'),
        ('a_ext:', 'answer_extraction'),
        ('text:', 'text'),
    )
    for label, attr_name in sections:
        print(label)
        get_wltv(getattr(record, attr_name))


def main():
    """Command-line entry point: print one record of a SQAD database.

    ``-d`` names the database file and ``-r`` the record id.  With
    ``--simple`` the record is printed as plain words via ``print_record``;
    otherwise the fully expanded record from ``get_record`` is
    pretty-printed.
    """
    import argparse
    # The description used to be copied from the conversion tool and did not
    # describe what this script does.
    parser = argparse.ArgumentParser(description='Print one record from a SQAD database file')
    parser.add_argument('-d', '--database_file', type=str,
                        required=True,
                        help='Database file name')
    parser.add_argument('-r', '--record_id', type=str,
                        required=True,
                        help='Record id')
    parser.add_argument('--simple', action='store_true',
                        required=False, default=False,
                        help='Simple output')
    args = parser.parse_args()
    db = SqadDb(args.database_file)
    try:
        if args.simple:
            print_record(db, args.record_id)
        else:
            pprint(get_record(db, args.record_id))
    finally:
        # Always release the database, even when the record id is unknown
        # or printing fails.
        db.close()


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

sqad2database.py

0 → 100755
+138 −0
Original line number Diff line number Diff line
#! /usr/bin/python3
# coding: utf-8
import re
import os
import datetime
from sqad_db import SqadDb
from sqad_db import Vocabulary
from sqad_db import W2V
from sqad_db import QAType
from sqad_db import Record
from sqad_db import word2id
from sqad_db import qt2id
from BTrees.OOBTree import BTree
import persistent.list

# A line that consists of a single structure tag, e.g. "<s>" or "</doc>".
struct_re = re.compile(r'^<[^>]+>$')
# Metadata lines carrying the question / answer type; group 1 is the type.
q_type_re = re.compile(r'^<q_type>(.*)</q_type>$')
a_type_re = re.compile(r'^<a_type>(.*)</a_type>$')
# A line that is just "#".
hash_re = re.compile(r'^#$')


def get_struct(data, struct_id):
    """Yield every ``<struct_id>`` ... ``</struct_id>`` block found in ``data``.

    Args:
        data: iterable of text lines; every line is stripped before it is
            matched or collected.
        struct_id: name of the structure, e.g. ``'s'``.  It is matched
            literally.

    Yields:
        A list of stripped lines running from the opening tag
        (``<struct_id>`` or ``<struct_id attributes>``) through the closing
        tag, both included.  A block that is never closed is not yielded.
    """
    # Build the tag patterns once instead of once per line.
    tag = re.escape(struct_id)
    open_re = re.compile(r'^<{0}(?: .*?)?>$'.format(tag))
    close_re = re.compile(r'^</{0}>$'.format(tag))

    struct = []
    inside = False
    for line in data:
        stripped = line.strip()
        if open_re.match(stripped):
            inside = True
        if not inside:
            # Lines outside a block are skipped.  This includes a stray
            # closing tag, which must not cut the following block short.
            continue
        struct.append(stripped)
        if close_re.match(stripped):
            yield struct
            struct = []
            inside = False


def fill(data, rec_part, vocabulary, w2v, context=False):
    """Convert vertical-format lines to word ids and append them to ``rec_part``.

    Without ``context`` every ``<s>`` block in ``data`` becomes one entry
    with its word ids under ``'sentence'`` and an empty ``'phrs'`` list; a
    token line that is just ``#`` is stored as word, lemma and tag ``#``.
    With ``context=True`` every ``<context>`` block becomes one entry: the
    ids of all its ``<s>`` tokens go to ``'sentence'`` and every ``<phr>``
    block contributes one id list to ``'phrs'``.

    Token lines are tab-separated and their first three fields are read as
    word, lemma and tag.  Structure tag lines are skipped.
    """
    def append_ids(lines, target, allow_hash):
        # Append the word id of every non-tag line in ``lines`` to ``target``.
        for token in lines:
            if struct_re.match(token):
                continue
            stripped = token.strip()
            if allow_hash and hash_re.match(stripped):
                fields = ('#', '#', '#')
            else:
                fields = stripped.split('\t')[:3]
            word, lemma, tag = fields
            target.append(word2id(vocabulary, word, lemma, tag, w2v))

    if context:
        for ctx in get_struct(data, 'context'):
            sent = persistent.list.PersistentList()
            phrs = persistent.list.PersistentList()
            for sentence_lines in get_struct(ctx, 's'):
                append_ids(sentence_lines, sent, False)
            for phrase_lines in get_struct(ctx, 'phr'):
                phr = persistent.list.PersistentList()
                append_ids(phrase_lines, phr, False)
                phrs.append(phr)
            rec_part.append(BTree({'sentence': sent, 'phrs': phrs}))
        return

    for sentence_lines in get_struct(data, 's'):
        sent = persistent.list.PersistentList()
        append_ids(sentence_lines, sent, True)
        rec_part.append(BTree({'sentence': sent, 'phrs': []}))


def fill_qa_type(data, qa_type):
    """Extract the question and answer type ids from metadata lines.

    Every line matching ``<q_type>...</q_type>`` or ``<a_type>...</a_type>``
    is converted to an id with ``qt2id``; when a tag occurs more than once
    the last occurrence wins.  Returns ``(q_type, a_type)``, where a value
    is ``-1`` if its tag was not found.
    """
    q_type = a_type = -1
    for line in data:
        q_match = q_type_re.match(line)
        if q_match:
            q_type = qt2id(qa_type, q_match.group(1))
            continue
        a_match = a_type_re.match(line)
        if a_match:
            a_type = qt2id(qa_type, a_match.group(1))
    return q_type, a_type


def main():
    """Command-line entry point: convert a SQAD directory tree to a database.

    Walks ``--path``; every directory whose name starts with digits is
    treated as one record, and its question, text, metadata, answer
    selection and answer extraction files are parsed into a ``Record``.
    The records, the vocabulary and the question/answer type dictionary are
    stored in a new database file ``sqad_db/<name>_<timestamp>``.  With
    ``--test`` no word vectors are loaded.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Transfrom SQAD to one database file with all features')
    parser.add_argument('-p', '--path', type=str,
                        required=True,
                        help='Path to SQAD records root dir')
    parser.add_argument('-n', '--name', type=str,
                        required=True,
                        help='Resulting database name')
    parser.add_argument('--test', action='store_true',
                        required=False, default=False,
                        help='Testing switch. For development (no vectors are loaded)')
    args = parser.parse_args()

    # The database file is created inside ./sqad_db, so make sure that
    # directory exists before the storage is opened.
    os.makedirs('sqad_db', exist_ok=True)
    db = SqadDb('sqad_db/{0}_{1:%d_%m_%Y-%H:%M:%S}'.format(args.name, datetime.datetime.now()))
    try:
        # Raw string: '\d' in a plain string literal is an invalid escape.
        rec_id_re = re.compile(r'(\d+)')
        # file name -> (Record attribute to fill, parse as context file)
        vert_parts = {
            '01question.vert': ('question', False),
            '03text.context.vert': ('text', True),
            '06answer.selection.vert': ('answer_selection', False),
            '09answer_extraction.vert': ('answer_extraction', False),
        }
        counter = 0

        vocabulary = Vocabulary()
        qa_type = QAType()
        vectors = W2V(test=args.test)

        print('Processing records:')
        for root, _dirs, files in os.walk(args.path):
            # Match the directory name once and reuse the match object.
            rec_match = rec_id_re.match(os.path.basename(root))
            if not rec_match:
                continue
            rec_id = rec_match.group(1)
            print(f'{counter}\r')
            record = Record(rec_id)
            # Files are handled in the order os.walk reports them, because
            # that order determines the ids assigned to new words.
            for file_name in files:
                if file_name == '05metadata.txt':
                    with open(os.path.join(root, file_name), 'r') as f:
                        record.q_type, record.a_type = fill_qa_type(f.readlines(), qa_type)
                elif file_name in vert_parts:
                    attr_name, is_context = vert_parts[file_name]
                    with open(os.path.join(root, file_name), 'r') as f:
                        fill(f.readlines(), getattr(record, attr_name),
                             vocabulary, vectors, context=is_context)
            counter += 1
            db.add(rec_id, record)

        db.add_vocab(vocabulary)
        db.add_qa_type(qa_type)
    finally:
        # Close the storage even if loading vectors or parsing a record fails.
        db.close()


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

sqad_db.py

0 → 100644
+145 −0
Original line number Diff line number Diff line
#! /usr/bin/python3
# coding: utf-8
import ZODB
import ZODB.FileStorage
import transaction
import persistent.list
from BTrees.OOBTree import BTree
from persistent import Persistent
from get_vector import Word2vec


# =============================================
# Vocabulary
# =============================================
def word2id(vocabulary, word, lemma, tag, w2v):
    """Return the id of ``word``, registering the word first if it is new.

    Lookup is by word form only, so the lemma and tag stored for a word are
    those of its first occurrence.  For a new word a fresh id is taken from
    ``vocabulary.new_id()``, the word/lemma/tag entry is stored, and
    ``w2v.add_vector`` is asked to store the word's vectors.
    """
    known_id = vocabulary.w2id.get(word, None)
    if known_id:
        return known_id

    new_key = vocabulary.new_id()
    vocabulary.w2id[word] = new_key
    vocabulary.wlt[new_key] = {'word': word, 'lemma': lemma, 'tag': tag}
    w2v.add_vector(vocabulary, new_key, word)
    return new_key


def id2word(vocabulary, key):
    """Return word, lemma, tag and the three stored vectors for id ``key``.

    The result is a dict with the keys ``word``, ``lemma``, ``tag``,
    ``v100``, ``v300`` and ``v500``.  Raises ``KeyError`` if ``key`` is not
    in the vocabulary.
    """
    entry = vocabulary.wlt[key]
    result = {
        'word': entry['word'],
        'lemma': entry['lemma'],
        'tag': entry['tag'],
    }
    # Vectors are stored as a three-item sequence in this order.
    vectors = vocabulary.vectors[key]
    result['v100'] = vectors[0]
    result['v300'] = vectors[1]
    result['v500'] = vectors[2]
    return result


class W2V:
    """Provides the word vectors that are stored with vocabulary entries.

    With ``test=True`` no ``Word2vec`` object is created and every vector
    is ``None``.
    """

    def __init__(self, test=False):
        self.test = test
        if not test:
            self.w2v = Word2vec()

    def get_vect(self, word):
        """Return a three-item persistent list of vectors for ``word``.

        The items are the ``get_vector`` results for dimensions 100, 300
        and 500, in that order, or three ``None`` values in test mode.
        """
        vectors = persistent.list.PersistentList()
        if self.test:
            for _ in range(3):
                vectors.append(None)
            return vectors
        for dim in (100, 300, 500):
            vectors.append(persistent.list.PersistentList(self.w2v.get_vector(word, dim)))
        return vectors

    def add_vector(self, vocabulary, key, word):
        """Store the vectors of ``word`` under ``key`` in ``vocabulary.vectors``."""
        vocabulary.vectors[key] = self.get_vect(word)


class Vocabulary(Persistent):
    """Persistent vocabulary: word entries, word-to-id index and vectors."""

    def __init__(self):
        # key: id, value: word, lemma, tag
        self.wlt = BTree()
        # key: word, value: id
        self.w2id = BTree()
        # Last id handed out by new_id(); the first id is therefore 1.
        self.key = 0
        # key: word_id, value: v100, v300, v500
        self.vectors = BTree()

    def new_id(self):
        """Advance the id counter and return the new id."""
        next_key = self.key + 1
        self.key = next_key
        return next_key


# =============================================
# QA type
# =============================================
def qt2id(qa_type, qt_type):
    """Return the id of type name ``qt_type``, registering it if it is new.

    A new name gets a fresh id from ``qa_type.new_id()`` and is recorded in
    both the id-to-type and the type-to-id mapping.
    """
    known_id = qa_type.t2id.get(qt_type, None)
    if known_id:
        return known_id

    new_key = qa_type.new_id()
    qa_type.id2t[new_key] = qt_type
    qa_type.t2id[qt_type] = new_key
    return new_key


def id2qt(qa_type, key):
    """Return the type name stored under id ``key``, or ``-1`` if unknown."""
    id_to_type = qa_type.id2t
    return id_to_type.get(key, -1)


class QAType(Persistent):
    """Persistent two-way dictionary of question/answer type names and ids."""

    def __init__(self):
        # key: id, value: type
        self.id2t = BTree()
        # key: type, value: id
        self.t2id = BTree()
        # Last id handed out by new_id(); the first id is therefore 1.
        self.key = 0

    def new_id(self):
        """Advance the id counter and return the new id."""
        next_key = self.key + 1
        self.key = next_key
        return next_key


# =============================================
# Record
# =============================================
class Record(Persistent):
    """One persistent SQAD record.

    The four parts start as empty persistent lists of sentences, and both
    type ids start at ``-1``.
    """

    def __init__(self, rec_id):
        new_list = persistent.list.PersistentList
        self.rec_id = rec_id
        # Each part is a list of sentences.
        self.question = new_list()
        self.answer_selection = new_list()
        self.answer_extraction = new_list()
        self.text = new_list()
        # -1 until a type id is assigned.
        self.q_type = -1
        self.a_type = -1


# =============================================
# Sqad database
# =============================================
class SqadDb:
    """Small wrapper around a ZODB file storage that holds SQAD data.

    Records are stored in the database root under their record id; the
    vocabulary and the QA type dictionary are stored under the keys
    ``'__vocabulary__'`` and ``'__qa_type__'``.
    """

    def __init__(self, file_name):
        """Open the file storage ``file_name`` and connect to its root."""
        self.file_name = file_name
        self.storage = ZODB.FileStorage.FileStorage(file_name)
        self.db = ZODB.DB(self.storage)
        self.connection = self.db.open()
        self.root = self.connection.root()

    def _store(self, key, value):
        # Put ``value`` under ``key`` in the root and commit the transaction.
        self.root[key] = value
        transaction.commit()

    def add(self, rec_id, record_object):
        """Store ``record_object`` under ``rec_id`` and commit."""
        self._store(rec_id, record_object)

    def add_vocab(self, vocab):
        """Store the vocabulary and commit."""
        self._store('__vocabulary__', vocab)

    def add_qa_type(self, qa_type):
        """Store the QA type dictionary and commit."""
        self._store('__qa_type__', qa_type)

    def get_dicts(self):
        """Return ``(vocabulary, qa_type)`` as stored in the root."""
        root = self.root
        return root['__vocabulary__'], root['__qa_type__']

    def get_record(self, rec_id):
        """Return the object stored under ``rec_id``."""
        return self.root[rec_id]

    def close(self):
        """Close the connection, then the database, then the storage."""
        for resource in (self.connection, self.db, self.storage):
            resource.close()