Commit 9fc9c2d2 authored by eckhart's avatar eckhart
Browse files

- transformation.py: bug fixes for transformation factory; unit tests added!

parent 2aeaf0de
......@@ -27,6 +27,7 @@ for CST -> AST transformations.
"""
import collections.abc
import inspect
from functools import partial, reduce, singledispatch
......@@ -34,7 +35,7 @@ from DHParser.error import Error
from DHParser.syntaxtree import Node, WHITESPACE_PTYPE, TOKEN_PTYPE, MockParser, ZOMBIE_NODE
from DHParser.toolkit import expand_table, smart_list, re, typing
from typing import AbstractSet, Any, ByteString, Callable, cast, Container, Dict, \
List, Sequence, Union, Text
Tuple, List, Sequence, Union, Text, GenericMeta
__all__ = ('TransformationDict',
'TransformationProc',
......@@ -149,7 +150,22 @@ def transformation_factory(t1=None, t2=None, t3=None, t4=None, t5=None):
does not have type annotations.
"""
def type_guard(t):
"""Raises an error if type `t` is a generic type or could be mistaken
for the type of the canonical first parameter "List[Node] of
transformation functions. Returns `t`."""
if isinstance(t, GenericMeta):
raise TypeError("Generic Type %s not permitted\n in transformation_factory "
"decorator. Use the equivalent non-generic type instead!"
% str(t))
if issubclass(List[Node], t):
raise TypeError("Sequence type %s not permitted\nin transformation_factory "
"decorator, because it could be mistaken for a base class "
"of List[Node].\nTry 'tuple' instead!" % str(t))
return t
def decorator(f):
nonlocal t1
sig = inspect.signature(f)
params = list(sig.parameters.values())[1:]
if len(params) == 0:
......@@ -157,29 +173,35 @@ def transformation_factory(t1=None, t2=None, t3=None, t4=None, t5=None):
assert t1 or params[0].annotation != params[0].empty, \
"No type information on second parameter found! Please, use type " \
"annotation or provide the type information via transformer-decorator."
p1type = t1 or params[0].annotation
f = singledispatch(f)
try:
if len(params) == 1 and issubclass(p1type, Container) \
and not (issubclass(p1type, Text) or issubclass(p1type, ByteString)):
def gen_special(*args):
c = set(args) if issubclass(p1type, AbstractSet) else \
list(args) if issubclass(p1type, Sequence) else args
d = {params[0].name: c}
return partial(f, **d)
f.register(p1type.__args__[0], gen_special)
except AttributeError:
pass # Union Type does not allow subclassing, but is not needed here
p1type = params[0].annotation
if t1 is None:
t1 = type_guard(p1type)
elif issubclass(p1type, type_guard(t1)):
try:
if len(params) == 1 and issubclass(p1type, Container) \
and not (issubclass(p1type, Text) or issubclass(p1type, ByteString)):
def gen_special(*args):
c = set(args) if issubclass(p1type, AbstractSet) else \
tuple(args) if issubclass(p1type, Sequence) else args
d = {params[0].name: c}
return partial(f, **d)
f.register(type_guard(p1type.__args__[0]), gen_special)
except AttributeError:
pass # Union Type does not allow subclassing, but is not needed here
else:
raise TypeError("Annotated type %s is not a subclass of decorated type %s !"
% (str(p1type), str(t1)))
def gen_partial(*args, **kwargs):
d = {p.name: arg for p, arg in zip(params, args)}
d.update(kwargs)
return partial(f, **d)
for t in (p1type, t2, t3, t4, t5):
for t in (t1, t2, t3, t4, t5):
if t:
f.register(t, gen_partial)
f.register(type_guard(t), gen_partial)
else:
break
return f
......@@ -314,26 +336,14 @@ def traverse_locally(context: List[Node],
traverse(context[-1], processing_table, key_func)
# @transformation_factory(List[Callable])
# def apply_to_child(context: List[Node], transformations: List[Callable], condition: Callable):
# """Applies a list of transformations to those children that meet a specifc condition."""
# node = context[-1]
# for child in node.children:
# context.append(child)
# if condition(context):
# for transform in transformations:
# transform(context)
# context.pop()
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def apply_if(context: List[Node], transformation: Callable, condition: Callable):
"""Applies a transformation only if a certain condition is met."""
if condition(context):
transformation(context)
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def apply_unless(context: List[Node], transformation: Callable, condition: Callable):
"""Applies a transformation if a certain condition is *not* met."""
if not condition(context):
......@@ -385,7 +395,7 @@ def is_expendable(context: List[Node]) -> bool:
return is_empty(context) or is_whitespace(context)
@transformation_factory # (AbstractSet[str])
@transformation_factory(collections.abc.Set)
def is_token(context: List[Node], tokens: AbstractSet[str] = frozenset()) -> bool:
"""Checks whether the last node in the context has `ptype == TOKEN_PTYPE`
and it's content matches one of the given tokens. Leading and trailing
......@@ -407,7 +417,7 @@ def is_token(context: List[Node], tokens: AbstractSet[str] = frozenset()) -> boo
return node.parser.ptype == TOKEN_PTYPE and (not tokens or stripped(node) in tokens)
@transformation_factory
@transformation_factory(collections.abc.Set)
def is_one_of(context: List[Node], tag_name_set: AbstractSet[str]) -> bool:
"""Returns true, if the node's tag_name is one of the given tag names."""
return context[-1].tag_name in tag_name_set
......@@ -426,7 +436,7 @@ def has_content(context: List[Node], regexp: str) -> bool:
return bool(re.match(regexp, context[-1].content))
@transformation_factory
@transformation_factory(collections.abc.Set)
def has_parent(context: List[Node], tag_name_set: AbstractSet[str]) -> bool:
"""Checks whether a node with one of the given tag names appears somewhere
in the context before the last node in the context."""
......@@ -550,7 +560,7 @@ def reduce_single_child(context: List[Node]):
_reduce_child(node, node.children[0])
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def replace_or_reduce(context: List[Node], condition: Callable=is_named):
"""
Replaces node by a single child, if condition is met on child,
......@@ -580,7 +590,7 @@ def replace_parser(context: List[Node], name: str):
node.parser = MockParser(name, ptype)
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def flatten(context: List[Node], condition: Callable=is_anonymous, recursive: bool=True):
"""
Flattens all children, that fulfil the given ``condition``
......@@ -645,8 +655,8 @@ def collapse(context: List[Node]):
# node.result = (nd for nd in leaves_iterator)
@transformation_factory
def merge_children(context: List[Node], tag_names: List[str]):
@transformation_factory(tuple)
def merge_children(context: List[Node], tag_names: Tuple[str]):
"""
Joins all children next to each other and with particular tag-names
into a single child node with a mock-parser with the name of the
......@@ -674,7 +684,7 @@ def merge_children(context: List[Node], tag_names: List[str]):
node.result = tuple(result)
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def replace_content(context: List[Node], func: Callable): # Callable[[Node], ResultType]
"""Replaces the content of the node. ``func`` takes the node's result
as an argument an returns the mapped result.
......@@ -683,7 +693,7 @@ def replace_content(context: List[Node], func: Callable): # Callable[[Node], Re
node.result = func(node.result)
@transformation_factory(str)
@transformation_factory # (str)
def replace_content_by(context: List[Node], content: str): # Callable[[Node], ResultType]
"""Replaces the content of the node with the given text content.
"""
......@@ -702,7 +712,7 @@ def replace_content_by(context: List[Node], content: str): # Callable[[Node], R
#######################################################################
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def lstrip(context: List[Node], condition: Callable = is_expendable):
"""Recursively removes all leading child-nodes that fulfill a given condition."""
node = context[-1]
......@@ -716,7 +726,7 @@ def lstrip(context: List[Node], condition: Callable = is_expendable):
node.result = node.children[i:]
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def rstrip(context: List[Node], condition: Callable = is_expendable):
"""Recursively removes all leading nodes that fulfill a given condition."""
node = context[-1]
......@@ -731,14 +741,14 @@ def rstrip(context: List[Node], condition: Callable = is_expendable):
node.result = node.children[:i]
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def strip(context: List[Node], condition: Callable = is_expendable):
"""Removes leading and trailing child-nodes that fulfill a given condition."""
lstrip(context, condition)
rstrip(context, condition)
@transformation_factory(slice)
@transformation_factory # (slice)
def keep_children(context: List[Node], section: slice = slice(None)):
"""Keeps only child-nodes which fall into a slice of the result field."""
node = context[-1]
......@@ -746,7 +756,7 @@ def keep_children(context: List[Node], section: slice = slice(None)):
node.result = node.children[section]
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def keep_children_if(context: List[Node], condition: Callable):
"""Removes all children for which `condition()` returns `True`."""
node = context[-1]
......@@ -754,7 +764,7 @@ def keep_children_if(context: List[Node], condition: Callable):
node.result = tuple(c for c in node.children if condition(context + [c]))
@transformation_factory
@transformation_factory(collections.abc.Set)
def keep_tokens(context: List[Node], tokens: AbstractSet[str]=frozenset()):
"""Removes any among a particular set of tokens from the immediate
descendants of a node. If ``tokens`` is the empty set, all tokens
......@@ -762,7 +772,7 @@ def keep_tokens(context: List[Node], tokens: AbstractSet[str]=frozenset()):
keep_children_if(context, partial(is_token, tokens=tokens))
@transformation_factory
@transformation_factory(collections.abc.Set)
def keep_nodes(context: List[Node], tag_names: AbstractSet[str]):
"""Removes children by tag name."""
keep_children_if(context, partial(is_one_of, tag_name_set=tag_names))
......@@ -774,7 +784,7 @@ def keep_content(context: List[Node], regexp: str):
keep_children_if(context, partial(has_content, regexp=regexp))
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def remove_children_if(context: List[Node], condition: Callable):
"""Removes all children for which `condition()` returns `True`."""
node = context[-1]
......@@ -821,7 +831,7 @@ remove_infix_operator = keep_children(slice(0, None, 2))
remove_single_child = apply_if(keep_children(slice(0)), lambda ctx: len(ctx[-1].children) == 1)
@transformation_factory(AbstractSet[str])
@transformation_factory(collections.abc.Set)
def remove_tokens(context: List[Node], tokens: AbstractSet[str]=frozenset()):
"""Removes any among a particular set of tokens from the immediate
descendants of a node. If ``tokens`` is the empty set, all tokens
......@@ -829,7 +839,7 @@ def remove_tokens(context: List[Node], tokens: AbstractSet[str]=frozenset()):
remove_children_if(context, partial(is_token, tokens=tokens))
@transformation_factory
@transformation_factory(collections.abc.Set)
def remove_nodes(context: List[Node], tag_names: AbstractSet[str]):
"""Removes children by tag name."""
remove_children_if(context, partial(is_one_of, tag_name_set=tag_names))
......@@ -847,7 +857,7 @@ def remove_content(context: List[Node], regexp: str):
#
########################################################################
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def error_on(context: List[Node], condition: Callable, error_msg: str = ''):
"""
Checks for `condition`; adds an error message if condition is not met.
......@@ -863,7 +873,7 @@ def error_on(context: List[Node], condition: Callable, error_msg: str = ''):
node.add_error("transform.error_on: Failed to meet condition " + cond_name)
@transformation_factory(Callable)
@transformation_factory(collections.abc.Callable)
def warn_on(context: List[Node], condition: Callable, warning: str = ''):
"""
Checks for `condition`; adds an warning message if condition is not met.
......@@ -892,7 +902,7 @@ def assert_content(context: List[Node], regexp: str):
(node.parser.name, str(regexp), node.content))
@transformation_factory
@transformation_factory(collections.abc.Set)
def require(context: List[Node], child_tags: AbstractSet[str]):
node = context[-1]
for child in node.children:
......@@ -901,7 +911,7 @@ def require(context: List[Node], child_tags: AbstractSet[str]):
(child.parser.name, node.parser.name))
@transformation_factory
@transformation_factory(collections.abc.Set)
def forbid(context: List[Node], child_tags: AbstractSet[str]):
node = context[-1]
for child in node.children:
......
......@@ -19,14 +19,17 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import collections.abc
import sys
sys.path.extend(['../', './'])
from DHParser.syntaxtree import parse_sxpr
from DHParser.syntaxtree import Node, parse_sxpr, ZOMBIE_NODE
from DHParser.transform import traverse, reduce_single_child, remove_whitespace, \
traverse_locally, collapse, lstrip, rstrip, remove_content, remove_tokens
traverse_locally, collapse, lstrip, rstrip, remove_content, remove_tokens, \
transformation_factory
from DHParser.toolkit import typing
from typing import AbstractSet, List, Sequence, Tuple
class TestRemoval:
......@@ -86,7 +89,68 @@ class TestRemoval:
"*": []
}
traverse(cst, ast_table)
print(cst.as_sxpr())
cst1 = cst.as_sxpr()
assert cst1.find('et') < 0
ast_table = {
"wortarten": [remove_tokens("et")],
"*": []
}
traverse(cst, ast_table)
assert cst1 == cst.as_sxpr()
class TestTransformationFactory:
def test_mismatching_types(self):
@transformation_factory(tuple)
def good_transformation(context: List[Node], parameters: Tuple[str]):
pass
try:
@transformation_factory(tuple)
def bad_transformation(context: List[Node], parameters: AbstractSet[str]):
pass
assert False, "mismatching types not recognized by transform.transformation_factory()"
except TypeError:
pass
def test_forbidden_generic_types_in_decorator(self):
try:
@transformation_factory(AbstractSet[str])
def forbidden_transformation(context: List[Node], parameters: AbstractSet[str]):
pass
assert False, "use of generics not recognized in transform.transformation_factory()"
except TypeError:
pass
def test_forbidden_mutable_sequence_types_in_decorator(self):
try:
@transformation_factory(collections.abc.Sequence)
def parameterized_transformation(context: List[Node], parameters: Sequence[str]):
pass
_ = parameterized_transformation('a', 'b', 'c')
assert False, ("use of mutable sequences not recognized in "
"transform.transformation_factory()")
except TypeError:
pass
def test_parameter_set_expansion1(self):
save = None
@transformation_factory(collections.abc.Set)
def parameterized_transformation(context: List[Node], parameters: AbstractSet[str]):
nonlocal save
save = parameters
transformation = parameterized_transformation('a', 'b', 'c')
transformation([ZOMBIE_NODE])
assert save == {'a', 'b', 'c'}
def test_parameter_set_expansion2(self):
save = None
@transformation_factory(tuple)
def parameterized_transformation(context: List[Node], parameters: Tuple[str]):
nonlocal save
save = parameters
transformation = parameterized_transformation('a', 'b', 'c')
transformation([ZOMBIE_NODE])
assert save == ('a', 'b', 'c'), str(save)
class TestConditionalTransformations:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment