import inspect
from os import getenv

import psycopg2
from psycopg2 import sql


def get_connection():
    """Open a new PostgreSQL connection configured from environment variables.

    Reads ``POSTGRES_HOST``, ``POSTGRES_DB``, ``POSTGRES_USER`` and
    ``POSTGRES_PASSWORD``. The caller owns the returned connection and is
    responsible for closing it.
    """
    return psycopg2.connect(
        host=getenv("POSTGRES_HOST"),
        database=getenv("POSTGRES_DB"),
        user=getenv("POSTGRES_USER"),
        password=getenv("POSTGRES_PASSWORD"),
    )


def _execute(query, params=None):
    """Run one statement on a fresh connection, commit it, and close the connection.

    Args:
        query: SQL string or ``psycopg2.sql`` composable.
        params: Values bound to ``%s`` placeholders, or ``None``.
    """
    conn = get_connection()
    try:
        # `with conn` commits on success and rolls back on error, but it does
        # NOT close the connection -- the `finally` below does that.
        with conn, conn.cursor() as cur:
            cur.execute(query, params)
    finally:
        conn.close()


def create_table(name):
    """Create table *name* whose only column is a ``row_number SERIAL`` primary key."""
    _execute(
        sql.SQL("CREATE TABLE {table} (row_number SERIAL PRIMARY KEY)").format(
            table=sql.Identifier(name)
        )
    )


def add_column(name, column, type):
    """Add *column* with SQL type *type* to table *name*.

    *type* is spliced into the statement verbatim (a type name cannot be a
    bound parameter), so it must come from trusted code, never from end users.
    """
    _execute(
        sql.SQL("ALTER TABLE {table} ADD {column} {type}").format(
            table=sql.Identifier(name),
            column=sql.Identifier(column),
            # Composing the type as sql.SQL keeps any braces in it out of the
            # format-string parser and guarantees a space after the column name.
            type=sql.SQL(type),
        )
    )


def insert_columns(name, data):
    """Insert one row into table *name*.

    The first column (``row_number``) receives ``DEFAULT``; the values in
    *data* fill the following columns in table order and are sent as bound
    parameters. An empty *data* inserts a row with only ``row_number`` set.
    """
    values = sql.SQL(", ").join(
        [sql.SQL("DEFAULT")] + [sql.Placeholder()] * len(data)
    )
    _execute(
        sql.SQL("INSERT INTO {table} VALUES ({values})").format(
            table=sql.Identifier(name), values=values
        ),
        data,
    )


def edit_columns(name, columns, data, id):
    """Update the row of table *name* whose ``row_number`` equals *id*.

    ``columns[i]`` is set to ``data[i]`` in a single UPDATE statement. Extra
    trailing values in *data* are ignored; if a column is listed more than
    once, its last value wins.

    Raises:
        IndexError: if *data* holds fewer values than *columns*.
    """
    columns = list(columns)
    if len(data) < len(columns):
        raise IndexError("edit_columns: fewer values in data than columns")
    # dict() collapses repeated columns (last value wins) so the single UPDATE
    # never assigns the same column twice, which PostgreSQL rejects.
    changes = dict(zip(columns, data))
    if not changes:
        return
    assignments = sql.SQL(", ").join(
        [
            sql.SQL("{} = {}").format(sql.Identifier(column), sql.Placeholder())
            for column in changes
        ]
    )
    _execute(
        sql.SQL("UPDATE {table} SET {assignments} WHERE row_number = {id}").format(
            table=sql.Identifier(name),
            assignments=assignments,
            id=sql.Placeholder(),
        ),
        [*changes.values(), id],
    )


def remove_column(name, column):
    """Drop *column* from table *name*."""
    _execute(
        sql.SQL("ALTER TABLE {table} DROP COLUMN {column}").format(
            table=sql.Identifier(name), column=sql.Identifier(column)
        )
    )


def create_sort(name):
    """Create the ``<name>_sort`` table holding sort rules for table *name*.

    Each row is ``(property, _order, priority)``: the column to sort by, the
    direction (``ASC`` or ``DESC``, case-insensitive), and the rule's rank.
    """
    # _order was CHAR(3), which cannot hold 'DESC'. The CHECK also keeps the
    # value safe to splice into an ORDER BY clause.
    _execute(
        sql.SQL(
            "CREATE TABLE {table} ("
            "property TEXT, "
            "_order VARCHAR(4) CHECK (upper(_order) IN ('ASC', 'DESC')), "
            "priority int)"
        ).format(table=sql.Identifier(name + "_sort"))
    )


def add_sort(name, property, order, priority):
    """Append a sort rule ``(property, order, priority)`` to ``<name>_sort``."""
    _execute(
        sql.SQL("INSERT INTO {table} VALUES (%s, %s, %s)").format(
            table=sql.Identifier(name + "_sort")
        ),
        (property, order, priority),
    )
def sort(name): conn = get_connection() cur = conn.cursor() cur.execute( sql.SQL("SELECT * FROM {table} ORDER BY priority").format( table=sql.Identifier(name + "_sort") ), ) order_clause = "ORDER BY " i = 0 for sort in cur: if i > 0: order_clause += ", " order_clause += sort[0] + " " + sort[1] i += 1 cur.execute( sql.SQL("SELECT * FROM {table} " + order_clause).format( table=sql.Identifier(name) ), ) return list(cur.fetchall()) def add_function(): conn = get_connection() cur = conn.cursor() cur.execute( """ CREATE OR REPLACE FUNCTION trigger_function() RETURNS TRIGGER LANGUAGE PLPGSQL AS $$ DECLARE name text := TG_ARGV[0]::text; BEGIN IF NEW.property NOT IN ( SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = name) THEN RAISE EXCEPTION 'ERROR %', NEW.property; END IF; RETURN NEW; END; $$; """ ) conn.commit() def add_filter_trigger(name): conn = get_connection() cur = conn.cursor() cur.execute( sql.SQL( """ CREATE TRIGGER {filter} BEFORE INSERT OR UPDATE ON {filter} FOR EACH ROW EXECUTE PROCEDURE trigger_function({table}); """ ).format(table=sql.Identifier(name), filter=sql.Identifier(name + "_filter")) ) conn.commit() def create_filter(name): conn = get_connection() cur = conn.cursor() cur.execute( sql.SQL( """ CREATE TABLE {table} ( property TEXT, value TEXT, function TEXT CHECK (function IN ('c', 'e', 'n')) ) """ ).format(table=sql.Identifier(name + "_filter")) ) conn.commit() def add_filter(name, property, value, function): conn = get_connection() cur = conn.cursor() cur.execute( sql.SQL("INSERT INTO {table} VALUES (%s, %s, %s)").format( table=sql.Identifier(name + "_filter") ), (property, value, function), ) conn.commit() def filter(name): conn = get_connection() cur = conn.cursor() cur.execute( sql.SQL("SELECT * FROM {table}").format(table=sql.Identifier(name + "_filter")), ) filter_clause = "WHERE " i = 0 for sort in cur: if i > 0: filter_clause += " AND " filter_clause += sort[0] match sort[2]: case "e": filter_clause += " = '" + sort[1] + 
"'" case "ne": filter_clause += " <> '" + sort[1] + "'" case "le": filter_clause += " <= " + sort[1] case "ge": filter_clause += " >= " + sort[1] case "l": filter_clause += " < " + sort[1] case "g": filter_clause += " > " + sort[1] case "c": filter_clause += " ILIKE '%" + sort[1] + "'" case "_": raise "Invalid filter function" i += 1 cur.execute( sql.SQL("SELECT * FROM {table} " + filter_clause).format( table=sql.Identifier(name) ), ) return list(cur.fetchall()) def create_user_table(): conn = get_connection() cur = conn.cursor() cur.execute( "CREATE TABLE users (id SERIAL PRIMARY KEY, username TEXT UNIQUE, password TEXT)" ) conn.commit() def add_user(username, password): conn = get_connection() cur = conn.cursor() cur.execute( "INSERT INTO users VALUES (DEFAULT, %s, %s)", (username, password), ) conn.commit() def get_users(): conn = get_connection() cur = conn.cursor() cur.execute( "SELECT * FROM users" ) return list(cur.fetchall()) def get_user_by_id(id): conn = get_connection() cur = conn.cursor() cur.execute( "SELECT * FROM users WHERE id = %s", id, ) return cur.fetchone() def get_user_by_username(username): conn = get_connection() cur = conn.cursor() cur.execute( sql.SQL("SELECT * FROM users WHERE username = {username}").format( username=sql.Literal(username) ) ) return cur.fetchone() def edit_user(id, username, password): conn = get_connection() cur = conn.cursor() columns = inspect.getfullargspec(edit_user)[0][1:] data = [username, password] i = -1 for column in columns: i += 1 if data[i] is None: continue print(id) cur.execute( sql.SQL("UPDATE users SET {col} = {value} WHERE id = {id}").format( col=sql.Identifier(column), value=sql.Literal(data[i]), id=sql.Literal(id) ), ) conn.commit()