diff --git a/dptb/data/build.py b/dptb/data/build.py
index a58a5b04..d2c6e56a 100644
--- a/dptb/data/build.py
+++ b/dptb/data/build.py
@@ -13,7 +13,7 @@ from dptb.data import AtomicDataset, register_fields
 from dptb.utils import instantiate, get_w_prefix
 from dptb.utils.tools import j_loader
-from dptb.utils.argcheck import normalize_setinfo
+from dptb.utils.argcheck import normalize_setinfo, normalize_lmdbsetinfo
 
 
 def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
@@ -153,7 +153,7 @@ def build_dataset(
     else:
         idp = None
 
-    if dataset_type in ["DefaultDataset", "DeePHDataset", "HDF5Dataset"]:
+    if dataset_type in ["DefaultDataset", "DeePHDataset", "HDF5Dataset", "LMDBDataset"]:
         # Explore the dataset's folder structure.
         #include_folders = []
@@ -176,7 +176,10 @@ def build_dataset(
         include_folders=[]
         for idir in prefix_folders:
             if os.path.isdir(idir):
-                if not glob.glob(os.path.join(idir, '*.dat')) and not glob.glob(os.path.join(idir, '*.traj')) and not glob.glob(os.path.join(idir, '*.h5')):
+                if not glob.glob(os.path.join(idir, '*.dat')) \
+                    and not glob.glob(os.path.join(idir, '*.traj')) \
+                    and not glob.glob(os.path.join(idir, '*.h5')) \
+                    and not glob.glob(os.path.join(idir, '*.mdb')):
                     raise Exception(f"{idir} does not have the proper traj data files. Please check the data files.")
                 include_folders.append(idir.split('/')[-1])
@@ -191,7 +194,10 @@ def build_dataset(
         #if "info.json" in os.listdir(root):
         if os.path.exists(f"{root}/info.json"):
             public_info = j_loader(os.path.join(root, "info.json"))
-            public_info = normalize_setinfo(public_info)
+            if dataset_type == "LMDBDataset":
+                public_info = normalize_lmdbsetinfo(public_info)
+            else:
+                public_info = normalize_setinfo(public_info)
             print("A public `info.json` file is provided, and will be used by the subfolders who do not have their own `info.json` file.")
         else:
             public_info = None
@@ -202,7 +208,10 @@ def build_dataset(
             if os.path.exists(f"{root}/{file}/info.json"):
                 # use info provided in this trajectory.
                 info = j_loader(f"{root}/{file}/info.json")
-                info = normalize_setinfo(info)
+                if dataset_type == "LMDBDataset":
+                    info = normalize_lmdbsetinfo(info)
+                else:
+                    info = normalize_setinfo(info)
                 info_files[file] = info
             elif public_info is not None:
                 # use public info instead
@@ -234,7 +243,7 @@ def build_dataset(
                 get_eigenvalues=get_eigenvalues,
                 info_files = info_files
             )
-        else:
+        elif dataset_type == "HDF5Dataset":
             dataset = HDF5Dataset(
                 root=root,
                 type_mapper=idp,
@@ -244,40 +253,16 @@ def build_dataset(
                 get_eigenvalues=get_eigenvalues,
                 info_files = info_files
             )
-
-    elif dataset_type == "LMDBDataset":
-        assert prefix is not None, "The prefix is not provided. Please provide the prefix to select the trajectory folders."
-        prefix_folders = glob.glob(f"{root}/{prefix}*.lmdb")
-        include_folders=[]
-        for idir in prefix_folders:
-            if os.path.isdir(idir):
-                if not glob.glob(os.path.join(idir, '*.mdb')):
-                    raise Exception(f"{idir} does not have the proper traj data files. Please check the data files.")
-                include_folders.append(idir.split('/')[-1])
-
-        assert isinstance(include_folders, list) and len(include_folders) == 1, "No trajectory folders are found. Please check the prefix."
-
-        # See if a public info is provided.
-        #if "info.json" in os.listdir(root):
-
-        if os.path.exists(f"{root}/info.json"):
-            info = j_loader(f"{root}/info.json")
-        else:
-            print("Please provide a info.json file.")
-            raise Exception("info.json is not properly provided for this dataset.")
-
-        # We will sort the info_files here. 
-        # The order itself is not important, but must be consistant for the same list.
-
-        dataset = LMDBDataset(
-            root=os.path.join(root, include_folders[0]),
+        elif dataset_type == "LMDBDataset":
+            dataset = LMDBDataset(
+                root=root,
                 type_mapper=idp,
-            info=info,
                 orthogonal=orthogonal,
                 get_Hamiltonian=get_Hamiltonian,
                 get_overlap=get_overlap,
                 get_DM=get_DM,
                 get_eigenvalues=get_eigenvalues,
+                info_files = info_files
             )
 
     else:
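With this change, LMDB trajectories go through the same folder-discovery and `info.json` machinery as the other dataset types, instead of the removed single-folder branch. A minimal sketch of the discovery step for `dataset_type == "LMDBDataset"`, assuming a hypothetical root layout with folders `data/traj_a.lmdb` and `data/traj_b.lmdb`:

```python
import glob
import os

# Hypothetical dataset root and prefix; each matched folder is expected
# to contain the *.mdb files of one LMDB trajectory.
root, prefix = "data", "traj"

include_folders = []
for idir in glob.glob(f"{root}/{prefix}*"):
    if os.path.isdir(idir):
        # *.mdb now marks a valid LMDB folder, alongside *.dat/*.traj/*.h5
        # for the other dataset types.
        if not glob.glob(os.path.join(idir, "*.mdb")):
            raise Exception(f"{idir} does not have the proper traj data files.")
        include_folders.append(idir.split("/")[-1])

# Each kept folder is then paired with its own info.json, or with the
# public one at the root, normalized via normalize_lmdbsetinfo.
```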
diff --git a/dptb/data/dataset/lmdb_dataset.py b/dptb/data/dataset/lmdb_dataset.py
index 4f8a9064..85f4d928 100644
--- a/dptb/data/dataset/lmdb_dataset.py
+++ b/dptb/data/dataset/lmdb_dataset.py
@@ -25,7 +25,7 @@ class LMDBDataset(AtomicDataset):
     def __init__(
         self,
         root: str,
-        info: dict,
+        info_files: dict,
         url: Optional[str] = None,
         include_frames: Optional[List[int]] = None,
         type_mapper: TypeMapper = None,
@@ -39,9 +39,7 @@ def __init__(
         # See if a subclass defines some inputs
         self.url = getattr(type(self), "URL", url)
         self.include_frames = include_frames
-        self.info = info # there should be one info file for one LMDB Dataset
-
-        assert "r_max" in info
+        self.info_files = info_files # there should be one info file for one LMDB Dataset
 
         self.data = None
@@ -66,10 +64,16 @@ def __init__(
         assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData."
 
-        db_env = lmdb.open(os.path.join(self.root), readonly=True, lock=False)
-        with db_env.begin() as txn:
-            self.num_graphs = txn.stat()['entries']
-        db_env.close()
+        self.num_graphs = 0
+        self.file_map = []
+        self.index_map = []
+        for file in self.info_files.keys():
+            db_env = lmdb.open(os.path.join(self.root, file), readonly=True, lock=False)
+            with db_env.begin() as txn:
+                self.num_graphs += txn.stat()['entries']
+                self.file_map += [file] * txn.stat()['entries']
+                self.index_map += list(range(txn.stat()['entries']))
+            db_env.close()
 
     def len(self):
         return self.num_graphs
@@ -94,9 +98,9 @@ def download(self):
             extract_zip(download_path, self.raw_dir)
 
     def get(self, idx):
-        db_env = lmdb.open(os.path.join(self.root), readonly=True, lock=False)
+        db_env = lmdb.open(os.path.join(self.root, self.file_map[idx]), readonly=True, lock=False)
         with db_env.begin() as txn:
-            data_dict = txn.get(int(idx).to_bytes(length=4, byteorder='big'))
+            data_dict = txn.get(self.index_map[int(idx)].to_bytes(length=4, byteorder='big'))
             data_dict = pickle.loads(data_dict)
             cell, pos, atomic_numbers = \
                 data_dict[AtomicDataDict.CELL_KEY], \
@@ -141,7 +145,7 @@ def get(self, idx):
                 cell=cell.reshape(3,3),
                 atomic_numbers=atomic_numbers,
                 pbc=pbc,
-                **self.info
+                **self.info_files[self.file_map[idx]]
             )
 
             # transform blocks to atomicdata features
diff --git a/dptb/entrypoints/run.py b/dptb/entrypoints/run.py
index 85c29569..0bcb4f5f 100644
--- a/dptb/entrypoints/run.py
+++ b/dptb/entrypoints/run.py
@@ -91,7 +91,7 @@ def run(
         block = write_ham(data=struct_file, AtomicData_options=jdata['AtomicData_options'], model=model, device=jdata["device"])
         # write to h5 file, block is a dict, write to a h5 file
         with h5py.File(os.path.join(results_path, task+".h5"), 'w') as fid:
-            default_group = fid.create_group("1")
+            default_group = fid.create_group("0")
             for key_str, value in block.items():
                 default_group[key_str] = value.detach().cpu().numpy()
         log.info(msg='write block successfully completed.')
\ No newline at end of file
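The `LMDBDataset` hunks above replace the single-environment assumption with per-folder bookkeeping: `file_map` and `index_map` translate a global sample index into an (LMDB folder, local key) pair. A small sketch of that mapping, with made-up entry counts standing in for `txn.stat()['entries']`:

```python
# Hypothetical per-folder entry counts; the folder names match the keys
# of info_files, which drive the loop in __init__.
entries = {"traj_a.lmdb": 3, "traj_b.lmdb": 2}

file_map, index_map = [], []
for file, n in entries.items():
    file_map += [file] * n        # which LMDB folder a global idx belongs to
    index_map += list(range(n))   # the local key inside that folder

# Global idx 4 resolves to folder "traj_b.lmdb", local key 1, so get() opens
# os.path.join(root, file_map[4]) and fetches the record stored under
# index_map[4].to_bytes(length=4, byteorder='big').
assert file_map[4] == "traj_b.lmdb" and index_map[4] == 1
```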
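For the `run.py` fix, blocks written by `dptb run` are now stored under group `"0"` rather than `"1"`, so any reader must look up the new key. A sketch of reading the output back, with a hypothetical file name in place of `task+".h5"`:

```python
import h5py

# "results/ham.h5" is a hypothetical path for the file run.py writes.
with h5py.File("results/ham.h5", "r") as fid:
    blocks = fid["0"]   # was fid["1"] before this fix
    for key_str in blocks:
        print(key_str, blocks[key_str][...].shape)
```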
diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py
index d921811d..ceee4d2a 100644
--- a/dptb/utils/argcheck.py
+++ b/dptb/utils/argcheck.py
@@ -1423,10 +1423,41 @@ def set_info_options():
     return Argument("setinfo", dict, sub_fields=args)
 
+def set_info_options():
+    doc_nframes = "Number of frames in this trajectory."
+    doc_natoms = "Number of atoms in each frame."
+    doc_pos_type = "Type of atomic position input. Can be frac / cart / ase."
+
+    args = [
+        Argument("nframes", int, optional=False, doc=doc_nframes),
+        Argument("natoms", int, optional=True, default=-1, doc=doc_natoms),
+        Argument("pos_type", str, optional=False, doc=doc_pos_type),
+        bandinfo_sub(),
+        AtomicData_options_sub()
+    ]
+
+    return Argument("setinfo", dict, sub_fields=args)
+
+def lmdbset_info_options():
+    doc_r_max = "the cutoff value for bond considering in TB model."
+
+    args = [
+        Argument("r_max", [float, int, dict], optional=False, doc=doc_r_max, default=4.0)
+    ]
+    return Argument("setinfo", dict, sub_fields=args)
+
 def normalize_setinfo(data):
 
     setinfo = set_info_options()
     data = setinfo.normalize_value(data)
     setinfo.check_value(data, strict=True)
+
     return data
+
+def normalize_lmdbsetinfo(data):
+
+    setinfo = lmdbset_info_options()
+    data = setinfo.normalize_value(data)
+    setinfo.check_value(data, strict=True)
+
+    return data
\ No newline at end of file
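Under the new schema, an LMDB `info.json` only has to supply `r_max`; frame-level keys such as `nframes` and `pos_type` remain specific to the non-LMDB `setinfo`. A minimal usage sketch of the new normalizer:

```python
from dptb.utils.argcheck import normalize_lmdbsetinfo

# r_max is the only required key; per the Argument spec it may be a
# float, an int, or a dict.
info = normalize_lmdbsetinfo({"r_max": 4.0})
assert info["r_max"] == 4.0
```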