bsition/backend/postgres.py

305 lines
7.4 KiB
Python

import inspect
from os import getenv
import psycopg2
from psycopg2 import sql
def get_connection():
    """Open and return a new psycopg2 connection.

    Connection settings are read from the ``POSTGRES_HOST``, ``POSTGRES_DB``,
    ``POSTGRES_USER`` and ``POSTGRES_PASSWORD`` environment variables
    (``getenv`` yields ``None`` for any that are unset).
    """
    settings = {
        "host": getenv("POSTGRES_HOST"),
        "database": getenv("POSTGRES_DB"),
        "user": getenv("POSTGRES_USER"),
        "password": getenv("POSTGRES_PASSWORD"),
    }
    return psycopg2.connect(**settings)
def create_table(name):
    """Create table *name* with a single ``row_number SERIAL PRIMARY KEY`` column.

    Args:
        name: Table name; quoted as an identifier, so it is used verbatim
            (case-sensitive).
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("CREATE TABLE {table} (row_number SERIAL PRIMARY KEY)").format(
                    table=sql.Identifier(name)
                )
            )
        conn.commit()
    finally:
        # Release the server connection deterministically instead of relying
        # on garbage collection (e.g. when a traceback keeps this frame alive).
        conn.close()
def add_column(name, column, type):
    """Add column *column* of SQL type *type* to table *name*.

    Args:
        name: Table name (quoted as an identifier).
        column: New column name (quoted as an identifier).
        type: Column type as raw SQL text, e.g. ``"TEXT"`` or
            ``"VARCHAR(20)"``. It is inserted into the statement verbatim,
            so it must never come from untrusted input.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                # ``type`` gets its own placeholder so it is always separated
                # from the column name by a space, and so braces inside it
                # (e.g. ``DEFAULT '{}'``) are not parsed as format fields.
                sql.SQL("ALTER TABLE {table} ADD {column} {col_type}").format(
                    table=sql.Identifier(name),
                    column=sql.Identifier(column),
                    col_type=sql.SQL(type),
                )
            )
        conn.commit()
    finally:
        conn.close()
def insert_columns(name, data):
    """Insert one row into table *name*.

    The first column (``row_number``, as created by ``create_table``) is
    given ``DEFAULT``; the values in *data* fill the following columns in
    table order and are sent as bound parameters.

    Args:
        name: Table name (quoted as an identifier).
        data: Sequence of column values, in column order after ``row_number``.
    """
    # DEFAULT for row_number, then one %s placeholder per value. This also
    # produces a correct "(DEFAULT)" when data is empty.
    values = [sql.SQL("DEFAULT")] + [sql.Placeholder()] * len(data)
    query = sql.SQL("INSERT INTO {table} VALUES ({values})").format(
        table=sql.Identifier(name),
        values=sql.SQL(", ").join(values),
    )
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(query, data)
        conn.commit()
    finally:
        conn.close()
def edit_columns(name, columns, data, id):
    """Update the row of table *name* whose ``row_number`` equals *id*.

    ``columns[i]`` is set to ``data[i]``. Each assignment is sent as its own
    UPDATE, and all of them are committed together.

    Args:
        name: Table name (quoted as an identifier).
        columns: Column names to update (quoted as identifiers).
        data: New values, parallel to *columns*; sent as bound parameters.
        id: ``row_number`` of the row to update.

    Raises:
        ValueError: if *columns* and *data* differ in length. This is checked
            before any statement is sent.
    """
    columns = list(columns)
    data = list(data)
    if len(columns) != len(data):
        raise ValueError(
            f"edit_columns: got {len(columns)} columns but {len(data)} values"
        )
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            for column, value in zip(columns, data):
                cur.execute(
                    sql.SQL(
                        "UPDATE {table} SET {col} = %s WHERE row_number = %s"
                    ).format(table=sql.Identifier(name), col=sql.Identifier(column)),
                    (value, id),
                )
        conn.commit()
    finally:
        conn.close()
def remove_column(name, column):
    """Drop column *column* from table *name* (both quoted as identifiers)."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("ALTER TABLE {table} DROP COLUMN {column}").format(
                    table=sql.Identifier(name), column=sql.Identifier(column)
                )
            )
        conn.commit()
    finally:
        conn.close()
def create_sort(name):
    """Create the ``<name>_sort`` table holding the sort keys for table *name*.

    Each row is ``(property, _order, priority)``: the column to sort by, its
    direction (``ASC``/``DESC``), and its rank among the sort keys.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                # _order was CHAR(3), which cannot store 'DESC' (PostgreSQL
                # rejects over-long values), making descending sorts impossible.
                sql.SQL(
                    "CREATE TABLE {table} (property TEXT, _order TEXT, priority int)"
                ).format(table=sql.Identifier(name + "_sort"))
            )
        conn.commit()
    finally:
        conn.close()
def add_sort(name, property, order, priority):
    """Append a sort key to the ``<name>_sort`` table.

    Args:
        name: Base table name; the row goes into ``<name>_sort``.
        property: Column of *name* to sort by.
        order: Sort direction, ``"ASC"`` or ``"DESC"``.
        priority: Rank of this key; lower priorities are applied first.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("INSERT INTO {table} VALUES (%s, %s, %s)").format(
                    table=sql.Identifier(name + "_sort")
                ),
                (property, order, priority),
            )
        conn.commit()
    finally:
        conn.close()
def sort(name):
    """Return all rows of table *name*, ordered by the keys in ``<name>_sort``.

    Sort keys are applied in ascending ``priority``. Each key's ``property``
    is quoted as an identifier, and its ``_order`` must be ``ASC`` or
    ``DESC`` (case-insensitive; blank/NULL means ``ASC``). With no sort keys
    the rows are returned unordered.

    Returns:
        A list of row tuples.

    Raises:
        ValueError: if a stored sort direction is not ASC/DESC.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("SELECT property, _order FROM {table} ORDER BY priority").format(
                    table=sql.Identifier(name + "_sort")
                )
            )
            terms = []
            for prop, order in cur.fetchall():
                # Whitelist the direction: it is spliced into the SQL text, so
                # it must never be arbitrary stored data.
                direction = (order or "").strip().upper() or "ASC"
                if direction not in ("ASC", "DESC"):
                    raise ValueError(f"Invalid sort order: {order!r}")
                terms.append(
                    sql.SQL("{} {}").format(sql.Identifier(prop), sql.SQL(direction))
                )

            query = sql.SQL("SELECT * FROM {table}").format(table=sql.Identifier(name))
            if terms:
                # An empty key list would otherwise produce invalid "ORDER BY ".
                query = sql.SQL("{} ORDER BY {}").format(
                    query, sql.SQL(", ").join(terms)
                )
            cur.execute(query)
            return cur.fetchall()
    finally:
        conn.close()
def add_function():
    """Create (or replace) the PL/pgSQL trigger function ``trigger_function``.

    The function reads a table name from its first trigger argument and
    raises an exception if ``NEW.property`` is not among that table's column
    names. Column names come from INFORMATION_SCHEMA.COLUMNS, matched on
    TABLE_NAME only, i.e. in any schema.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                CREATE OR REPLACE FUNCTION trigger_function()
                RETURNS TRIGGER
                LANGUAGE PLPGSQL
                AS $$
                DECLARE
                name text := TG_ARGV[0]::text;
                BEGIN
                IF NEW.property NOT IN (
                SELECT column_name
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_NAME = name)
                THEN
                RAISE EXCEPTION 'ERROR %', NEW.property;
                END IF;
                RETURN NEW;
                END;
                $$;
                """
            )
        conn.commit()
    finally:
        conn.close()
def add_filter_trigger(name):
    """Attach a row trigger named ``<name>_filter`` to the ``<name>_filter`` table.

    The trigger fires BEFORE INSERT OR UPDATE and runs ``trigger_function``
    with *name* as its argument, so every filter row is checked against the
    columns of table *name*. ``trigger_function`` must already exist.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL(
                    """
                    CREATE TRIGGER {filter}
                    BEFORE INSERT OR UPDATE
                    ON {filter}
                    FOR EACH ROW
                    EXECUTE PROCEDURE trigger_function({table});
                    """
                ).format(
                    table=sql.Identifier(name),
                    filter=sql.Identifier(name + "_filter"),
                )
            )
        conn.commit()
    finally:
        conn.close()
def create_filter(name):
    """Create the ``<name>_filter`` table holding the row filters for table *name*.

    Each row is ``(property, value, function)``. ``function`` is an operator
    code: ``e`` (=), ``ne`` (<>), ``l`` (<), ``le`` (<=), ``g`` (>),
    ``ge`` (>=), ``c`` (case-insensitive contains). ``n`` is still accepted
    for existing callers.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                # The CHECK previously allowed only c/e/n, so the comparison
                # operators the filter query implements could never be stored.
                # NOTE(review): 'n' has no matching operator in the filter
                # query; confirm its intended meaning.
                sql.SQL(
                    """
                    CREATE TABLE {table} (
                    property TEXT,
                    value TEXT,
                    function TEXT CHECK (
                        function IN ('c', 'e', 'n', 'ne', 'l', 'le', 'g', 'ge')
                    )
                    )
                    """
                ).format(table=sql.Identifier(name + "_filter"))
            )
        conn.commit()
    finally:
        conn.close()
