Skip to content

Commit

Permalink
Add custom formatter for Fire result (#345)
Browse files Browse the repository at this point in the history
Fixes #344 (see issue for more details)

This lets you define a function that will take the result from the Fire component and allows the user to alter it before fire looks at it to render it.
  • Loading branch information
beasteers authored Apr 16, 2022
1 parent 8469e48 commit 8bddeec
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
13 changes: 10 additions & 3 deletions fire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(argv):
import asyncio # pylint: disable=import-error,g-import-not-at-top # pytype: disable=import-error


def Fire(component=None, command=None, name=None):
def Fire(component=None, command=None, name=None, serialize=None):
"""This function, Fire, is the main entrypoint for Python Fire.
Executes a command either from the `command` argument or from sys.argv by
Expand Down Expand Up @@ -164,7 +164,7 @@ def Fire(component=None, command=None, name=None):
raise FireExit(0, component_trace)

# The command succeeded normally; print the result.
_PrintResult(component_trace, verbose=component_trace.verbose)
_PrintResult(component_trace, verbose=component_trace.verbose, serialize=serialize)
result = component_trace.GetResult()
return result

Expand Down Expand Up @@ -241,12 +241,19 @@ def _IsHelpShortcut(component_trace, remaining_args):
return show_help


def _PrintResult(component_trace, verbose=False):
def _PrintResult(component_trace, verbose=False, serialize=None):
"""Prints the result of the Fire call to stdout in a human readable way."""
# TODO(dbieber): Design human readable deserializable serialization method
# and move serialization to its own module.
result = component_trace.GetResult()

# Allow users to modify the return value of the component and provide
# custom formatting.
if serialize:
if not callable(serialize):
raise FireError("serialize argument {} must be empty or callable.".format(serialize))
result = serialize(result)

if value_types.HasCustomStr(result):
# If the object has a custom __str__ method, rather than one inherited from
# object, then we use that to serialize the object.
Expand Down
24 changes: 24 additions & 0 deletions fire/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def testClassMethod(self):
7,
)

def testCustomSerialize(self):
def serialize(x):
if isinstance(x, list):
return ', '.join(str(xi) for xi in x)
if isinstance(x, dict):
return ', '.join('{}={!r}'.format(k, v) for k, v in x.items())
if x == 'special':
return ['SURPRISE!!', "I'm a list!"]
return x

ident = lambda x: x

with self.assertOutputMatches(stdout='a, b', stderr=None):
result = core.Fire(ident, command=['[a,b]'], serialize=serialize)
with self.assertOutputMatches(stdout='a=5, b=6', stderr=None):
result = core.Fire(ident, command=['{a:5,b:6}'], serialize=serialize)
with self.assertOutputMatches(stdout='asdf', stderr=None):
result = core.Fire(ident, command=['asdf'], serialize=serialize)
with self.assertOutputMatches(stdout="SURPRISE!!\nI'm a list!\n", stderr=None):
result = core.Fire(ident, command=['special'], serialize=serialize)
with self.assertRaises(core.FireError):
core.Fire(ident, command=['asdf'], serialize=55)


@testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.')
def testLruCacheDecoratorBoundArg(self):
self.assertEqual(
Expand Down

0 comments on commit 8bddeec

Please sign in to comment.