Commit 2ce46062 authored by Eckhart Arnold's avatar Eckhart Arnold
Browse files

- fixed mypy-type errors

parent be770f50
...@@ -26,9 +26,9 @@ try: ...@@ -26,9 +26,9 @@ try:
except ImportError: except ImportError:
import re import re
try: try:
from typing import Any, cast, Tuple, Union from typing import Any, cast, Tuple, Union, Iterable
except ImportError: except ImportError:
from .typing34 import Any, cast, Tuple, Union from .typing34 import Any, cast, Tuple, Union, Iterable
from DHParser.ebnf import EBNFCompiler, grammar_changed, \ from DHParser.ebnf import EBNFCompiler, grammar_changed, \
get_ebnf_preprocessor, get_ebnf_grammar, get_ebnf_transformer, get_ebnf_compiler, \ get_ebnf_preprocessor, get_ebnf_grammar, get_ebnf_transformer, get_ebnf_compiler, \
...@@ -511,7 +511,7 @@ def recompile_grammar(ebnf_filename, force=False) -> bool: ...@@ -511,7 +511,7 @@ def recompile_grammar(ebnf_filename, force=False) -> bool:
base, ext = os.path.splitext(ebnf_filename) base, ext = os.path.splitext(ebnf_filename)
compiler_name = base + 'Compiler.py' compiler_name = base + 'Compiler.py'
error_file_name = base + '_ebnf_ERRORS.txt' error_file_name = base + '_ebnf_ERRORS.txt'
errors = [] errors = [] # type: Iterable[str]
if (not os.path.exists(compiler_name) or force or if (not os.path.exists(compiler_name) or force or
grammar_changed(compiler_name, ebnf_filename)): grammar_changed(compiler_name, ebnf_filename)):
# print("recompiling parser for: " + ebnf_filename) # print("recompiling parser for: " + ebnf_filename)
......
...@@ -25,9 +25,9 @@ try: ...@@ -25,9 +25,9 @@ try:
except ImportError: except ImportError:
import re import re
try: try:
from typing import Callable, Dict, List, Set, Tuple from typing import Callable, Dict, List, Set, Tuple, Union
except ImportError: except ImportError:
from .typing34 import Callable, Dict, List, Set, Tuple from .typing34 import Callable, Dict, List, Set, Tuple, Union
from DHParser.toolkit import load_if_file, escape_re, md5, sane_parser_name from DHParser.toolkit import load_if_file, escape_re, md5, sane_parser_name
from DHParser.parser import Grammar, mixin_comment, nil_preprocessor, Forward, RE, NegativeLookahead, \ from DHParser.parser import Grammar, mixin_comment, nil_preprocessor, Forward, RE, NegativeLookahead, \
...@@ -222,7 +222,7 @@ EBNF_AST_transformation_table = { ...@@ -222,7 +222,7 @@ EBNF_AST_transformation_table = {
} }
def EBNFTransform() -> TransformationDict: def EBNFTransform() -> TransformationFunc:
return partial(traverse, processing_table=EBNF_AST_transformation_table.copy()) return partial(traverse, processing_table=EBNF_AST_transformation_table.copy())
def get_ebnf_transformer() -> TransformationFunc: def get_ebnf_transformer() -> TransformationFunc:
......
...@@ -158,8 +158,7 @@ class HistoryRecord: ...@@ -158,8 +158,7 @@ class HistoryRecord:
def __init__(self, call_stack: List['Parser'], node: Node, remaining: int) -> None: def __init__(self, call_stack: List['Parser'], node: Node, remaining: int) -> None:
# copy call stack, dropping uninformative Forward-Parsers # copy call stack, dropping uninformative Forward-Parsers
self.call_stack = [p for p in call_stack if p.ptype != ":Forward"] self.call_stack = [p for p in call_stack if p.ptype != ":Forward"] # type: List['Parser']
# type: List['Parser']
self.node = node # type: Node self.node = node # type: Node
self.remaining = remaining # type: int self.remaining = remaining # type: int
document = call_stack[-1].grammar.document__.text if call_stack else '' document = call_stack[-1].grammar.document__.text if call_stack else ''
...@@ -188,7 +187,7 @@ class HistoryRecord: ...@@ -188,7 +187,7 @@ class HistoryRecord:
else slice(-self.remaining, None)) else slice(-self.remaining, None))
@staticmethod @staticmethod
def last_match(history: List['HistoryRecord']) -> Optional['HistoryRecord']: def last_match(history: List['HistoryRecord']) -> Union['HistoryRecord', None]:
""" """
Returns the last match from the parsing-history. Returns the last match from the parsing-history.
Args: Args:
...@@ -204,7 +203,7 @@ class HistoryRecord: ...@@ -204,7 +203,7 @@ class HistoryRecord:
return None return None
@staticmethod @staticmethod
def most_advanced_match(history: List['HistoryRecord']) -> Optional['HistoryRecord']: def most_advanced_match(history: List['HistoryRecord']) -> Union['HistoryRecord', None]:
""" """
Returns the closest-to-the-end-match from the parsing-history. Returns the closest-to-the-end-match from the parsing-history.
Args: Args:
...@@ -632,10 +631,10 @@ class Grammar: ...@@ -632,10 +631,10 @@ class Grammar:
# root__ must be overwritten with the root-parser by grammar subclass # root__ must be overwritten with the root-parser by grammar subclass
parser_initialization__ = "pending" # type: str parser_initialization__ = "pending" # type: str
# some default values # some default values
COMMENT__ = r'' # r'#.*(?:\n|$)' COMMENT__ = r'' # type: str # r'#.*(?:\n|$)'
WSP__ = mixin_comment(whitespace=r'[\t ]*', comment=COMMENT__) WSP__ = mixin_comment(whitespace=r'[\t ]*', comment=COMMENT__) # type: str
wspL__ = '' wspL__ = '' # type: str
wspR__ = WSP__ wspR__ = WSP__ # type: str
@classmethod @classmethod
...@@ -741,7 +740,7 @@ class Grammar: ...@@ -741,7 +740,7 @@ class Grammar:
@property @property
def reversed__(self) -> str: def reversed__(self) -> StringView:
if not self._reversed__: if not self._reversed__:
self._reversed__ = StringView(self.document__.text[::-1]) self._reversed__ = StringView(self.document__.text[::-1])
return self._reversed__ return self._reversed__
......
...@@ -16,6 +16,7 @@ implied. See the License for the specific language governing ...@@ -16,6 +16,7 @@ implied. See the License for the specific language governing
permissions and limitations under the License. permissions and limitations under the License.
""" """
import collections.abc
import copy import copy
import os import os
from functools import partial from functools import partial
...@@ -144,7 +145,7 @@ def flatten_sxpr(sxpr: str) -> str: ...@@ -144,7 +145,7 @@ def flatten_sxpr(sxpr: str) -> str:
return re.sub('\s(?=\))', '', re.sub('\s+', ' ', sxpr)).strip() return re.sub('\s(?=\))', '', re.sub('\s+', ' ', sxpr)).strip()
class Node: class Node(collections.abc.Sized):
""" """
Represents a node in the concrete or abstract syntax tree. Represents a node in the concrete or abstract syntax tree.
...@@ -199,7 +200,7 @@ class Node: ...@@ -199,7 +200,7 @@ class Node:
# self.error_flag = False # type: bool # self.error_flag = False # type: bool
self._errors = [] # type: List[str] self._errors = [] # type: List[str]
self.result = result self.result = result
self._len = len(result) if not self.children else \ self._len = len(self._result) if not self.children else \
sum(child._len for child in self.children) # type: int sum(child._len for child in self.children) # type: int
# self.pos: int = 0 # continuous updating of pos values wastes a lot of time # self.pos: int = 0 # continuous updating of pos values wastes a lot of time
self._pos = -1 # type: int self._pos = -1 # type: int
......
...@@ -43,9 +43,9 @@ except ImportError: ...@@ -43,9 +43,9 @@ except ImportError:
import sys import sys
try: try:
from typing import Any, List, Tuple, Iterable, Union, Optional from typing import Any, List, Tuple, Iterable, Sequence, Union, Optional, TypeVar
except ImportError: except ImportError:
from .typing34 import Any, List, Tuple, Iterable, Union, Optional from .typing34 import Any, List, Tuple, Iterable, Sequence, Union, Optional, TypeVar
__all__ = ('logging', __all__ = ('logging',
'is_logging', 'is_logging',
...@@ -154,7 +154,7 @@ def clear_logs(logfile_types={'.cst', '.ast', '.log'}): ...@@ -154,7 +154,7 @@ def clear_logs(logfile_types={'.cst', '.ast', '.log'}):
os.rmdir(log_dirname) os.rmdir(log_dirname)
class StringView: class StringView(collections.abc.Sized):
""""A rudimentary StringView class, just enough for the use cases """"A rudimentary StringView class, just enough for the use cases
in parswer.py. in parswer.py.
...@@ -218,7 +218,7 @@ def sv_match(regex, sv: StringView): ...@@ -218,7 +218,7 @@ def sv_match(regex, sv: StringView):
return regex.match(sv.text, pos=sv.begin, endpos=sv.end) return regex.match(sv.text, pos=sv.begin, endpos=sv.end)
def sv_index(absolute_index: Union[int, Iterable], sv: StringView) -> Union[int, tuple]: def sv_index(absolute_index: Union[int, Iterable], sv: StringView) -> int:
""" """
Converts the an index into string watched by a StringView object Converts the an index into string watched by a StringView object
to an index relativ to the string view object, e.g.: to an index relativ to the string view object, e.g.:
...@@ -229,10 +229,14 @@ def sv_index(absolute_index: Union[int, Iterable], sv: StringView) -> Union[int, ...@@ -229,10 +229,14 @@ def sv_index(absolute_index: Union[int, Iterable], sv: StringView) -> Union[int,
>>> sv_index(match.end(), sv) >>> sv_index(match.end(), sv)
1 1
""" """
try: return absolute_index - sv.begin
return absolute_index - sv.begin
except TypeError:
return tuple(index - sv.begin for index in absolute_index) def sv_indices(absolute_indices: Iterable[int], sv: StringView) -> Tuple[int]:
"""Converts the an index into string watched by a StringView object
to an index relativ to the string view object. See also: `sv_index()`
"""
return tuple(index - sv.begin for index in absolute_indices)
def sv_search(regex, sv: StringView): def sv_search(regex, sv: StringView):
...@@ -366,7 +370,8 @@ def md5(*txt): ...@@ -366,7 +370,8 @@ def md5(*txt):
return md5_hash.hexdigest() return md5_hash.hexdigest()
def smart_list(arg) -> list: # def smart_list(arg: Union[str, Iterable[T]]) -> Union[Sequence[str], Sequence[T]]:
def smart_list(arg: Union[str, Iterable, Any]) -> Sequence:
"""Returns the argument as list, depending on its type and content. """Returns the argument as list, depending on its type and content.
If the argument is a string, it will be interpreted as a list of If the argument is a string, it will be interpreted as a list of
...@@ -402,9 +407,9 @@ def smart_list(arg) -> list: ...@@ -402,9 +407,9 @@ def smart_list(arg) -> list:
if len(lst) > 1: if len(lst) > 1:
return [s.strip() for s in lst] return [s.strip() for s in lst]
return [s.strip() for s in arg.strip().split(' ')] return [s.strip() for s in arg.strip().split(' ')]
elif isinstance(arg, collections.abc.Container): elif isinstance(arg, Sequence):
return arg return arg
elif isinstance(arg, collections.abc.Iterable): elif isinstance(arg, Iterable):
return list(arg) return list(arg)
else: else:
return [arg] return [arg]
......
...@@ -82,7 +82,8 @@ __all__ = ('TransformationDict', ...@@ -82,7 +82,8 @@ __all__ = ('TransformationDict',
TransformationProc = Callable[[List[Node]], None] TransformationProc = Callable[[List[Node]], None]
TransformationDict = Dict TransformationDict = Dict[str, Sequence[Callable]]
ProcessingTableType = Dict[str, Union[Sequence[Callable], TransformationDict]]
ConditionFunc = Callable # Callable[[List[Node]], bool] ConditionFunc = Callable # Callable[[List[Node]], bool]
KeyFunc = Callable[[Node], str] KeyFunc = Callable[[Node], str]
...@@ -172,7 +173,7 @@ def key_tag_name(node: Node) -> str: ...@@ -172,7 +173,7 @@ def key_tag_name(node: Node) -> str:
def traverse(root_node: Node, def traverse(root_node: Node,
processing_table: Dict[str, List[Callable]], processing_table: ProcessingTableType,
key_func: KeyFunc=key_tag_name) -> None: key_func: KeyFunc=key_tag_name) -> None:
""" """
Traverses the snytax tree starting with the given ``node`` depth Traverses the snytax tree starting with the given ``node`` depth
...@@ -216,7 +217,7 @@ def traverse(root_node: Node, ...@@ -216,7 +217,7 @@ def traverse(root_node: Node,
# into lists with a single value # into lists with a single value
table = {name: smart_list(call) for name, call in list(processing_table.items())} table = {name: smart_list(call) for name, call in list(processing_table.items())}
table = expand_table(table) table = expand_table(table)
cache = table.setdefault('__cache__', {}) # type: Dict[str, List[Callable]] cache = table.setdefault('__cache__', cast(TransformationDict, dict()))
# change processing table in place, so its already expanded and cache filled next time # change processing table in place, so its already expanded and cache filled next time
processing_table.clear() processing_table.clear()
processing_table.update(table) processing_table.update(table)
...@@ -278,13 +279,13 @@ def replace_child(node: Node): ...@@ -278,13 +279,13 @@ def replace_child(node: Node):
node.children[0].parser.name = node.parser.name node.children[0].parser.name = node.parser.name
node.parser = node.children[0].parser node.parser = node.children[0].parser
node._errors.extend(node.children[0]._errors) node._errors.extend(node.children[0]._errors)
node.result = node.result[0].result node.result = node.children[0].result
def reduce_child(node: Node): def reduce_child(node: Node):
assert len(node.children) == 1 assert len(node.children) == 1
node._errors.extend(node.children[0]._errors) node._errors.extend(node.children[0]._errors)
node.result = node.result[0].result node.result = node.children[0].result
@transformation_factory(Callable) @transformation_factory(Callable)
...@@ -320,7 +321,7 @@ def reduce_single_child(context: List[Node], condition: Callable=TRUE_CONDITION) ...@@ -320,7 +321,7 @@ def reduce_single_child(context: List[Node], condition: Callable=TRUE_CONDITION)
def is_named(context: List[Node]) -> bool: def is_named(context: List[Node]) -> bool:
return context[-1].parser.name return bool(context[-1].parser.name)
def is_anonymous(context: List[Node]) -> bool: def is_anonymous(context: List[Node]) -> bool:
...@@ -376,7 +377,7 @@ def flatten(context: List[Node], condition: Callable=is_anonymous, recursive: bo ...@@ -376,7 +377,7 @@ def flatten(context: List[Node], condition: Callable=is_anonymous, recursive: bo
""" """
node = context[-1] node = context[-1]
if node.children: if node.children:
new_result = [] new_result = [] # type: List[Node]
for child in node.children: for child in node.children:
context.append(child) context.append(child)
if child.children and condition(context): if child.children and condition(context):
...@@ -405,7 +406,7 @@ def merge_children(context: List[Node], tag_names: List[str]): ...@@ -405,7 +406,7 @@ def merge_children(context: List[Node], tag_names: List[str]):
names into a single child node with a mock-parser with the name of names into a single child node with a mock-parser with the name of
the first tag-name in the list. the first tag-name in the list.
""" """
node = context node = context[-1]
result = [] result = []
name, ptype = ('', tag_names[0]) if tag_names[0][:1] == ':' else (tag_names[0], '') name, ptype = ('', tag_names[0]) if tag_names[0][:1] == ':' else (tag_names[0], '')
if node.children: if node.children:
...@@ -421,7 +422,8 @@ def merge_children(context: List[Node], tag_names: List[str]): ...@@ -421,7 +422,8 @@ def merge_children(context: List[Node], tag_names: List[str]):
k += 1 k += 1
if i < L: if i < L:
result.append(Node(MockParser(name, ptype), result.append(Node(MockParser(name, ptype),
reduce(lambda a, b: a + b, (node.result for node in node.children[i:k])))) reduce(lambda a, b: a + b,
(node.children for node in node.children[i:k]))))
i = k i = k
node.result = tuple(result) node.result = tuple(result)
...@@ -558,7 +560,7 @@ def remove_content(context: List[Node], regexp: str): ...@@ -558,7 +560,7 @@ def remove_content(context: List[Node], regexp: str):
######################################################################## ########################################################################
@transformation_factory(Callable) @transformation_factory(Callable)
def assert_condition(context: List[Node], condition: Callable, error_msg: str = '') -> bool: def assert_condition(context: List[Node], condition: Callable, error_msg: str = ''):
"""Checks for `condition`; adds an error message if condition is not met.""" """Checks for `condition`; adds an error message if condition is not met."""
node = context[-1] node = context[-1]
if not condition(context): if not condition(context):
......
...@@ -354,9 +354,6 @@ class TestFlowControlOperators: ...@@ -354,9 +354,6 @@ class TestFlowControlOperators:
SUCC_LB = indirection SUCC_LB = indirection
indirection = /\s*?\n/ indirection = /\s*?\n/
""" """
# result, messages, syntax_tree = compile_source(lang, None, get_ebnf_grammar(),
# get_ebnf_transformer(), get_ebnf_compiler('LookbehindTest'))
# print(result)
parser = grammar_provider(lang)() parser = grammar_provider(lang)()
cst = parser(self.t1) cst = parser(self.t1)
assert not cst.error_flag, cst.as_sxpr() assert not cst.error_flag, cst.as_sxpr()
......
...@@ -82,7 +82,6 @@ class TestInfiLoopsAndRecursion: ...@@ -82,7 +82,6 @@ class TestInfiLoopsAndRecursion:
assert not syntax_tree.error_flag, syntax_tree.collect_errors() assert not syntax_tree.error_flag, syntax_tree.collect_errors()
snippet = "7 + 8 * 4" snippet = "7 + 8 * 4"
syntax_tree = parser(snippet) syntax_tree = parser(snippet)
# print(syntax_tree.as_sxpr())
assert not syntax_tree.error_flag, syntax_tree.collect_errors() assert not syntax_tree.error_flag, syntax_tree.collect_errors()
snippet = "9 + 8 * (4 + 3)" snippet = "9 + 8 * (4 + 3)"
syntax_tree = parser(snippet) syntax_tree = parser(snippet)
...@@ -95,7 +94,6 @@ class TestInfiLoopsAndRecursion: ...@@ -95,7 +94,6 @@ class TestInfiLoopsAndRecursion:
parser = grammar_provider(minilang)() parser = grammar_provider(minilang)()
syntax_tree = parser(snippet) syntax_tree = parser(snippet)
assert syntax_tree.error_flag assert syntax_tree.error_flag
# print(syntax_tree.collect_errors())
class TestFlowControl: class TestFlowControl:
......
...@@ -97,7 +97,6 @@ class TestNode: ...@@ -97,7 +97,6 @@ class TestNode:
transform = get_ebnf_transformer() transform = get_ebnf_transformer()
compiler = get_ebnf_compiler() compiler = get_ebnf_compiler()
tree = parser(ebnf) tree = parser(ebnf)
print(tree.as_sxpr())
tree_copy = copy.deepcopy(tree) tree_copy = copy.deepcopy(tree)
transform(tree_copy) transform(tree_copy)
res1 = compiler(tree_copy) res1 = compiler(tree_copy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment