Skip to content

Commit ab42d09

Browse files
committed
add a check for hash uniqueness across all tables
1 parent 053843a commit ab42d09

File tree

6 files changed

+60
-1
lines changed

6 files changed

+60
-1
lines changed

sqlab/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.3"
1+
__version__ = "0.5.4"

sqlab/cmd_create.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,25 @@ def run(config: dict):
7979
sql_dump.write(data_inserts_queries)
8080
db.execute_non_select(data_inserts_queries)
8181

82+
# Check that the hash columns contain unique values across all tables. An alternative would be
83+
# to declare each hash column as UNIQUE, but:
84+
# 1. The uniqueness would not be cross-table.
85+
# 2. In MySQL, the hash column is calculated from all other columns except the auto-incremented
86+
# primary key column. In a stateful game like SQL Island, the player is instructed to insert
87+
# a row in the table inhabitant. If they repeat the same insertion, the hash will be the
88+
# same (although the personid will be different), raising an IntegrityError.
89+
90+
seen_hashes = {} # use a dictionary for better warning messages
91+
for table_name in db.get_table_names():
92+
query = f"SELECT * FROM {table_name};"
93+
(_, _, rows) = db.execute_select(query)
94+
for row in rows:
95+
if (hash := row[-1]) in seen_hashes:
96+
(t1, v1) = seen_hashes[hash]
97+
(t2, v2) = (table_name, row[:-1])
98+
print(f"{WARNING}Hash collision:\n {t1}: {v1}\n {t2}: {v2}\nhave same hash {hash}.{RESET}")
99+
seen_hashes[hash] = (table_name, row[:-1])
100+
82101
sql_dump.write(db.fk_constraints_queries)
83102

84103
# If the source is a notebook, parse it and populate the `records` list.

sqlab/database.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ def get_headers(self, table: str, keep_auto_increment_columns=True) -> list[str]
3535
"""
3636
raise NotImplementedError
3737

38+
def get_table_names(self) -> list[str]:
    """List the tables of the database, leaving out the utility tables.

    The utility tables are the ones whose name starts with "sqlab_" and, in SQLite,
    the virtual tables decrypt and sqlean_define.

    Abstract: this base implementation always raises NotImplementedError.
    """
    raise NotImplementedError
43+
3844
def encrypt(self, plain: str, token: int) -> str:
3945
"""Return the encrypted version of the given plain text."""
4046
raise NotImplementedError

sqlab/dbms/mysql/database.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ def get_headers(self, table: str, keep_auto_increment_columns=False) -> list[str
3737
headers = [row[0] for row in cursor]
3838
return headers
3939

40+
def get_table_names(self) -> list[str]:
    """Return the names of the base tables of the current MySQL schema.

    Tables whose name starts with the literal prefix "sqlab_" are left out.
    Views are left out too (only rows with table_type 'BASE TABLE' are kept).

    Returns:
        The table names, in the order returned by information_schema.tables.
    """
    # Both values are passed as query parameters rather than interpolated:
    # - the schema name no longer depends on double quotes being string
    #   delimiters (they are identifier quotes under the ANSI_QUOTES mode);
    # - the LIKE pattern carries no literal "%" in the statement text.
    # In a LIKE pattern "_" matches any single character, so it is escaped
    # with a backslash to match an actual underscore only.
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
          AND table_type = 'BASE TABLE'
          AND table_name NOT LIKE %s;
    """
    with self.cnx.cursor() as cursor:
        cursor.execute(query, (self.cnx.database, r"sqlab\_%"))
        return [row[0] for row in cursor]
50+
4051
def encrypt(self, clear_text, token):
4152
"""In MySQL, the function aes_encrypt() takes a numeric key."""
4253
with self.cnx.cursor() as cursor:

sqlab/dbms/postgresql/database.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ def get_headers(self, table, keep_auto_increment_columns=True):
3939
headers = [row[0] for row in cursor.fetchall()]
4040
return headers
4141

42+
def get_table_names(self) -> list[str]:
    """Return the names of the base tables of the PostgreSQL schema 'public'.

    Tables whose name starts with the literal prefix "sqlab_" are left out.
    Views are left out too (only rows with table_type 'BASE TABLE' are kept).

    Returns:
        The table names, in the order returned by information_schema.tables.
    """
    # In a LIKE pattern "_" matches any single character: declare an explicit
    # escape character so that only a real underscore is matched after "sqlab".
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
          AND table_type = 'BASE TABLE'
          AND table_name NOT LIKE 'sqlab!_%' ESCAPE '!';
    """
    with self.cnx.cursor() as cursor:
        cursor.execute(query)
        return [row[0] for row in cursor]
52+
4253
def encrypt(self, clear_text, token):
4354
"""
4455
In PostgreSQL, the function pgp_sym_encrypt() takes a textual key, not a numeric one.

sqlab/dbms/sqlite/database.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def get_headers(self, table: str, keep_auto_increment_columns=True) -> list[str]
4444
headers = [header for header in headers if header != "hash"]
4545
return headers
4646

47+
def get_table_names(self) -> list[str]:
    """Return the names of the user tables of the SQLite database.

    Left out of the result:
    - the tables whose name starts with the literal prefix "sqlab_";
    - the tables whose name starts with "sqlite_", a prefix reserved by SQLite
      for its internal tables (e.g. sqlite_sequence, created as soon as a
      table uses AUTOINCREMENT);
    - the virtual tables decrypt and sqlean_define;
    - anything which is not of type 'table' (views, indexes, triggers).

    Returns:
        The table names, in the order returned by sqlite_master.
    """
    # In a LIKE pattern "_" matches any single character and SQLite defines no
    # default escape character: declare one explicitly so that only a real
    # underscore is matched after the prefix.
    query = """
        SELECT name
        FROM sqlite_master
        WHERE type = 'table'
          AND name NOT LIKE 'sqlab!_%' ESCAPE '!'
          AND name NOT LIKE 'sqlite!_%' ESCAPE '!'
          AND name NOT IN ('sqlean_define', 'decrypt');
    """
    cursor = self.cnx.cursor()
    cursor.execute(query)
    return [row[0] for row in cursor.fetchall()]
58+
4759
def encrypt(self, clear_text, token):
4860
query = f"SELECT encode(sha256({token}), 'hex') || encode(brotli({repr(clear_text)}), 'hex');"
4961
return repr(self.execute_select(query)[2][0][0])

0 commit comments

Comments
 (0)