transform.py 21.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""transform.py - transformation functions for converting the
                       concrete into the abstract syntax tree

Copyright 2016  by Eckhart Arnold (arnold@badw.de)
                Bavarian Academy of Sciences and Humanities (badw.de)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied.  See the License for the specific language governing
permissions and limitations under the License.
"""

import inspect
from functools import partial, reduce, singledispatch

23
from DHParser.syntaxtree import Node, WHITESPACE_PTYPE, TOKEN_PTYPE, MockParser
24

25
26
27
28
from DHParser.toolkit import expand_table, smart_list, re, typing

from typing import AbstractSet, Any, ByteString, Callable, cast, Container, Dict, \
    Iterator, List, NamedTuple, Sequence, Union, Text, Tuple
29

30
31
32
33
34
# Names exported by ``from <this module> import *``: the type aliases, the
# factory decorator, the key functions, ``traverse`` and the individual
# transformation, condition and validation functions.
__all__ = ('TransformationDict',
           'TransformationProc',
           'ConditionFunc',
           'KeyFunc',
           'transformation_factory',
           'key_parser_name',
           'key_tag_name',
           'traverse',
           'is_named',
           'replace_by_single_child',
           'reduce_single_child',
           'replace_or_reduce',
           'replace_parser',
           'collapse',
           'merge_children',
           'replace_content',
           'apply_if',
           'is_anonymous',
           'is_whitespace',
           'is_empty',
           'is_expendable',
           'is_token',
           'is_one_of',
           'has_content',
           'remove_children_if',
           'remove_parser',
           'remove_content',
           'remove_first',
           'remove_last',
           'remove_whitespace',
           'remove_empty',
           'remove_expendables',
           'remove_brackets',
           'remove_infix_operator',
           'remove_single_child',
           'remove_tokens',
           'keep_children',
           'flatten',
           'forbid',
           'require',
           'assert_content',
           'assert_condition',
           'assert_has_children',
           'TRUE_CONDITION')
74
75


76
# A transformation receives the "context": a list of nodes whose last element
# is the node to be processed (this is how ``traverse`` calls the functions).
TransformationProc = Callable[[List[Node]], None]
# Maps a node key to the sequence of callables that is applied to matching nodes.
TransformationDict = Dict[str, Sequence[Callable]]
# Type of the table passed to ``traverse``; it may also hold the '__cache__'
# dictionary that ``traverse`` stores inside the table.
ProcessingTableType = Dict[str, Union[Sequence[Callable], TransformationDict]]
ConditionFunc = Callable  # Callable[[List[Node]], bool]
# Computes the key under which a node is looked up in the processing table.
KeyFunc = Callable[[Node], str]


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
def transformation_factory(t=None):
    """Creates factory functions from transformation-functions that
    dispatch on the first parameter after the node parameter.

    Decorating a transformation-function that has more than merely the
    ``node``-parameter with ``transformation_factory`` creates a
    function with the same name, which returns a partial-function that
    takes just the node-parameter.

    Additionally, there is some syntactic sugar for
    transformation-functions that receive a collection as their second
    parameter and do not have any further parameters. In this case a
    list of parameters passed to the factory function will be converted
    into a collection.

    Main benefit is readability of processing tables.

    Usage:
        @transformation_factory(AbstractSet[str])
        def remove_tokens(context, tokens):
            ...
      or, alternatively:
        @transformation_factory
        def remove_tokens(context, tokens: AbstractSet[str]):
            ...

    Example:
        trans_table = { 'expression': remove_tokens('+', '-') }
      instead of:
        trans_table = { 'expression': partial(remove_tokens, tokens={'+', '-'}) }

    Parameters:
        t:  type of the second argument of the transformation function,
            only necessary if the transformation functions' parameter list
            does not have type annotations.

    Returns:
        The decorator proper or, if ``transformation_factory`` was applied
        as a plain decorator (i.e. ``t`` is a function), the already
        decorated function.
    """

    def decorator(f):
        sig = inspect.signature(f)
        # all parameters except the first one (the node/context parameter)
        params = list(sig.parameters.values())[1:]
        if len(params) == 0:
            return f  # '@transformer' not needed w/o free parameters
        assert t or params[0].annotation != params[0].empty, \
            "No type information on second parameter found! Please, use type " \
            "annotation or provide the type information via transfomer-decorator."
        p1type = t or params[0].annotation
        # singledispatch selects the implementation by the type of the first
        # positional argument: calls that match none of the types registered
        # below run the original function ``f``, calls with a registered
        # type run one of the factories.
        f = singledispatch(f)
        if len(params) == 1 and issubclass(p1type, Container) and not issubclass(p1type, Text) \
                and not issubclass(p1type, ByteString):
            def gen_special(*args):
                # collect the loose arguments in the kind of collection
                # that the transformation function expects
                c = set(args) if issubclass(p1type, AbstractSet) else \
                    list(args) if issubclass(p1type, Sequence) else args
                d = {params[0].name: c}
                return partial(f, **d)

            # dispatch on the element type of the collection, so that the
            # items can be passed one by one, e.g. remove_tokens('+', '-')
            f.register(p1type.__args__[0], gen_special)

        def gen_partial(*args, **kwargs):
            # bind positional arguments to the parameters after the first one
            d = {p.name: arg for p, arg in zip(params, args)}
            d.update(kwargs)
            return partial(f, **d)

        f.register(p1type, gen_partial)
        return f

    if isinstance(t, type(lambda: 1)):
        # Provide for the case that transformation_factory has been
        # written as plain decorator and not as a function call that
        # returns the decorator proper.
        func = t
        t = None
        return decorator(func)
    else:
        return decorator


159
def key_parser_name(node: Node) -> str:
    """Key function for ``traverse``: yields the name of the node's parser."""
    parser = node.parser
    return parser.name


