Commit 7aa81606 authored by Eckhart Arnold
Browse files

- syntaxtree.py: general json serialization

parent 6c6d63ce
......@@ -45,12 +45,13 @@ For JSON see:
import asyncio
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, CancelledError
import json
from multiprocessing import Process, Value, Queue
from multiprocessing import Process, Queue, Value, Array
import sys
import time
from typing import Callable, Coroutine, Optional, Union, Dict, List, Tuple, Sequence, Set, cast
from DHParser.toolkit import get_config_value, is_filename, load_if_file, re
from DHParser.syntaxtree import Node_JSONEncoder
from DHParser.toolkit import get_config_value, re
__all__ = ('RPC_Table',
'RPC_Type',
......@@ -67,7 +68,7 @@ RPC_Table = Dict[str, Callable]
RPC_Type = Union[RPC_Table, List[Callable], Callable]
JSON_Type = Union[Dict, Sequence, str, int, None]
RE_IS_JSON = b'\s*(?:{|\[|"|\d|true|false|null)'
RE_IS_JSONRPC = b'\s*{' # b'\s*(?:{|\[|"|\d|true|false|null)'
RE_GREP_URL = b'GET ([^ \n]+) HTTP'
SERVER_ERROR = "COMPILER-SERVER-ERROR"
......@@ -173,11 +174,14 @@ class Server:
assert not (self.blocking - self.rpc_table.keys())
self.max_source_size = get_config_value('max_rpc_size') #type: int
self.stage = Value('b', SERVER_OFFLINE) # type: Value
self.server = None # type: Optional[asyncio.AbstractServer]
self.server_messages = Queue() # type: Queue
self.server_process = None # type: Optional[Process]
# shared variables
self.stage = Value('b', SERVER_OFFLINE) # type: Value
self.host = Array('c', b' ' * 1024) # type: Array
self.port = Value('H', 0) # type: Value
# if the server is run in a separate process, the following variables
# should only be accessed from the server process
self.server = None # type: Optional[asyncio.AbstractServer]
......@@ -204,6 +208,33 @@ class Server:
response = RESPONSE_HEADER.format(date=gmt, length=len(encoded_html))
return response.encode() + encoded_html
async def run(method_name: str, method: Callable, params: Union[Dict, Sequence]) \
-> Tuple[JSON_Type, Optional[Tuple[int, str]]]:
nonlocal result, rpc_error
try:
# run method either a) directly if it is short running or
# b) in a thread pool if it contains blocking io or
# c) in a process pool if it is cpu bound
# see: https://docs.python.org/3/library/asyncio-eventloop.html
# #executing-code-in-thread-or-process-pools
has_kw_params = isinstance(params, Dict)
assert has_kw_params or isinstance(params, Sequence)
loop = asyncio.get_running_loop()
executor = self.pp_executor if method_name in self.cpu_bound else \
self.tp_executor if method_name in self.blocking else None
if executor is None:
result = method(**params) if has_kw_params else method(*params)
elif has_kw_params:
result = await loop.run_in_executor(executor, method, **params)
else:
result = await loop.run_in_executor(executor, method, *params)
except TypeError as e:
rpc_error = -32602, "Invalid Params: " + str(e)
except NameError as e:
rpc_error = -32601, "Method not found: " + str(e)
except Exception as e:
rpc_error = -32000, "Server Error: " + str(e)
if data.startswith(b'GET'):
# HTTP request
m = re.match(RE_GREP_URL, data)
......@@ -217,15 +248,19 @@ class Server:
func = self.rpc_table.get(func_name,
lambda _: UNKNOWN_FUNC_HTML.format(func=func_name))
result = func(argument) if argument is not None else func()
if isinstance(result, str):
writer.write(http_response(result))
await run(func.__name__, func, (argumnet,) if argument else ())
if rpc_error is None:
if isinstance(result, str):
writer.write(http_response(result))
else:
writer.write(http_response(json.dumps(result, indent=2)))
else:
writer.write(http_response(json.dumps(result, indent=2)))
writer.write(http_response(rpc_error[1]))
elif not re.match(RE_IS_JSON, data):
elif not re.match(RE_IS_JSONRPC, data):
# plain data
if oversized:
writer.write("Source code too large! Only %i MB allowed" \
writer.write("Source code too large! Only %i MB allowed"
% (self.max_source_size // (1024 ** 2)))
elif data == STOP_SERVER_REQUEST:
writer.write(self.stop_response.encode())
......@@ -236,11 +271,15 @@ class Server:
else:
err = lambda arg: 'function "compile_src" not registered!'
func = self.rpc_table.get('compile_src', self.rpc_table.get('compile', err))
result = func(data.decode())
if isinstance(result, str):
writer.write(result.encode())
# result = func(data.decode())
await run(func.__name__, func, (data.decode(),))
if rpc_error is None:
if isinstance(result, str):
writer.write(result.encode())
else:
writer.write(json.dumps(result).encode())
else:
writer.write(json.dumps(result).encode())
writer.write(rpc_error[1].encode())
else:
# JSON RPC
......@@ -276,33 +315,11 @@ class Server:
method_name = obj['method']
method = self.rpc_table[method_name]
params = obj['params'] if 'params' in obj else ()
try:
# run method either a) directly if it is short running or
# b) in a thread pool if it contains blocking io or
# c) in a process pool if it is cpu bound
# see: https://docs.python.org/3/library/asyncio-eventloop.html
# #executing-code-in-thread-or-process-pools
has_kw_params = isinstance(params, Dict)
assert has_kw_params or isinstance(params, Sequence)
loop = asyncio.get_running_loop()
executor = self.pp_executor if method_name in self.cpu_bound else \
self.tp_executor if method_name in self.blocking else None
if executor is None:
result = method(**params) if has_kw_params else method(*params)
elif has_kw_params:
result = await loop.run_in_executor(executor, method, **params)
else:
result = await loop.run_in_executor(executor, method, *params)
except TypeError as e:
rpc_error = -32602, "Invalid Params: " + str(e)
except NameError as e:
rpc_error = -32601, "Method not found: " + str(e)
except Exception as e:
rpc_error = -32000, "Server Error: " + str(e)
await run(method_name, method, params)
if rpc_error is None:
json_result = {"jsonrpc": "2.0", "result": result, "id": json_id}
writer.write(json.dumps(json_result).encode())
writer.write(json.dumps(json_result, cls=Node_JSONEncoder).encode())
else:
writer.write(('{"jsonrpc": "2.0", "error": {"code": %i, "message": "%s"}, "id": %s}'
% (rpc_error[0], rpc_error[1], json_id)).encode())
......@@ -317,6 +334,8 @@ class Server:
self.pp_executor = p
self.tp_executor = t
self.stop_response = "DHParser server at {}:{} stopped!".format(host, port)
self.host.value = host.encode()
self.port.value = port
self.server = cast(asyncio.base_events.Server,
await asyncio.start_server(self.handle_request, host, port))
async with self.server:
......@@ -341,6 +360,7 @@ class Server:
except CancelledError:
self.pp_executor = None
self.tt_exectuor = None
asyncio_run(self.server.wait_closed())
self.server_messages.put(SERVER_OFFLINE)
self.stage.value = SERVER_OFFLINE
......@@ -365,11 +385,26 @@ class Server:
self.server_process.start()
self.wait_until_server_online()
def terminate_server_process(self):
async def termination_request(self):
try:
host, port = self.host.value.decode(), self.port.value
reader, writer = await asyncio.open_connection(host, port)
writer.write(STOP_SERVER_REQUEST)
await reader.read(500)
while self.stage.value != SERVER_OFFLINE \
and self.server_messages.get() != SERVER_OFFLINE:
pass
writer.close()
except ConnectionRefusedError:
pass
def terminate_server(self):
"""
Terminates the server process.
"""
try:
if self.stage.value in (SERVER_STARTING, SERVER_ONLINE):
asyncio_run(self.termination_request())
if self.server_process and self.server_process.is_alive():
if self.stage.value in (SERVER_STARTING, SERVER_ONLINE):
self.stage.value = SERVER_TERMINATE
......@@ -390,4 +425,4 @@ class Server:
if self.stage.value in (SERVER_STARTING, SERVER_ONLINE, SERVER_TERMINATE):
while self.server_messages.get() != SERVER_OFFLINE:
pass
self.terminate_server_process()
self.terminate_server()
......@@ -43,10 +43,11 @@ __all__ = ('WHITESPACE_PTYPE',
'StrictResultType',
'ChildrenType',
'Node',
'serialize',
'FrozenNode',
'tree_sanity_check',
'RootNode',
'Node_JSONEncoder',
'node_obj_hook',
'parse_sxpr',
'parse_xml',
'parse_json_syntaxtree',
......@@ -193,7 +194,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
self.result = result
self.tag_name = tag_name # type: str
def __deepcopy__(self, memo):
if self.children:
duplicate = self.__class__(self.tag_name, copy.deepcopy(self.children), False)
......@@ -205,7 +205,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
# duplicate._xml_attr = copy.deepcopy(self._xml_attr) # this is not cython compatible
return duplicate
def __str__(self):
if isinstance(self, RootNode):
root = cast(RootNode, self)
......@@ -217,7 +216,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
(content[e_pos - self.pos:], '; '.join(e.message for e in errors))
return self.content
def __repr__(self):
# mpargs = {'name': self.parser.name, 'ptype': self.parser.ptype}
# name, ptype = (self._tag_name.split(':') + [''])[:2]
......@@ -226,12 +224,10 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
"(" + ", ".join(child.__repr__() for child in self.children) + ")"
return "Node(%s, %s)" % (self.tag_name, rarg)
def __len__(self):
return (sum(len(child) for child in self.children)
if self.children else len(self._result))
def __bool__(self):
"""Returns the bool value of a node, which is always True. The reason
for this is that a boolean test on a variable that can contain a node
......@@ -239,11 +235,9 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
"""
return True
def __hash__(self):
return hash(self.tag_name)
def __getitem__(self, index_or_tagname: Union[int, str]) -> Union['Node', Iterator['Node']]:
"""
Returns the child node with the given index if ``index_or_tagname`` is
......@@ -271,7 +265,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
raise KeyError(index_or_tagname)
raise ValueError('Leave nodes have no children that can be indexed!')
def __contains__(self, tag_name: str) -> bool:
"""
Returns true if a child with the given tag name exists.
......@@ -290,7 +283,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
return False
raise ValueError('Leave node cannot contain other nodes')
def equals(self, other: 'Node') -> bool:
"""
Equality of value: Two nodes are considered as having the same value,
......@@ -308,7 +300,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
return self.result == other.result
return False
def get(self, index_or_tagname: Union[int, str],
surrogate: Union['Node', Iterator['Node']]) -> Union['Node', Iterator['Node']]:
"""Returns the child node with the given index if ``index_or_tagname``
......@@ -322,7 +313,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
except KeyError:
return surrogate
def is_anonymous(self) -> bool:
"""Returns True, if the Node is an "anonymous" Node, i.e. a node that
has not been created by a named parser.
......@@ -334,6 +324,7 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
"""
return not self.tag_name or self.tag_name[0] == ':'
## node content
@property
def result(self) -> StrictResultType:
......@@ -344,7 +335,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
"""
return self._result
@result.setter
def result(self, result: ResultType):
# # made obsolete by static type checking with mypy
......@@ -366,7 +356,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
self.children = NoChildren
self._result = result # cast(StrictResultType, result)
def _content(self) -> List[str]:
"""
Returns string content as list of string fragments
......@@ -380,7 +369,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
self._result = str(self._result)
return [self._result]
@property
def content(self) -> str:
"""
......@@ -399,6 +387,7 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
# return "".join(child.content for child in self.children) if self.children \
# else str(self._result)
## node position
@property
def pos(self) -> int:
......@@ -407,7 +396,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
raise AssertionError("Position value not initialized! Use Node.with_pos()")
return self._pos
def with_pos(self, pos: int) -> 'Node':
"""
Initialize position value. Usually, the parser guard
......@@ -434,6 +422,23 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
offset = child.pos + len(child)
return self
## (XML-)attributes
def has_attr(self) -> bool:
"""
Returns `True`, if the node has any attributes, `False` otherwise.
This function does not create an attribute dictionary, therefore
it should be preferred to querying node.attr when testing for the
existence of any attributes.
"""
try:
# if self._xml_attr is not None:
# return True
return bool(self._xml_attr)
except AttributeError:
pass
return False
@property
def attr(self):
......@@ -465,24 +470,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
self._xml_attr = OrderedDict()
return self._xml_attr
    def has_attr(self) -> bool:
        """
        Returns `True`, if the node has any attributes, `False` otherwise.

        This function does not create an attribute dictionary, therefore
        it should be preferred to querying node.attr when testing for the
        existence of any attributes.
        """
        try:
            # if self._xml_attr is not None:
            #     return True
            return bool(self._xml_attr)
        except AttributeError:
            # the attribute dictionary has never been created for this node
            pass
        return False
def compare_attr(self, other: 'Node') -> bool:
"""
Returns True, if `self` and `other` have the same attributes with the
......@@ -498,6 +485,7 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
# other has empty attribute dictionary and self as no attributes
return True # neither self nor other have any attributes
## tree traversal and node selection
def select(self, match_function: Callable, include_root: bool = False, reverse: bool = False) \
-> Iterator['Node']:
......@@ -532,7 +520,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
# for child in child_iterator:
# yield from child.select(match_function, True, reverse)
def select_by_tag(self, tag_names: Union[str, AbstractSet[str]],
include_root: bool = False) -> Iterator['Node']:
"""
......@@ -565,7 +552,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
tag_names = frozenset({tag_names})
return self.select(lambda node: node.tag_name in tag_names, include_root)
def pick(self, tag_names: Union[str, Set[str]]) -> Optional['Node']:
"""
Picks the first descendant with one of the given tag_names.
......@@ -580,16 +566,7 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
except StopIteration:
return None
def tree_size(self) -> int:
"""
Recursively counts the number of nodes in the tree including the root node.
"""
return sum(child.tree_size() for child in self.children) + 1
#
# serialization methods
#
## serialization methods
def _tree_repr(self, tab, open_fn, close_fn, data_fn=lambda i: i,
density=0, inline=False, inline_fn=lambda node: False) -> str:
......@@ -652,7 +629,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
else:
return head + '\n'.join([usetab + data_fn(s) for s in res.split('\n')]) + tail
def as_sxpr(self, src: Optional[str] = None,
indentation: int = 2,
compact: bool = False,
......@@ -709,7 +685,6 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
sxpr = self._tree_repr(' ' * indentation, opening, closing, pretty, density=density)
return sxpr if compact else flatten_sxpr(sxpr, flatten_threshold)
def as_xml(self, src: str = None,
indentation: int = 2,
inline_tags: Set[str] = frozenset(),
......@@ -783,36 +758,40 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
return self._tree_repr(' ' * indentation, opening, closing, sanitizer,
density=1, inline_fn=inlining)
## JSON reading and writing
    def to_json_obj(self) -> Dict:
        """Serialize a node or tree as json-object.

        NOTE(review): this method contains what looks like two merged
        implementations (diff residue). The list-based format is returned
        by the first `return` statement; everything after that statement is
        unreachable dead code that would have produced a key-based format.
        Confirm which format `from_json_obj` expects and remove the other
        half.
        """
        data = [self.tag_name,
                [child.to_json_obj() for child in self.children]
                if self.children else str(self._result)]
        has_attr = self.has_attr()
        # position is only recorded when known or when attributes follow,
        # keeping the serialized list as short as possible
        if self._pos >= 0 or has_attr:
            data.append(self._pos)
        if has_attr:
            data.append(dict(self._xml_attr))
        return {'__class__': 'DHParser.Node', 'data': data}
        # --- unreachable from here on (see NOTE in the docstring) ---
        json = {'__class__': 'DHParser.Node', 'tag_name': self.tag_name }
        if self.children:
            json['result'] = [child.to_json_obj() for child in self.children]
        else:
            json['result'] = str(self._result)
        if self.has_attr():
            json['attr'] = dict(self._xml_attr)
        if self._pos >= 0:
            json['pos'] = self._pos
        return json
    @staticmethod
    def from_json_obj(json_obj: Dict) -> 'Node':
        """Convert a json object representing a node (or tree) back into a
        Node object. Raises a ValueError, if `json_obj` does not represent
        a node.

        NOTE(review): this method contains what looks like two merged
        implementations (diff residue): the values unpacked from
        `json_obj['data']` are immediately overwritten by the key-based
        lookups that follow, and the `['data']` access would raise KeyError
        for objects serialized without a 'data' entry. Confirm which
        serialization format is current and remove the other read path.
        """
        assert isinstance(json_obj, Dict)
        if json_obj.get('__class__', '') != 'DHParser.Node':
            raise ValueError('JSON object: ' + str(json_obj) +
                             ' does not represent a Node object.')
        # list-based format: [tag_name, result, pos?, attr?]
        tag_name, result, pos, attr = (json_obj['data'] + [-1, None])[:4]
        # key-based format: overwrites the values read above
        tag_name = json_obj['tag_name']
        result = json_obj['result']
        # a string result marks a leaf node; otherwise the children are
        # deserialized recursively
        if isinstance(result, str):
            leafhint = True
        else:
            leafhint = False
            result = tuple(Node.from_json_obj(child) for child in result)
        node = Node(tag_name, result, leafhint)
        node._pos = pos
        node._pos = json_obj.get('pos', -1)  # overwrites the assignment above
        attr = json_obj.get('attr', {})
        if attr:
            node.attr.update(attr)
        return node
......@@ -821,35 +800,36 @@ class Node: # (collections.abc.Sized): Base class omitted for cython-compatibil
return json.dumps(self.to_json_obj(), indent=indent, ensure_ascii=ensure_ascii,
separators=(', ', ': ') if indent is not None else (',', ':'))
## generalized serialization method
def serialize(node: Node, how: str = 'default') -> str:
    """
    Serializes the tree starting with `node` either as S-expression, XML,
    JSON, or in compact form. Possible values for `how` are 'S-expression',
    'XML', 'JSON', 'compact' accordingly, or 'AST', 'CST', 'default', in
    which case the value of the respective configuration variable determines
    the serialization format. (See module `configuration.py`.)
    """
    mode = how.lower()
    # 'ast', 'cst' and 'default' are resolved indirectly through the
    # corresponding configuration variable
    config_key = {'ast': 'ast_serialization',
                  'cst': 'cst_serialization',
                  'default': 'default_serialization'}.get(mode)
    if config_key is not None:
        mode = get_config_value(config_key).lower()
    if mode == SXPRESSION_SERIALIZATION.lower():
        return node.as_sxpr(flatten_threshold=get_config_value('flatten_sxpr_threshold'))
    if mode == XML_SERIALIZATION.lower():
        return node.as_xml()
    if mode == JSON_SERIALIZATION.lower():
        return node.as_json()
    if mode == COMPACT_SERIALIZATION.lower():
        return node.as_sxpr(compact=True)
    raise ValueError('Unknown serialization %s. Allowed values are either: %s or : %s'
                     % (how, "'ast', 'cst', 'default'", ", ".join(list(SERIALIZATIONS))))
def serialize_as(self: 'Node', how: str = 'default') -> str:
"""
Serializes the tree starting with `node` either as S-expression, XML, JSON,
or in compact form. Possible values for `how` are 'S-expression',
'XML', 'JSON', 'compact' accordingly, or 'AST', 'CST', 'default' in which case
the value of respective configuration variable determines the
serialization format. (See module `configuration.py`.)
"""
switch = how.lower()
if switch == 'ast':
switch = get_config_value('ast_serialization').lower()
elif switch == 'cst':
switch = get_config_value('cst_serialization').lower()
elif switch == 'default':
switch = get_config_value('default_serialization').lower()
if switch == SXPRESSION_SERIALIZATION.lower():
return self.as_sxpr(flatten_threshold=get_config_value('flatten_sxpr_threshold'))
elif switch == XML_SERIALIZATION.lower():
return self.as_xml()
elif switch == JSON_SERIALIZATION.lower():
return self.as_json()
elif switch == COMPACT_SERIALIZATION.lower():
return self.as_sxpr(compact=True)
else:
raise ValueError('Unknown serialization %s. Allowed values are either: %s or : %s'
% (how, "'ast', 'cst', 'default'", ", ".join(list(SERIALIZATIONS))))
class FrozenNode(Node):
......@@ -857,9 +837,11 @@ class FrozenNode(Node):
FrozenNode is an immutable kind of Node, i.e. it must not be changed
after initialization. The purpose is mainly to allow certain kinds of
optimization, like not having to instantiate empty nodes (because they
are always the same and will be dropped while parsing, anyway).
are always the same and will be dropped while parsing, anyway) or,
rather, throw errors if the program tries to treat a node that is
supposed to be a temporary (frozen) node as if it was a regular node.
Frozen nodes must be used only temporarily during parsing or
Frozen nodes must only be used temporarily during parsing or
tree-transformation and should not occur in the product of the
transformation any more. This can be verified with `tree_sanity_check()`.
"""
......@@ -1066,6 +1048,20 @@ class RootNode(Node):
empty_tags=self.empty_tags)
class Node_JSONEncoder(json.JSONEncoder):
    """A JSON-encoder that additionally handles DHParser.Node objects by
    delegating them to `Node.to_json_obj()`."""

    def default(self, obj):
        # Nodes know how to render themselves as plain json objects;
        # everything else is left to the base-class implementation
        # (which raises TypeError for unserializable objects).
        if isinstance(obj, Node):
            return obj.to_json_obj()
        return super().default(obj)
def node_obj_hook(dct):
    """Object-hook for `json.loads` that revives dictionaries carrying the
    marker `'__class__': 'DHParser.Node'` as Node objects. Any other
    dictionary is passed through unchanged."""
    if dct.get('__class__', '') == "DHParser.Node":
        return Node.from_json_obj(dct)
    return dct
#######################################################################
#
# S-expression- and XML-parsers and JSON-reader
......
......@@ -40,7 +40,7 @@ from typing import Dict, List, Union, cast
from DHParser.error import Error, is_error, adjust_error_locations
from DHParser.log import log_dir, logging, is_logging, clear_logs, log_parsing_history
from DHParser.parse import UnknownParserError, Parser, Lookahead
from DHParser.syntaxtree import Node, RootNode, parse_tree, flatten_sxpr, serialize, ZOMBIE_TAG
from DHParser.syntaxtree import Node, RootNode, parse_tree, flatten_sxpr, ZOMBIE_TAG
from DHParser.toolkit import GLOBALS, get_config_value, load_if_file, re
......@@ -269,10 +269,10 @@ def get_report(test_unit):
cst = tests.get('__cst__', {}).get(test_name, None)
if cst and (not ast or str(test_name).endswith('*')):
report.append('\n### CST')
report.append(indent(serialize(cst, 'cst')))
report.append(indent(cst.serialize_as('cst')))
if ast:
report.append('\n### AST')
report.append(indent(serialize(ast, 'ast')))
report.append(indent(ast.serialize_as('ast')))
for test_name, test_code in tests.get('fail', dict()).items():