Commit 2707beba authored by Marek Medved's avatar Marek Medved
Browse files

New embedding from RobeCzech

parent 7a6b5033
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ updates:
	@# Contains answer sentece
	./make_copy.sh $(DB) $(DB)_addAS
	@printf "Contains answer\n======================\n" >> $(DB)_addAS.log
	($(CONDA_ACTIVATE) base; ./add_contains_answer_sentences.py -d $(DB)_addAS 2>> $(DB)_addAS.log)
	($(CONDA_ACTIVATE) base; ./add_contains_answer_sentences.py -d $(DB)_addAS --no_partial 2>> $(DB)_addAS.log)
	@# Similar sentences
	./make_copy.sh $(DB)_addAS $(DB)_addAS_simS
	@printf "Similar answers\n======================\n" >> $(DB)_addAS_simS.log
+9 −6
Original line number Diff line number Diff line
@@ -23,14 +23,14 @@ def get_all_partial(answer):

def partial_match(answer, sentence, a_type):
    answer_content = [x['lemma'] if x['lemma'] != '[number]' else x['word']  for x in answer]
    answer_tags = [x['tag'] for x in answer]
    # answer_tags = [x['tag'] for x in answer]
    sentence_content = [x['lemma'] if x['lemma'] != '[number]' else x['word']  for x in sentence]
    sentence_tags = [x['tag'] for x in sentence]
    # sentence_tags = [x['tag'] for x in sentence]

    # Only if question type is suitable for partial match and exact answer is more than one word
    if a_type in ['PERSON', 'DATETIME', 'ENTITY', 'LOCATION'] and len(answer_content) > 1:
        partial_answers = get_all_partial(answer_content)
        print(f'{[x for x in zip(answer_content, answer_tags)]}::{partial_answers}::{[x for x in zip(sentence_content, sentence_tags)]}')
        # print(f'{[x for x in zip(answer_content, answer_tags)]}::{partial_answers}::{[x for x in zip(sentence_content, sentence_tags)]}')
        partial_answers.sort(key=len, reverse=True)
        # print(f'partial_answers: {partial_answers}')
        while partial_answers:
@@ -42,7 +42,7 @@ def partial_match(answer, sentence, a_type):
    return False