163
def key_tag_name(node: Node) -> str:
    """Key function for ``traverse``: yields the node's ``tag_name``."""
    tag = node.tag_name
    return tag


167
def traverse(root_node: Node,
             processing_table: ProcessingTableType,
             key_func: KeyFunc=key_tag_name) -> None:
    """
    Traverses the syntax tree starting with the given ``root_node`` depth
    first and applies the sequences of callback-functions registered
    in the ``processing_table``-dictionary.

    The most important use case is the transformation of a concrete
    syntax tree into an abstract tree (AST). But it is also imaginable
    to employ tree-traversal for the semantic analysis of the AST.

    In order to assign sequences of callback-functions to nodes, a
    dictionary ("processing table") is used. The keys usually represent
    tag names, but any other key function is possible. There exist
    three special keys:
        '+': always called (before any other processing function)
        '*': called for those nodes for which no (other) processing
             function appears in the table
        '~': always called (after any other processing function)

    Args:
        root_node (Node): The root-node of the syntax tree to be traversed
        processing_table (dict): node key -> sequence of functions that
            will be applied to matching nodes in order. This dictionary
            is interpreted as a `compact_table`. See
            `toolkit.expand_table` or `EBNFCompiler.EBNFTransTable`.
            Unless it already contains the key '__cache__', the table is
            changed in place: it is replaced by its expanded form and
            receives a '__cache__' entry.
        key_func (function): A mapping key_func(node) -> keystr. The default
            key_func (``key_tag_name``) yields node.tag_name.

    Example:
        table = { "term": [replace_by_single_child, flatten],
            "factor, flowmarker, retrieveop": replace_by_single_child }
        traverse(node, table)
    """
    # Is this optimization really needed?
    if '__cache__' in processing_table:
        # assume that processing table has already been expanded
        table = processing_table
        cache = processing_table['__cache__']
    else:
        # normalize processing_table entries by turning single values
        # into lists with a single value
        table = {name: smart_list(call) for name, call in list(processing_table.items())}
        table = expand_table(table)
        cache = table.setdefault('__cache__', cast(TransformationDict, dict()))
        # change processing table in place, so its already expanded and cache filled next time
        processing_table.clear()
        processing_table.update(table)

    def traverse_recursive(context):
        # context: the list of nodes leading to (and ending with) the current node
        node = context[-1]
        if node.children:
            for child in node.result:
                context.append(child)
                traverse_recursive(context)  # depth first
                node.error_flag = max(node.error_flag, child.error_flag)  # propagate error flag
                context.pop()

        # the children have been processed; now look up the callbacks for this node
        key = key_func(node)
        try:
            sequence = cache[key]
        except KeyError:
            sequence = table.get('+', []) + \
                       table.get(key, table.get('*', [])) + \
                       table.get('~', [])
            # '+' always called (before any other processing function)
            # '*' called for those nodes for which no (other) processing function
            #     appears in the table
            # '~' always called (after any other processing function)
            cache[key] = sequence

        for call in sequence:
            call(context)

    traverse_recursive([root_node])

