Skip to content

Commit

Permalink
refine typehints and expose visit_sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
rkaminsk committed Feb 10, 2021
1 parent 2e0bfcf commit 635b1cb
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions libpyclingo_cffi/clingo/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@
'''

from enum import Enum, IntEnum
from typing import Any, Callable, ContextManager, Dict, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Callable, ContextManager, Dict, List, MutableSequence, NamedTuple, Optional, Sequence, Tuple, Union
from collections import abc
from functools import total_ordering

Expand Down Expand Up @@ -782,8 +782,8 @@ def _py_location(rep):
_attribute_names = { _to_str(_lib.g_clingo_ast_attribute_names.names[i]): i
for i in range(_lib.g_clingo_ast_attribute_names.size) }

ASTValue = Union[str, int, Symbol, None, 'AST', StrSequence, ASTSequence]
ASTUpdate = Union[str, int, Symbol, None, 'AST', Sequence[str], Sequence['AST']]
ASTValue = Union[str, int, Symbol, Location, None, 'AST', StrSequence, ASTSequence]
ASTUpdate = Union[str, int, Symbol, Location, None, 'AST', Sequence[str], Sequence['AST']]

@total_ordering
class AST:
Expand Down Expand Up @@ -1113,8 +1113,6 @@ def add(self, statement: AST) -> None:
'''
_handle_error(_lib.clingo_program_builder_add_ast(self._rep, statement._rep))

T = TypeVar('T', AST, Union[ASTSequence, List[AST]], None)

class Transformer:
'''
Utility class to transform ASTs.
Expand Down Expand Up @@ -1147,15 +1145,28 @@ def visit_children(self, ast: AST, *args: Any, **kwargs: Any) -> Dict[str, ASTUp
The functions returns a dictionary that can be passed to `AST.update`.
It contains the attributes and values that have been transformed.
'''
update = dict()
update: Dict[str, ASTUpdate] = dict()
for key in ast.child_keys:
old = getattr(ast, key)
new = self._visit(old, *args, **kwargs)
new = self._dispatch(old, *args, **kwargs)
if new is not old:
update[key] = new
return update

def _visit(self, ast: T, *args: Any, **kwargs: Any) -> T:
def visit_sequence(self, sequence: ASTSequence, *args: Any, **kwargs: Any) -> MutableSequence[AST]:
'''
Transform a sequence of ASTs returning the same sequnce if there are no
changes or a list of ASTs otherwise.
'''
ret: MutableSequence[AST]
ret, lst = sequence, []
for old in sequence:
lst.append(self(old, *args, **kwargs))
if lst[-1] is not old:
ret = lst
return ret

def _dispatch(self, ast: Union[None, AST, ASTSequence], *args: Any, **kwargs: Any) -> Union[None, AST, MutableSequence[AST]]:
'''
Visit and transform an (optional) AST or a sequence of ASTs.
'''
Expand All @@ -1166,12 +1177,7 @@ def _visit(self, ast: T, *args: Any, **kwargs: Any) -> T:
return self.visit(ast, *args, **kwargs) # type: ignore

if isinstance(ast, Sequence):
ret, lst = ast, []
for old in ast:
lst.append(self.visit(old, *args, **kwargs))
if lst[-1] is not old:
ret = lst
return ret
return self.visit_sequence(ast, *args, **kwargs)

raise TypeError('unexpected type')

Expand Down

0 comments on commit 635b1cb

Please sign in to comment.