Skip to content

Commit

Permalink
docs generate RPC task
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob Beck committed Oct 2, 2019
1 parent 31e085b commit 90dff84
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 95 deletions.
73 changes: 71 additions & 2 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dbt.contracts.graph.manifest import CompileResultNode
from dbt.contracts.graph.unparsed import Time, FreshnessStatus
from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.contracts.util import Writable
from dbt.contracts.util import Writable, Replaceable
from dbt.logger import LogMessage
from hologram.helpers import StrEnum
from hologram import JsonSchemaMixin
Expand All @@ -10,7 +10,7 @@

from dataclasses import dataclass, field
from datetime import datetime
from typing import Union, Dict, List, Optional, Any
from typing import Union, Dict, List, Optional, Any, NamedTuple
from numbers import Real


Expand Down Expand Up @@ -227,3 +227,72 @@ class ResultTable(JsonSchemaMixin):
@dataclass
class RemoteRunResult(RemoteCompileResult):
table: ResultTable


# Scalar types accepted for a StatsItem value.
Primitive = Union[bool, str, float, None]


# Identifies a relation by its (database, schema, name) triple.
class CatalogKey(NamedTuple):
    database: str
    schema: str
    name: str


# One statistic attached to a catalog table (collected into StatsDict below).
@dataclass
class StatsItem(JsonSchemaMixin):
id: str
label: str
# value is limited to the Primitive union (bool, str, float or None)
value: Primitive
description: str
# NOTE(review): presumably flags whether the stat should be shown - confirm
include: bool


# Mapping of string key to StatsItem; presumably keyed by stat id - TODO confirm
StatsDict = Dict[str, StatsItem]


# Metadata describing a single column of a catalog table.
@dataclass
class ColumnMetadata(JsonSchemaMixin):
type: str
# optional: may be None
comment: Optional[str]
# NOTE(review): presumably the column's ordinal position - confirm
index: int
# used as the key when the column is stored in a table's ColumnMap
name: str


# Columns of one table, keyed by column name.
ColumnMap = Dict[str, ColumnMetadata]


# Relation-level metadata for a catalog entry. database/schema/name are the
# components CatalogTable.key() lower-cases to form the CatalogKey.
@dataclass
class TableMetadata(JsonSchemaMixin):
# NOTE(review): presumably the relation kind (e.g. table vs view) - confirm
type: str
database: str
schema: str
name: str
# comment and owner are optional and may be None
comment: Optional[str]
owner: Optional[str]


# Catalog entry for one relation: its metadata, its columns and its stats.
@dataclass
class CatalogTable(JsonSchemaMixin, Replaceable):
    metadata: TableMetadata
    columns: ColumnMap
    stats: StatsDict
    # the same table with two unique IDs will just be listed two times
    unique_id: Optional[str] = None

    def key(self) -> CatalogKey:
        # Build the lookup key from the lower-cased database, schema and
        # name held in this table's metadata.
        meta = self.metadata
        database = meta.database.lower()
        schema = meta.schema.lower()
        name = meta.name.lower()
        return CatalogKey(database=database, schema=schema, name=name)


# Result of catalog generation: the catalog tables plus a timestamp.
@dataclass
class CatalogResults(JsonSchemaMixin, Writable):
# catalog tables keyed by string; the generate task fills this from
# Catalog.make_unique_id_map()
nodes: Dict[str, CatalogTable]
generated_at: datetime
# output of the optional compile step; stays None when no compile was run
_compile_results: Optional[Any] = None


# CatalogResults extended with a list of log messages; this is the catalog
# member of the RPC RemoteCallableResult union.
@dataclass
class RemoteCatalogResults(CatalogResults):
# defaults to a fresh empty list per instance
logs: List[LogMessage] = field(default_factory=list)
10 changes: 8 additions & 2 deletions core/dbt/rpc/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@
from queue import Empty
from typing import Optional, Any, Union

from dbt.contracts.results import RemoteCompileResult, RemoteExecutionResult
from dbt.contracts.results import (
RemoteCompileResult, RemoteExecutionResult, RemoteCatalogResults
)
from dbt.exceptions import InternalException
from dbt.utils import restrict_to


RemoteCallableResult = Union[RemoteCompileResult, RemoteExecutionResult]
RemoteCallableResult = Union[
RemoteCompileResult,
RemoteExecutionResult,
RemoteCatalogResults,
]


