Commit 76d05e77 authored by Eckhart Arnold

- some slight adjustments regarding static type checking

parent c74091ca
@@ -24,7 +24,7 @@ try:
     import regex as re
 except ImportError:
     import re
-from typing import Any, Tuple, cast
+from typing import Any, cast, Tuple, Union
 from DHParser.ebnf import EBNFTransformer, EBNFCompiler, grammar_changed, \
     get_ebnf_scanner, get_ebnf_grammar, get_ebnf_transformer, get_ebnf_compiler, \
@@ -147,7 +147,7 @@ def grammar_instance(grammar_representation) -> Tuple[Grammar, str]:
     if isinstance(grammar_representation, str):
         # read grammar
         grammar_src = load_if_file(grammar_representation)
-        if is_python_code(grammar_representation):
+        if is_python_code(grammar_src):
             parser_py, errors, AST = grammar_src, '', None
         else:
             with logging(False):
@@ -170,7 +170,7 @@ def grammar_instance(grammar_representation) -> Tuple[Grammar, str]:
 def compileDSL(text_or_file: str,
                scanner: ScannerFunc,
-               dsl_grammar: Grammar,
+               dsl_grammar: Union[str, Grammar],
                ast_transformation: TransformerFunc,
                compiler: Compiler) -> Any:
     """Compiles a text in a domain specific language (DSL) with an
...
@@ -22,7 +22,7 @@ try:
     import regex as re
 except ImportError:
     import re
-from typing import Callable, cast, List, Set, Tuple
+from typing import Callable, List, Set, Tuple
 from DHParser.toolkit import load_if_file, escape_re, md5, sane_parser_name
 from DHParser.parsers import Grammar, mixin_comment, nil_scanner, Forward, RE, NegativeLookahead, \
@@ -197,10 +197,10 @@ EBNF_transformation_table = {
     "syntax":
         remove_expendables,
     "directive, definition":
-        partial(remove_tokens, tokens={'@', '='}),
+        partial(remove_tokens, {'@', '='}),
     "expression":
         [replace_by_single_child, flatten,
-         partial(remove_tokens, tokens={'|'})],
+         partial(remove_tokens, {'|'})],
     "term":
         [replace_by_single_child, flatten],  # supports both idioms: "{ factor }+" and "factor { factor }"
     "factor, flowmarker, retrieveop":
@@ -214,7 +214,7 @@ EBNF_transformation_table = {
     (TOKEN_PTYPE, WHITESPACE_PTYPE):
         [remove_expendables, reduce_single_child],
     "list_":
-        [flatten, partial(remove_tokens, tokens={','})],
+        [flatten, partial(remove_tokens, {','})],
     "*":
         [remove_expendables, replace_by_single_child]
 }
@@ -223,8 +223,8 @@ EBNF_transformation_table = {
 EBNF_validation_table = {
     # Semantic validation on the AST
     "repetition, option, oneormore":
-        [partial(forbid, child_tags=['repetition', 'option', 'oneormore']),
-         partial(assert_content, regex=r'(?!§)')],
+        [partial(forbid, ['repetition', 'option', 'oneormore']),
+         partial(assert_content, r'(?!§)')],
 }
@@ -466,7 +466,7 @@ class EBNFCompiler(Compiler):
         return self.assemble_parser(definitions, node)

     def on_definition(self, node: Node) -> Tuple[str, str]:
-        rule = cast(str, node.children[0].result)
+        rule = str(node.children[0])  # cast(str, node.children[0].result)
         if rule in self.rules:
             node.add_error('A rule with name "%s" has already been defined.' % rule)
         elif rule in EBNFCompiler.RESERVED_SYMBOLS:
@@ -507,7 +507,7 @@ class EBNFCompiler(Compiler):
         return rx

     def on_directive(self, node: Node) -> str:
-        key = cast(str, node.children[0].result).lower()
+        key = str(node.children[0]).lower()  # cast(str, node.children[0].result).lower()
         assert key not in self.directives['tokens']
         if key in {'comment', 'whitespace'}:
             if node.children[1].parser.name == "list_":
@@ -520,8 +520,8 @@ class EBNFCompiler(Compiler):
             else:
                 node.add_error('Value "%s" not allowed for directive "%s".' % (value, key))
         else:
-            value = cast(str, node.children[1].result).strip("~")
-            if value != cast(str, node.children[1].result):
+            value = str(node.children[1]).strip("~")  # cast(str, node.children[1].result).strip("~")
+            if value != str(node.children[1]):  # cast(str, node.children[1].result):
                 node.add_error("Whitespace marker '~' not allowed in definition of "
                                "%s regular expression." % key)
             if value[0] + value[-1] in {'""', "''"}:
@@ -576,7 +576,7 @@ class EBNFCompiler(Compiler):
     def on_factor(self, node: Node) -> str:
         assert node.children
         assert len(node.children) >= 2, node.as_sexpr()
-        prefix = cast(str, node.children[0].result)
+        prefix = str(node.children[0])  # cast(str, node.children[0].result)
         custom_args = []  # type: List[str]
         if prefix in {'::', ':'}:
@@ -588,7 +588,7 @@ class EBNFCompiler(Compiler):
                 return str(arg.result)
             if str(arg) in self.directives['filter']:
                 custom_args = ['retrieve_filter=%s' % self.directives['filter'][str(arg)]]
-            self.variables.add(cast(str, arg.result))
+            self.variables.add(str(arg))  # cast(str, arg.result)
         elif len(node.children) > 2:
             # shift = (Node(node.parser, node.result[1].result),)
@@ -623,7 +623,7 @@ class EBNFCompiler(Compiler):
                            "AST transformation!")

     def on_symbol(self, node: Node) -> str:
-        result = cast(str, node.result)
+        result = str(node)  # ; assert result == cast(str, node.result)
         if result in self.directives['tokens']:
             return 'ScannerToken("' + result + '")'
         else:
@@ -633,10 +633,10 @@ class EBNFCompiler(Compiler):
             return result

     def on_literal(self, node) -> str:
-        return 'Token(' + cast(str, node.result).replace('\\', r'\\') + ')'  # return 'Token(' + ', '.join([node.result]) + ')' ?
+        return 'Token(' + str(node).replace('\\', r'\\') + ')'  # return 'Token(' + ', '.join([node.result]) + ')' ?

     def on_regexp(self, node: Node) -> str:
-        rx = cast(str, node.result)
+        rx = str(node)  # ; assert rx == cast(str, node.result)
         name = []  # type: List[str]
         if rx[:2] == '~/':
             if not 'left' in self.directives['literalws']:
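Throughout these compiler methods, ``cast(str, node.result)`` gives way to ``str(node)``. The ``cast`` only silences the type checker and performs no runtime check, while ``str(node)`` goes through ``Node.__str__``, which flattens leaf and branch nodes alike (cf. the ``__str__`` fix further below). A reduced illustration with a toy node class, not DHParser's:

    from typing import Tuple, Union, cast

    class Node:
        def __init__(self, result: Union[Tuple['Node', ...], str]) -> None:
            self.result = result

        def __str__(self) -> str:
            if isinstance(self.result, tuple):   # branch node: flatten children
                return "".join(str(child) for child in self.result)
            return self.result                   # leaf node: plain string

    leaf = Node("wo")
    branch = Node((leaf, Node("rd")))
    print(str(leaf), str(branch))       # -> wo word  (safe for any node)
    print(cast(str, branch.result))     # typechecks, but prints a tuple!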
...
@@ -117,10 +117,8 @@ ZOMBIE_PARSER = ZombieParser()
 #     msg: str
 Error = NamedTuple('Error', [('pos', int), ('msg', str)])

-ChildrenType = Tuple['Node', ...]
-ResultType = Union[ChildrenType, str]
-SloppyResultT = Union[ChildrenType, 'Node', str, None]
+ResultType = Union[Tuple['Node', ...], str]
+SloppyResultType = Union[Tuple['Node', ...], 'Node', str, None]


 class Node:
@@ -163,13 +161,14 @@ class Node:
     parsing stage and never during or after the
     AST-transformation.
     """
-    def __init__(self, parser, result: SloppyResultT) -> None:
+
+    def __init__(self, parser, result: SloppyResultType) -> None:
         """Initializes the ``Node``-object with the ``Parser``-Instance
         that generated the node and the parser's result.
         """
         self._result = ''  # type: ResultType
         self._errors = []  # type: List[str]
-        self._children = ()  # type: ChildrenType
+        self._children = ()  # type: Tuple['Node', ...]
         self._len = len(self.result) if not self.children else \
             sum(child._len for child in self.children)  # type: int
         # self.pos: int = 0  # continuous updating of pos values
@@ -181,7 +180,7 @@ class Node:
     def __str__(self):
         if self.children:
-            return "".join(str(child) for child in self.result)
+            return "".join(str(child) for child in self.children)
         return str(self.result)

     def __eq__(self, other):
@@ -207,17 +206,17 @@ class Node:
         return self._result

     @result.setter
-    def result(self, result: SloppyResultT):
+    def result(self, result: SloppyResultType):
         # # made obsolete now that static type checking with mypy is done
         # assert ((isinstance(result, tuple) and all(isinstance(child, Node) for child in result))
         #         or isinstance(result, Node)
         #         or isinstance(result, str)), str(result)
         self._result = (result,) if isinstance(result, Node) else result or ''
-        self._children = cast(ChildrenType, self._result) \
-            if isinstance(self._result, tuple) else cast(ChildrenType, ())
+        self._children = cast(Tuple['Node', ...], self._result) \
+            if isinstance(self._result, tuple) else cast(Tuple['Node', ...], ())

     @property
-    def children(self) -> ChildrenType:
+    def children(self) -> Tuple['Node', ...]:
         return self._children

     @property
@@ -515,6 +514,14 @@ def traverse(root_node, processing_table, key_func=key_tag_name):
     traverse_recursive(root_node)


+# Note on processing functions: If processing functions receive more
+# than one parameter, the ``node``-parameter should always be the
+# last parameter to ease partial function application, e.g.:
+#     def replace_parser(name, node):
+#         ...
+#     processing_func = partial(replace_parser, "special")
+
+
 def no_operation(node):
     pass
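The comment added above spells out the convention behind all of this commit's signature flips: configuration parameters first, ``node`` last, so that ``functools.partial`` can pre-bind the configuration and leave the unary callable that the traversal tables expect (as in ``partial(remove_tokens, {'|'})`` earlier). A runnable version of the comment's own example, with an illustrative stub in place of DHParser's ``Node`` and ``MockParser``:

    from functools import partial

    class StubNode:
        """Illustrative stand-in for DHParser's Node."""
        tag_name = "old"

    def replace_parser(name, node):    # ``node`` last, per the convention
        node.tag_name = name           # stands in for attaching a MockParser

    processing_func = partial(replace_parser, "special")

    node = StubNode()
    processing_func(node)              # only the node remains to be supplied
    print(node.tag_name)               # -> special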
@@ -552,10 +559,15 @@ def reduce_single_child(node):
         node.result = node.result[0].result


-def replace_parser(node, name, ptype=''):
+def replace_parser(name, node):
     """Replaces the parser of a Node with a mock parser with the given
-    name and pseudo-type.
+    name.
+
+    Parameters:
+        name (str): "NAME:PTYPE" of the surrogate. The ptype is optional.
+        node (Node): The node whose parser shall be replaced.
     """
+    name, ptype = (name.split(':') + [''])[:2]
     node.parser = MockParser(name, ptype)
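The reworked ``replace_parser`` folds the former ``ptype`` parameter into the name string, so a single "NAME:PTYPE" argument configures the mock parser and ``node`` stays the trailing parameter. The splitting idiom, checked in isolation:

    # "NAME:PTYPE" with the ptype part optional, defaulting to ''
    for s in ("special", "special:token"):
        name, ptype = (s.split(':') + [''])[:2]
        print(repr(name), repr(ptype))
    # -> 'special' ''
    # -> 'special' 'token'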
@@ -616,28 +628,28 @@ def is_expendable(node):
     return is_empty(node) or is_whitespace(node)


-def is_token(node, token_set=frozenset()):
+def is_token(token_set, node):
     return node.parser.ptype == TOKEN_PTYPE and (not token_set or node.result in token_set)


-def remove_children_if(node, condition):
+def remove_children_if(condition, node):
     """Removes all nodes from the result field if the function
     ``condition(child_node)`` evaluates to ``True``."""
     if node.children:
         node.result = tuple(c for c in node.children if not condition(c))


-remove_whitespace = partial(remove_children_if, condition=is_whitespace)
+remove_whitespace = partial(remove_children_if, is_whitespace)
 # remove_scanner_tokens = partial(remove_children_if, condition=is_scanner_token)
-remove_expendables = partial(remove_children_if, condition=is_expendable)
+remove_expendables = partial(remove_children_if, is_expendable)


-def remove_tokens(node, tokens=frozenset()):
+def remove_tokens(tokens, node):
     """Removes any among a particular set of tokens from the immediate
     descendants of a node. If ``tokens`` is the empty set, all tokens
     are removed.
     """
-    remove_children_if(node, partial(is_token, token_set=tokens))
+    remove_children_if(partial(is_token, tokens), node)


 def remove_enclosing_delimiters(node):
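With ``token_set`` moved to the front of ``is_token``, ``remove_tokens`` composes its child filter the same way: ``partial(is_token, tokens)`` yields the unary predicate that ``remove_children_if`` expects. A self-contained trace with stub classes standing in for DHParser's real ``Node`` and parser (the stub assigns ``children`` directly instead of going through the ``result`` setter):

    from functools import partial

    TOKEN_PTYPE = ":Token"    # stub constant, value immaterial here

    class StubParser:
        ptype = TOKEN_PTYPE

    class StubNode:
        def __init__(self, result, children=()):
            self.parser = StubParser()
            self.result = result
            self.children = children

    def is_token(token_set, node):
        return node.parser.ptype == TOKEN_PTYPE and (not token_set or node.result in token_set)

    def remove_children_if(condition, node):
        if node.children:
            node.children = tuple(c for c in node.children if not condition(c))

    def remove_tokens(tokens, node):
        remove_children_if(partial(is_token, tokens), node)

    root = StubNode('', (StubNode('|'), StubNode('term'), StubNode('|')))
    remove_tokens({'|'}, root)
    print([c.result for c in root.children])    # -> ['term']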
def remove_enclosing_delimiters(node): def remove_enclosing_delimiters(node):
...@@ -649,7 +661,7 @@ def remove_enclosing_delimiters(node): ...@@ -649,7 +661,7 @@ def remove_enclosing_delimiters(node):
node.result = node.result[1:-1] node.result = node.result[1:-1]
def map_content(node, func): def map_content(func, node):
"""Replaces the content of the node. ``func`` takes the node """Replaces the content of the node. ``func`` takes the node
as an argument an returns the mapped result. as an argument an returns the mapped result.
""" """
@@ -664,21 +676,21 @@ def map_content(node, func):
 ########################################################################


-def require(node, child_tag):
+def require(child_tag, node):
     for child in node.children:
         if child.tag_name not in child_tag:
             node.add_error('Element "%s" is not allowed inside "%s".' %
                            (child.parser.name, node.parser.name))


-def forbid(node, child_tags):
+def forbid(child_tags, node):
     for child in node.children:
         if child.tag_name in child_tags:
             node.add_error('Element "%s" cannot be nested inside "%s".' %
                            (child.parser.name, node.parser.name))


-def assert_content(node, regex):
+def assert_content(regex, node):
     content = str(node)
     if not re.match(regex, content):
         node.add_error('Element "%s" violates %s on %s' %
...
@@ -24,10 +24,10 @@ import os
 import sys
 from functools import partial

-from DHParser.toolkit import logging
 from DHParser.dsl import compileDSL, compile_on_disk
 from DHParser.ebnf import get_ebnf_grammar, get_ebnf_transformer, get_ebnf_compiler
 from DHParser.parsers import compile_source, nil_scanner
+from DHParser.toolkit import logging


 def selftest(file_name):
@@ -47,7 +47,6 @@ def selftest(file_name):
     else:
         # compile the grammar again using the result of the previous
         # compilation as parser
-        print(type(result))
         result = compileDSL(grammar, nil_scanner, result, transformer, compiler)
     print(result)
     return result
...