def find_sentences_containing_answer(db, verbose=False):
def find_sentences_containing_answer(db, no_partial=False, verbose=False):
    """
    Searching for sentences containing the exact answer
    :param db: ZODB database
@@ -64,7 +64,7 @@ def find_sentences_containing_answer(db, verbose=False):
                        print(f'Full match: {ans_ext_lemma} -> {doc_sent_content}')
                    containing_answer.append(idx)

                elif partial_match(sent['sent'], sent_and_phrs["sent"], a_type):
                elif not no_partial and partial_match(sent['sent'], sent_and_phrs["sent"], a_type):
                    if verbose:
                        print(f'Partial match: {ans_ext_lemma} -> {doc_sent_content}')
                    containing_answer_partial.append(idx)
@@ -87,6 +87,9 @@ def main():
    parser.add_argument('-v', '--verbose', action='store_true',
                        required=False, default=False,
                        help='Verbose mode')
    parser.add_argument('--no_partial', action='store_true',
                        required=False, default=False,
                        help='Disable partial match computation')
    args = parser.parse_args()

    if (args.url and args.port) or args.db_path:
@@ -99,7 +102,7 @@ def main():
        sys.exit()

    try:
        for record, sent_containing_answer, containing_answer_partial in find_sentences_containing_answer(db, verbose=args.verbose):
        for record, sent_containing_answer, containing_answer_partial in find_sentences_containing_answer(db, no_partial=args.no_partial, verbose=args.verbose):
            if args.verbose:
                print(f'{record.rec_id}: {sent_containing_answer} ::: {containing_answer_partial}')
                print('==============================')
+37 −8
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import transaction
from sentece_vector import add_vector_record
from sentece_vector import add_vector_text
from transformers import BertTokenizer, BertConfig, BertModel
from transformers import RobertaModel, RobertaTokenizer
import os
dir_path = os.path.dirname(os.path.realpath(__file__))

@@ -45,7 +46,33 @@ class S2ClsBert:
        return cls_emb


def compute_vector(data, model, vocabulary, ctx=False, verbose=False):
class S2RobeCzech:
    def __init__(self):
        """
        Load the RobeCzech model (a Czech RoBERTa) and its tokenizer
        from the Hugging Face hub ('ufal/robeczech-base').
        """
        self.model = RobertaModel.from_pretrained('ufal/robeczech-base')
        self.tokenizer = RobertaTokenizer.from_pretrained('ufal/robeczech-base')

    def get_sent_embedding(self, sentece, verbose=False):
        """
        Return a sentence embedding taken from the first special token
        (RoBERTa's <s>, the CLS-equivalent) of the encoded input.

        :param sentece: sentence text (or token list) accepted by the tokenizer
        :param verbose: if True, print every sub-token and its id
        :return: numpy array with the first-token embedding from the last hidden state
        """
        input_ids = self.tokenizer.encode(sentece, return_tensors="pt", add_special_tokens=True)
        if verbose:
            # Re-encode without tensor wrapping so individual ids can be decoded.
            input_ids_2 = self.tokenizer.encode(sentece, add_special_tokens=True)
            for i in input_ids_2:
                print(f'{self.tokenizer.decode(i)} -> {i}')
        # Inference only: disable autograd bookkeeping to save memory and time.
        import torch
        with torch.no_grad():
            outputs = self.model(input_ids)
        # outputs[0] is the last hidden state: [batch, seq, dim] -> first token of first item.
        cls_emb = outputs[0][0][0].detach().numpy()
        return cls_emb


def compute_vector(data, model, vocabulary, name, ctx=False, verbose=False):
    """
    Add BERT vector to sentece
    :param data: dict
@@ -60,7 +87,7 @@ def compute_vector(data, model, vocabulary, ctx=False, verbose=False):
        for w_id in sentece['sent']:
            s_content.append(id2word(vocabulary, w_id, parts='w')['word'])
        sent_v = model.get_sent_embedding(s_content, verbose=verbose)
        sentece['cls_bert'] = persistent.list.PersistentList(sent_v)
        sentece[name] = persistent.list.PersistentList(sent_v)

        if verbose:
            print(f"{' '.join(s_content)}\t{sent_v}")
@@ -74,10 +101,9 @@ def compute_vector(data, model, vocabulary, ctx=False, verbose=False):
                            phr_content.append(id2word(vocabulary, w_id, parts='w')['word'])
                        if phr_content:
                            phr_sent_v = model.get_sent_embedding(phr_content, verbose=verbose)
                            sentece['cls_bert'] = persistent.list.PersistentList(phr_sent_v)
                            sentece[name] = persistent.list.PersistentList(phr_sent_v)
                        else:
                            sentece['cls_bert'] = persistent.list.PersistentList([0]*768)

                            sentece[name] = persistent.list.PersistentList([0]*768)


def main():
@@ -107,10 +133,13 @@ def main():
        sys.stderr.write('Please specify --db_path or (--port and --url)')
        sys.exit()

    model = S2ClsBert()
    try:
        add_vector_text(db, model, compute_vector, verbose=args.verbose)
        add_vector_record(db, model, compute_vector, verbose=args.verbose)
        model_slavic_bert = S2ClsBert()
        add_vector_text(db, model_slavic_bert, compute_vector, 'slavic_bert', verbose=args.verbose)
        add_vector_record(db, model_slavic_bert, compute_vector, 'slavic_bert', verbose=args.verbose)
        model_robe_czech = S2RobeCzech()
        add_vector_text(db, model_robe_czech, compute_vector, 'robeczech', verbose=args.verbose)
        add_vector_record(db, model_robe_czech, compute_vector, 'robeczech', verbose=args.verbose)
        db.update()
        db._p_changed = True
        transaction.commit()