Verified Commit 2d101c55 authored by Peter Stanko's avatar Peter Stanko
Browse files

More database tests and list refactor

parent b1b9ce2d
Loading
Loading
Loading
Loading
Loading
+30 −4
Original line number Diff line number Diff line
@@ -271,11 +271,30 @@ class User(EntityBase, Client):

        return Project.query.filter(Project.id.in_(pids))

    def query_all_projects(self):
        return Project.query.join(Project.course).join(Course.roles) \
            .join(Role.clients).filter(Client.id == self.id)

    def query_all_submissions(self):
        return Submission.query.join(Submission.course).join(Course.roles) \
            .join(Role.clients).filter(Client.id == self.id)

    def query_all_submissions_by_group(self):
        return self.query_all_submissions().join(Course.groups) \
            .filter(Group.users.contains(self))

    def get_all_submissions(self):
        return self.query_all_submissions().all()

    def query_all_projects_in_group(self):
        return self.query_all_projects().join(Course.groups).join(Group.users) \
            .filter(Group.users.contains(self))

    def get_all_projects(self):
        result = []
        for course in self.courses:
            result += self.get_projects_by_course(course)
        return result
        return self.query_all_projects().all()

    def get_all_projects_in_group(self):
        return self.query_all_projects_in_group().all()

    def get_projects_by_course(self, course: 'Course') -> List['Project']:
        """Gets projects from the course
@@ -367,6 +386,13 @@ class Course(db.Model, EntityBase, NamedMixin):
    def __eq__(self, other):
        return self.id == other.id

    def query_roles_by_client(self, client):
        return Role.query.join(Role.clients) \
            .filter(Role.clients.contains(client) | Role.course == self)

    def get_roles_by_client(self, client):
        return self.query_roles_by_client(client).all()

    def get_users_by_role(self, role: 'Role') -> List['User']:
        """Gets all users in the course based on their role
        Args:
+16 −9
Original line number Diff line number Diff line
@@ -5,8 +5,8 @@ Permissions service
import logging
from typing import Union

from portal.database.models import ClientType, Course, Role, RolePermissions, Submission, User, \
    Worker
from portal.database.models import ClientType, Course, Group, Role, RolePermissions, Submission, \
    User, Worker
from portal.service.errors import ForbiddenError
from portal.service.general import GeneralService

