Commit e87d216d authored by Richard Glosner's avatar Richard Glosner
Browse files

Merge branch '410-implement-registration-endpoint-with-domain-restiction' into 'main'

Resolve "Implement registration endpoint with domain restiction"

See merge request inject/backend!375
parents a9d785b4 c5afd4f1
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -10,12 +10,17 @@ urlpatterns = [
    ),
    path(
        "auth/logout/",
        views.Logout.as_view(),
        views.LogoutView.as_view(),
        name="logout",
    ),
    path(
        "auth/session/",
        views.CheckSession.as_view(),
        views.CheckSessionView.as_view(),
        name="session",
    ),
    path(
        "auth/register/",
        views.RegisterView.as_view(),
        name="register",
    ),
]
+23 −2
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from rest_framework.views import APIView
from common_lib.exceptions import ApiException
from common_lib.logger import logger, log_user_msg
from user.models import User
from user.lib.user_manager import UserManager, RegisterUserInput


class LoginView(APIView):
@@ -35,7 +36,7 @@ class LoginView(APIView):
        )


class Logout(APIView):
class LogoutView(APIView):
    parser_classes = [parsers.JSONParser]
    renderer_classes = [JSONRenderer]

@@ -50,7 +51,7 @@ class Logout(APIView):
        )


class CheckSession(APIView):
class CheckSessionView(APIView):
    parser_classes = [parsers.JSONParser]
    renderer_classes = [JSONRenderer]

@@ -60,3 +61,23 @@ class CheckSession(APIView):
            None if request.session.is_empty() else request.session.session_key
        )
        return Response({"sessionid": session})


class RegisterView(APIView):
    parser_classes = [parsers.JSONParser]
    renderer_classes = [JSONRenderer]

    def post(self, request: Request, *args, **kwargs):
        """Register a new user."""
        registration_data = RegisterUserInput.from_request(request)
        new_user = UserManager.register_user(registration_data)

        login(request, new_user)
        logger.info(
            log_user_msg(request, request.user) + "successful registration"
        )
        return Response(
            {
                "sessionid": request.session.session_key,
            }
        )
+6 −1
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ from running_exercise.models import (
    InstructorComment,
    LogType,
)
from user.models import User, Tag
from user.models import User, Tag, DomainRestriction


class RestrictedUser(DjangoObjectType):
@@ -814,3 +814,8 @@ class TeamLearningActivityType(DjangoObjectType):

    def resolve_achieved_score(self, info):
        return self.achieved_score()


class DomainRestrictionType(DjangoObjectType):
    class Meta:
        model = DomainRestriction
+34 −1
Original line number Diff line number Diff line
@@ -275,7 +275,7 @@ paths:
        '500':
          $ref: "#/components/responses/Error"
  
  /auth/session:
  /auth/session/:
    get:
      tags:
        - auth
@@ -286,6 +286,39 @@ paths:
        '500':
          $ref: "#/components/responses/Error"
  
  /auth/register/:
    post:
      tags:
        - auth
      description: Register a new user
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                username:
                  type: string
                password:
                  type: string
                first_name:
                  type: string
                last_name:
                  type: string
              required:
                - username
                - password
                - first_name
                - last_name
      responses:
        '200':
          description: Registration successful
          $ref: "#/components/responses/AuthResponse"
        '500':
          $ref: "#/components/responses/Error"
  

components:
  schemas:
    JsonResponse:
+48 −0
Original line number Diff line number Diff line
from common_lib.logger import logger
from user.models import DomainRestriction, User


ALL_DOMAINS_ALLOWED = "*"


class DomainManager:
    @classmethod
    def extract_domain(cls, user_email: str) -> str:
        if "@" not in user_email:
            raise ValueError("Invalid email address format.")
        return user_email.split("@")[-1].lower()

    @classmethod
    def check_domain_restriction(cls, user_email: str) -> bool:
        allowed_domains = DomainRestriction.objects.values_list(
            "domain", flat=True
        )

        if ALL_DOMAINS_ALLOWED in allowed_domains:
            return True

        user_domain = cls.extract_domain(user_email)
        return user_domain in allowed_domains

    @classmethod
    def create_domain_restriction(
        cls, domain: str, added_by: User
    ) -> DomainRestriction:
        domain = domain.lower().strip()
        if not domain:
            raise ValueError("Domain cannot be empty.")

        if len(domain) > 255:
            raise ValueError("Domain length cannot exceed 255 characters.")

        if DomainRestriction.objects.filter(domain=domain).exists():
            raise ValueError(f"Domain `{domain}` already exists.")

        domain_restriction = DomainRestriction.objects.create(
            domain=domain, added_by=added_by
        )
        return domain_restriction

    @classmethod
    def delete_domain_restriction(cls, domain_id: int) -> None:
        DomainRestriction.objects.filter(id=domain_id).delete()
Loading