251
252
253
254
255
256
257
258
259
260
261
262
263


# ------------------------------------------------
#
# rearranging transformations:
#     - tree may be rearranged (e.g.flattened)
#     - nodes that are not leaves may be dropped
#     - order is preserved
#     - leaf content is preserved (though not necessarily the leaves themselves)
#
# ------------------------------------------------


264
def TRUE_CONDITION(context: List[Node]) -> bool:
    """Condition that holds for every context; serves as the default
    condition of transformations that accept a condition."""
    always = True
    return always


268
def replace_child(node: Node):
    """Replaces ``node`` in place by its only child.

    The node takes over the parser, the errors and the result of the
    child. A child without a parser name gets a ``MockParser`` that
    carries this node's parser name and the child's own ptype.
    """
    assert len(node.children) == 1
    only_child = node.children[0]
    if not only_child.parser.name:
        # parser names must not be overwritten; therefore the anonymous
        # child receives a new mock parser instead of a changed name
        only_child.parser = MockParser(node.parser.name, only_child.parser.ptype)
    node.parser = only_child.parser
    node._errors.extend(only_child._errors)
    node.result = only_child.result
277
278


279
def reduce_child(node: Node):
    """Transfers errors and result of the only child to ``node``; the
    parser of ``node`` is left as it is."""
    assert len(node.children) == 1
    only_child = node.children[0]
    node._errors.extend(only_child._errors)
    node.result = only_child.result
283
284


285
@transformation_factory(Callable)
def replace_by_single_child(context: List[Node], condition: Callable=TRUE_CONDITION):
    """
    Removes a single branch node, replacing it by its immediate
    descendant, if and only if the condition on the descendant is true.
    (In case the descendant's name is empty (i.e. anonymous) the
    name of this node's parser is kept.)
    """
    node = context[-1]
    if len(node.children) != 1:
        return
    # the condition is evaluated on the context of the descendant
    context.append(node.children[0])
    if condition(context):
        replace_child(node)
    context.pop()
299
300


301
@transformation_factory(Callable)
def reduce_single_child(context: List[Node], condition: Callable=TRUE_CONDITION):
    """
    Reduces a single branch node by transferring the result of its
    immediate descendant to this node, but keeping this node's parser
    entry. If the condition evaluates to false on the descendant, it
    will not be reduced.
    """
    node = context[-1]
    if len(node.children) != 1:
        return
    # the condition is evaluated on the context of the descendant
    context.append(node.children[0])
    if condition(context):
        reduce_child(node)
    context.pop()
315
316


317
def is_named(context: List[Node]) -> bool:
    """Returns True, if the parser of the last node in the context has
    a non-empty name."""
    node = context[-1]
    return bool(node.parser.name)
319
320
321
322


