Commit 0289dd73 authored by Martin Juhás's avatar Martin Juhás
Browse files

fix: change how questionnaires are evaluated for more safety

No API changes

Closes #400
parent 37198b25
Loading
Loading
Loading
Loading
+33 −1
Original line number Diff line number Diff line
import json
import os
from io import BytesIO
from typing import List, Optional
from typing import List, Optional, Dict
from zipfile import ZipFile

from django.conf import settings
@@ -115,6 +115,38 @@ def is_milestone_reached(team: Team, milestone: Milestone) -> bool:
    return milestone_state.reached


def compare_team_milestone_state(
    team: Team,
    expected_states: Dict[str, bool],
) -> bool:
    """
    This function checks whether the given milestones have the expected state for the given team.
    Milestones not included in the expected state are ignored.
    """
    milestone_states = MilestoneState.objects.select_related(
        "milestone"
    ).filter(
        team_state_id=team.state_id, milestone__name__in=expected_states.keys()
    )

    for milestone_state in milestone_states:
        if (
            expected_states[milestone_state.milestone.name]
            != milestone_state.reached
        ):
            return False
    return True


def reach_milestone_str(team: Team, milestone: str, reached: bool):
    milestone_state = MilestoneState.objects.filter(
        team_state_id=team.state_id, milestone__name=milestone
    ).get()

    milestone_state.reached = reached
    milestone_state.save()


def print_milestones(team: Team):
    print(f"team: {team.id}")
    for milestone_state in team.state.milestone_states.all():
+8 −0
Original line number Diff line number Diff line
@@ -156,6 +156,14 @@ class Control(models.Model):
    class Meta:
        default_permissions = ()

    def join_modifications(self, other: "Control"):
        self.activate_milestone.extend(other.activate_milestone)
        self.deactivate_milestone.extend(other.deactivate_milestone)

    def dedupe_modifications(self):
        self.activate_milestone = list(set(self.activate_milestone))
        self.deactivate_milestone = list(set(self.deactivate_milestone))


class LearningObjective(models.Model):
    name = models.TextField()
+166 −126
Original line number Diff line number Diff line
import re
from typing import Dict, Optional, List, Union
from typing import Dict, Optional, List, Union, Set, Callable

from django.utils import timezone

@@ -26,98 +26,117 @@ from running_exercise.lib.milestone_handler import update_milestones
from running_exercise.models import QuestionnaireAnswer, ActionLog, LogType
from user.models import User

Error = Optional[Err[str]]


def get_controls(questions: Dict[int, Question]) -> Dict[int, Control]:
    control_ids = []
    for question in questions.values():
        if question.type != QuestionTypes.RADIO:
            continue
        if question.type == QuestionTypes.RADIO:
            control_ids.extend(question.details.controls.values())
        elif question.type == QuestionTypes.AUTO_FREE_FORM:
            control_ids.append(question.details.correct_id)
            control_ids.append(question.details.incorrect_id)

    return Control.objects.in_bulk(id_list=control_ids)


def handle_radio_question_answer(
    question: Question,
    tqs: TeamQuestionnaireState,
    controls: Dict[int, Control],
    answer: AnswerInput,
) -> QuestionnaireAnswer:
def check_answer_length(
    details: Union[AutoFreeFormQuestion, FreeFormQuestion], answer: AnswerInput
) -> Error:
    if len(answer.value) < details.min:
        return Err(f"Answer `{answer.value}` is too short")

    if details.max != -1 and len(answer.value) > details.max:
        return Err(f"Answer `{answer.value}` is too long")
    return None


def broadcast_tqs_update(tqs: TeamQuestionnaireState):
    log = ensure_exists(
        ActionLog.objects.filter(
            team_id=tqs.team_id,
            type=LogType.FORM,
            details_id=tqs.id,
        )
    )
    SubscriptionHandler.broadcast_action_logs(log, tqs.team, EventType.modify())


class AnswerHandler:
    tqs: TeamQuestionnaireState
    questions: Dict[int, Question]
    controls: Dict[int, Control]
    answers: List[QuestionnaireAnswer]
    milestone_modifications: Control

    def __init__(self, tqs: TeamQuestionnaireState):
        self.tqs = tqs
        self.questions = tqs.questionnaire.questions.in_bulk()
        self.controls = get_controls(self.questions)
        self.answers = []
        self.milestone_modifications = (
            tqs.questionnaire.control
        )  # Do not use .save for the tqs control in the AnswerHandler!

    def handle_radio_question(
        self, question: Question, answer: AnswerInput
    ) -> Error:
        details: RadioQuestion = question.details
        if not answer.value.isdecimal():
        raise RunningExerciseOperationException(
            f"Invalid answer ({answer.value}) for question ({question.id}): Not a number"
        )
            return Err("Not a number")

        choice = int(answer.value)

        if choice > details.max or choice < 1:
        raise RunningExerciseOperationException(
            f"Invalid answer ({choice}) for question ({question.id})"
        )
            return Err(f"Invalid answer ({choice})")

    correct: Optional[bool]
        is_correct: Optional[bool]
        if details.correct == 0:
        correct = None
            is_correct = None
        else:
        correct = details.correct == choice
            is_correct = details.correct == choice

        if (control_id := details.controls.get(answer.value, None)) is not None:
            # we know that the control always exists
        update_milestones(tqs.team, controls[control_id])
            self.milestone_modifications.join_modifications(
                self.controls[control_id]
            )

    return QuestionnaireAnswer(
        self.answers.append(
            QuestionnaireAnswer(
                question=question,
        team_questionnaire_state=tqs,
                team_questionnaire_state=self.tqs,
                answer=answer.value,
        is_correct=correct,
                is_correct=is_correct,
            )
        )


Error = Optional[Err[str]]


def check_answer_length(
    details: Union[AutoFreeFormQuestion, FreeFormQuestion], answer: AnswerInput
) -> Error:
    if len(answer.value) < details.min:
        return Err(f"Answer `{answer.value}` is too short")

    if details.max != -1 and len(answer.value) > details.max:
        return Err(f"Answer `{answer.value}` is too long")
        return None


    def handle_free_form_question(
    question, tqs: TeamQuestionnaireState, answer: AnswerInput
) -> QuestionnaireAnswer:
        self, question: Question, answer: AnswerInput
    ) -> Error:
        if (err := check_answer_length(question.details, answer)) is not None:
        raise RunningExerciseOperationException(
            f"Question ({question.id}): {err.unwrap_err()}"
        )
    # nothing to do here except create the answer
            return err

    return QuestionnaireAnswer(
        self.answers.append(
            QuestionnaireAnswer(
                question=question,
        team_questionnaire_state=tqs,
                team_questionnaire_state=self.tqs,
                answer=answer.value,
                is_correct=None,
            )
        )

        return None

    def handle_auto_free_form_question(
    question: Question, tqs: TeamQuestionnaireState, answer: AnswerInput
) -> QuestionnaireAnswer:
    details: AutoFreeFormQuestion = ensure_exists(
        AutoFreeFormQuestion.objects.filter(
            id=question.details_id
        ).select_related("correct", "incorrect")
    )

        self, question: Question, answer: AnswerInput
    ) -> Error:
        details: AutoFreeFormQuestion = question.details
        if (err := check_answer_length(details, answer)) is not None:
        raise RunningExerciseOperationException(
            f"Question ({question.id}): {err.unwrap_err()}"
        )
            return err

        if details.regex:
            is_correct = (
@@ -127,63 +146,83 @@ def handle_auto_free_form_question(
            is_correct = details.correct_answer == answer.value

        if is_correct:
        update_milestones(tqs.team, details.correct)
            self.milestone_modifications.join_modifications(
                self.controls[details.correct_id]
            )
        else:
        update_milestones(tqs.team, details.incorrect)
            self.milestone_modifications.join_modifications(
                self.controls[details.incorrect_id]
            )

    return QuestionnaireAnswer(
        self.answers.append(
            QuestionnaireAnswer(
                question=question,
        team_questionnaire_state=tqs,
                team_questionnaire_state=self.tqs,
                answer=answer.value,
                is_correct=is_correct,
            )
        )

        return None

def check_questionnaire_answers(
    tqs: TeamQuestionnaireState, answers: List[AnswerInput]
):
    questions: Dict[int, Question] = tqs.questionnaire.questions.in_bulk()
    controls = get_controls(questions)
    def handle_unknown_question_type(
        self, question: Question, answer: AnswerInput
    ) -> Error:
        return Err(f"Unhandled question type `{question.type}`")

    if len(questions) != len(answers):
        raise RunningExerciseOperationException(
            f"Mismatch between the number of questions ({len(questions)})"
    def validate_answers(self, answers: List[AnswerInput]) -> Error:
        if len(answers) != len(self.questions):
            return Err(
                f"Mismatch between the number of questions ({len(self.questions)})"
                f" and number of answers ({len(answers)})"
            )

    question_answers: List[QuestionnaireAnswer] = []
        errors: List[str] = []
        answered_questions: Set[int] = set()
        for answer in answers:
        question = questions.get(int(answer.question_id))
            question = self.questions.get(int(answer.question_id), None)
            if question is None:
            raise RunningExerciseOperationException(
                f"Question ({answer.question_id}) does not exist"
                errors.append(
                    f"Question ({answer.question_id}): Question does not exist"
                )
                continue

        if question.type == QuestionTypes.RADIO:
            question_answer = handle_radio_question_answer(
                question, tqs, controls, answer
            if question.id in answered_questions:
                errors.append(
                    f"Question ({answer.question_id}): Question has already been answered"
                )
                continue

            answered_questions.add(question.id)

            handler: Callable[
                [Question, AnswerInput], Error
            ] = self.handle_unknown_question_type
            if question.type == QuestionTypes.RADIO:
                handler = self.handle_radio_question
            elif question.type == QuestionTypes.FREE_FORM:
                handler = self.handle_free_form_question
            elif question.type == QuestionTypes.AUTO_FREE_FORM:
            question_answer = handle_auto_free_form_question(
                question, tqs, answer
            )
        else:
            question_answer = handle_free_form_question(question, tqs, answer)
                handler = self.handle_auto_free_form_question

        question_answers.append(question_answer)
            if (err := handler(question, answer)) is not None:
                errors.append(f"Question ({question.id}): {err.unwrap_err()}")

    QuestionnaireAnswer.objects.bulk_create(question_answers)
        if len(errors) == 0:
            return None

        return Err("\n".join(errors))

def broadcast_tqs_update(tqs: TeamQuestionnaireState):
    log = ensure_exists(
        ActionLog.objects.filter(
            team_id=tqs.team_id,
            type=LogType.FORM,
            details_id=tqs.id,
        )
    )
    SubscriptionHandler.broadcast_action_logs(log, tqs.team, EventType.modify())
    def commit(self):
        QuestionnaireAnswer.objects.bulk_create(self.answers)
        self.milestone_modifications.dedupe_modifications()
        update_milestones(self.tqs.team, self.milestone_modifications)

        self.tqs.status = TeamQuestionnaireState.Status.ANSWERED
        self.tqs.timestamp_answered = timezone.now()
        self.tqs.save()

        broadcast_tqs_update(self.tqs)


class QuestionnaireHandler:
@@ -212,12 +251,13 @@ class QuestionnaireHandler:
                "Can only answer questionnaires which have been already sent"
            )

        tqs.status = TeamQuestionnaireState.Status.ANSWERED
        tqs.timestamp_answered = timezone.now()
        check_questionnaire_answers(tqs, quest_input.answers)
        tqs.save()
        update_milestones(tqs.team, tqs.questionnaire.control)
        broadcast_tqs_update(tqs)
        ah = AnswerHandler(tqs)
        if (err := ah.validate_answers(quest_input.answers)) is not None:
            raise RunningExerciseOperationException(
                f"Error: {err.unwrap_err()}"
            )
        ah.commit()

        return tqs

    @staticmethod
+131 −81

File changed.

Preview size limit exceeded, changes collapsed.