diff --git a/DHParser/transform.pxd b/DHParser/transform.pxd index cfd18852b0f7c7958434b6d3f14b2ad853030603..4ac1c7017c12dd3035a16c8e3fc28c246576e42c 100644 --- a/DHParser/transform.pxd +++ b/DHParser/transform.pxd @@ -33,8 +33,8 @@ cpdef is_empty(context: List[Node]) # cpdef matches_re(context: List[Node], patterns: AbstractSet[str]) # cpdef has_content(context: List[Node], regexp: str) # cpdef has_ancestor(context: List[Node], tag_name_set: AbstractSet[str]) -cpdef _replace_by(node: Node, child: Node) -cpdef _reduce_child(node: Node, child: Node) +cpdef _replace_by(node: Node, child: Node, root: RootNode) +cpdef _reduce_child(node: Node, child: Node, root: RootNode) cpdef replace_by_single_child(context: List[Node]) cpdef reduce_single_child(context: List[Node]) # cpdef replace_or_reduce(context: List[Node], condition: Callable = ?) diff --git a/DHParser/transform.py b/DHParser/transform.py index 55740b123b01717fc4ba0c93e6455b5f2de3d28d..ef618d1f1335adefaf84eef6e135116daba9c089 100644 --- a/DHParser/transform.py +++ b/DHParser/transform.py @@ -36,7 +36,7 @@ from typing import AbstractSet, Any, ByteString, Callable, cast, Container, Dict from DHParser.error import Error, ErrorCode, AST_TRANSFORM_CRASH, ERROR from DHParser.syntaxtree import Node, WHITESPACE_PTYPE, TOKEN_PTYPE, LEAF_PTYPES, PLACEHOLDER, \ - RootNode, parse_sxpr, flatten_sxpr + RootNode, parse_sxpr, flatten_sxpr, ZOMBIE_TAG from DHParser.toolkit import issubtype, isgenerictype, expand_table, smart_list, re, cython, \ escape_formatstr @@ -647,12 +647,14 @@ def has_sibling(context: List[Node], tag_name_set: AbstractSet[str]): ####################################################################### -def update_attr(dest: Node, src: Tuple[Node, ...]): +def update_attr(dest: Node, src: Tuple[Node, ...], root: RootNode): """ - Adds all attributes from `src` to `dest`.This is needed, in order - to keep the attributes if the child node is going to be eliminated. + Adds all attributes from `src` to `dest` and transfers all errors + from `src` to `dest`. This is needed, in order to keep the attributes + if the child node is going to be eliminated. """ for s in src: + # update attributes if s != dest and hasattr(s, '_xml_attr'): for k, v in s.attr.items(): if k in dest.attr and v != dest.attr[k]: @@ -660,12 +662,21 @@ def update_attr(dest: Node, src: Tuple[Node, ...]): 'when reducing %s to %s ! Tree transformation stopped.' % (v, dest.attr[k], k, str(src), str(dest))) dest.attr[k] = v + # transfer errors + try: + ids = id(s) + if ids in root.error_nodes: + root.error_nodes.setdefault(id(dest), []).extend(root.error_nodes[ids]) + del root.error_nodes[ids] + except AttributeError: + # root was not of type RootNode, probably a doc-test + pass def swap_attributes(node: Node, other: Node): """ Exchanges the attributes between node and other. This might be - needed when rearanging trees. + needed when rearangeing trees. """ NA = node.has_attr() OA = other.has_attr() @@ -681,7 +692,7 @@ def swap_attributes(node: Node, other: Node): other._xml_attr = None -def _replace_by(node: Node, child: Node): +def _replace_by(node: Node, child: Node, root: RootNode): """ Replaces node's contents by child's content including the tag name. """ @@ -691,17 +702,18 @@ def _replace_by(node: Node, child: Node): # child.parser = MockParser(name, ptype) # parser names must not be overwritten, else: child.parser.name = node.parser.name node._set_result(child.result) - update_attr(node, (child,)) + update_attr(node, (child,), root) -def _reduce_child(node: Node, child: Node): +def _reduce_child(node: Node, child: Node, root: RootNode): """ Sets node's results to the child's result, keeping node's tag_name. """ node._set_result(child.result) - update_attr(child, (node,)) - if child.has_attr(): - node._xml_attr = child._xml_attr + update_attr(node, (child,), root) + # update_attr(child, (node,), root) + # if child.has_attr(): + # node._xml_attr = child._xml_attr ####################################################################### @@ -727,7 +739,7 @@ def replace_by_single_child(context: List[Node]): """ node = context[-1] if len(node.children) == 1: - _replace_by(node, node.children[0]) + _replace_by(node, node.children[0], cast(RootNode, context[0])) def replace_by_children(context: List[Node]): @@ -760,7 +772,7 @@ def reduce_single_child(context: List[Node]): """ node = context[-1] if len(node.children) == 1: - _reduce_child(node, node.children[0]) + _reduce_child(node, node.children[0], cast(RootNode, context[0])) @transformation_factory(collections.abc.Callable) @@ -773,9 +785,9 @@ def replace_or_reduce(context: List[Node], condition: Callable = is_named): if len(node.children) == 1: child = node.children[0] if condition(context): - _replace_by(node, child) + _replace_by(node, child, cast(RootNode, context[0])) else: - _reduce_child(node, child) + _reduce_child(node, child, cast(RootNode, context[0])) @transformation_factory(str) @@ -791,6 +803,8 @@ def change_tag_name(context: List[Node], tag_name: str, restriction: Callable = """ if restriction(context): node = context[-1] + # ZOMBIE_TAGS will not be changed, so that errors don't get overlooked + # if node.tag_name != ZOMBIE_TAG: node.tag_name = tag_name @@ -806,6 +820,8 @@ def replace_tag_names(context: List[Node], replacements: Dict[str, str]): node that exists as a key in the dictionary will be replaces by the value for that key. """ + # assert ZOMBIE_TAG not in replacements, 'Replacing ZOMBIE_TAGS is not allowed, " \ + # "because they result from errors that could otherwise be overlooked, subsequently!' for child in context[-1].children: child.tag_name = replacements.get(child.tag_name, child.tag_name) @@ -840,7 +856,7 @@ def flatten(context: List[Node], condition: Callable = is_anonymous, recursive: if recursive: flatten(context, condition, recursive) new_result.extend(child.children) - update_attr(node, (child,)) + update_attr(node, (child,), cast(RootNode, context[0])) else: new_result.append(child) context.pop() @@ -998,7 +1014,7 @@ def merge_adjacent(context: List[Node], condition: Callable, tag_name: str = '') tag_names = {nd.tag_name for nd in adjacent} head.result = reduce(operator.add, (nd.result for nd in adjacent), initial) for nd in adjacent[1:]: - update_attr(head, nd) + update_attr(head, nd, cast(RootNode, context[0])) if tag_name in tag_names: head.tag_name = tag_name new_result.append(head) @@ -1046,7 +1062,7 @@ def merge_connected(context: List[Node], content: Callable, delimiter: Callable, tag_names = {nd.tag_name for nd in adjacent} head.result = reduce(operator.add, (nd.result for nd in adjacent), initial) for nd in adjacent[1:]: - update_attr(head, nd) + update_attr(head, nd, cast(RootNode, context[0])) if content_name in tag_names: head.tag_name = content_name new_result.append(head) @@ -1056,7 +1072,7 @@ def merge_connected(context: List[Node], content: Callable, delimiter: Callable, node._set_result(tuple(new_result)) -def merge_results(dest: Node, src: Tuple[Node, ...]) -> bool: +def merge_results(dest: Node, src: Tuple[Node, ...], root: RootNode) -> bool: """ Merges the results of nodes `src` and writes them to the result of `dest` type-safely, if all src nodes are leaf-nodes (in which case @@ -1067,18 +1083,18 @@ def merge_results(dest: Node, src: Tuple[Node, ...]) -> bool: Example: >>> head, tail = Node('head', '123'), Node('tail', '456') - >>> merge_results(head, (head, tail)) # merge head and tail (in that order) into head + >>> merge_results(head, (head, tail), head) # merge head and tail (in that order) into head True >>> str(head) '123456' """ if all(nd.children for nd in src): dest.result = reduce(operator.add, (nd.children for nd in src[1:]), src[0].children) - update_attr(dest, src) + update_attr(dest, src, root) return True elif all(not nd.children for nd in src): dest.result = reduce(operator.add, (nd.content for nd in src[1:]), src[0].content) - update_attr(dest, src) + update_attr(dest, src, root) return True return False @@ -1132,13 +1148,13 @@ def move_adjacent(context: List[Node], condition: Callable, merge: bool = True): if len(before) + len(prevN) > 1: target = before[-1] if before else prevN[0] - if merge_results(target, prevN + before): + if merge_results(target, prevN + before, cast(RootNode, context[0])): before = (target,) before = before or prevN if len(after) + len(nextN) > 1: target = after[0] if after else nextN[-1] - if merge_results(target, after + nextN): + if merge_results(target, after + nextN, cast(RootNode, context[0])): after = (target,) after = after or nextN