@@ -266,14 +266,21 @@ class PermissionsService(GeneralService):
            for key, value in vars(permission).items():
                if not key.startswith("_") and key not in FILTER_PERMISSION_ATTRS:
                    result[key] = result.get(key) or value
        log.debug(f"[PERM] Effective permissions: {self.client_name} "
        log.debug(f"[PERM] Effective permissions for client {self.client_name} "
                  f"in course {course.log_name}: {result}")
        return result

    def submission_access_group(self, submission, perm):
        if self.check.permissions(perm):
            group_intersection = [
                group for group in self.client.groups if submission.user in group.users]
            return any(
                group in submission.project.groups for group in group_intersection)
        if not self.check.permissions(perm):
            return False
        # Filter groups that contains current user and submission users
        query = Group.query.join(Group.users).filter(
            Group.users.contains(submission.user) & Group.users.contains(self.client_owner)
        )
        # Filter Groups that also has active projects
        query = query.join(Group.projects).filter(Group.projects.contains(submission.project))
        # group_intersection = [
        #     group for group in self.client.groups if submission.user in group.users]
        # return any(
        #     group in submission.project.groups for group in group_intersection)
        return len(query.all()) > 0
+13 −11
Original line number Diff line number Diff line
@@ -4,14 +4,15 @@ Submissions service
import json
import logging
from pathlib import Path
from typing import List, Union
from typing import Union

from celery.result import AsyncResult
from werkzeug.utils import secure_filename

from portal import storage
from portal.async_celery import submission_processor, tasks
from portal.database.models import Group, Project, Role, Submission, SubmissionState, User, Worker
from portal.database.models import Course, Group, Project, Role, Submission, SubmissionState, User, \
    Worker
from portal.rest.rest_helpers import FlaskRequestHelper
from portal.service import errors
from portal.service.general import GeneralService
@@ -173,12 +174,10 @@ class SubmissionsService(GeneralService):
        path = upload_files_to_storage(file)
        return path

    def filter_user_avail_submissions(self, query, roles: List[Role], groups: List[Group]):
    def filter_user_avail_submissions(self, query):
        submissions = query.all()
        return [submission for submission in submissions
                if nonempty_intersection(submission.user.roles, roles)
                and nonempty_intersection(submission.user.groups, groups)
                and self.perm_service(submission=submission).check.read_submission()]
        return [submission for submission in submissions if
                self.perm_service(submission=submission).check.read_submission()]

    def find_all(self):
        request_helper = FlaskRequestHelper()
@@ -188,8 +187,8 @@ class SubmissionsService(GeneralService):
        course = request_helper.args.course()

        projects = request_helper.args.projects()
        roles = request_helper.args.roles()
        groups = request_helper.args.groups()
        role_ids = self.request.args.get('roles')
        group_ids = self.request.args.get('groups')
        state = request_helper.args.state()

        if user:
@@ -204,8 +203,11 @@ class SubmissionsService(GeneralService):
            if projects:
                query = query.filter(Submission.project_id.in_(projects))
            # TODO: filter by groups and roles in the query here

        return self.filter_user_avail_submissions(query, roles, groups)
            if role_ids:
                query.join(Submission.course).join(Course.roles).filter(Role.id.in_(role_ids))
            if group_ids:
                query.join(Submission.course).join(Course.groups).filter(Group.id.in_(group_ids))
        return self.filter_user_avail_submissions(query)

    def process_submission_params(self, params: dict, project: Project, user: User):
        file_params = params.get('file_params')
+109 −2
Original line number Diff line number Diff line
import os

import pytest

from portal import create_app
from portal import db
from portal import create_app, db
from portal.database import Course, Group, Project, Role, User


@pytest.fixture(scope='session')
@@ -38,3 +39,109 @@ http://alexmic.net/flask-sqlalchemy-pytest/
https://xvrdm.github.io/2017/07/03/testing-flask-sqlalchemy-database-with-pytest/
https://scotch.io/tutorials/build-a-crud-web-app-with-python-and-flask-part-one
'''


def _get_user(username):
    return User(username=username, email=f'username@example.com')


def _get_course(codename):
    return Course(codename)


def _get_role(course, codename):
    return Role(course=course, codename=codename)


def _get_group(course, codename):
    return Group(course=course, codename=codename)


@pytest.fixture(autouse=True)
def db_student():
    return _get_user(username='student1')


@pytest.fixture(autouse=True)
def db_student2():
    return _get_user(username='student1')


@pytest.fixture(autouse=True)
def db_teacher():
    return _get_user(username='teacher1')


@pytest.fixture(autouse=True)
def db_teacher():
    return _get_user(username='teacher2')


@pytest.fixture(autouse=True)
def db_course():
    return Course(codename='course1')


@pytest.fixture(autouse=True)
def db_course2():
    return Course(codename='course2')


@pytest.fixture(autouse=True)
def db_c1_project1(db_course):
    return Project(course=db_course, codename='project1')


@pytest.fixture(autouse=True)
def db_c1_project2(db_course):
    return Project(course=db_course, codename='project2')


@pytest.fixture(autouse=True)
def db_c2_project1(db_course2):
    return Project(course=db_course2, codename='project1')


@pytest.fixture(autouse=True)
def db_c2_project2(db_course2):
    return Project(course=db_course2, codename='project2')


@pytest.fixture(autouse=True)
def db_c1_role1(db_course):
    return Role(course=db_course, codename='role1')


@pytest.fixture(autouse=True)
def db_c1_role2(db_course):
    return Role(course=db_course, codename='role2')


@pytest.fixture(autouse=True)
def db_c2_role1(db_course2):
    return Role(course=db_course2, codename='role1')


@pytest.fixture(autouse=True)
def db_c2_role2(db_course2):
    return Role(course=db_course2, codename='role2')


@pytest.fixture(autouse=True)
def db_c1_group1(db_course):
    return Group(course=db_course, codename='group1')


@pytest.fixture(autouse=True)
def db_c1_group2(db_course):
    return Group(course=db_course, codename='group2')


@pytest.fixture(autouse=True)
def db_c2_group1(db_course2):
    return Group(course=db_course2, codename='group1')


@pytest.fixture(autouse=True)
def db_c2_group2(db_course2):
    return Group(course=db_course2, codename='group2')
+3 −0
Original line number Diff line number Diff line
@@ -966,3 +966,6 @@ def test_get_users_in_group_based_on_role(session):
    assert user in res
    assert teacher not in res
    assert teacher2 not in res


Loading