import datetime
import hashlib
import inspect
import json
import logging
import operator
import os
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from hamilton_sdk.api.clients import UnauthorizedException
from hamilton import base, driver, graph, node
from hamilton.driver import Variable
from hamilton.io import materialization
from hamilton.lifecycle.base import BaseDoNodeExecute
from hamilton.node import Node
try:
import git
except ImportError:
git = None
from hamilton_sdk.api import clients, constants
from hamilton_sdk.api.projecttypes import GitInfo
from hamilton_sdk.tracking.runs import Status, TrackingState, monkey_patch_adapter

logger = logging.getLogger(__name__)


def _hash_module(
module: ModuleType, hash_object: hashlib.sha256, seen_modules: Set[ModuleType]
) -> hashlib.sha256:
"""Generate a hash of the specified module and its imports.
It will recursively hash the contents of the modules and their imports, and only does so
if the import is from the same package. This is to avoid hashing the entire python
environment...
:param module: the python module to hash and then crawl.
:param hash_object: the object to update.
:param seen_modules: the python modules we've already hashed.
:return: the updated hash object
"""
# Check if we've already hashed this module
if module in seen_modules:
return hash_object
else:
seen_modules.add(module)
# Update the hash with the module's source code
if hasattr(module, "__file__") and module.__file__ is not None:
with open(module.__file__, "rb") as f:
hash_object.update(f.read())
else:
logger.debug(
"Skipping hash for module %s because it has no __file__ attribute or it is None.",
module,
)
# Loop through the module's attributes
for name, value in inspect.getmembers(module):
# Check if the attribute is a module
if inspect.ismodule(value):
if value.__package__ is None:
logger.info(
f"Skipping hash for module {value.__name__} because it has no __package__ "
f"attribute or it is None. This happens with lazy loaders."
)
continue
# Check if the module is in the same top level package
if value.__package__ != module.__package__ and not value.__package__.startswith(
module.__package__
):
logger.debug(
f"Skipping hash for module {value.__name__} because it is in a different "
f"package {value.__package__} than {module.__package__}"
)
continue
# Recursively hash the sub-module
hash_object = _hash_module(value, hash_object, seen_modules)
# Return the hash object
return hash_object


def _get_modules_hash(modules: Tuple[ModuleType]) -> str:
"""Generate a hash of the contents of the specified modules.
It recursively hashes the contents of the modules and their imports, and only does so
if the import is from the same package. This is to avoid hashing the entire python
environment...
:param modules: python modules to hash
:return: the hex digest of the hash
"""
# Create a hash object
h = hashlib.sha256()
seen_modules = set()
# Loop through each module name
for module in modules:
# Update the hash with the module's source code
h = _hash_module(module, h, seen_modules)
# Return the hex digest of the hash
return h.hexdigest()
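

# A minimal sketch (not part of the driver) of how the hashing helpers above
# can be used to fingerprint code versions. `my_dataflow` is a hypothetical
# module; the digest changes whenever its source (or that of same-package
# imports) changes:
#
#   import my_dataflow
#   code_version = _get_modules_hash((my_dataflow,))  # stable sha256 hex digest

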
def _derive_version_control_info(module_hash: str) -> GitInfo:
"""Derive the git info for the current project.
Currently, this decides whether we're in a git repository.
This is not going to work for everything, but we'll see what the customers want.
We might end up having to pass this data in...
"""
default = GitInfo(
branch="unknown",
commit_hash=module_hash,
committed=False,
repository="Error: No repository to link to.",
local_repo_base_path=os.getcwd(),
)
if git is None:
return default
try:
repo = git.Repo(".", search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
        logger.warning(
            "Warning: We are not currently in a git repository. We recommend using one as a "
            "way to version the project *if* your hamilton code lives within this repository "
            "too. If it does not, then we'll try to version code based on the python modules "
            "passed to the Driver. "
            "In case you want to get set up with git quickly you can run:\n "
            "git init && git add . && git commit -m 'Initial commit'\n"
            "Still have questions? Reach out to stefan @ dagworks.io, elijah @ dagworks.io "
            "and we'll try to help you as soon as possible."
        )
return default
if "COLAB_RELEASE_TAG" in os.environ:
        logger.warning(
            "We currently do not support logging version information inside a Google "
            "Colab notebook. This is something we are planning to do. "
            "If you have any questions, please reach out to support@dagworks.io "
            "and we'll try to help you as soon as possible."
        )
return default
commit = repo.head.commit
try:
repo_url = repo.remote().url
except ValueError:
# TODO: change this to point to our docs on what to do.
repo_url = "Error: No repository to link to."
try:
branch_name = repo.active_branch.name
except TypeError:
branch_name = "unknown" # detached head
        logger.warning(
            "Warning: we are unable to determine the branch name. "
            "This is likely because you are in a detached head state. "
            "If so, you can check out a branch by running "
            "`git checkout -b <branch_name>`. "
            "If you intend to be in a detached head state (e.g. you are using a CI "
            "system that checks out a detached head) then you can ignore this."
        )
return GitInfo(
branch=branch_name,
commit_hash=commit.hexsha,
committed=not repo.is_dirty(),
repository=repo_url,
local_repo_base_path=repo.working_dir,
)
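

# Sketch of what _derive_version_control_info returns inside a clean git repo
# (values are illustrative placeholders, not real):
#
#   GitInfo(
#       branch="main",
#       commit_hash="0a1b2c3...",          # repo.head.commit.hexsha
#       committed=True,                    # not repo.is_dirty()
#       repository="https://github.com/your-org/your-repo.git",
#       local_repo_base_path="/home/you/your-repo",
#   )

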
def filter_json_dict_to_serializable(
    dict_to_filter: Dict[str, Any], curr_result: Dict[str, Any] = None
) -> Dict[str, Any]:
    """Recursively filters a dict down to JSON-serializable values,
    stringifying any value that cannot be serialized."""
    if curr_result is None:
        curr_result = {}
if dict_to_filter is None:
dict_to_filter = {}
for key, value in dict_to_filter.items():
try:
json.dumps(value)
curr_result[key] = value
except TypeError:
if isinstance(value, dict):
new_result = {}
filter_json_dict_to_serializable(value, new_result)
curr_result[key] = new_result
else:
curr_result[key] = str(value)
return curr_result
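

# Example behavior (sketch): non-serializable values are stringified, nested
# dicts are filtered recursively:
#
#   filter_json_dict_to_serializable({"a": 1, "b": {1, 2}, "c": {"d": object()}})
#   # -> {"a": 1, "b": "{1, 2}", "c": {"d": "<object object at 0x...>"}}

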
def validate_tags(tags: Any):
"""Validates that tags are a dictionary of strings to strings.
:param tags: Tags to validate
:raises ValueError: If tags are not a dictionary of strings to strings
"""
if not isinstance(tags, dict):
raise ValueError(f"Tags must be a dictionary, but got {tags}")
for key, value in tags.items():
if not isinstance(key, str):
raise ValueError(f"Tag keys must be strings, but got {key}")
if not isinstance(value, str):
raise ValueError(f"Tag values must be strings, but got {value}")
def safe_len(x):
return len(x) if x is not None else 0


# Placeholder so we can monkey-patch later (see monkey_patch_adapter)
class DefaultExecutionMethod(BaseDoNodeExecute):
def do_node_execute(
self,
*,
run_id: str,
node_: node.Node,
kwargs: Dict[str, Any],
task_id: Optional[str] = None,
) -> Any:
return node_(**kwargs)


class Driver(driver.Driver):
def __init__(
self,
config: Dict[str, Any],
*modules: ModuleType,
project_id: int,
api_key: str,
username: str,
dag_name: str,
tags: Dict[str, str] = None,
client_factory: Callable[
[str, str, str], clients.HamiltonClient
] = clients.BasicSynchronousHamiltonClient,
adapter: base.HamiltonGraphAdapter = None,
dagworks_api_url=os.environ.get("DAGWORKS_API_URL", constants.HAMILTON_API_URL),
dagworks_ui_url=os.environ.get("DAGWORKS_UI_URL", constants.HAMILTON_UI_URL),
):
"""Instantiates a DAGWorks driver. This:
1. Requires a project to exist. Create one via https://app.dagworks.io/dashboard/projects.
2. Sends over the shape of the DAG.
3. Sets up execute() run-tracking.
:param config: Configuration to use, same as standard Hamilton driver.
:param modules: Modules to use, same as standard Hamilton driver.
:param project_id: Identifier for the project to use to store this DAG under.
:param api_key: API key to use for authentication. Remember not to save this in plaintext!
:param username: email address to use for authentication.
:param dag_name: name for this DAG. You will use this for top level curation of DAGs
within a project.
:param tags: Optional key value string pairs to help identify and curate this instance of
the DAG and subsequent execution runs. E.g. {"environment": "production"}.
Currently all .execute() runs will be tagged with these.
:param client_factory: Optional. Advanced use. Factory to use to create the underlying
client.
:param adapter: Optional. Adapter to use, same as standard Hamilton driver.
:param dagworks_api_url: Optional. URL to use for the DAGWorks API.
:param dagworks_ui_url: Optional. URL to use for the DAGWorks UI.
"""
if adapter is None:
adapter = base.SimplePythonGraphAdapter(result_builder=base.DictResult())
super(Driver, self).__init__(config, *modules, adapter=adapter)
self.config = config
self.project = project_id
self.api_key = api_key
self.username = username
# TODO -- figure out how to pass any additional configuration to the client if needed
self.client = client_factory(api_key, username, dagworks_api_url)
self.module_hash = _get_modules_hash(modules)
self.vcs_info = _derive_version_control_info(self.module_hash)
self.initialized = False
self.modules = modules
self.project_version = None
self.run_tags = tags if tags is not None else {}
validate_tags(self.run_tags)
self.dag_name = dag_name
self.dagworks_ui_url = dagworks_ui_url
# reassign the graph executor with all the information we have
self.graph_executor = DAGWorksGraphExecutor(
self.graph_executor,
self.client,
self.run_tags,
self.dagworks_ui_url,
self.project,
self.vcs_info.local_repo_base_path,
self.vcs_info,
self.dag_name,
list(self.graph_modules),
self.graph,
)
self.initialize()

    def set_name(self, new_name: str):
        """Sets a name for the driver. This allows you to force a change in the name/version
        of the DAG so that the next run logs a new one.
        :param new_name: New name to use for the DAG.
        """
self.dag_name = new_name
self.graph_executor.dag_name = new_name

    def initialize(self):
"""Initializes the driver. This:
1. Validates authentication
2. Creates a project if it does not already exist
3. Sets initialization as true
Note this is idempotent -- it can be called by the user to test, but will get called when
the driver runs.
"""
logger.debug("Validating authentication against DAGWorks API...")
self.client.validate_auth()
logger.debug("Authentication successful!")
logger.debug(f"Ensuring project {self.project} exists...")
try:
exists = self.client.project_exists(self.project)
except UnauthorizedException:
logger.exception(
f"Authentication failed. Please check your credentials and try again. "
f"Username: {self.username}, API key: {self.api_key[0:6]}..."
)
raise
if not exists:
logger.error(
f"Project {self.project} does not exist. Please create it first in the UI! "
f"You can do so at https://www.app.dagworks.io/dashboard/projects"
)
self.graph_executor.initialize()
self.initialized = True

    def execute(
self,
final_vars: List[Union[str, Callable]],
overrides: Dict[str, Any] = None,
display_graph: bool = False,
inputs: Dict[str, Any] = None,
) -> Any:
logger.warning(
f"\nCapturing execution run. All runs for project can be found at "
f"{self.dagworks_ui_url}/dashboard/project/{self.project}/runs"
)
return super(Driver, self).execute(final_vars, overrides, display_graph, inputs)

    def raw_execute(
self,
final_vars: List[str],
overrides: Dict[str, Any] = None,
display_graph: bool = False,
inputs: Dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
) -> Dict[str, Any]:
return super(Driver, self).raw_execute(
final_vars, overrides, display_graph, inputs, _fn_graph
)

    def materialize(
self,
*materializers: materialization.MaterializerFactory,
additional_vars: List[Union[str, Callable, Variable]] = None,
overrides: Dict[str, Any] = None,
inputs: Dict[str, Any] = None,
) -> Tuple[Any, Dict[str, Any]]:
return super(Driver, self).materialize(
*materializers, additional_vars=additional_vars, overrides=overrides, inputs=inputs
)
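

# A hedged usage sketch of the tracking Driver above. `my_dataflow` and the
# credential values are hypothetical placeholders; see the __init__ docstring
# for what each parameter means:
#
#   import my_dataflow
#   dr = Driver(
#       {},                      # config, as with the standard Hamilton driver
#       my_dataflow,             # module(s) defining the dataflow
#       project_id=1,
#       api_key="YOUR_API_KEY",
#       username="you@example.com",
#       dag_name="my_dag",
#       tags={"environment": "dev"},
#   )
#   result = dr.execute(["output_node"], inputs={"input_node": 42})

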
def _get_fully_qualified_function_path(fn: Callable) -> str:
"""Gets the fully qualified path of a function.
:param fn: Function to get the path of
:return: Fully qualified path of the function
"""
module = inspect.getmodule(fn)
fn_name = fn.__name__
if module is not None:
fn_name = f"{module.__name__}.{fn_name}"
return fn_name


def hash_dag(dag: graph.FunctionGraph) -> str:
"""Hashes a DAG.
:param dag: DAG to hash
:return: Hash of the DAG
"""
digest = hashlib.sha256()
hashing_node_fields = {
"name": str,
"type": str, # Hash it to a string for the type
"node_role": str,
"dependencies": lambda x: " ".join(
[item.name for item in sorted(x, key=operator.attrgetter("name"))]
), # Sort dependencies
}
for node_ in sorted(dag.nodes.values(), key=operator.attrgetter("name")):
for field, serializer in hashing_node_fields.items():
digest.update(json.dumps(serializer(getattr(node_, field))).encode())
if node_.originating_functions is None:
continue
for fn in node_.originating_functions:
fn_name = _get_fully_qualified_function_path(fn)
digest.update(fn_name.encode())
return digest.hexdigest()
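

# Note (a sketch, not part of the driver): hash_dag fingerprints the DAG's
# *structure* -- node names, types, roles, sorted dependencies, and originating
# function paths -- so renaming a node changes the hash, while editing a
# function body alone does not (hash_dag_modules below covers that). Assuming a
# recent Hamilton version where graph.FunctionGraph.from_modules exists:
#
#   fg = graph.FunctionGraph.from_modules(my_dataflow, config={})
#   dag_version = hash_dag(fg)

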
def hash_dag_modules(dag: graph.FunctionGraph, modules: List[ModuleType]):
modules_by_path = {}
for module in modules:
if hasattr(module, "__file__") and module.__file__ is not None:
try:
modules_by_path[module.__file__] = inspect.getsource(module)
except OSError:
logger.warning(
f"Skipping hashing of module {module.__name__} because we could not read the source code."
)
digest = hashlib.sha256()
for node_ in sorted(dag.nodes.values(), key=operator.attrgetter("name")):
if node_.originating_functions is None:
continue
for fn in node_.originating_functions:
module = inspect.getmodule(fn)
if hasattr(module, "__file__") and module.__file__ is not None:
modules_by_path[module.__file__] = module
for module_path, module_contents in sorted(modules_by_path.items()):
        # If the filename looks like tmpXXXXXXXX.py, assume it's a temporary file (e.g. one
        # generated in a jupyter context) and skip hashing the path; otherwise identical code
        # would show up as a different version on every run.
        file_name = os.path.basename(module_path)
        if not (file_name.startswith("tmp") and len(file_name) == 14):
            digest.update(module_path.encode())
        digest.update(_get_modules_hash((module_contents,)).encode())
return digest.hexdigest()


def _convert_node_dependencies(node: Node) -> dict:
dependencies = []
dependency_specs = []
dependency_specs_type = "python_type"
dependency_specs_schema_version = 1
for node_ in node.dependencies:
dependencies.append(node_.name)
dependency_specs.append({"type_name": str(node_.type)})
return {
"dependencies": dependencies,
"dependency_specs": dependency_specs,
"dependency_specs_type": dependency_specs_type,
"dependency_specs_schema_version": dependency_specs_schema_version,
}
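

# Example output shape (sketch) for a node depending on `a: int` and `b: str`:
#
#   {
#       "dependencies": ["a", "b"],
#       "dependency_specs": [{"type_name": "<class 'int'>"}, {"type_name": "<class 'str'>"}],
#       "dependency_specs_type": "python_type",
#       "dependency_specs_schema_version": 1,
#   }

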
def _convert_classifications(node_: Node) -> List[str]:
out = []
if (
node_.tags.get("hamilton.data_loader")
and node_.tags.get("hamilton.data_loader.has_metadata") is not False
):
out.append("data_loader")
elif node_.tags.get("hamilton.data_saver"):
out.append("data_saver")
elif node_.user_defined:
out.append("input")
else:
out.append("transform")
return out


def _extract_node_templates_from_function_graph(fn_graph: graph.FunctionGraph) -> List[dict]:
"""Converts a function graph to a list of nodes that the DAGWorks graph can understand.
    @param fn_graph: Function graph to convert
@return: A list of node objects
"""
node_templates = []
for node_ in fn_graph.nodes.values():
code_artifact_pointers = (
[]
if (node_.originating_functions is None or len(node_.originating_functions) == 0)
else [_get_fully_qualified_function_path(fn) for fn in node_.originating_functions]
)
node_templates.append(
dict(
name=node_.name,
output={"type_name": str(node_.type)},
output_type="python_type",
output_schema_version=1, # TODO -- merge this with _convert_node_dependencies
documentation=node_.documentation,
tags=node_.tags, # TODO -- ensure serializable
classifications=_convert_classifications(node_), # TODO -- manage classifications
                code_artifact_pointers=code_artifact_pointers,
**_convert_node_dependencies(node_),
)
)
return node_templates


def _derive_url(vcs_info: GitInfo, path: str, line: int) -> str:
"""Derives a URL from a VCS info, a path, and a line number.
@param vcs_info: VCS info
@param path: Path
@param line: Line number
@return: A URL
"""
if vcs_info.repository == "Error: No repository to link to.":
return "Error: No repository to link to."
if vcs_info.repository.endswith(".git"):
repo_url = vcs_info.repository[:-4]
else:
repo_url = vcs_info.repository
return f"{repo_url}/blob/{vcs_info.commit_hash}/{path}#L{line}"
def getsourcelines(object, stop: Callable = None) -> tuple:
"""Adding this here incase we want to pull decorator code too.
This is modification of the underlying inspect function.
Return a list of source lines and starting line number for an object.
The argument may be a module, class, method, function, traceback, frame,
or code object. The source code is returned as a list of the lines
corresponding to the object and the line number indicates where in the
original source file the first line of code was found. An OSError is
raised if the source code cannot be retrieved.
This will return the decorator code, or the underlying wrapped function code.
"""
object = inspect.unwrap(object, stop=stop)
lines, lnum = inspect.findsource(object)
if inspect.istraceback(object):
object = object.tb_frame
# for module or frame that corresponds to module, return all source lines
if inspect.ismodule(object) or (
inspect.isframe(object) and object.f_code.co_name == "<module>"
):
return lines, 0
else:
return inspect.getblock(lines[lnum:]), lnum + 1


def extract_code_artifacts_from_function_graph(
fn_graph: graph.FunctionGraph, vcs_info: GitInfo, repo_base_path: str
) -> List[dict]:
"""Converts a function graph to a list of code artifacts that the function graph uses.
    @param fn_graph: Function graph to convert.
    @param vcs_info: Version control info used to build code URLs.
    @param repo_base_path: Base path of the repository, used to relativize file paths.
    @return: A list of code artifact dicts.
"""
seen = set()
out = []
for node_ in fn_graph.nodes.values():
originating_functions = node_.originating_functions
if originating_functions is None:
continue
for fn in originating_functions:
fn_name = _get_fully_qualified_function_path(fn)
if fn_name not in seen:
seen.add(fn_name)
# need to handle decorators -- they will return the wrong sourcefile.
unwrapped_fn = inspect.unwrap(fn)
if unwrapped_fn != fn:
# TODO: pull decorator stuff too
source_file = inspect.getsourcefile(unwrapped_fn)
else:
source_file = inspect.getsourcefile(fn)
if source_file is not None:
path = os.path.relpath(source_file, repo_base_path)
else:
path = ""
                try:
                    source_lines, start_line = inspect.getsourcelines(fn)
                    out.append(
                        dict(
                            name=fn_name,
                            type="p_function",
                            path=path,
                            start=start_line - 1,
                            end=start_line - 1 + len(source_lines),
                            url=_derive_url(vcs_info, path, start_line),
                        )
                    )
except OSError:
# This is an error state where somehow we don't have
# source code.
out.append(
dict(
name=fn_name,
type="p_function",
path=path,
start=0,
end=0,
url=_derive_url(vcs_info, path, 0),
)
)
return out
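

# Example element (sketch) produced for each originating function; names and
# paths are hypothetical:
#
#   {
#       "name": "my_dataflow.featurize",    # fully qualified function path
#       "type": "p_function",
#       "path": "dags/my_dataflow.py",      # relative to the repo base
#       "start": 10,                        # 0-indexed first source line
#       "end": 25,
#       "url": "https://github.com/your-org/your-repo/blob/<commit>/dags/my_dataflow.py#L11",
#   }

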
def extract_attributes_from_tracking_state(tracking_state: TrackingState) -> List[dict]:
"""Extracts attributes from tracking state. We'll likely rewrite this shortly --
this is just to bridge so we can get the client out. Next, we'll want it putting
stuff on a queue, and then sends it over in batches. The tracking state is a hack
and we'll get rid of it.
@param tracking_state: Tracking state
@return: A list of attributes
"""
# This just bridges some of the old code so we can move quickly
# TODO -- fix so we don't have to do that
out = []
dag_run = tracking_state.get()
for task in dag_run.tasks:
if task.error is not None:
out.append(
dict(
node_name=task.node_name,
name="stack_trace",
type="error",
schema_version=1,
value={
"stack_trace": task.error,
},
attribute_role="error",
)
)
if task.result_summary is not None:
out.append(
dict(
node_name=task.node_name,
name="result_summary",
type=task.result_summary["observability_type"],
# 0.0.3 -> 3
schema_version=int(
task.result_summary["observability_schema_version"].split(".")[-1]
),
value=task.result_summary["observability_value"],
attribute_role="result_summary",
)
)
return out
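

# Example attribute row (sketch): one "error" row per failed task, plus one
# "result_summary" row per summarized result (values illustrative):
#
#   {"node_name": "featurize", "name": "stack_trace", "type": "error",
#    "schema_version": 1, "value": {"stack_trace": ["Traceback ..."]},
#    "attribute_role": "error"}

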
def extract_task_updates_from_tracking_state(
tracking_state: TrackingState, fg: graph.FunctionGraph
) -> List[dict]:
"""Extracts task updates from tracking state. We'll likely rewrite this shortly --
this is a hack (using the tracking state) -- we'll want to extract these as we go along,
and we'll want it putting stuff on a queue, and then sends it over in batches.
@param tracking_state:
@return:
"""
# TODO -- do the tracking state in a cleaner way
# This is left over from the old way we were doing things
dag_run = tracking_state.get()
out = []
for task in dag_run.tasks:
node_ = fg.nodes[task.node_name]
out.append(
dict(
node_template_name=task.node_name,
node_name=task.node_name,
realized_dependencies=[dep.name for dep in node_.dependencies],
status=task.status,
start_time=task.start_time,
end_time=task.end_time,
)
)
return out


def _slurp_code(fg: graph.FunctionGraph, repo_base: str) -> List[dict]:
modules = set()
for node_ in fg.nodes.values():
originating_functions = node_.originating_functions
if originating_functions is None:
continue
for fn in originating_functions:
module = inspect.getmodule(fn)
modules.add(module)
out = []
for module in modules:
if hasattr(module, "__file__") and module.__file__ is not None:
module_path = os.path.relpath(module.__file__, repo_base)
with open(module.__file__, "r") as f:
out.append({"path": module_path, "contents": f.read()})
return out
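

# Example output shape (sketch): one entry per module file backing the graph,
# with paths relative to the repo base:
#
#   [{"path": "dags/features.py", "contents": "import pandas as pd\n..."}]

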
class DAGWorksGraphExecutor(driver.GraphExecutor):
def __init__(
self,
wrapping_executor: driver.GraphExecutor,
client: clients.HamiltonClient,
run_tags: Dict[str, str],
dagworks_ui_url: str,
project_id: int,
repo_base: str,
vcs_info: GitInfo,
dag_name: str,
graph_modules: List[ModuleType],
initial_graph: graph.FunctionGraph,
):
self.executor = wrapping_executor
self.client = client
self.run_tags = run_tags
self.dagworks_ui_url = dagworks_ui_url
self.project_id = project_id
self.repo_base = repo_base
self.vcs_info = vcs_info
self.dag_name = dag_name
self.graph_modules = graph_modules
self.dag_template_id_cache = {}
self.initial_graph = initial_graph

    def initialize(self):
self._register_or_query(self.initial_graph)

    def _register_or_query(self, fg: graph.FunctionGraph) -> int:
"""Creates a DAG template if the funtion graph doesn't exist. Otherwise
we use the cache we see. Note that if this already
@param fg:
@return:
"""
# Quick way to bypass anything complicated
fg_id = id(fg)
if fg_id in self.dag_template_id_cache:
return self.dag_template_id_cache[fg_id]
dag_hash = hash_dag(fg)
code_hash = hash_dag_modules(fg, self.graph_modules)
dag_template_id = self.client.register_dag_template_if_not_exists(
project_id=self.project_id,
dag_hash=dag_hash,
code_hash=code_hash,
name=self.dag_name,
nodes=_extract_node_templates_from_function_graph(fg),
code_artifacts=extract_code_artifacts_from_function_graph(
fg, self.vcs_info, self.repo_base
),
config=fg.config,
tags=self.run_tags,
code=_slurp_code(fg, self.repo_base),
vcs_info=self.vcs_info,
)
self.dag_template_id_cache[fg_id] = dag_template_id
return dag_template_id

    def execute(
self,
fg: graph.FunctionGraph,
final_vars: List[Union[str, Callable, Variable]],
overrides: Dict[str, Any],
inputs: Dict[str, Any],
run_id: str,
) -> Dict[str, Any]:
"""Executes a graph in a blocking function.
:param fg: Graph to execute
:param final_vars: Variables we want
:param overrides: Overrides --- these short-circuit computation
:param inputs: Inputs to the Graph.
:return: The output of the final variables, in dictionary form.
"""
logger.info(f"Logging code version for DAG {self.dag_name}...")
dag_template_id = self._register_or_query(fg)
tracking_state = TrackingState(run_id)
with monkey_patch_adapter(fg.adapter, tracking_state):
tracking_state.clock_start()
dag_run_id = self.client.create_and_start_dag_run(
dag_template_id=dag_template_id,
tags=self.run_tags,
inputs=inputs if inputs is not None else {},
outputs=final_vars,
)
try:
out = self.executor.execute(fg, final_vars, overrides, inputs, run_id=dag_run_id)
tracking_state.clock_end(status=Status.SUCCESS)
return out
except Exception as e:
tracking_state.clock_end(status=Status.FAILURE)
raise e
finally:
finally_block_time = datetime.datetime.utcnow()
if tracking_state.status != Status.SUCCESS:
tracking_state.status = Status.FAILURE
# this assumes the task map only has things that have been processed, not
# nodes that have yet to be computed.
for task_name, task_run in tracking_state.task_map.items():
if task_run.status != Status.SUCCESS:
task_run.status = Status.FAILURE
task_run.end_time = finally_block_time
if task_run.error is None: # we likely aborted it.
# Note if we start to do concurrent execution we'll likely
# need to adjust this.
task_run.error = ["Run was likely aborted."]
if task_run.end_time is None and task_run.status == Status.SUCCESS:
task_run.end_time = finally_block_time
self.client.update_tasks(
dag_run_id=dag_run_id,
attributes=extract_attributes_from_tracking_state(tracking_state),
task_updates=extract_task_updates_from_tracking_state(tracking_state, fg),
)
self.client.log_dag_run_end(
dag_run_id=dag_run_id,
status=tracking_state.status.value,
)
logger.warning(
f"\nCaptured execution run. Results can be found at "
f"{self.dagworks_ui_url}/dashboard/project/{self.project_id}/runs/{dag_run_id}\n"
)

    def validate(self, nodes_to_execute: List[node.Node]):
pass