def is_anonymous(context: List[Node]) -> bool:
    """Returns True, if the parser of the last node in the context has
    an empty name."""
    node = context[-1]
    return not node.parser.name
323
324
325


@transformation_factory(Callable)
def replace_or_reduce(context: List[Node], condition: Callable=is_named):
    """
    Replaces node by a single child, if condition is met on child,
    otherwise (i.e. if the child is anonymous) reduces the child.
    """
    node = context[-1]
    if len(node.children) != 1:
        return
    context.append(node.children[0])
    operation = replace_child if condition(context) else reduce_child
    operation(node)
    context.pop()
339
340
341


@transformation_factory
def replace_parser(context: List[Node], name: str):
    """
    Replaces the parser of the last node in the context with a mock
    parser with the given name.

    Parameters:
        context: the node where the parser shall be replaced is the
            last item of this list
        name(str): "NAME:PTYPE" of the surrogate. The ptype is optional
    """
    node = context[-1]
    # padding with '' yields an empty ptype if no ':PTYPE' part was given
    fields = name.split(':') + ['']
    name, ptype = fields[:2]
    node.parser = MockParser(name, ptype)


@transformation_factory(Callable)
def flatten(context: List[Node], condition: Callable=is_anonymous, recursive: bool=True):
    """
    Flattens all children that fulfil the given `condition`
    (default: all unnamed children). Flattening means that wherever a
    child has child nodes of its own, these are inserted in place of
    the child.

    If the parameter `recursive` is `True` the same will recursively be
    done with the child-nodes, first.

    Applying flatten recursively will result in these kinds of
    structural transformation:
        (1 (+ 2) (+ 3))    ->   (1 + 2 + 3)
        (1 (+ (2 + (3))))  ->   (1 + 2 + 3)
    """
    node = context[-1]
    if not node.children:
        return
    flattened = []     # type: List[Node]
    for child in node.children:
        context.append(child)
        if child.children and condition(context):
            if recursive:
                # context now ends with the child, so the child gets flattened first
                flatten(context, condition, recursive)
            flattened.extend(child.children)
        else:
            flattened.append(child)
        context.pop()
    node.result = tuple(flattened)


387
388
389
def collapse(context: List[Node]):
    """
    Collapses all sub-nodes of the last node in the context by
    replacing the node's result with the string representation of
    the node.
    """
    node = context[-1]
    as_text = str(node)
    node.result = as_text


@transformation_factory
def merge_children(context: List[Node], tag_names: List[str]):
    """
    Joins all children next to each other and with particular tag-
    names into a single child node with a mock-parser with the name of
    the first tag-name in the list.

    Parameters:
        context: list of nodes; the node to be processed is the last item
        tag_names: tag names of the children to be merged. If the first
            tag name starts with a colon, it is used as the ptype of the
            mock-parser (with an empty name), otherwise as its name.
    """
    node = context[-1]
    result = []
    name, ptype = ('', tag_names[0]) if tag_names[0][:1] == ':' else (tag_names[0], '')
    if node.children:
        i = 0
        L = len(node.children)
        while i < L:
            # children that are not to be merged are copied unchanged
            while i < L and not node.children[i].tag_name in tag_names:
                result.append(node.children[i])
                i += 1
            # extend the run i..k-1 as long as the following children have a
            # matching tag name and agree with child i on having children
            k = i + 1
            while (k < L and node.children[k].tag_name in tag_names
                   and bool(node.children[i].children) == bool(node.children[k].children)):
                k += 1
            if i < L:
                # the merged node receives the concatenated children of the run
                result.append(Node(MockParser(name, ptype),
                                   reduce(lambda a, b: a + b,
                                          (node.children for node in node.children[i:k]))))
            i = k
        node.result = tuple(result)


# ------------------------------------------------
#
# destructive transformations:
#     - tree may be rearranged (flattened),
#     - order is preserved
#     - but (irrelevant) leaves may be dropped
#     - errors of dropped leaves will be lost
#
# ------------------------------------------------


