From 253d24bb493ee25fdd407f1074bb93a761b8bd15 Mon Sep 17 00:00:00 2001 From: Marek Medved <xmedved1@fi.muni.cz> Date: Fri, 3 Mar 2023 17:08:56 +0100 Subject: [PATCH] fix models --- Makefile | 3 ++- get_vector.py | 26 +++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index d4e0967..d128a20 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,8 @@ updates: @echo "$(hostname)" | mail -s "Done AQA job" "xmedved1@fi.muni.cz" run_ZODB_server: - exec "/usr/bin/python3.6" -m "ZEO.runzeo" -C /nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/zeo_server.conf + conda activate aqa-apollo; exec "python3" -m "ZEO.runzeo" -C /nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/zeo_server.conf + #exec "/usr/bin/python3.6" -m "ZEO.runzeo" -C /nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/zeo_server.conf #cd "$(HOME)/.local/lib/python3.6/site-packages/"; exec "/usr/bin/python3.6" -m "ZEO.runzeo" -a "0.0.0.0:9001" -f "/nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database_devel/sqad_db/stable" demo_query: diff --git a/get_vector.py b/get_vector.py index df86e68..f0390e5 100755 --- a/get_vector.py +++ b/get_vector.py @@ -9,22 +9,26 @@ from transformers import BertTokenizer, BertConfig, BertModel dir_path = os.path.dirname(os.path.realpath(__file__)) class Word2vec: - def __init__(self): + def __init__(self, dim='all'): """ Load pretrained models """ - self.model_100 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_100_5_5_0.05_skip') - self.model_300 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_300_5_5_0.05_skip') - self.model_500 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_500_5_5_0.05_skip') + if dim == 'all' or dim == '100': + self.model_100 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_100_5_5_0.05_skip') + if dim == 'all' or dim == '300': + self.model_300 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_300_5_5_0.05_skip') + if dim == 'all' or dim == '500': + self.model_500 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_500_5_5_0.05_skip') # Bert embedding from slavic bert - bert_config = BertConfig.from_json_file( - f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/bert_config.json') - self.model_slavic_bert = BertModel.from_pretrained( - f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/pytorch_model.bin', - config=bert_config, local_files_only=True) - self.tokenizer_slavic_bert = BertTokenizer.from_pretrained( - f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/') + if dim == 'all' or dim == 'slavic_bert': + bert_config = BertConfig.from_json_file( + f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/bert_config.json') + self.model_slavic_bert = BertModel.from_pretrained( + f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/pytorch_model.bin', + config=bert_config, local_files_only=True) + self.tokenizer_slavic_bert = BertTokenizer.from_pretrained( + f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/') def word2embedding_cls(self, word): """ -- GitLab