-
Notifications
You must be signed in to change notification settings - Fork 657
/
fire_bind.py
332 lines (305 loc) · 14.1 KB
/
fire_bind.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
try:
import fire
import fire.core
import fire.helptext
except ImportError:
fire = None
import inspect
from .frameworks import _patched_call # noqa
from ..config import get_remote_task_id, running_remotely
from ..utilities.dicts import cast_str_to_bool
class SimpleNamespace(object):
    """Minimal attribute container initialised from keyword arguments.

    Behaves like ``types.SimpleNamespace``: every keyword becomes an attribute, the repr
    lists attributes sorted by name, and two instances compare equal when their
    attribute dictionaries are equal.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        items = ("{}={!r}".format(key, self.__dict__[key]) for key in sorted(self.__dict__))
        return "{}({})".format(type(self).__name__, ", ".join(items))

    def __eq__(self, other):
        # Objects without a __dict__ (ints, strings, slotted classes, ...) cannot be compared
        # attribute-wise. Let Python fall back to its default comparison instead of raising
        # AttributeError.
        other_dict = getattr(other, "__dict__", None)
        if other_dict is None:
            return NotImplemented
        return self.__dict__ == other_dict
class PatchFire:
    """Bind ``python-fire`` command-line arguments to a ClearML task.

    When running locally, ``fire.core._CallAndUpdateTrace`` is wrapped. The fire commands
    and the arguments fire parses are then written to the task's ``Args`` hyperparameter
    section. When running remotely, ``fire.core._Fire`` is wrapped instead. The argument
    list is then rebuilt from the values stored in that section rather than from the
    arguments fire received.

    Parameter layout inside the section (names relative to ``Args/``):

    * ``<command>``: ``True`` for the selected command, ``False`` otherwise (type ``fire.Command``)
    * ``<command>/<arg>``: argument of a specific command (type ``fire.Arg@<command>``)
    * ``<arg>``: argument not bound to a command (type ``fire.Arg.shared``)

    Nested command paths are joined with ``.``, e.g. ``group.command``.
    """

    # Arguments parsed by fire for the current call (argument name -> value)
    _args = {}
    _command_type = "fire.Command"
    _command_arg_type_template = "fire.Arg@%s"
    _shared_arg_type = "fire.Arg.shared"
    _section_name = "Args"
    _args_sep = "/"
    _commands_sep = "."
    _current_task = None
    # Remote execution only: the task's raw "hyperparams" and the "Args" values keyed by
    # their name relative to the section (both filled by _load_task_params)
    __remote_task_params = None
    __remote_task_params_dict = {}
    __patched = False
    # Fire structure discovered on the first local call:
    # group paths, and command path -> member
    __groups = []
    __commands = {}
    # Flag arguments used when probing the component with fire.core._Fire
    __default_args = SimpleNamespace(
        completion=None, help=False, interactive=False, separator="-", trace=False, verbose=False
    )
    __current_command = None
    __fetched_current_command = False
    # Command path (None for the top-level component) -> list of its argument names
    __command_args = {}

    @classmethod
    def patch(cls, task=None):
        """Attach ``task`` and install the fire hooks (hooks are installed only once).

        :param task: Task that receives the arguments. ``None`` only installs the hooks.
        """
        if fire is None:
            # fire is not installed, nothing to bind
            return

        cls._current_task = task
        if task:
            cls._update_task_args()

        if cls.__patched:
            return
        cls.__patched = True
        if running_remotely():
            # Remote: feed fire the arguments stored on the task
            fire.core._Fire = _patched_call(fire.core._Fire, PatchFire.__Fire)
        else:
            # Local: record what fire parses
            fire.core._CallAndUpdateTrace = _patched_call(
                fire.core._CallAndUpdateTrace, PatchFire.__CallAndUpdateTrace
            )

    @classmethod
    def _update_task_args(cls):
        """Write the known fire commands and parsed arguments to the attached task.

        Does nothing when running remotely or when no task is attached.
        """
        if running_remotely() or not cls._current_task:
            return

        def full_name(*parts):
            # e.g. full_name("cmd", "lr") -> "Args/cmd/lr"
            return cls._args_sep.join((cls._section_name,) + parts)

        args = {}
        parameters_types = {}
        current = cls.__current_command
        if current is None:
            # No command selected: every parsed argument is shared
            args = {full_name(k): v for k, v in cls._args.items()}
            parameters_types = {full_name(k): cls._shared_arg_type for k in cls._args}
            # Expose the top-level component's arguments even when they were not passed
            for k in cls.__command_args.get(None) or []:
                args.setdefault(full_name(k), None)
        else:
            # Build the membership set once instead of per argument
            own_arg_names = set(cls.__command_args.get(current) or [])
            own_arg_type = cls._command_arg_type_template % current
            args[full_name(current)] = True
            parameters_types[full_name(current)] = cls._command_type
            # Arguments declared by the selected command live under "<command>/<arg>"...
            args.update({full_name(current, k): v for k, v in cls._args.items() if k in own_arg_names})
            # ...the remaining parsed arguments are shared
            args.update({full_name(k): v for k, v in cls._args.items() if k not in own_arg_names})
            parameters_types.update(
                {full_name(current, k): own_arg_type for k in cls._args if k in own_arg_names}
            )
            parameters_types.update(
                {full_name(k): cls._shared_arg_type for k in cls._args if k not in own_arg_names}
            )

        # Every other known command is recorded as not selected, with empty (None) arguments
        for command in cls.__commands:
            if command == current:
                continue
            args[full_name(command)] = False
            parameters_types[full_name(command)] = cls._command_type
            command_arg_type = cls._command_arg_type_template % command
            for k in cls.__command_args.get(command) or []:
                args[full_name(command, k)] = None
                parameters_types[full_name(command, k)] = command_arg_type

        # noinspection PyProtectedMember
        cls._current_task._set_parameters(
            args,
            __update=True,
            __parameters_types=parameters_types,
        )

    @staticmethod
    def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs):  # noqa
        """Remote hook for ``fire.core._Fire``: replace ``args_`` with the values stored on the task."""
        if not running_remotely():
            return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
        command = PatchFire._load_task_params()
        # The selected command path ("group.command") goes first, as positional arguments
        replaced_args = command.split(PatchFire._commands_sep) if command is not None else []
        remote_params = PatchFire.__remote_task_params
        # The task may have no "Args" section; treat that as "no parameters"
        section = remote_params[PatchFire._section_name] if PatchFire._section_name in remote_params else {}
        values = PatchFire.__remote_task_params_dict
        for param in section.values():
            if command is not None and param.type == PatchFire._command_arg_type_template % command:
                # "<command>/<arg>" -> "--<arg>"
                replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
                value = values[param.name]
                if value:  # empty value: pass the flag alone
                    replaced_args.append(value)
            if param.type == PatchFire._shared_arg_type:
                replaced_args.append("--" + param.name)
                value = values[param.name]
                if value:
                    replaced_args.append(value)
        return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)

    @staticmethod
    def __CallAndUpdateTrace(  # noqa
        original_fn, component, args_, component_trace, treatment, target, *args, **kwargs
    ):
        """Local hook for ``fire.core._CallAndUpdateTrace``: record the parsed arguments, then call through."""
        if running_remotely():
            return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)

        if not PatchFire.__fetched_current_command:
            # First call only: discover the fire groups/commands and their arguments
            PatchFire.__fetched_current_command = True
            context, component_context = PatchFire.__get_context_and_component(component)
            PatchFire.__groups, PatchFire.__commands = PatchFire.__get_all_groups_and_commands(
                component_context, context
            )
            PatchFire.__current_command = PatchFire.__get_current_command(
                args_, PatchFire.__groups, PatchFire.__commands
            )
            for command in PatchFire.__commands:
                PatchFire.__command_args[command] = PatchFire.__get_command_args(
                    component_context, command.split(PatchFire._commands_sep), PatchFire.__default_args, context
                )
            # Arguments of the top-level component itself (empty command path)
            PatchFire.__command_args[None] = PatchFire.__get_command_args(
                component_context,
                [],
                PatchFire.__default_args,
                context,
            )

        # If the callable being invoked is one of the known commands, it becomes the current command
        for command_name, command_member in PatchFire.__commands.items():
            if command_member == component:
                PatchFire.__current_command = command_name
                break
            # Comparing methods in Python is equivalent to comparing the __func__ of the methods
            # and the objects they are bound to. We do not care about the object in this case,
            # so we just compare the __func__
            if (
                inspect.ismethod(component)
                and inspect.ismethod(command_member)
                and command_member.__func__ == component.__func__
            ):
                PatchFire.__current_command = command_name
                break

        fn = component.__call__ if treatment == "callable" else component
        metadata = fire.decorators.GetMetadata(component)
        fn_spec = fire.inspectutils.GetFullArgSpec(component)
        parse = fire.core._MakeParseFn(fn, metadata)  # noqa
        (parsed_args, parsed_kwargs), _, _, _ = parse(args_)
        # Positional values are matched to their parameter names
        PatchFire._args.update(dict(zip(fn_spec.args, parsed_args)))
        PatchFire._args.update(parsed_kwargs)
        PatchFire._update_task_args()
        return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)

    @staticmethod
    def __get_context_and_component(component):
        """Find the ``component`` (and, if it is None, the caller's namespace) given to ``fire.Fire()``.

        We walk the stack up to the ``Fire`` frame. Wrapping ``fire.Fire`` itself is not an
        option because the hook sits on ``_CallAndUpdateTrace``, which ``fire.Fire`` calls.

        :return: ``(context, component_context)``. When ``Fire`` was called with
            ``component=None``, ``context`` holds the caller's globals and locals, with
            modules and ``clearml.task`` classes removed. Otherwise it is empty.
        """
        context = {}
        component_context = component
        # context=0: only function names and frames are needed, so skip reading source lines
        frame_infos = inspect.stack(0)
        try:
            for frame_info_ind, frame_info in enumerate(frame_infos):
                if frame_info.function != "Fire":
                    continue
                component_context = inspect.getargvalues(frame_info.frame).locals["component"]
                if component_context is None:
                    # This is similar to how fire finds this context
                    fire_context_frame = frame_infos[frame_info_ind + 1].frame
                    context.update(fire_context_frame.f_globals)
                    context.update(fire_context_frame.f_locals)
                    # Ignore modules, as they yield too many commands.
                    # Also ignore clearml.task.
                    context = {
                        k: v
                        for k, v in context.items()
                        if not inspect.ismodule(v) and (not inspect.isclass(v) or v.__module__ != "clearml.task")
                    }
                break
        finally:
            # Break the frame <-> local-variable reference cycle deterministically
            del frame_infos
        return context, component_context

    @staticmethod
    def __get_all_groups_and_commands(component, context):
        """Enumerate every fire group and command reachable from ``component``.

        :return: ``(groups, commands)``. ``groups`` holds dot-joined group paths, with
            ``""`` as the root. ``commands`` maps each dot-joined command path to its
            member. Modules are skipped entirely.
        """
        groups = []
        commands = {}
        # skip modules
        if inspect.ismodule(component):
            return groups, commands
        component_trace_result = PatchFire.__safe_Fire(component, [], PatchFire.__default_args, context)
        pending = [[]]  # depth-first stack of group paths still to explore
        while pending:
            query_group = pending.pop()
            group_path = PatchFire._commands_sep.join(query_group)
            groups.append(group_path)
            prefix = group_path + PatchFire._commands_sep if query_group else ""
            current_groups, current_commands = PatchFire.__get_groups_and_commands_for_args(
                component_trace_result, query_group, PatchFire.__default_args, context
            )
            for command_name, member in current_commands:
                commands[prefix + command_name] = member
            for group_name, _ in current_groups:
                pending.append(query_group + [group_name])
        return groups, commands

    @staticmethod
    def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
        """Return the ``(name, member)`` groups and commands fire lists for ``component`` at ``args_``."""
        component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name)
        # set verbose to True or else we might miss some commands
        groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=True)  # noqa
        groups = [(item_name, member) for item_name, member in groups.GetItems()]
        commands = [(item_name, member) for item_name, member in commands.GetItems()]
        return groups, commands

    @staticmethod
    def __get_current_command(args_, groups, commands):
        """Return the command path selected by the leading positional ``args_``, or ``None``.

        Leading arguments are consumed while they name a known group. The first argument
        that does not extend a group path must complete a known command path.
        """
        current_command = ""
        for arg in args_:
            candidate = current_command + PatchFire._commands_sep + arg if current_command else arg
            if candidate in groups:
                current_command = candidate
                continue
            return candidate if candidate in commands else None
        return None

    @staticmethod
    def __get_command_args(component, args_, parsed_flag_args, context, name=None):
        """Return the parameter names of the member fire resolves ``args_`` to on ``component``."""
        # Forward ``name`` (previously dropped), matching __get_groups_and_commands_for_args
        component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name)
        fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
        return fn_spec.args

    @staticmethod
    def __safe_Fire(component, args_, parsed_flag_args, context, name=None):
        """Run ``fire.core._Fire`` on ``args_`` without letting it call user code.

        While it runs, ``fire.core._CallAndUpdateTrace`` is replaced by a guard that raises
        ``FireError``. The original is always restored afterwards. This is best effort:
        any exception yields ``None``.
        """
        orig = None
        # noinspection PyBroadException
        try:

            def __CallAndUpdateTrace_rogue_call_guard(*args, **kwargs):
                raise fire.core.FireError()

            orig = fire.core._CallAndUpdateTrace  # noqa
            fire.core._CallAndUpdateTrace = __CallAndUpdateTrace_rogue_call_guard  # noqa
            result = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult()  # noqa
        except Exception:
            result = None
        finally:
            if orig is not None:
                fire.core._CallAndUpdateTrace = orig  # noqa
        return result

    @staticmethod
    def _load_task_params():
        """Fetch the remote task's ``Args`` parameters (once) and return the selected command.

        :return: The name of the first ``fire.Command`` parameter whose value is true,
            or ``None`` if there is none.
        """
        if PatchFire.__remote_task_params is None:
            from clearml import Task

            t = Task.get_task(task_id=get_remote_task_id())
            # noinspection PyProtectedMember
            PatchFire.__remote_task_params = t._get_task_property("hyperparams") or {}
            params_dict = t.get_parameters(backwards_compatibility=False)
            section_prefix = PatchFire._section_name + PatchFire._args_sep
            PatchFire.__remote_task_params_dict = {
                k[len(section_prefix):]: v for k, v in params_dict.items() if k.startswith(section_prefix)
            }

        remote_params = PatchFire.__remote_task_params
        # The task may have no "Args" section; treat that as "no command selected"
        section = remote_params[PatchFire._section_name] if PatchFire._section_name in remote_params else {}
        command = [
            p.name
            for p in section.values()
            if p.type == PatchFire._command_type and cast_str_to_bool(p.value, strip=True)
        ]
        return command[0] if command else None
# patch fire before anything: install the hooks as soon as this module is imported.
# No task is attached here, so this only wraps the fire.core functions
# (and is a no-op when the fire package is not installed).
PatchFire.patch()