The name of the initial branch for new projects is now "main" instead of "master". Existing projects remain unchanged. More information: https://doku.lrz.de/display/PUBLIC/GitLab

Commit c74091ca authored by Eckhart Arnold's avatar Eckhart Arnold
Browse files

- added type annotations for better documentation and mypy type checks

parent 4589c6b6
...@@ -20,17 +20,18 @@ compilation of domain specific languages based on an EBNF-grammar. ...@@ -20,17 +20,18 @@ compilation of domain specific languages based on an EBNF-grammar.
""" """
import os import os
try: try:
import regex as re import regex as re
except ImportError: except ImportError:
import re import re
from typing import Any, Tuple, cast
from .ebnf import EBNFTransformer, grammar_changed, \ from DHParser.ebnf import EBNFTransformer, EBNFCompiler, grammar_changed, \
get_ebnf_scanner, get_ebnf_grammar, get_ebnf_transformer, get_ebnf_compiler get_ebnf_scanner, get_ebnf_grammar, get_ebnf_transformer, get_ebnf_compiler, \
from .toolkit import logging, load_if_file, is_python_code, compile_python_object ScannerFactoryFunc, ParserFactoryFunc, TransformerFactoryFunc, CompilerFactoryFunc
from .parsers import Grammar, CompilerBase, compile_source, nil_scanner from DHParser.toolkit import logging, load_if_file, is_python_code, compile_python_object
from .syntaxtree import Node from DHParser.parsers import Grammar, Compiler, compile_source, nil_scanner, ScannerFunc
from DHParser.syntaxtree import Node, TransformerFunc
__all__ = ['GrammarError', __all__ = ['GrammarError',
...@@ -71,7 +72,7 @@ try: ...@@ -71,7 +72,7 @@ try:
except ImportError: except ImportError:
import re import re
from DHParser.toolkit import logging, is_filename, load_if_file from DHParser.toolkit import logging, is_filename, load_if_file
from DHParser.parsers import Grammar, CompilerBase, nil_scanner, \\ from DHParser.parsers import Grammar, Compiler, nil_scanner, \\
Lookbehind, Lookahead, Alternative, Pop, Required, Token, \\ Lookbehind, Lookahead, Alternative, Pop, Required, Token, \\
Optional, NegativeLookbehind, OneOrMore, RegExp, Retrieve, Sequence, RE, Capture, \\ Optional, NegativeLookbehind, OneOrMore, RegExp, Retrieve, Sequence, RE, Capture, \\
ZeroOrMore, Forward, NegativeLookahead, mixin_comment, compile_source, \\ ZeroOrMore, Forward, NegativeLookahead, mixin_comment, compile_source, \\
...@@ -137,7 +138,7 @@ class CompilationError(Exception): ...@@ -137,7 +138,7 @@ class CompilationError(Exception):
return '\n'.join(self.error_messages) return '\n'.join(self.error_messages)
def grammar_instance(grammar_representation): def grammar_instance(grammar_representation) -> Tuple[Grammar, str]:
"""Returns a grammar object and the source code of the grammar, from """Returns a grammar object and the source code of the grammar, from
the given `grammar`-data which can be either a file name, ebnf-code, the given `grammar`-data which can be either a file name, ebnf-code,
python-code, a Grammar-derived grammar class or an instance of python-code, a Grammar-derived grammar class or an instance of
...@@ -167,7 +168,11 @@ def grammar_instance(grammar_representation): ...@@ -167,7 +168,11 @@ def grammar_instance(grammar_representation):
return parser_root, grammar_src return parser_root, grammar_src
def compileDSL(text_or_file, scanner, dsl_grammar, ast_transformation, compiler): def compileDSL(text_or_file: str,
scanner: ScannerFunc,
dsl_grammar: Grammar,
ast_transformation: TransformerFunc,
compiler: Compiler) -> Any:
"""Compiles a text in a domain specific language (DSL) with an """Compiles a text in a domain specific language (DSL) with an
EBNF-specified grammar. Returns the compiled text or raises a EBNF-specified grammar. Returns the compiled text or raises a
compilation error. compilation error.
...@@ -176,10 +181,10 @@ def compileDSL(text_or_file, scanner, dsl_grammar, ast_transformation, compiler) ...@@ -176,10 +181,10 @@ def compileDSL(text_or_file, scanner, dsl_grammar, ast_transformation, compiler)
CompilationError if any errors occurred during compilation CompilationError if any errors occurred during compilation
""" """
assert isinstance(text_or_file, str) assert isinstance(text_or_file, str)
assert isinstance(compiler, CompilerBase) assert isinstance(compiler, Compiler)
parser_root, grammar_src = grammar_instance(dsl_grammar) parser, grammar_src = grammar_instance(dsl_grammar)
result, errors, AST = compile_source(text_or_file, scanner, parser_root, result, errors, AST = compile_source(text_or_file, scanner, parser,
ast_transformation, compiler) ast_transformation, compiler)
if errors: if errors:
src = load_if_file(text_or_file) src = load_if_file(text_or_file)
...@@ -187,7 +192,7 @@ def compileDSL(text_or_file, scanner, dsl_grammar, ast_transformation, compiler) ...@@ -187,7 +192,7 @@ def compileDSL(text_or_file, scanner, dsl_grammar, ast_transformation, compiler)
return result return result
def raw_compileEBNF(ebnf_src, branding="DSL"): def raw_compileEBNF(ebnf_src: str, branding="DSL") -> EBNFCompiler:
"""Compiles an EBNF grammar file and returns the compiler object """Compiles an EBNF grammar file and returns the compiler object
that was used and which can now be queried for the result as well that was used and which can now be queried for the result as well
as skeleton code for scanner, transformer and compiler objects. as skeleton code for scanner, transformer and compiler objects.
...@@ -208,7 +213,7 @@ def raw_compileEBNF(ebnf_src, branding="DSL"): ...@@ -208,7 +213,7 @@ def raw_compileEBNF(ebnf_src, branding="DSL"):
return compiler return compiler
def compileEBNF(ebnf_src, branding="DSL"): def compileEBNF(ebnf_src: str, branding="DSL") -> str:
"""Compiles an EBNF source file and returns the source code of a """Compiles an EBNF source file and returns the source code of a
compiler suite with skeletons for scanner, transformer and compiler suite with skeletons for scanner, transformer and
compiler. compiler.
...@@ -234,7 +239,7 @@ def compileEBNF(ebnf_src, branding="DSL"): ...@@ -234,7 +239,7 @@ def compileEBNF(ebnf_src, branding="DSL"):
return '\n'.join(src) return '\n'.join(src)
def parser_factory(ebnf_src, branding="DSL"): def parser_factory(ebnf_src: str, branding="DSL") -> Grammar:
"""Compiles an EBNF grammar and returns a grammar-parser factory """Compiles an EBNF grammar and returns a grammar-parser factory
function for that grammar. function for that grammar.
...@@ -253,7 +258,8 @@ def parser_factory(ebnf_src, branding="DSL"): ...@@ -253,7 +258,8 @@ def parser_factory(ebnf_src, branding="DSL"):
return compile_python_object(DHPARSER_IMPORTS + grammar_src, 'get_(?:\w+_)?grammar$') return compile_python_object(DHPARSER_IMPORTS + grammar_src, 'get_(?:\w+_)?grammar$')
def load_compiler_suite(compiler_suite): def load_compiler_suite(compiler_suite: str) -> \
Tuple[ScannerFactoryFunc, ParserFactoryFunc, TransformerFactoryFunc, CompilerFactoryFunc]:
"""Extracts a compiler suite from file or string ``compiler suite`` """Extracts a compiler suite from file or string ``compiler suite``
and returns it as a tuple (scanner, parser, ast, compiler). and returns it as a tuple (scanner, parser, ast, compiler).
...@@ -282,13 +288,14 @@ def load_compiler_suite(compiler_suite): ...@@ -282,13 +288,14 @@ def load_compiler_suite(compiler_suite):
if errors: if errors:
raise GrammarError('\n\n'.join(errors), source) raise GrammarError('\n\n'.join(errors), source)
scanner = get_ebnf_scanner scanner = get_ebnf_scanner
parser = get_ebnf_grammar
ast = get_ebnf_transformer ast = get_ebnf_transformer
compiler = compile_python_object(imports + compiler_py, 'get_(?:\w+_)?compiler$') compiler = compile_python_object(imports + compiler_py, 'get_(?:\w+_)?compiler$')
return scanner, parser, ast, compiler return scanner, parser, ast, compiler
def is_outdated(compiler_suite, grammar_source): def is_outdated(compiler_suite: str, grammar_source: str) -> bool:
"""Returns ``True`` if the ``compile_suite`` needs to be updated. """Returns ``True`` if the ``compile_suite`` needs to be updated.
An update is needed, if either the grammar in the compieler suite An update is needed, if either the grammar in the compieler suite
...@@ -313,7 +320,7 @@ def is_outdated(compiler_suite, grammar_source): ...@@ -313,7 +320,7 @@ def is_outdated(compiler_suite, grammar_source):
return True return True
def run_compiler(text_or_file, compiler_suite): def run_compiler(text_or_file: str, compiler_suite: str) -> Any:
"""Compiles a source with a given compiler suite. """Compiles a source with a given compiler suite.
Args: Args:
...@@ -336,7 +343,7 @@ def run_compiler(text_or_file, compiler_suite): ...@@ -336,7 +343,7 @@ def run_compiler(text_or_file, compiler_suite):
return compileDSL(text_or_file, scanner(), parser(), ast(), compiler()) return compileDSL(text_or_file, scanner(), parser(), ast(), compiler())
def compile_on_disk(source_file, compiler_suite="", extension=".xml"): def compile_on_disk(source_file: str, compiler_suite="", extension=".xml"):
"""Compiles the a source file with a given compiler and writes the """Compiles the a source file with a given compiler and writes the
result to a file. result to a file.
...@@ -373,18 +380,20 @@ def compile_on_disk(source_file, compiler_suite="", extension=".xml"): ...@@ -373,18 +380,20 @@ def compile_on_disk(source_file, compiler_suite="", extension=".xml"):
rootname = os.path.splitext(filepath)[0] rootname = os.path.splitext(filepath)[0]
compiler_name = os.path.basename(rootname) compiler_name = os.path.basename(rootname)
if compiler_suite: if compiler_suite:
scanner, parser, trans, cfactory = load_compiler_suite(compiler_suite) sfactory, pfactory, tfactory, cfactory = load_compiler_suite(compiler_suite)
else: else:
scanner = get_ebnf_scanner sfactory = get_ebnf_scanner
parser = get_ebnf_grammar pfactory = get_ebnf_grammar
trans = get_ebnf_transformer tfactory = get_ebnf_transformer
cfactory = get_ebnf_compiler cfactory = get_ebnf_compiler
compiler1 = cfactory(compiler_name, source_file) compiler1 = cfactory()
result, errors, ast = compile_source(source_file, scanner(), parser(), trans(), compiler1) compiler1.set_grammar_name(compiler_name, source_file)
result, errors, ast = compile_source(source_file, sfactory(), pfactory(), tfactory(), compiler1)
if errors: if errors:
return errors return errors
elif cfactory == get_ebnf_compiler: # trans == get_ebnf_transformer or trans == EBNFTransformer: # either an EBNF- or no compiler suite given elif cfactory == get_ebnf_compiler: # trans == get_ebnf_transformer or trans == EBNFTransformer: # either an EBNF- or no compiler suite given
ebnf_compiler = cast(EBNFCompiler, compiler1)
global SECTION_MARKER, RX_SECTION_MARKER, SCANNER_SECTION, PARSER_SECTION, \ global SECTION_MARKER, RX_SECTION_MARKER, SCANNER_SECTION, PARSER_SECTION, \
AST_SECTION, COMPILER_SECTION, END_SECTIONS_MARKER, RX_WHITESPACE, \ AST_SECTION, COMPILER_SECTION, END_SECTIONS_MARKER, RX_WHITESPACE, \
DHPARSER_MAIN, DHPARSER_IMPORTS DHPARSER_MAIN, DHPARSER_IMPORTS
...@@ -412,11 +421,11 @@ def compile_on_disk(source_file, compiler_suite="", extension=".xml"): ...@@ -412,11 +421,11 @@ def compile_on_disk(source_file, compiler_suite="", extension=".xml"):
if RX_WHITESPACE.fullmatch(imports): if RX_WHITESPACE.fullmatch(imports):
imports = DHPARSER_IMPORTS imports = DHPARSER_IMPORTS
if RX_WHITESPACE.fullmatch(scanner): if RX_WHITESPACE.fullmatch(scanner):
scanner = compiler1.gen_scanner_skeleton() scanner = ebnf_compiler.gen_scanner_skeleton()
if RX_WHITESPACE.fullmatch(ast): if RX_WHITESPACE.fullmatch(ast):
ast = compiler1.gen_transformer_skeleton() ast = ebnf_compiler.gen_transformer_skeleton()
if RX_WHITESPACE.fullmatch(compiler): if RX_WHITESPACE.fullmatch(compiler):
compiler = compiler1.gen_compiler_skeleton() compiler = ebnf_compiler.gen_compiler_skeleton()
try: try:
f = open(rootname + 'Compiler.py', 'w', encoding="utf-8") f = open(rootname + 'Compiler.py', 'w', encoding="utf-8")
...@@ -441,6 +450,7 @@ def compile_on_disk(source_file, compiler_suite="", extension=".xml"): ...@@ -441,6 +450,7 @@ def compile_on_disk(source_file, compiler_suite="", extension=".xml"):
if f: f.close() if f: f.close()
else: else:
f = None
try: try:
f = open(rootname + extension, 'w', encoding="utf-8") f = open(rootname + extension, 'w', encoding="utf-8")
if isinstance(result, Node): if isinstance(result, Node):
......
...@@ -18,19 +18,20 @@ permissions and limitations under the License. ...@@ -18,19 +18,20 @@ permissions and limitations under the License.
import keyword import keyword
from functools import partial from functools import partial
try: try:
import regex as re import regex as re
except ImportError: except ImportError:
import re import re
from typing import Callable, cast, List, Set, Tuple
from .toolkit import load_if_file, escape_re, md5, sane_parser_name from DHParser.toolkit import load_if_file, escape_re, md5, sane_parser_name
from .parsers import Grammar, mixin_comment, nil_scanner, Forward, RE, NegativeLookahead, \ from DHParser.parsers import Grammar, mixin_comment, nil_scanner, Forward, RE, NegativeLookahead, \
Alternative, Sequence, Optional, Required, OneOrMore, ZeroOrMore, Token, CompilerBase Alternative, Sequence, Optional, Required, OneOrMore, ZeroOrMore, Token, Compiler, \
from .syntaxtree import Node, traverse, remove_enclosing_delimiters, reduce_single_child, \ ScannerFunc
from DHParser.syntaxtree import Node, traverse, remove_enclosing_delimiters, reduce_single_child, \
replace_by_single_child, TOKEN_PTYPE, remove_expendables, remove_tokens, flatten, \ replace_by_single_child, TOKEN_PTYPE, remove_expendables, remove_tokens, flatten, \
forbid, assert_content, WHITESPACE_PTYPE, key_tag_name forbid, assert_content, WHITESPACE_PTYPE, key_tag_name, TransformerFunc
from .versionnumber import __version__ from DHParser.versionnumber import __version__
__all__ = ['get_ebnf_scanner', __all__ = ['get_ebnf_scanner',
...@@ -41,7 +42,11 @@ __all__ = ['get_ebnf_scanner', ...@@ -41,7 +42,11 @@ __all__ = ['get_ebnf_scanner',
'EBNFTransformer', 'EBNFTransformer',
'EBNFCompilerError', 'EBNFCompilerError',
'EBNFCompiler', 'EBNFCompiler',
'grammar_changed'] 'grammar_changed',
'ScannerFactoryFunc',
'ParserFactoryFunc',
'TransformerFactoryFunc',
'CompilerFactoryFunc']
######################################################################## ########################################################################
...@@ -51,7 +56,7 @@ __all__ = ['get_ebnf_scanner', ...@@ -51,7 +56,7 @@ __all__ = ['get_ebnf_scanner',
######################################################################## ########################################################################
def get_ebnf_scanner(): def get_ebnf_scanner() -> ScannerFunc:
return nil_scanner return nil_scanner
...@@ -137,7 +142,7 @@ class EBNFGrammar(Grammar): ...@@ -137,7 +142,7 @@ class EBNFGrammar(Grammar):
root__ = syntax root__ = syntax
def grammar_changed(grammar_class, grammar_source): def grammar_changed(grammar_class, grammar_source: str) -> bool:
"""Returns ``True`` if ``grammar_class`` does not reflect the latest """Returns ``True`` if ``grammar_class`` does not reflect the latest
changes of ``grammar_source`` changes of ``grammar_source``
...@@ -168,7 +173,7 @@ def grammar_changed(grammar_class, grammar_source): ...@@ -168,7 +173,7 @@ def grammar_changed(grammar_class, grammar_source):
return chksum != grammar_class.source_hash__ return chksum != grammar_class.source_hash__
def get_ebnf_grammar(): def get_ebnf_grammar() -> EBNFGrammar:
global thread_local_ebnf_grammar_singleton global thread_local_ebnf_grammar_singleton
try: try:
grammar = thread_local_ebnf_grammar_singleton grammar = thread_local_ebnf_grammar_singleton
...@@ -223,13 +228,13 @@ EBNF_validation_table = { ...@@ -223,13 +228,13 @@ EBNF_validation_table = {
} }
def EBNFTransformer(syntax_tree): def EBNFTransformer(syntax_tree: Node):
for processing_table, key_func in [(EBNF_transformation_table, key_tag_name), for processing_table, key_func in [(EBNF_transformation_table, key_tag_name),
(EBNF_validation_table, key_tag_name)]: (EBNF_validation_table, key_tag_name)]:
traverse(syntax_tree, processing_table, key_func) traverse(syntax_tree, processing_table, key_func)
def get_ebnf_transformer(): def get_ebnf_transformer() -> TransformerFunc:
return EBNFTransformer return EBNFTransformer
...@@ -239,6 +244,13 @@ def get_ebnf_transformer(): ...@@ -239,6 +244,13 @@ def get_ebnf_transformer():
# #
######################################################################## ########################################################################
ScannerFactoryFunc = Callable[[], ScannerFunc]
ParserFactoryFunc = Callable[[], Grammar]
TransformerFactoryFunc = Callable[[], TransformerFunc]
CompilerFactoryFunc = Callable[[], Compiler]
SCANNER_FACTORY = ''' SCANNER_FACTORY = '''
def get_scanner(): def get_scanner():
return {NAME}Scanner return {NAME}Scanner
...@@ -283,7 +295,7 @@ class EBNFCompilerError(Exception): ...@@ -283,7 +295,7 @@ class EBNFCompilerError(Exception):
pass pass
class EBNFCompiler(CompilerBase): class EBNFCompiler(Compiler):
"""Generates a Parser from an abstract syntax tree of a grammar specified """Generates a Parser from an abstract syntax tree of a grammar specified
in EBNF-Notation. in EBNF-Notation.
""" """
...@@ -305,13 +317,13 @@ class EBNFCompiler(CompilerBase): ...@@ -305,13 +317,13 @@ class EBNFCompiler(CompilerBase):
self._reset() self._reset()
def _reset(self): def _reset(self):
self._result = None self._result = '' # type: str
self.rules = set() self.rules = set() # type: Set[str]
self.variables = set() self.variables = set() # type: Set[str]
self.symbol_nodes = [] self.symbol_nodes = [] # type: List[Node]
self.definition_names = [] self.definition_names = [] # type: List[str]
self.recursive = set() self.recursive = set() # type: Set[str]
self.root = "" self.root = "" # type: str
self.directives = {'whitespace': self.WHITESPACE['horizontal'], self.directives = {'whitespace': self.WHITESPACE['horizontal'],
'comment': '', 'comment': '',
'literalws': ['right'], 'literalws': ['right'],
...@@ -319,15 +331,15 @@ class EBNFCompiler(CompilerBase): ...@@ -319,15 +331,15 @@ class EBNFCompiler(CompilerBase):
'filter': dict()} # alt. 'retrieve_filter' 'filter': dict()} # alt. 'retrieve_filter'
@property @property
def result(self): def result(self) -> str:
return self._result return self._result
def gen_scanner_skeleton(self): def gen_scanner_skeleton(self) -> str:
name = self.grammar_name + "Scanner" name = self.grammar_name + "Scanner"
return "def %s(text):\n return text\n" % name \ return "def %s(text):\n return text\n" % name \
+ SCANNER_FACTORY.format(NAME=self.grammar_name) + SCANNER_FACTORY.format(NAME=self.grammar_name)
def gen_transformer_skeleton(self): def gen_transformer_skeleton(self) -> str:
if not self.definition_names: if not self.definition_names:
raise EBNFCompilerError('Compiler must be run before calling ' raise EBNFCompilerError('Compiler must be run before calling '
'"gen_transformer_Skeleton()"!') '"gen_transformer_Skeleton()"!')
...@@ -343,11 +355,11 @@ class EBNFCompiler(CompilerBase): ...@@ -343,11 +355,11 @@ class EBNFCompiler(CompilerBase):
transtable += [TRANSFORMER_FACTORY.format(NAME=self.grammar_name)] transtable += [TRANSFORMER_FACTORY.format(NAME=self.grammar_name)]
return '\n'.join(transtable) return '\n'.join(transtable)
def gen_compiler_skeleton(self): def gen_compiler_skeleton(self) -> str:
if not self.definition_names: if not self.definition_names:
raise EBNFCompilerError('Compiler has not been run before calling ' raise EBNFCompilerError('Compiler has not been run before calling '
'"gen_Compiler_Skeleton()"!') '"gen_Compiler_Skeleton()"!')
compiler = ['class ' + self.grammar_name + 'Compiler(CompilerBase):', compiler = ['class ' + self.grammar_name + 'Compiler(Compiler):',
' """Compiler for the abstract-syntax-tree of a ' + ' """Compiler for the abstract-syntax-tree of a ' +
self.grammar_name + ' source file.', self.grammar_name + ' source file.',
' """', '', ' """', '',
...@@ -357,23 +369,23 @@ class EBNFCompiler(CompilerBase): ...@@ -357,23 +369,23 @@ class EBNFCompiler(CompilerBase):
'Compiler, self).__init__(grammar_name, grammar_source)', 'Compiler, self).__init__(grammar_name, grammar_source)',
" assert re.match('\w+\Z', grammar_name)", ''] " assert re.match('\w+\Z', grammar_name)", '']
for name in self.definition_names: for name in self.definition_names:
method_name = CompilerBase.derive_method_name(name) method_name = Compiler.derive_method_name(name)
if name == self.root: if name == self.root:
compiler += [' def ' + method_name + '(self, node):', compiler += [' def ' + method_name + '(self, node: Node) -> str:',
' return node', ''] ' return node', '']
else: else:
compiler += [' def ' + method_name + '(self, node):', compiler += [' def ' + method_name + '(self, node: Node) -> str:',
' pass', ''] ' pass', '']
compiler += [COMPILER_FACTORY.format(NAME=self.grammar_name)] compiler += [COMPILER_FACTORY.format(NAME=self.grammar_name)]
return '\n'.join(compiler) return '\n'.join(compiler)
def assemble_parser(self, definitions, root_node): def assemble_parser(self, definitions: List[Tuple[str, str]], root_node: Node) -> str:
# fix capture of variables that have been defined before usage [sic!] # fix capture of variables that have been defined before usage [sic!]
if self.variables: if self.variables:
for i in range(len(definitions)): for i in range(len(definitions)):
if definitions[i][0] in self.variables: if definitions[i][0] in self.variables:
definitions[i] = (definitions[i][0], 'Capture(%s)' % definitions[1]) definitions[i] = (definitions[i][0], 'Capture(%s)' % definitions[i][1])
self.definition_names = [defn[0] for defn in definitions] self.definition_names = [defn[0] for defn in definitions]
definitions.append(('wspR__', self.WHITESPACE_KEYWORD definitions.append(('wspR__', self.WHITESPACE_KEYWORD
...@@ -434,27 +446,27 @@ class EBNFCompiler(CompilerBase): ...@@ -434,27 +446,27 @@ class EBNFCompiler(CompilerBase):
+ GRAMMAR_FACTORY.format(NAME=self.grammar_name) + GRAMMAR_FACTORY.format(NAME=self.grammar_name)
return self._result return self._result
def on_syntax(self, node): def on_syntax(self, node: Node) -> str:
self._reset() self._reset()
definitions = [] definitions = []
# drop the wrapping sequence node # drop the wrapping sequence node
if len(node.children) == 1 and not node.result[0].parser.name: if len(node.children) == 1 and not node.children[0].parser.name:
node = node.result[0] node = node.children[0]
# compile definitions and directives and collect definitions # compile definitions and directives and collect definitions
for nd in node.result: for nd in node.children:
if nd.parser.name == "definition": if nd.parser.name == "definition":
definitions.append(self._compile(nd)) definitions.append(self._compile(nd))
else: else:
assert nd.parser.name == "directive", nd.as_sexpr() assert nd.parser.name == "directive", nd.as_sexpr()
self._compile(nd) self._compile(nd)
node.error_flag |= nd.error_flag node.error_flag = node.error_flag or nd.error_flag
return self.assemble_parser(definitions, node) return self.assemble_parser(definitions, node)
def on_definition(self, node): def on_definition(self, node: Node) -> Tuple[str, str]:
rule = node.result[0].result rule = cast(str, node.children[0].result)
if rule in self.rules: if rule in self.rules:
node.add_error('A rule with name "%s" has already been defined.' % rule) node.add_error('A rule with name "%s" has already been defined.' % rule)
elif rule in EBNFCompiler.RESERVED_SYMBOLS: elif rule in EBNFCompiler.RESERVED_SYMBOLS:
...@@ -470,7 +482,7 @@ class EBNFCompiler(CompilerBase): ...@@ -470,7 +482,7 @@ class EBNFCompiler(CompilerBase):
% rule + '(This may change in the furute.)') % rule + '(This may change in the furute.)')
try: try:
self.rules.add(rule) self.rules.add(rule)
defn = self._compile(node.result[1]) defn = self._compile(node.children[1])
if rule in self.variables: if rule in self.variables:
defn = 'Capture(%s)' % defn defn = 'Capture(%s)' % defn
self.variables.remove(rule) self.variables.remove(rule)
...@@ -481,7 +493,7 @@ class EBNFCompiler(CompilerBase): ...@@ -481,7 +493,7 @@ class EBNFCompiler(CompilerBase):
return rule, defn return rule, defn
@staticmethod @staticmethod
def _check_rx(node, rx): def _check_rx(node: Node, rx: str) -> str:
"""Checks whether the string `rx` represents a valid regular """Checks whether the string `rx` represents a valid regular
expression. Makes sure that multiline regular expressions are expression. Makes sure that multiline regular expressions are
prepended by the multiline-flag. Returns the regular expression string. prepended by the multiline-flag. Returns the regular expression string.
...@@ -494,22 +506,22 @@ class EBNFCompiler(CompilerBase): ...@@ -494,22 +506,22 @@ class EBNFCompiler(CompilerBase):
(repr(rx), str(re_error))) (repr(rx), str(re_error)))
return rx return rx
def on_directive(self, node): def on_directive(self, node: Node) -> str:
key = node.result[0].result.lower() key = cast(str, node.children[0].result).lower()
assert key not in self.directives['tokens'] assert key not in self.directives['tokens']
if key in {'comment', 'whitespace'}: </