@transformation_factory
def replace_content(context: List[Node], func: Callable):  # Callable[[Node], ResultType]
    """Replaces the content of the last node in the context. ``func``
    is called with the node's current result and returns the new result.
    """
    node = context[-1]
    new_result = func(node.result)
    node.result = new_result


445
def is_whitespace(context: List[Node]) -> bool:
    """Returns True, if the ptype of the last node's parser is
    ``WHITESPACE_PTYPE``."""
    node = context[-1]
    return node.parser.ptype == WHITESPACE_PTYPE
449
450


451
452
def is_empty(context: List[Node]) -> bool:
    """Returns True, if the result of the last node in the context is falsy."""
    node = context[-1]
    return not node.result
453
454


455
456
def is_expendable(context: List[Node]) -> bool:
    """Returns True, if the last node in the context is empty or whitespace."""
    if is_empty(context):
        return True
    return is_whitespace(context)
457
458


459
460
def is_token(context: List[Node], tokens: AbstractSet[str] = frozenset()) -> bool:
    """Returns True, if the last node in the context is a token (i.e. its
    parser's ptype is ``TOKEN_PTYPE``) and, in case ``tokens`` is not
    empty, its result is one of ``tokens``."""
    node = context[-1]
    if node.parser.ptype != TOKEN_PTYPE:
        return False
    return not tokens or node.result in tokens


464
def is_one_of(context: List[Node], tag_name_set: AbstractSet[str]) -> bool:
    """Returns True, if the tag_name of the last node in the context is
    one of the given tag names."""
    tag = context[-1].tag_name
    return tag in tag_name_set
468
469


470
def has_content(context: List[Node], regexp: str) -> bool:
    """Checks the string content of the last node in the context against
    a regular expression (matched from the start of the content)."""
    content = str(context[-1])
    return re.match(regexp, content) is not None
473
474


475
@transformation_factory(Callable)
def apply_if(context: List[Node], transformation: Callable, condition: Callable):
    """Applies a transformation only if a certain condition is met.

    Note that ``condition`` is called with the last node of the context,
    while ``transformation`` is called with the context itself.
    """
    node = context[-1]
    if not condition(node):
        return
    transformation(context)
481
482


483
@transformation_factory(slice)
def keep_children(context: List[Node], section: slice = slice(None)):
    """Keeps only those child-nodes which fall into the given slice of
    the children. Nodes without children are left unchanged."""
    node = context[-1]
    if not node.children:
        return
    kept = node.children[section]
    node.result = kept
489
490
491


@transformation_factory(Callable)
def remove_children_if(context: List[Node], condition: Callable, section: slice = slice(None)):
    """Removes all children for which `condition()` returns `True`.

    Parameters:
        context: list of nodes; the node to be processed is the last item
        condition: called with the context extended by the child in
            question; the child is removed if the result is true
        section: only children whose index falls into this slice are
            candidates for removal. The default covers all children.
    """
    node = context[-1]
    if node.children:
        children = node.children
        # indices of the children that may be removed at all
        candidates = range(*section.indices(len(children)))
        node.result = tuple(child for i, child in enumerate(children)
                            if i not in candidates or not condition(context + [child]))


@transformation_factory(Callable)
def remove_children(context: List[Node], condition: Callable, section: slice = slice(None)):
    """Removes those children from a slice of the children for which
    `condition` evaluates to `True`. The condition is called with the
    context extended by the respective child."""
    node = context[-1]
    if not node.children:
        return
    children = node.children
    affected = range(*section.indices(len(children)))
    # children outside of the slice are always kept
    node.result = tuple(child for pos, child in enumerate(children)
                        if pos not in affected or not condition(context + [child]))
518
519
520
521
522


