Commit 3d8cd55c authored by Marek Medved's avatar Marek Medved
Browse files

bert embeddings

parent 8b27823c
Loading
Loading
Loading
Loading

add_bert_emberdings.py

0 → 100755
+83 −0
Original line number Diff line number Diff line
#! /usr/bin/python3
# coding: utf-8
import sys
from sqad_db import SqadDb
import persistent.list
import transaction
from deeppavlov.core.common.file import read_json
from deeppavlov import build_model, configs

class Bert_Embeddings:
    """Thin wrapper around the DeepPavlov BERT embedder model."""

    # Local checkout of the Slavic cased BERT model the embedder loads.
    _BERT_PATH = '/nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt'

    def __init__(self):
        config = read_json(configs.embedder.bert_embedder)
        config['metadata']['variables']['BERT_PATH'] = self._BERT_PATH
        # Calling the model yields: tokens, token_embs, subtokens,
        # subtoken_embs, sent_max_embs, sent_mean_embs, bert_pooler_outputs.
        self.model = build_model(config)

    def word2embedding(self, word):
        """Return the embedding of the first token of *word*."""
        _, token_embs, _, _, _, _, _ = self.model(word)
        return token_embs[0][0]


def add_bert_word_embeddings_word(vocabulary, model, db):
    """Append a BERT embedding to every word's vector list in *vocabulary*.

    :param vocabulary: object exposing ``id2wlt`` (word id -> {'word': ...})
        and ``vectors`` (word id -> persistent list of embedding lists)
    :param model: object providing ``word2embedding(word)``
    :param db: open database; marked changed and committed after each word
    """
    vocab_size = len(vocabulary.id2wlt)  # len() of the mapping; .keys() was redundant
    for progress, (w_id, value) in enumerate(vocabulary.id2wlt.items(), start=1):
        bert_embedding = model.word2embedding(value['word'])
        sys.stderr.write(f'{progress}/{vocab_size}\r')
        vocabulary.vectors[w_id].append(persistent.list.PersistentList(bert_embedding))
        # Commit per word so an interrupted long run keeps its progress.
        db._p_changed = True
        transaction.commit()


def main():
    """Parse CLI options, open the SQAD database and add BERT word
    embeddings to its vocabulary.

    Exits with status 1 when neither a db path nor a url+port pair is
    given; with 130 on Ctrl-C. The database is always closed.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Add bert embeddings to vocabulary')
    parser.add_argument('-u', '--url', type=str,
                        required=False, default='',
                        help='Database URL')
    parser.add_argument('-p', '--port', type=int,
                        required=False, default=None,
                        help='Server port')
    parser.add_argument('-d', '--db_path', type=str,
                        required=False, default='',
                        help='Database path')
    parser.add_argument('-v', '--verbose', action='store_true',
                        required=False, default=False,
                        help='Verbose mode')
    args = parser.parse_args()

    # A remote connection needs both --url and --port; otherwise fall back
    # to a local file path. (The original double-checked the same condition.)
    if args.url and args.port:
        db = SqadDb(url=args.url, port=args.port)
    elif args.db_path:
        db = SqadDb(file_name=args.db_path)
    else:
        sys.stderr.write('Please specify --db_path or (--port and --url)\n')
        sys.exit(1)  # usage error must not exit with status 0

    model = Bert_Embeddings()
    vocabulary, _, kb = db.get_dicts()
    try:
        add_bert_word_embeddings_word(vocabulary, model, db)
        # add_bert_word_embeddings_sent(vocabulary, kb, model)
        db.update()
        db._p_changed = True
        transaction.commit()
    except KeyboardInterrupt:
        sys.exit(130)  # conventional SIGINT exit status
    finally:
        # Close on success, interrupt, or any unexpected exception —
        # the original leaked the connection on non-KeyboardInterrupt errors.
        db.close()


if __name__ == "__main__":
    main()
+4 −0
Original line number Diff line number Diff line
# turnus03
# python3 -m deeppavlov install bert_sentence_embedder
# Run the BERT embedder smoke test, restricted to GPUs 1 and 2.
all:
	CUDA_VISIBLE_DEVICES=1,2 ./test_bert.py
+632 MiB

File added.

No diff preview for this file type.

+19 −0
Original line number Diff line number Diff line
{
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": 119547
}
+681 MiB

File added.

No diff preview for this file type.

Loading