class QueueMessageType(StrEnum):
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/rpc/response_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _get_responses(cls, requests, dispatcher):
# to_dict
if hasattr(output, 'result'):
if isinstance(output.result, JsonSchemaMixin):
output.result = output.result.to_dict(omit_empty=False)
output.result = output.result.to_dict(omit_none=False)
yield output

@classmethod
Expand Down
21 changes: 21 additions & 0 deletions core/dbt/rpc/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RemoteCompileResult,
RemoteRunResult,
RemoteExecutionResult,
RemoteCatalogResults,
)
from dbt.logger import LogMessage
from dbt.rpc.error import dbt_error, RPCException
Expand Down Expand Up @@ -210,6 +211,24 @@ def from_result(cls, status, base):
)


# Poll result that carries catalog results; its status is restricted to
# TaskHandlerState.Success.
@dataclass
class PollCatalogSuccessResult(PollResult, RemoteCatalogResults):
    status: TaskHandlerState = field(
        default=TaskHandlerState.Success,
        metadata=restrict_to(TaskHandlerState.Success),
    )

    @classmethod
    def from_result(cls, status, base):
        # Copy each catalog field off ``base`` and attach the poll status.
        catalog_fields = {
            'nodes': base.nodes,
            'generated_at': base.generated_at,
            '_compile_results': base._compile_results,
            'logs': base.logs,
        }
        return cls(status=status, **catalog_fields)


