Skip to content
This repository has been archived by the owner on Aug 5, 2024. It is now read-only.

Commit

Permalink
fix: Make merging from user_licenses multi-threaded
Browse files Browse the repository at this point in the history
  • Loading branch information
rsavoye committed Dec 25, 2023
1 parent 3aa60ef commit eaa1226
Showing 1 changed file with 82 additions and 22 deletions.
104 changes: 82 additions & 22 deletions tm_admin/users/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from dateutil.parser import parse
import tm_admin.types_tm
from tm_admin.types_tm import Userrole, Mappinglevel
import concurrent.futures
from cpuinfo import get_cpu_info

from tm_admin.dbsupport import DBSupport
from tm_admin.users.users_class import UsersTable
Expand All @@ -37,6 +39,29 @@
# Instantiate logger
log = logging.getLogger(__name__)

# The number of threads is based on the CPU cores
info = get_cpu_info()
cores = info["count"]

def importThread(
data: list,
db: PostgresClient,
):
"""Thread to handle importing
Args:
data (list): The list of records to import
db (PostgresClient): A database connection
"""
for record in data:
sql = f" UPDATE users SET licenses = ARRAY[{record[0]['license']}] WHERE id={record[0]['user']}"
# print(sql)
try:
result = db.dbcursor.execute(f"{sql};")
except:
return False

return True

class UsersDB(DBSupport):
def __init__(self,
Expand Down Expand Up @@ -77,43 +102,56 @@ def mergeInterests(self):
data[entry['user_id']] = list()
data[entry['user_id']].append(entry['interest_id'])

for uid, value in data.items():
sql = f" UPDATE users SET interests = ARRAY{str(value)} WHERE id={uid}"
print(sql)
try:
result = self.pg.dbcursor.execute(f"{sql};")
except:
return False
for uid, value in data.items():
sql = f" UPDATE users SET interests = ARRAY{str(value)} WHERE id={uid}"
print(sql)
try:
result = self.pg.dbcursor.execute(f"{sql};")
except:
return False

return True

def mergeLicenses(self):
"""Merge data from the TM user_licenses table into TM Admin."""
table = 'user_licenses'
# FIXME: this shouldn't be hardcoded!
pg = PostgresClient('localhost/tm4')
sql = f"SELECT row_to_json({table}) as row FROM {table}"
# print(sql)
# One database connection per thread
tmpg = list()
for i in range(0, cores + 1):
tmpg.append(PostgresClient('localhost/tm_admin'))

# just one thread to read the data
pg = PostgresClient('localhost/tm4')
try:
result = pg.dbcursor.execute(sql)
except:
log.error(f"Couldn't execute query! {sql}")
return False

result = pg.dbcursor.fetchall()

for record in result:
sql = f" UPDATE users SET licenses = ARRAY[{record[0]['license']}] WHERE id={record[0]['user']}"
print(sql)
try:
result = self.pg.dbcursor.execute(f"{sql};")
except:
return False
data = pg.dbcursor.fetchall()
entries = len(data)
log.debug(f"There are {entries} entries in {table}")
chunk = round(entries / cores)

# if True:
# importThread(data, tmpg[0])
index = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor:
block = 0
while block <= entries:
log.debug("Dispatching Block %d:%d" % (block, block + chunk))
result = executor.submit(importThread, data[block : block + chunk], tmpg[index])
block += chunk
index += 1
executor.shutdown()

return True

def mergeFavorites(self):
table = 'project_favorites'
pg = PostgresClient(inuri)
# FIXME: this shouldn't be hardcoded!
pg = PostgresClient('localhost/tm4')
sql = f"SELECT row_to_json({table}) as row FROM {table}"
# print(sql)
try:
Expand All @@ -122,6 +160,24 @@ def mergeFavorites(self):
log.error(f"Couldn't execute query! {sql}")
return False

result = pg.dbcursor.fetchall()
data = dict()
for record in result:
entry = record[0] # there's only one item in the input data
if entry['user_id'] not in data:
data[entry['user_id']] = list()
data[entry['user_id']].append(entry['project_id'])

for uid, value in data.items():
sql = f" UPDATE users SET favorite_projects = ARRAY{str(value)} WHERE id={uid}"
print(sql)
try:
result = self.pg.dbcursor.execute(f"{sql};")
except:
return False

return True

# These are just convience wrappers to support the REST API.
def updateRole(self,
id: int,
Expand Down Expand Up @@ -208,12 +264,16 @@ def main():

user = UsersDB(args.uri)

if user.mergeInterests():
log.info("UserDB.mergeInterests worked!")
# These may take a long time to complete
# if user.mergeInterests():
# log.info("UserDB.mergeInterests worked!")

if user.mergeLicenses():
log.info("UserDB.mergeLicenses worked!")

# if user.mergeFavorites():
# log.info("UserDB.mergeFavorites worked!")

# user.resetSequence()
# all = user.getAll()
# # Don't pass id, let postgres auto increment
Expand Down

0 comments on commit eaa1226

Please sign in to comment.