Commit 1e3b75a8 authored by Marek Medved's avatar Marek Medved
Browse files

simple question type

parent 26f0288f
Loading
Loading
Loading
Loading
+11 −4
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from request_service import get_answer_selection
from request_service import get_answer_extraction_ml
from request_service import get_answer_extraction_ml_yes_no
from request_service import get_document_selection
from request_service import get_qa_type
# from request_service import get_qa_type
from request_service import get_q_type_2


def print_basic_data(data, output):
@@ -62,16 +63,22 @@ def main():

        # Question type
        # q_type, a_type = 'VERB_PHR', 'YES_NO'
        q_type, a_type = get_qa_type(input_l_question)['types'].split('; ')
        # q_type, a_type = get_qa_type(input_l_question)['types'].split('; ')
        q_type = get_q_type_2(input_l_question)['type']
        a_type = ''

        if args.verbose:
            print(f'q_type: {q_type}, a_type: {a_type}')
            print(f'q_type: {q_type}')

        # Answer selection
        ans_sel = get_answer_selection(input_w_question, top_n_docs[0][0])

        if args.verbose:
            print(f'ans_sel: {ans_sel}')

        # Answer extraction
        if a_type == 'YES_NO':
        # if a_type == 'YES_NO':
        if q_type == 'YES_NO':
            ans_ext = get_answer_extraction_ml_yes_no(input_w_question, ans_sel['answer_list'][0][1])
        else:
            ans_ext = get_answer_extraction_ml(input_w_question, ans_sel['answer_list'][0][1])
+10 −0
Original line number Diff line number Diff line
@@ -68,3 +68,13 @@ def get_qa_type(question):
    """
    r = requests.post(f'{url}:4451/qa_type', data=json.dumps({'question': f'{question}'}))
    return json.loads(r.content)


def get_q_type_2(question):
    """
    Get answer selection senteces
    :param question: str
    :return:
    """
    r = requests.post(f'{url}:4450/q_type_2', data=json.dumps({'question': f'{question}'}))
    return json.loads(r.content)
+26 −0
Original line number Diff line number Diff line
@@ -58,6 +58,18 @@ answer_extraction_ml_yes_no_tokenizer.padding_side = "right"
answer_extraction_ml_yes_no_model = AutoModelForSequenceClassification.from_pretrained(answer_extraction_ml_yes_no_model_path)
answer_extraction_ml_yes_no_classify = pipeline('text-classification', model=answer_extraction_ml_yes_no_model,
                                                tokenizer=answer_extraction_ml_yes_no_tokenizer)

# =========================================
# Question classifier yes/no, rest
# =========================================
sys.stderr.write('Loading question classifier (yes/no, rest) model ... ')
id2label_2 = {'LABEL_0': 'REST', 'LABEL_1': 'YES_NO'}
question_classification_model_path = f'{current_dir}/question_classification_v2/train_2022_10_24_18-20-18/xlm-roberta-large-squad2-finetuned-V1/'
question_classification_tokenizer = AutoTokenizer.from_pretrained(question_classification_model_path)
question_classification_tokenizer.padding_side = "right"
question_classification_model = AutoModelForSequenceClassification.from_pretrained(question_classification_model_path)
question_classification_classify = pipeline('text-classification', model=question_classification_model,
                                            tokenizer=question_classification_tokenizer)
sys.stderr.write('loaded\n')

sys.stderr.write('Service ready.\n')
@@ -126,6 +138,20 @@ class JobServer(BaseHTTPRequestHandler):
                                         'answer_extraction': id2label[prediction["label"]],
                                         'answer_extraction_score': prediction["score"]}).encode('utf-8'))

        elif command == 'q_type_2':
            query_input = {
                'text': f'{data["question"]}'
            }
            translate_type = {'LABEL_0': 'REST', 'LABEL_1': 'YES_NO'}

            q_type = translate_type[question_classification_classify(query_input)['label']]

            logging.info(f'POST {command}; {data["question"]}')
            self._set_headers()
            self.wfile.write(json.dumps({'command': command,
                                         'question': data['question'],
                                         'type': q_type}).encode('utf-8'))

    @staticmethod
    def v100(sentence):
        data = []