From 253d24bb493ee25fdd407f1074bb93a761b8bd15 Mon Sep 17 00:00:00 2001
From: Marek Medved <xmedved1@fi.muni.cz>
Date: Fri, 3 Mar 2023 17:08:56 +0100
Subject: [PATCH] Load embedding models selectively; run ZEO server under conda env

---
 Makefile      |  3 ++-
 get_vector.py | 26 +++++++++++++++-----------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/Makefile b/Makefile
index d4e0967..d128a20 100644
--- a/Makefile
+++ b/Makefile
@@ -59,7 +59,8 @@ updates:
 	@echo "$(hostname)" | mail -s "Done AQA job" "xmedved1@fi.muni.cz"
 
 run_ZODB_server:
-	exec "/usr/bin/python3.6" -m "ZEO.runzeo" -C /nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/zeo_server.conf
+	conda activate aqa-apollo; exec "python3" -m "ZEO.runzeo" -C /nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/zeo_server.conf
+	#exec "/usr/bin/python3.6" -m "ZEO.runzeo" -C /nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database/zeo_server.conf
 	#cd "$(HOME)/.local/lib/python3.6/site-packages/"; exec "/usr/bin/python3.6" -m "ZEO.runzeo" -a "0.0.0.0:9001" -f "/nlp/projekty/question_answering/AQA_v2/sqad_tools/sqad2database_devel/sqad_db/stable"
 
 demo_query:
diff --git a/get_vector.py b/get_vector.py
index df86e68..f0390e5 100755
--- a/get_vector.py
+++ b/get_vector.py
@@ -9,22 +9,26 @@ from transformers import BertTokenizer, BertConfig, BertModel
 dir_path = os.path.dirname(os.path.realpath(__file__))
 
 class Word2vec:
-    def __init__(self):
+    def __init__(self, dim='all'):
         """
         Load pretrained models
         """
-        self.model_100 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_100_5_5_0.05_skip')
-        self.model_300 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_300_5_5_0.05_skip')
-        self.model_500 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_500_5_5_0.05_skip')
+        if dim == 'all' or dim == '100':
+            self.model_100 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_100_5_5_0.05_skip')
+        if dim == 'all' or dim == '300':
+            self.model_300 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_300_5_5_0.05_skip')
+        if dim == 'all' or dim == '500':
+            self.model_500 = fasttext.load_model(f'{dir_path}/../../fasttext/models/cstenten_17_500_5_5_0.05_skip')
 
         # Bert embedding from slavic bert
-        bert_config = BertConfig.from_json_file(
-            f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/bert_config.json')
-        self.model_slavic_bert = BertModel.from_pretrained(
-            f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/pytorch_model.bin',
-            config=bert_config, local_files_only=True)
-        self.tokenizer_slavic_bert = BertTokenizer.from_pretrained(
-            f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/')
+        if dim == 'all' or dim == 'slavic_bert':
+            bert_config = BertConfig.from_json_file(
+                f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/bert_config.json')
+            self.model_slavic_bert = BertModel.from_pretrained(
+                f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/pytorch_model.bin',
+                config=bert_config, local_files_only=True)
+            self.tokenizer_slavic_bert = BertTokenizer.from_pretrained(
+                f'{dir_path}/bert_embeder_models/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt/')
 
     def word2embedding_cls(self, word):
         """
-- 
GitLab