diff --git a/liteindex/defined_index.py b/liteindex/defined_index.py index dd10688..68b7fc0 100644 --- a/liteindex/defined_index.py +++ b/liteindex/defined_index.py @@ -245,12 +245,22 @@ def update(self, data): self._connection.executemany(sql, transactions) self._connection.commit() - def get(self, ids): + def get(self, ids, select_keys=[]): if isinstance(ids, str): ids = [ids] + if not select_keys: + select_keys = list(self.original_key_to_key_hash.values()) + else: + if [k for k in select_keys if k not in self.original_key_to_key_hash]: + raise ValueError( + f"Invalid select_keys: {[k for k in select_keys if k not in self.original_key_to_key_hash]}" + ) + + select_keys = [self.original_key_to_key_hash[k] for k in select_keys] + # Prepare the SQL command - columns = ", ".join([f'"{h}"' for h in self.original_key_to_key_hash.values()]) + columns = ", ".join([f'"{h}"' for h in select_keys]) column_str = "id, " + columns # Update this to include `id` # Format the ids for the where clause @@ -261,7 +271,7 @@ def get(self, ids): for row in self._connection.execute(sql, ids).fetchall(): record = { self.key_hash_to_original_key[h]: val - for h, val in zip(self.original_key_to_key_hash.values(), row[1:]) + for h, val in zip(select_keys, row[1:]) if val is not None } for k, v in record.items(): diff --git a/setup.py b/setup.py index 2168d7d..a5f936c 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ EMAIL = "praneeth@bpraneeth.com" AUTHOR = "BEDAPUDI PRANEETH" REQUIRES_PYTHON = ">=3.6.0" -VERSION = "0.0.2.dev3" +VERSION = "0.0.2.dev4" # What packages are required for this module to be executed? REQUIRED = [