Commit 95ecc339 authored by Petr Babic's avatar Petr Babic
Browse files

fix pylint

parent 13d8c6f1
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
import secrets
from flask import Flask, render_template, request, redirect, url_for, session
import src.database as db

import database as db

app = Flask(__name__, static_folder='static')
app.secret_key = 'dev_key'  # TODO: change key
app.secret_key = secrets.token_hex(16)


database = db.Database('data/data.db')
+17 −16
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from typing import Tuple, Optional


class Database:
    __slots__ = "path"
    __slots__ = ["path"]

    def __init__(self, path: str):
        self.path = path
@@ -33,32 +33,33 @@ class Database:
        if 'customer' not in tables:
            cur.execute('CREATE TABLE customer(email TEXT, password_hash TEXT)')
        if 'order_' not in tables:
            cur.execute('CREATE TABLE order_(status TEXT, time_created INTEGER, customer_id INTEGER)')
            cur.execute(
                'CREATE TABLE order_(status TEXT, time_created INTEGER, customer_id INTEGER)')
        if 'menu_item' not in tables:
            cur.execute('CREATE TABLE menu_item(name TEXT, description TEXT, price INTEGER)')
        con.commit()


class Customer:
    __slots__ = "row_id", "email", "pass_hash"
    __slots__ = ["row_id", "email", "pass_hash"]

    def __init__(self, email: str, pass_hash: str, row_id: int = -1) -> None:
        self.row_id, self.email, self.pass_hash = row_id, email, pass_hash

    @classmethod
    def get(cls, db: Database, email: str) -> Optional['Customer']:
        con, cur = db.connect()
    def get(cls, data: Database, email: str) -> Optional['Customer']:
        _, cur = data.connect()
        res = cur.execute(f"SELECT *, rowid FROM customer WHERE email = '{email}'").fetchall()
        return cls(*res[0]) if res else None

    def write(self, db: Database) -> None:
        con, cur = db.connect()
    def write(self, data: Database) -> None:
        con, cur = data.connect()
        cur.execute(f'INSERT INTO customer VALUES(\'{self.email}\', \'{self.pass_hash}\')')
        con.commit()


class Order:
    __slots__ = 'status', 'time_created', 'customer_id', 'row_id'
    __slots__ = ['status', 'time_created', 'customer_id', 'row_id']

    def __init__(self, status: str, time_created: int,
                 customer_id: int, row_id: int = -1) -> None:
@@ -68,31 +69,31 @@ class Order:
        self.row_id = row_id

    @classmethod
    def get(cls, db: Database, customer_id: int) -> Optional['Order']:
        con, cur = db.connect()
    def get(cls, data: Database, customer_id: int) -> Optional['Order']:
        _, cur = data.connect()
        res = cur.execute(
            f"SELECT *, rowid FROM order_ WHERE customer_id = '{customer_id}'"
        ).fetchall()
        return cls(*res[0]) if res else None

    def write(self, db: Database) -> None:
        con, cur = db.connect()
    def write(self, data: Database) -> None:
        con, cur = data.connect()
        cur.execute(f'INSERT INTO order_ VALUES(\'{self.status}\','
                    f'\'{self.time_created}\', \'{self.customer_id}\')')
        con.commit()


def login(db: Database, email: str, password: str) -> Tuple[int, bool]:
def login(data: Database, email: str, password: str) -> Tuple[int, bool]:
    if not re.match(r'^[\w.-]+@([\w-]+\.)+[\w-]{2,4}$', email) or not password:
        return -1, False

    pass_hash = hashlib.sha256(bytes(password, 'utf-8'),
                               usedforsecurity=True).hexdigest()
    cust = Customer.get(db, email)
    cust = Customer.get(data, email)
    if cust:
        if cust.pass_hash == pass_hash:
            return cust.row_id, True
        return -1, False

    Customer(email, pass_hash).write(db)
    return Customer.get(db, email).row_id, True
    Customer(email, pass_hash).write(data)
    return Customer.get(data, email).row_id, True
+5 −6
Original line number Diff line number Diff line
import unittest
import src.database as db
import os

import unittest
import database as db

DATABASE = 'tests.db'

@@ -24,7 +23,7 @@ class TestLogin(unittest.TestCase):
        """ This test focuses on the `login` function from `database.py`. """
        # connect to database
        database = db.Database(DATABASE)
        con, cur = database.connect()
        _, cur = database.connect()

        # tests invalid email and password
        self.assertFalse(db.login(database, '', '')[1])
@@ -91,7 +90,7 @@ class TestCustomer(unittest.TestCase):
        """ This test focuses on the `write` method in `Customer`. """
        # connect to database
        database = db.Database(DATABASE)
        con, cur = database.connect()
        _, cur = database.connect()

        # make sure database is empty
        res = cur.execute('SELECT * FROM customer').fetchall()
@@ -103,7 +102,7 @@ class TestCustomer(unittest.TestCase):
            db.Customer(mail, pass_hash).write(database)

        # make sure they're there
        res = cur.execute(f"SELECT * FROM customer").fetchall()
        res = cur.execute("SELECT * FROM customer").fetchall()
        self.assertEqual(20, len(res))

        # check the data