Commit 5983bb17 authored by Marek Medved
Browse files

parameter for context type

parent b49dbdfb
Loading
Loading
Loading
Loading
+2 −1
Original line number Original line Diff line number Diff line
@@ -120,6 +120,7 @@ def get_title(text_by_url, vocabulary):
def add_np_phrases(db, context_window, num_phr_per_sent, w2v, verbose=False):
def add_np_phrases(db, context_window, num_phr_per_sent, w2v, verbose=False):
    vocabulary, qa_type_dict, kb = db.get_dicts()
    vocabulary, qa_type_dict, kb = db.get_dicts()
    for url, text in kb.url2doc.items():
    for url, text in kb.url2doc.items():
        if verbose:
            print(f'Processing: {url}')
            print(f'Processing: {url}')
        text_title_vert = get_title(text['title'], vocabulary)
        text_title_vert = get_title(text['title'], vocabulary)
        text_vert = get_text_vert(text['text'], vocabulary)
        text_vert = get_text_vert(text['text'], vocabulary)
+51 −24
Original line number Original line Diff line number Diff line
@@ -4,11 +4,17 @@ from sqad_db import SqadDb
from sqad_db import id2word
from sqad_db import id2word
from sqad_db import id2qt
from sqad_db import id2qt
from pprint import pprint
from pprint import pprint
import sys




def get_ctx(data, vocabulary, part=''):
def get_ctx(data, vocabulary, part='', context_type=''):
    sentence_phrases = {}
    sentence_phrases = {}
    if context_type:
        required_ctx = context_type.strip().split(';')
    else:
        required_ctx = ['all']
    for ctx_type, phrs in data.items():
    for ctx_type, phrs in data.items():
        if ctx_type in required_ctx or 'all' in required_ctx:
            for p in phrs:
            for p in phrs:
                p_content = []
                p_content = []
                for w_id_cx in p:
                for w_id_cx in p:
@@ -42,15 +48,15 @@ def get_content(data, vocabulary, part=''):
    return result
    return result




def get_content_ctx(url, kb, vocabulary, part=''):
def get_content_ctx(url, kb, vocabulary, part='', context_type=''):
    result = []
    result = []
    for sentence in kb.url2doc.get(url)['text']:
    for sentence in kb.url2doc.get(url)['text']:
        result.append({'sent': get_senence(sentence['sent'], vocabulary, part),
        result.append({'sent': get_senence(sentence['sent'], vocabulary, part),
                       'ctx' : get_ctx(sentence['ctx'], vocabulary, part)})
                       'ctx' : get_ctx(sentence['ctx'], vocabulary, part, context_type)})
    return result
    return result




def get_record(db, record_id, word_parts=''):
def get_record(db, record_id, word_parts='', context_type=''):
    record = db.get_record(record_id)
    record = db.get_record(record_id)
    vocabulary, qa_type_dict, kb = db.get_dicts()
    vocabulary, qa_type_dict, kb = db.get_dicts()
    """
    """
@@ -75,13 +81,13 @@ def get_record(db, record_id, word_parts=''):
    data['a_sel_pos'] = record.text_answer_position
    data['a_sel_pos'] = record.text_answer_position
    data['a_ext'] = get_content(record.answer_extraction, vocabulary, word_parts)
    data['a_ext'] = get_content(record.answer_extraction, vocabulary, word_parts)
    data['similar_answers'] = record.similar_answers
    data['similar_answers'] = record.similar_answers
    data['text_title'] = kb.url2doc.get(record.text)['title']
    data['text_title'] = get_content(kb.url2doc.get(record.text)["title"], vocabulary, word_parts)
    data['text'] = get_content_ctx(record.text, kb, vocabulary, word_parts)
    data['text'] = get_content_ctx(record.text, kb, vocabulary, word_parts, context_type)


    return data
    return data




def print_record(db, record_id):
def print_record(db, record_id, context_type=''):
    record = db.get_record(record_id)
    record = db.get_record(record_id)
    vocabulary, qa_type_dict, kb = db.get_dicts()
    vocabulary, qa_type_dict, kb = db.get_dicts()


