diff --git a/bsition/backend/postgres/relations.py b/bsition/backend/postgres/relations.py new file mode 100644 index 0000000..7f44c89 --- /dev/null +++ b/bsition/backend/postgres/relations.py @@ -0,0 +1,120 @@ +from psycopg2 import sql + +from bsition.backend.postgres.utils import get_connection + + +def create_relations_tables(): + conn = get_connection() + cur = conn.cursor() + cur.execute( + sql.SQL(""" + CREATE TABLE table_access ( + user_id INTEGER REFERENCES users(id), + table_id INTEGER, + access_type INTEGER CHECK (access_type IN (1, 2, 3)), + PRIMARY KEY (user_id, table_id) + ) + """) + ) + cur.execute( + sql.SQL(""" + CREATE TABLE doc_access ( + user_id INTEGER REFERENCES users(id), + doc_id INTEGER, + access_type INTEGER CHECK (access_type IN (1, 2, 3)), + PRIMARY KEY (user_id, doc_id) + ) + """) + ) + conn.commit() + + +def give_access_table(user_id, table_id, access_type): + give_access(user_id, table_id, access_type, "table") + + +def give_access_doc(user_id, table_id, access_type): + give_access(user_id, table_id, access_type, "doc") + + +def give_access(user_id, id, access_type, destination): + if destination != "table" and destination != "doc": + raise "Invalid access destination" + + conn = get_connection() + cur = conn.cursor() + cur.execute( + sql.SQL(""" + INSERT INTO {destination_name} (user_id, {destination_id}, access_type) + VALUES ({user_id}, {id}, {access_type}) + ON CONFLICT (user_id, {destination_id}) DO UPDATE + SET access_type = {access_type} + """).format( + user_id=sql.Literal(user_id), + destination_name=sql.Identifier(destination + "_access"), + destination_id=sql.Identifier(destination + "_id"), + access_type=sql.Literal(access_type), + id=sql.Literal(id) + ) + ) + conn.commit() + + +def has_access_table(user_id, table_id): + return has_access(user_id, table_id, "table") + + +def has_access_doc(user_id, table_id): + return has_access(user_id, table_id, "doc") + + +def has_access(user_id, id, destination): + if destination != "table" and destination != "doc": + raise "Invalid access destination" + conn = get_connection() + cur = conn.cursor() + cur.execute( + sql.SQL("SELECT access_type FROM {destination_access} WHERE user_id = {user_id} AND {destination_id} = {id}").format( + user_id=sql.Literal(user_id), + destination_access=sql.Identifier(destination + "_access"), + destination_id=sql.Identifier(destination + "_id"), + id=sql.Literal(id) + ) + ) + return list(cur.fetchall()) + + +def deny_access_table(user_id, table_id): + return deny_access(user_id, table_id, "table") + + +def deny_access_doc(user_id, table_id): + return deny_access(user_id, table_id, "doc") + + +def deny_access(user_id, id, destination): + if destination != "table" and destination != "doc": + raise "Invalid access destination" + + conn = get_connection() + cur = conn.cursor() + cur.execute( + sql.SQL("DELETE FROM {destination_access} WHERE user_id = {user_id} AND {destination_id} = {id}").format( + user_id=sql.Literal(user_id), + destination_access=sql.Identifier(destination + "_access"), + destination_id=sql.Identifier(destination + "_id"), + id=sql.Literal(id) + ) + ) + conn.commit() + + +def get_accesible_documents(user_id): + conn = get_connection() + cur = conn.cursor() + cur.execute( + sql.SQL("SELECT doc_id FROM doc_access WHERE user_id = {user_id}").format( + user_id=sql.Literal(user_id) + ) + ) + return list(cur.fetchall())