Merge pull request #26 from pedrobcst/multiuser_query
Added support for multiple users querying at the same time (experimental)
pedrobcst authored Apr 15, 2022
2 parents d57bce4 + 7c55464 commit 13b1df5
Showing 4 changed files with 65 additions and 73 deletions.
11 changes: 5 additions & 6 deletions Xerus/__init__.py
@@ -21,10 +21,8 @@
 from __future__ import annotations
 
 import os
-import sys
-from multiprocessing import Pool
 from pathlib import Path
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Union
 
 import numpy as np
 import pandas as pd
@@ -42,10 +40,10 @@
                                      run_correlation_analysis,
                                      run_correlation_analysis_riet)
 from Xerus.similarity.visualization import make_plot_all, make_plot_step
-from Xerus.utils.cifutils import make_system, make_system_types
+from Xerus.utils.cifutils import make_system
 from Xerus.utils.preprocessing import remove_baseline, standarize_intensity
-from Xerus.utils.tools import (blockPrinting, create_folder, group_data,
-                               load_json, make_offset, normalize_formula)
+from Xerus.utils.tools import (create_folder, group_data, load_json,
+                               make_offset, normalize_formula)
 from Xerus.utils.tools import plotly_add as to_add
 from Xerus.utils.tools import save_json
 
@@ -227,6 +225,7 @@ def get_cifs(
             outfolder=self.working_folder,
             maxn=self.maxsys,
             max_oxy=self.max_oxy,
+            name = self.name
         )
         self.cif_info = cif_meta
         if ignore_provider is not None:
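Note: the only change here is threading the data-set name into get_cifs, so downstream queriers can build per-user folder names. A minimal usage sketch, not part of this commit; it assumes the top-level class is XRay with the constructor arguments shown (as in the project README) and may need adjusting to the actual API:

from Xerus import XRay

# Two users analyzing the same chemical system at the same time. Each
# instance carries its own `name`, which get_cifs() now forwards to the
# queriers, so their scratch folders (e.g. "userA_Ho-B_COD") no longer collide.
user_a = XRay(name="userA", working_folder="runs/userA",
              elements=["Ho", "B"], max_oxy=2)  # other required args omitted
user_b = XRay(name="userB", working_folder="runs/userB",
              elements=["Ho", "B"], max_oxy=2)

user_a.get_cifs()
user_b.get_cifs()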
43 changes: 28 additions & 15 deletions Xerus/db/localdb.py
@@ -19,19 +19,19 @@
 # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 from __future__ import annotations
 
+import os
+import shutil
+import sys
 from pathlib import Path
-from Xerus.utils.cifutils import write_cif, make_system_types
-from Xerus.utils.tools import create_folder, load_json
+from typing import List, Tuple
 
 import pandas as pd
-import os
 import pymongo
 from pymongo.errors import ConnectionFailure
-from typing import Tuple, List
 from Xerus.settings.settings import DB_CONN
-import shutil
+
+from Xerus.utils.cifutils import make_system_types, write_cif
+from Xerus.utils.tools import create_folder, load_json
 
 
 class LocalDB:
@@ -61,8 +61,14 @@ def __init__(self, DB_NAME="CIF", COLLECTION_NAME="cifs"):
         except:
             raise ConnectionFailure
         self.database = self.client[self.DB_NAME]
+        # If collection does not exist yet, create it and create a unique ID for providers.
+        if self.COLLECTION_NAME not in self.database.list_collection_names():
+            self.database.create_collection(self.COLLECTION_NAME)
+            self.database[self.COLLECTION_NAME].create_index("id", unique=True)
         self.cif_meta = self.database[self.COLLECTION_NAME]
-        self.cif_p1 = self.database['AFLOW-P1']
+
+        # This will be removed in the future.
+        # self.cif_p1 = self.database['AFLOW-P1']
 
     def check_system(self, system_type: str) -> bool:
         """
@@ -137,10 +143,15 @@ def upload_many(self, data: Tuple[dict]) -> None:
         data : array-like
             An array-like containing dictionaries with data for upload into the database
         """
-        self.cif_meta.insert_many(data)
+        # Ignore error due to duplicate keys (if we are ensuring uniqueness for provider ID).
+        try:
+            self.cif_meta.insert_many(data, ordered = False)
+        except pymongo.errors.BulkWriteError:
+            pass
 
 
 
-    def check_and_download(self, system_type: str) -> LocalDB:
+    def check_and_download(self, system_type: str, name : str) -> LocalDB:
         """
         Check for system type in the database. If it is missing, query providers and download the respective CIFs
         pertaining to that system.
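Note: ordered=False is the important detail. An ordered bulk insert aborts at the first duplicate-key error; an unordered one attempts every document and reports all failures in a single BulkWriteError, which upload_many can then swallow. Continuing the sketch above:

docs = [
    {"id": "cod-001", "formula": "HoB2"},
    {"id": "cod-001", "formula": "HoB2"},   # duplicate provider id
    {"id": "cod-002", "formula": "Ho2O3"},  # must still be inserted
]
try:
    # With ordered=True the insert would stop at the duplicate and
    # "cod-002" would be lost; ordered=False inserts it anyway.
    db["cifs"].insert_many(docs, ordered=False)
except pymongo.errors.BulkWriteError:
    pass  # duplicates are expected when two users query the same system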
@@ -152,6 +163,8 @@
         `system_type` is a string representation of the chemical system of elements. For example, for Ho and B containing
         materials, `system_type` would be equal to: "Ho-B". For constructing the string in correct manner, please refer
         to Xerus.utils.tools make_system_type function.
+        name: str
+            Data set name used for folder creation.
 
         Returns
         -------
@@ -160,10 +173,10 @@
         from Xerus.queriers.multiquery import multiquery
         if not self.check_system(system_type):
             elements = system_type.split("-")
-            multiquery(elements, max_num_elem=len(elements))
+            multiquery(elements, max_num_elem=len(elements), name = name)
         return self
 
-    def check_all(self, system_types: Tuple[str]) -> LocalDB:
+    def check_all(self, system_types: Tuple[str], name: str) -> LocalDB:
         """
         Check for a list of system types and download missing.
@@ -184,10 +197,10 @@
                 pass
             else:
                 print("Checking the following combination:{}".format(combination))
-                self.check_and_download(combination)
+                self.check_and_download(combination, name = name)
         return self
 
-    def get_cifs_and_write(self, element_list : List[str], outfolder: str, maxn: int, max_oxy: int = 2) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    def get_cifs_and_write(self, element_list : List[str], name: str, outfolder: str, maxn: int, max_oxy: int = 2) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         Automatically receives a list of elements and make `system_types`.
@@ -224,7 +237,7 @@ def get_cifs_and_write(self, element_list : List[str], outfolder: str, maxn: int
         folder_to_write = 'cifs/'
         final_path = os.path.join(outfolder, folder_to_write)
         queries = make_system_types(element_list, maxn)
-        self.check_all(queries)
+        self.check_all(queries, name = name)
 
         # check oxygen limit
         if max_oxy is not None:
@@ -274,7 +287,7 @@ def find_duplicate(self, field):
 
         pipeline = [
             {"$group": {"_id": field, "count": {"$sum": 1}}},  # group by field and sum occurence of field
-            {"$match": {"count": {"$gte": 1}}}  # match to values of count > 1.
+            {"$match": {"count": {"$gt": 1}}}  # match to values of count > 1.
         ]
         return pd.DataFrame(self.cif_meta.aggregate(pipeline))
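Note: this pipeline fix matters. $group counts the documents sharing each value of the field, and only count > 1 indicates a duplicate; the old {"$gte": 1} matched every group, so find_duplicate effectively flagged the whole collection. A sketch against the collection from the earlier snippet (the grouped field name is assumed):

import pandas as pd

pipeline = [
    # Count how many documents carry each distinct value of the "id" field.
    {"$group": {"_id": "$id", "count": {"$sum": 1}}},
    # Keep only values that occur more than once, i.e. true duplicates.
    {"$match": {"count": {"$gt": 1}}},
]
print(pd.DataFrame(db["cifs"].aggregate(pipeline)))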
55 changes: 16 additions & 39 deletions Xerus/queriers/multiquery.py
@@ -32,7 +32,6 @@
 from Xerus.utils.cifutils import (get_provider, make_system, rename_multicif,
                                   standarize)
 
-dbhandler = LocalDB()
 abs_path = Path(__file__).parent
 proj_path = os.sep.join(str(abs_path).split(os.sep)[1:-1])
 dump_folders = ["mp_dump/", "cod_dump/", "aflow_dump/", "oqmd_dump/"]
@@ -81,7 +80,7 @@ def load_json(path: os.PathLike) -> dict:
         return json.load(fp)
 
 
-def multiquery(element_list: List[str], max_num_elem: int, resync:bool = False) -> None:
+def multiquery(element_list: List[str], max_num_elem: int, name: str, resync:bool = False) -> None:
     """
     Query multiple providers for a given element list and maximum number of elements
@@ -98,15 +97,18 @@
     -------
     None
     """
+    dbhandler = LocalDB()
     query_type = make_system(standarize(element_list))
     td=[]
     if dbhandler.check_system(query_type) and not resync:
-        return "Asked to query {}, but already present in database".format(query_type)
+        return f"Asked to query {query_type}, but already present in database"
 
-    cod_path = str(os.path.join(abs_path, "cod_dump")) + os.sep
-    mp_path = os.path.join(abs_path, "mp_dump")
-    aflow_path = os.path.join(abs_path, "aflow_dump")
-    oqmd_path = os.path.join(abs_path, "oqmd_dump")
+
+
+    cod_path = str(os.path.join(abs_path, f"{name}_{query_type}_COD")) + os.sep
+    mp_path = os.path.join(abs_path, f"{name}_{query_type}_MP")
+    # aflow_path = os.path.join(abs_path, "aflow_dump")
+    oqmd_path = os.path.join(abs_path, f"{name}_{query_type}_OQMD")
     ## MP query
     querymp(inc_eles=element_list,
             max_num_elem=max_num_elem,
@@ -121,32 +123,6 @@
         td.append(cod_path)
     cod.query_one(element_list=element_list, rename=True)
 
-    # AFLOW query
-    # AFLOWQuery(element_list, outfolder=aflow_path).query()
-    # td.append(aflow_path)
-
-    # # OQMD Query
-    # print("Querying OQMD...")
-    # oqmd = OQMDQuery(element_list=element_list, dumpfolder=oqmd_path)
-    # oqmd.query()
-    # td.append(oqmd_path)
-
-
-    # TODO: Evaluate on how to implement each querier through optimade via the generic OptimadeQuery interface
-    # # Some example OPTIMADE queries
-    # print("Querying some OPTIMADE APIs")
-    # cod_optimade = OptimadeQuery(
-    #     base_url="https://www.crystallography.net/tcod/optimade/v1/",
-    #     elements=element_list,
-    #     folder_path=Path("optimade")
-    # )
-
-    # Currently only MP works.
-    # mp_optimade = OptimadeQuery(
-    #     base_url="https://optimade.materialsproject.org/v1/",
-    #     elements=element_list,
-    #     folder_path=Path("optimade")
-    # )
     print(f"Querying OQMD through OPTIMADE....")
     oqmd_optimade = OptimadeQuery(
         base_url="https://oqmd.org/optimade/v1",
@@ -157,9 +133,8 @@
     oqmd_optimade.query()
     td.append(oqmd_path)
 
-    # td.append(Path("optimade"))
-
-    movecifs(dump_folders=td)
+    test_folder = os.path.join(abs_path,f"{name}_{query_type}_cifs")
+    movecifs(dump_folders=td, test_folder=test_folder)
     print("Finished downloading CIFs.")
 
     if resync:
@@ -178,13 +153,13 @@
     ## TEST CIFS ##
     print("Testing, uploading and deleting cifs...")
     #cmd = 'python ' +os.sep + str(os.path.join(proj_path,'test_cif.py'))
-    cmd = "python " + str(os.path.join(abs_path, "tcif.py"))
+    cmd = f"python {os.path.join(abs_path, 'tcif.py')} {test_folder}"
     print(cmd)
     os.system(cmd)
 
     # ## UPDATE DB ##
     print("Uploading database with cifs..")
-    data = load_json(os.path.join(abs_path,'queried_cifs','cif.json'))
+    data = load_json(os.path.join(test_folder, 'cif.json'))
     print(len(data))
     if len(data) == 0:
         if resync:
@@ -208,5 +183,7 @@
 
     ## Remove rest ##
     print("Deleting..")
-    shutil.rmtree(os.path.join(abs_path,'queried_cifs'))
+    shutil.rmtree(test_folder)
+    # Kill connection? We should really add a direct way to do this instead of deleting. This is a placeholder.
+    del(dbhandler)
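Note: the two structural changes in this file are (1) LocalDB is now constructed inside multiquery() instead of once at import time, so each call owns its own handler, and (2) every scratch folder is keyed by both the data-set name and the chemical system. A small sketch of the naming scheme, with an illustrative base path not taken from this commit:

import os

def scratch_paths(name: str, query_type: str, base: str = "/tmp/xerus"):
    # Folder names embed the user's data-set name and the system, e.g.
    # "userA_Ho-B_COD", so concurrent queries write to disjoint paths.
    suffixes = ("COD", "MP", "OQMD", "cifs")
    return {s: os.path.join(base, f"{name}_{query_type}_{s}") for s in suffixes}

print(scratch_paths("userA", "Ho-B")["COD"])  # /tmp/xerus/userA_Ho-B_COD
print(scratch_paths("userB", "Ho-B")["COD"])  # /tmp/xerus/userB_Ho-B_COD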
29 changes: 16 additions & 13 deletions Xerus/queriers/tcif.py
@@ -42,22 +42,15 @@
 import sys
 from pathlib import Path
 
-from Xerus.engine.gsas2riet import run_gsas, simulate_pattern
+from tqdm import tqdm
+
+from Xerus.engine.gsas2riet import run_gsas
 from Xerus.settings.settings import INSTR_PARAMS, TEST_XRD
+from Xerus.utils.cifutils import get_ciflist
 
 instr_params = INSTR_PARAMS
 powder_data = TEST_XRD
 abs_path = Path(__file__).parent
-cif_folder = os.path.join(abs_path, "queried_cifs")
-cif_name = "cif.json"
-datapath = os.path.join(cif_folder,cif_name)
-from tqdm import tqdm
-from Xerus.utils.cifutils import get_ciflist
-
-if not os.path.exists(datapath):
-    get_ciflist(str(cif_folder) + os.sep, "r", True)
-with open(datapath, "r") as fp:
-    data = json.load(fp)
+# cif_folder = os.path.join(abs_path, "queried_cifs")
 
 def load_json(path):
     with open(path, "r") as fp:
@@ -66,7 +59,7 @@ def load_json(path):
 def check_extension(file, allowed=(".json", ".cif", ".ipynb_checkpoints")):
     return any([file.endswith(extension) for extension in allowed])
 
-def save_json(dict, filename=cif_name, folder=cif_folder):
+def save_json(dict, filename, folder):
     """
 
     :param dict:
@@ -78,6 +71,16 @@
         json.dump(dict, fp)
 
 if __name__ == "__main__":
+    # Get cif folder from command line
+    cif_folder = sys.argv[1]
+    cif_name = "cif.json"
+    datapath = os.path.join(cif_folder,cif_name)
+
+    if not os.path.exists(datapath):
+        get_ciflist(str(cif_folder) + os.sep, "r", True)
+    with open(datapath, "r") as fp:
+        data = json.load(fp)
+
 
     for i, entry in tqdm(enumerate(data)):
         filename = entry['filename']
@@ -120,7 +123,7 @@ def save_json(dict, filename=cif_name, folder=cif_folder):
             # data[i]['gsas_status']['status'] = 1
             # data[i]['gsas_status']['tested'] = True
             # data[i]['ran'] = True
-    save_json(data)
+    save_json(data, "cif.json", cif_folder)
     ## clean up .lst, .gpx, .bak
     files = os.listdir(cif_folder)
     for file in files:
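Note: tcif.py used to hard-code the "queried_cifs" folder and load it at import time; it now takes the CIF folder as its first command-line argument and only does work under if __name__ == "__main__", so parallel test runs stay independent. A sketch of the invocation multiquery.py now performs (paths illustrative, not from this commit):

import os
import subprocess
import sys

test_folder = "/tmp/xerus/userA_Ho-B_cifs"            # assumed per-request folder
script = os.path.join("Xerus", "queriers", "tcif.py")

# Equivalent to the os.system(cmd) string built in multiquery.py above,
# but with the arguments passed as a list.
subprocess.run([sys.executable, script, test_folder], check=True)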
