diff --git a/redash/query_runner/rockset.py b/redash/query_runner/rockset.py index 715fc100ae..6319f67503 100644 --- a/redash/query_runner/rockset.py +++ b/redash/query_runner/rockset.py @@ -1,3 +1,5 @@ +from multiprocessing.pool import ThreadPool + import requests from redash.query_runner import * from redash.utils import json_dumps @@ -23,7 +25,8 @@ def __init__(self, api_key, api_server): self.api_server = api_server def _request(self, endpoint, method='GET', body=None): - headers = {'Authorization': 'ApiKey {}'.format(self.api_key)} + headers = {'Authorization': 'ApiKey {}'.format(self.api_key), + 'User-Agent': 'rest:redash/1.0'} url = '{}/v1/orgs/self/{}'.format(self.api_server, endpoint) if method == 'GET': @@ -35,9 +38,17 @@ def _request(self, endpoint, method='GET', body=None): else: raise 'Unknown method: {}'.format(method) - def list(self): - response = self._request('ws/commons/collections') - return response['data'] + def list_workspaces(self): + response = self._request('ws') + return [x['name'] for x in response['data'] if x['collection_count'] > 0] + + def list_collections(self, workspace='commons'): + response = self._request('ws/{}/collections'.format(workspace)) + return [x['name'] for x in response['data']] + + def collection_columns(self, workspace, collection): + response = self.query('DESCRIBE "{}"."{}"'.format(workspace, collection)) + return list(set([x['field'][0] for x in response['results']])) def query(self, sql): return self._request('queries', 'POST', {'sql': {'query': sql}}) @@ -76,12 +87,23 @@ def __init__(self, configuration): 'api_server', "https://api.rs2.usw2.rockset.com")) def _get_tables(self, schema): - for col in self.api.list(): - table_name = col['name'] - describe = self.api.query('DESCRIBE "{}"'.format(table_name)) - columns = list(set(map(lambda x: x['field'][0], describe['results']))) - schema[table_name] = {'name': table_name, 'columns': columns} - return schema.values() + pool = ThreadPool(processes=10) + + try: + workspaces = self.api.list_workspaces() + collections = pool.map(self.api.list_collections, workspaces) + + args = [(w, c) for (w, cols) in zip(workspaces, collections) for c in cols] + describe_results = pool.map(lambda (w, c): self.api.collection_columns(w, c), args) + + for (w, c), columns in zip(args, describe_results): + table_name = c if w == 'commons' else '{}.{}'.format(w, c) + schema[table_name] = {'name': table_name, 'columns': columns} + + return sorted(schema.values(), key=lambda x: x['name']) + + finally: + pool.close() def run_query(self, query, user): results = self.api.query(query)