def add_filter(name, property, value, function):
    """Append a filter row ``(property, value, function)`` to ``<name>_filter``.

    Args:
        name: Base table name; the row goes into ``<name>_filter``.
        property: Column of *name* the filter applies to.
        value: Comparison value, stored as text.
        function: Operator code, constrained by the filter table's CHECK.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("INSERT INTO {table} VALUES (%s, %s, %s)").format(
                    table=sql.Identifier(name + "_filter")
                ),
                (property, value, function),
            )
        conn.commit()
    finally:
        conn.close()
def filter(name):
    """Return the rows of table *name* that satisfy every filter in ``<name>_filter``.

    Each filter row is ``(property, value, function)`` and all filters are
    ANDed together. Supported ``function`` codes:

    - ``e``: property = value
    - ``ne``: property <> value
    - ``l``: property < value
    - ``le``: property <= value
    - ``g``: property > value
    - ``ge``: property >= value
    - ``c``: property ILIKE '%value%' (case-insensitive contains)

    With no filters, every row is returned.

    Returns:
        A list of row tuples.

    Raises:
        ValueError: for an unsupported ``function`` code.
    """
    comparisons = {"e": "=", "ne": "<>", "l": "<", "le": "<=", "g": ">", "ge": ">="}
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                sql.SQL("SELECT property, value, function FROM {table}").format(
                    table=sql.Identifier(name + "_filter")
                )
            )
            conditions = []
            params = []
            for prop, value, func in cur.fetchall():
                if func == "c":
                    # Wrap both ends so this is "contains", not "ends with".
                    operator, value = "ILIKE", "%" + value + "%"
                elif func in comparisons:
                    operator = comparisons[func]
                else:
                    raise ValueError(f"Invalid filter function: {func!r}")
                # The column is quoted and the value is a bound parameter, so
                # nothing read from the filter table is spliced into the SQL.
                conditions.append(
                    sql.SQL("{} {} %s").format(sql.Identifier(prop), sql.SQL(operator))
                )
                params.append(value)

            query = sql.SQL("SELECT * FROM {table}").format(table=sql.Identifier(name))
            if conditions:
                # An empty filter list would otherwise produce invalid "WHERE ".
                query = sql.SQL("{} WHERE {}").format(
                    query, sql.SQL(" AND ").join(conditions)
                )
            cur.execute(query, params)
            return cur.fetchall()
    finally:
        conn.close()
def create_user_table():
    """Create the ``users`` table: ``(id SERIAL PRIMARY KEY, username TEXT UNIQUE, password TEXT)``."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "CREATE TABLE users (id SERIAL PRIMARY KEY, username TEXT UNIQUE, password TEXT)"
            )
        conn.commit()
    finally:
        conn.close()
def add_user(username, password):
    """Insert a user row with an auto-assigned id.

    Args:
        username: Must be unique (enforced by the table's UNIQUE constraint).
        password: Stored exactly as given.
    """
    # NOTE(review): this stores *password* verbatim. If callers pass
    # plaintext, it should be hashed with a password KDF before reaching
    # here. Confirm.
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO users VALUES (DEFAULT, %s, %s)",
                (username, password),
            )
        conn.commit()
    finally:
        conn.close()
def get_users():
    """Return every row of the ``users`` table as a list of tuples."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM users")
            return cur.fetchall()
    finally:
        conn.close()
def get_user_by_id(id):
    """Return the ``users`` row whose primary key is *id*, or ``None`` if there is none."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # Parameters must be a sequence. Passing the bare id raised
            # TypeError for ints, and split strings into one arg per character.
            cur.execute("SELECT * FROM users WHERE id = %s", (id,))
            return cur.fetchone()
    finally:
        conn.close()
def get_user_by_username(username):
    """Return the ``users`` row for *username*, or ``None`` if there is none."""
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # Bound parameter, consistent with the other users queries.
            cur.execute("SELECT * FROM users WHERE username = %s", (username,))
            return cur.fetchone()
    finally:
        conn.close()
def edit_user(id, username, password):
    """Update the username and/or password of the user with primary key *id*.

    A field passed as ``None`` is left unchanged. If both are ``None``,
    nothing is sent to the database.
    """
    # Explicit column -> value pairs, rather than deriving column names from
    # this function's own signature via ``inspect`` (which breaks silently if
    # a parameter is renamed or reordered).
    changes = [
        (column, value)
        for column, value in (("username", username), ("password", password))
        if value is not None
    ]
    if not changes:
        return

    assignments = sql.SQL(", ").join(
        [sql.SQL("{} = %s").format(sql.Identifier(column)) for column, _ in changes]
    )
    query = sql.SQL("UPDATE users SET {assignments} WHERE id = %s").format(
        assignments=assignments
    )
    params = [value for _, value in changes] + [id]

    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(query, params)
        conn.commit()
    finally:
        conn.close()