remove_whitespace = remove_children_if(is_whitespace)  # partial(remove_children_if, condition=is_whitespace)
remove_empty = remove_children_if(is_empty)
remove_expendables = remove_children_if(is_expendable)  # partial(remove_children_if, condition=is_expendable)
523
524
525
526
527
remove_first = apply_if(keep_children(slice(1, None)), lambda nd: len(nd.children) > 1)
remove_last = apply_if(keep_children(slice(None, -1)), lambda nd: len(nd.children) > 1)
remove_brackets = apply_if(keep_children(slice(1, -1)), lambda nd: len(nd.children) >= 2)
remove_infix_operator = keep_children(slice(0, None, 2))
remove_single_child = apply_if(keep_children(slice(0)), lambda nd: len(nd.children) == 1)
528
529
530


@transformation_factory
def remove_tokens(context: List[Node], tokens: AbstractSet[str] = frozenset()):
    """Removes any among a particular set of tokens from the immediate
    descendants of a node. If ``tokens`` is the empty set, all tokens
    are removed."""
    token_matches = partial(is_token, tokens=tokens)
    remove_children_if(context, token_matches)
536
537
538


@transformation_factory
def remove_parser(context: List[Node], tag_names: AbstractSet[str]):
    """Removes those children whose tag name is one of ``tag_names``."""
    tag_matches = partial(is_one_of, tag_name_set=tag_names)
    remove_children_if(context, tag_matches)
542
543
544


@transformation_factory
def remove_content(context: List[Node], regexp: str):
    """Removes those children whose string value matches ``regexp``."""
    content_matches = partial(has_content, regexp=regexp)
    remove_children_if(context, content_matches)
548
549
550
551


########################################################################
#
552
# AST semantic validation functions (EXPERIMENTAL!!!)
553
554
555
#
########################################################################

556
@transformation_factory(Callable)
def assert_condition(context: List[Node], condition: Callable, error_msg: str = ''):
    """Checks for `condition`; adds an error message if condition is not met.

    Parameters:
        context: list of nodes; the node to be checked is the last item
        condition: called with the context; a false result is reported
            as an error on the node
        error_msg: error message to be added. If it contains "%s", the
            placeholder is filled with the node's tag name. If empty, a
            generic message that names the condition is used.
    """
    node = context[-1]
    if condition(context):
        return
    if error_msg:
        # `in` also finds a placeholder at the very beginning of the message
        node.add_error(error_msg % node.tag_name if "%s" in error_msg else error_msg)
    else:
        # objects without a __name__ (e.g. partials) are named by their class
        cond_name = getattr(condition, '__name__', type(condition).__name__)
        node.add_error("transform.assert_condition: Failed to meet condition " + cond_name)


# assert_condition calls its condition with the whole context (a list of
# nodes), so the node to be checked must be picked as context[-1].
assert_has_children = assert_condition(lambda context: context[-1].children,
                                       'Element "%s" has no children')


@transformation_factory
def assert_content(context: List[Node], regexp: str):
    """Adds an error to the last node in the context, if its content
    does not match ``regexp``."""
    node = context[-1]
    if has_content(context, regexp):
        return
    node.add_error('Element "%s" violates %s on %s' %
                   (node.parser.name, str(regexp), str(node)))

580
581

@transformation_factory
def require(context: List[Node], child_tags: AbstractSet[str]):
    """Adds an error to the last node in the context for every child
    whose tag name is not contained in ``child_tags``."""
    node = context[-1]
    for child in node.children:
        if child.tag_name in child_tags:
            continue
        node.add_error('Element "%s" is not allowed inside "%s".' %
                       (child.parser.name, node.parser.name))


@transformation_factory
def forbid(context: List[Node], child_tags: AbstractSet[str]):
    """Adds an error to the last node in the context for every child
    whose tag name is contained in ``child_tags``."""
    node = context[-1]
    for child in node.children:
        if child.tag_name not in child_tags:
            continue
        node.add_error('Element "%s" cannot be nested inside "%s".' %
                       (child.parser.name, node.parser.name))