def poll_success(status, logs, result):
if status != TaskHandlerState.Success:
raise dbt.exceptions.InternalException(
Expand All @@ -223,6 +242,8 @@ def poll_success(status, logs, result):
return PollRunSuccessResult.from_result(status=status, base=result)
elif isinstance(result, RemoteCompileResult):
return PollCompileSuccessResult.from_result(status=status, base=result)
elif isinstance(result, RemoteCatalogResults):
return PollCatalogSuccessResult.from_result(status=status, base=result)
else:
raise dbt.exceptions.InternalException(
'got invalid result in poll_success: {}'.format(result)
Expand Down
129 changes: 40 additions & 89 deletions core/dbt/task/generate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
import shutil
from dataclasses import dataclass
from datetime import datetime
from typing import Union, Dict, List, Optional, Any, NamedTuple
from typing import Dict, List, Any

from hologram import JsonSchemaMixin, ValidationError
from hologram import ValidationError

from dbt.adapters.factory import get_adapter
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.util import Writable, Replaceable
from dbt.contracts.results import (
TableMetadata, CatalogTable, CatalogResults, Primitive, CatalogKey,
StatsItem, StatsDict, ColumnMetadata
)
from dbt.include.global_project import DOCS_INDEX_FILE_PATH
import dbt.ui.printer
import dbt.utils
Expand All @@ -33,87 +35,31 @@ def get_stripped_prefix(source: Dict[str, Any], prefix: str) -> Dict[str, Any]:
}


Primitive = Union[bool, str, float, None]
PrimitiveDict = Dict[str, Primitive]


Key = NamedTuple(
'Key',
[('database', str), ('schema', str), ('name', str)]
)


@dataclass
class StatsItem(JsonSchemaMixin):
id: str
label: str
value: Primitive
description: str
include: bool


StatsDict = Dict[str, StatsItem]


@dataclass
class ColumnMetadata(JsonSchemaMixin):
type: str
comment: Optional[str]
index: int
name: str


ColumnMap = Dict[str, ColumnMetadata]


@dataclass
class TableMetadata(JsonSchemaMixin):
type: str
database: str
schema: str
name: str
comment: Optional[str]
owner: Optional[str]

def build_catalog_table(data) -> CatalogTable:
# build the new table's metadata + stats
metadata = TableMetadata.from_dict(get_stripped_prefix(data, 'table_'))
stats = format_stats(get_stripped_prefix(data, 'stats:'))

@dataclass
class Table(JsonSchemaMixin, Replaceable):
metadata: TableMetadata
columns: ColumnMap
stats: StatsDict
# the same table with two unique IDs will just be listed two times
unique_id: Optional[str] = None

@classmethod
def from_query_result(cls, data) -> 'Table':
# build the new table's metadata + stats
metadata = TableMetadata.from_dict(get_stripped_prefix(data, 'table_'))
stats = format_stats(get_stripped_prefix(data, 'stats:'))

return cls(
metadata=metadata,
stats=stats,
columns={},
)

def key(self) -> Key:
return Key(
self.metadata.database.lower(),
self.metadata.schema.lower(),
self.metadata.name.lower(),
)
return CatalogTable(
metadata=metadata,
stats=stats,
columns={},
)


# keys are database name, schema name, table name
class Catalog(Dict[Key, Table]):
class Catalog(Dict[CatalogKey, CatalogTable]):
def __init__(self, columns: List[PrimitiveDict]):
super().__init__()
for col in columns:
self.add_column(col)

def get_table(self, data: PrimitiveDict) -> Table:
def get_table(self, data: PrimitiveDict) -> CatalogTable:
try:
key = Key(
key = CatalogKey(
str(data['table_database']),
str(data['table_schema']),
str(data['table_name']),
Expand All @@ -123,10 +69,11 @@ def get_table(self, data: PrimitiveDict) -> Table:
'Catalog information missing required key {} (got {})'
.format(exc, data)
)
table: CatalogTable
if key in self:
table = self[key]
else:
table = Table.from_query_result(data)
table = build_catalog_table(data)
self[key] = table
return table

Expand All @@ -140,8 +87,10 @@ def add_column(self, data: PrimitiveDict):
column = ColumnMetadata.from_dict(column_data)
table.columns[column.name] = column

def make_unique_id_map(self, manifest: Manifest) -> Dict[str, Table]:
nodes: Dict[str, Table] = {}
def make_unique_id_map(
self, manifest: Manifest
) -> Dict[str, CatalogTable]:
nodes: Dict[str, CatalogTable] = {}

manifest_mapping = get_unique_id_mapping(manifest)
for table in self.values():
Expand Down Expand Up @@ -201,16 +150,16 @@ def format_stats(stats: PrimitiveDict) -> StatsDict:
return stats_collector


def mapping_key(node: CompileResultNode) -> Key:
return Key(
def mapping_key(node: CompileResultNode) -> CatalogKey:
return CatalogKey(
node.database.lower(), node.schema.lower(), node.identifier.lower()
)


def get_unique_id_mapping(manifest: Manifest) -> Dict[Key, List[str]]:
def get_unique_id_mapping(manifest: Manifest) -> Dict[CatalogKey, List[str]]:
# A single relation could have multiple unique IDs pointing to it if a
# source were also a node.
ident_map: Dict[Key, List[str]] = {}
ident_map: Dict[CatalogKey, List[str]] = {}
for unique_id, node in manifest.nodes.items():
key = mapping_key(node)

Expand All @@ -221,13 +170,6 @@ def get_unique_id_mapping(manifest: Manifest) -> Dict[Key, List[str]]:
return ident_map


@dataclass
class CatalogResults(JsonSchemaMixin, Writable):
nodes: Dict[str, Table]
generated_at: datetime
_compile_results: Optional[Any] = None


def _coerce_decimal(value):
if isinstance(value, dbt.utils.DECIMALS):
return float(value)
Expand All @@ -242,7 +184,7 @@ def _get_manifest(self) -> Manifest:
def run(self):
compile_results = None
if self.args.compile:
compile_results = super().run()
compile_results = CompileTask.run(self)
if any(r.error is not None for r in compile_results):
dbt.ui.printer.print_timestamped_line(
'compile failed, cannot generate docs'
Expand All @@ -266,10 +208,10 @@ def run(self):
]

catalog = Catalog(catalog_data)
results = CatalogResults(
results = self.get_catalog_results(
nodes=catalog.make_unique_id_map(manifest),
generated_at=datetime.utcnow(),
_compile_results=compile_results,
compile_results=compile_results,
)

path = os.path.join(self.config.target_path, CATALOG_FILENAME)
Expand All @@ -280,6 +222,15 @@ def run(self):
)
return results

def get_catalog_results(
    self, nodes, generated_at, compile_results
) -> CatalogResults:
    """Build the CatalogResults object returned by this task.

    :param nodes: catalog tables keyed by unique ID.
    :param generated_at: timestamp to record on the results.
    :param compile_results: output of the compile step, or None when no
        compile was run.
    :return: a CatalogResults holding exactly the values passed in.
    """
    # Use the caller-supplied timestamp. The previous body ignored the
    # ``generated_at`` argument and took a second clock reading instead.
    return CatalogResults(
        nodes=nodes,
        generated_at=generated_at,
        _compile_results=compile_results,
    )

def interpret_results(self, results):
compile_results = results._compile_results
if compile_results is None:
Expand Down
Loading

0 comments on commit 90dff84

Please sign in to comment.