@@ -115,7 +121,8 @@ def print_record(db, record_id):
        print(f'\ts: {" ".join([x["word"] for x in i])}')
        print(f'\ts: {" ".join([x["word"] for x in i])}')


    print('text:')
    print('text:')
    for idx, sent_and_phrs in enumerate(get_content_ctx(record.text, kb, vocabulary, part='w')):
    for idx, sent_and_phrs in enumerate(get_content_ctx(record.text, kb, vocabulary,
                                                        part='w', context_type=context_type)):
        print(f'\ts_{idx}: {" ".join([x["word"] for x in sent_and_phrs["sent"]])}')
        print(f'\ts_{idx}: {" ".join([x["word"] for x in sent_and_phrs["sent"]])}')
        for key, phrs in sent_and_phrs['ctx'].items():
        for key, phrs in sent_and_phrs['ctx'].items():
            print(f'\t\tctx_type: {key}')
            print(f'\t\tctx_type: {key}')
@@ -126,12 +133,21 @@ def print_record(db, record_id):
def main():
def main():
    import argparse
    import argparse
    parser = argparse.ArgumentParser(description='Transfrom SQAD to one database file with all features')
    parser = argparse.ArgumentParser(description='Transfrom SQAD to one database file with all features')
    # ================================================================
    # Required parameters
    # ================================================================
    parser.add_argument('-d', '--database_file', type=str,
    parser.add_argument('-d', '--database_file', type=str,
                        required=True,
                        required=True,
                        help='Database file name')
                        help='Database file name')
    parser.add_argument('-r', '--record_id', type=str,
    parser.add_argument('-r', '--record_id', type=str,
                        required=True,
                        required=False, default='',
                        help='Record id')
                        help='Record id')
    parser.add_argument('--list_ctx_types', action='store_true',
                        required=False, default=False,
                        help='List context types')
    # ================================================================
    # Optional parameters
    # ================================================================
    parser.add_argument('--simple', action='store_true',
    parser.add_argument('--simple', action='store_true',
                        required=False, default=False,
                        required=False, default=False,
                        help='Simple output')
                        help='Simple output')
@@ -139,13 +155,24 @@ def main():
                        required=False, default='',
                        required=False, default='',
                        help='Which word parts will be provided. Semicolon separated. For example "w;l;t;v100" '
                        help='Which word parts will be provided. Semicolon separated. For example "w;l;t;v100" '
                             'will return word, lemma, tag and 100 dim. vector')
                             'will return word, lemma, tag and 100 dim. vector')
    parser.add_argument('--context_type', type=str,
                        required=False, default='',
                        help='List of context types separated by semicolon. Example "name_phrs_w5_n5;prev_sent_n1"')
    args = parser.parse_args()
    args = parser.parse_args()

    if args.list_ctx_types:
        db = SqadDb(args.database_file, read_only=True)
        print(db.get_ctx_types())
    elif args.record_id:
        db = SqadDb(args.database_file, read_only=True)
        db = SqadDb(args.database_file, read_only=True)
        if args.simple:
        if args.simple:
        print_record(db, args.record_id)
            print_record(db, args.record_id, args.context_type)
        else:
        else:
        pprint(get_record(db, args.record_id, args.word_parts))
            pprint(get_record(db, args.record_id, args.word_parts, args.context_type))
        db.close()
        db.close()
    else:
        sys.stderr.write('Please specify one of attributes: record_id, list_ctx_types')
        sys.exit()




if __name__ == "__main__":
if __name__ == "__main__":
+3 −0
Original line number Original line Diff line number Diff line
@@ -170,6 +170,9 @@ class SqadDb:
        self.root['__knowledge_base__'] = knowledge_base
        self.root['__knowledge_base__'] = knowledge_base
        transaction.commit()
        transaction.commit()


    def get_ctx_types(self):
        """Return the value stored under the ``'__ctx_types__'`` key of the database root.

        The command-line ``--list_ctx_types`` option prints this value directly.
        Presumably it is the collection of context type names available in the
        database -- confirm against the code that writes ``'__ctx_types__'``.

        NOTE(review): there is no fallback here; if ``self.root`` behaves like a
        dict, a database created without this key raises ``KeyError`` -- confirm.
        """
        return self.root['__ctx_types__']

    def get_dicts(self):
    def get_dicts(self):
        return self.root['__vocabulary__'], self.root['__qa_type__'], self.root['__knowledge_base__']
        return self.root['__vocabulary__'], self.root['__qa_type__'], self.root['__knowledge_base__']