# server.py - an asynchronous tcp-server to compile sources with DHParser
#
# Copyright 2019  by Eckhart Arnold (arnold@badw.de)
#                 Bavarian Academy of Sciences and Humanities (badw.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.  See the License for the specific language governing
# permissions and limitations under the License.


"""
Module `server` contains an asynchronous tcp-server that receives compilation
requests and runs custom compilation functions to answer them.

This allows starting a DHParser compilation environment just once, saving the
startup time of DHParser for each subsequent compilation. In particular, with
a just-in-time-compiler like PyPy (https://pypy.org) setting up a
compilation-server is highly recommended, because jit-compilers typically
sacrifice startup-speed for running-speed.

It is up to the compilation function to either return the result of the
compilation in serialized form, or just save the compilation results on the
file system and merely return a success or failure message. Module `server`
does not define any of these messages. This is completely up to the clients
of module `server`, i.e. the compilation-modules, to decide.

The communication, i.e. requests and responses, follows the json-rpc protocol
(https://www.jsonrpc.org/specification).
"""


import asyncio
import json
from multiprocessing import Process, Value, Queue
from typing import Callable, Optional, Union, Dict, List, Sequence, cast

from DHParser.toolkit import get_config_value


# Maps rpc method names to the functions that implement them.
RPC_Table = Dict[str, Callable]
# Accepted ways of passing rpc-functions to `Server`: a ready-made table,
# a list of functions (registered under their `__name__`) or a single function.
RPC_Type = Union[RPC_Table, List[Callable], Callable]


# NOTE(review): not referenced within this module; presumably meant for
# clients of the server -- confirm before removing.
SERVER_ERROR = "COMPILER-SERVER-ERROR"


# Lifecycle stages of a `Server`. They are stored in a
# `multiprocessing.Value('b', ...)` (a signed char), so they must remain
# small integers.
SERVER_OFFLINE = 0
SERVER_STARTING = 1
SERVER_ONLINE = 2
SERVER_TERMINATE = 3

class Server:
    """An asynchronous JSON-RPC 2.0 server built on `asyncio` streams.

    Each connection carries exactly one request. The handler reads at most
    `max_source_size` bytes and dispatches the call to one of the registered
    rpc-functions. It then writes back a single JSON-RPC response object and
    closes the connection.

    The server can run in the current process (`run_server()`) or in a
    separate `multiprocessing.Process` (`run_as_process()`). Its lifecycle
    stage is kept in the shared `self.stage` value. The server reports
    SERVER_ONLINE to the parent through the `self.server_messages` queue.
    """

    def __init__(self, rpc_functions: 'RPC_Type'):
        """
        :param rpc_functions: either a dictionary mapping method names to
            callables, a list of callables (each registered under its
            `__name__`) or a single callable.
        :raises TypeError: if `rpc_functions` is none of the above.
        """
        if isinstance(rpc_functions, dict):
            self.rpc_table = cast(RPC_Table, rpc_functions)  # type: RPC_Table
        elif isinstance(rpc_functions, list):
            self.rpc_table = {func.__name__: func for func in rpc_functions}
        elif callable(rpc_functions):
            self.rpc_table = {rpc_functions.__name__: rpc_functions}
        else:
            # an explicit exception, unlike `assert`, survives `python -O`
            raise TypeError('rpc_functions must be a dict, a list or a callable, not '
                            + type(rpc_functions).__name__)

        # maximum size (in bytes) of a single request
        self.max_source_size = get_config_value('max_rpc_size')
        # lifecycle stage, kept in shared memory so it is visible across processes
        self.stage = Value('b', SERVER_OFFLINE)
        self.server = None  # type: Optional[asyncio.base_events.Server]

        # channel through which the server (process) reports its stage
        self.server_messages = Queue()  # type: Queue
        self.server_process = None  # type: Optional[Process]

    @staticmethod
    def _rpc_error(code: int, message: str, json_id=None) -> dict:
        """Return a JSON-RPC 2.0 error-response object."""
        return {"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": json_id}

    @classmethod
    def _encode_response(cls, response: dict) -> bytes:
        """Serialize a response object to UTF-8 encoded JSON.

        If `response` cannot be serialized (e.g. because an rpc-function
        returned an object that is not JSON-serializable), an
        "Internal error" (-32603) response is encoded instead.
        """
        try:
            text = json.dumps(response)
        except (TypeError, ValueError) as e:
            text = json.dumps(cls._rpc_error(-32603, 'Internal error: ' + str(e),
                                             response.get('id')))
        return text.encode('utf-8')

    def _execute_request(self, data: bytes) -> dict:
        """Parse a raw JSON-RPC request, call the requested rpc-function and
        return the response object: either a result or an error response.

        Exceptions raised by the rpc-function are not propagated. They are
        reported as "Invalid Params" (-32602) errors.
        """
        try:
            obj = json.loads(data)
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError
            return self._rpc_error(-32700, 'Parse error: ' + str(e))
        if not isinstance(obj, dict):
            return self._rpc_error(
                -32600, 'Invalid Request: Request does not appear to be an RPC-call!?')

        json_id = obj.get('id')  # a missing id is answered with JSON null
        version = obj.get('jsonrpc', 'unknown')
        if version != '2.0':
            return self._rpc_error(
                -32600, 'Invalid Request: jsonrpc version 2.0 needed, version "%s" '
                'found.' % version, json_id)
        if 'method' not in obj:
            return self._rpc_error(-32600, 'Invalid Request: No method specified.', json_id)
        method_name = obj['method']
        # a non-string (possibly unhashable) method name can never be in the table
        if not isinstance(method_name, str) or method_name not in self.rpc_table:
            return self._rpc_error(-32601, 'Method not found: ' + str(method_name), json_id)
        method = self.rpc_table[method_name]

        params = obj.get('params', ())
        if isinstance(params, dict):
            args, kwargs = (), params
        elif isinstance(params, (list, tuple)):
            args, kwargs = params, {}
        else:
            # JSON-RPC 2.0 only allows structured params; e.g. a JSON string
            # must not be splatted into single characters
            return self._rpc_error(
                -32602, 'Invalid Params: params must be an array or an object', json_id)
        try:
            result = method(*args, **kwargs)
        except Exception as e:
            return self._rpc_error(-32602, "Invalid Params: " + str(e), json_id)
        return {"jsonrpc": "2.0", "result": result, "id": json_id}

    async def handle_compilation_request(self,
                                         reader: asyncio.StreamReader,
                                         writer: asyncio.StreamWriter):
        """Answer a single JSON-RPC request and close the connection.

        The connection is closed even if reading or writing fails.
        """
        try:
            # Read one byte more than allowed in order to detect oversized requests.
            # NOTE(review): `read(n)` returns whatever data is available, so a
            # request split across several TCP segments may arrive truncated.
            data = await reader.read(self.max_source_size + 1)
            if len(data) > self.max_source_size:
                response = self._rpc_error(
                    -32600, 'Invalid Request: Source code too large! Only %i MB allowed'
                    % (self.max_source_size // (1024 ** 2)))
            else:
                response = self._execute_request(data)
            writer.write(self._encode_response(response))
            await writer.drain()
        finally:
            writer.close()
            await writer.wait_closed()
        # TODO: in case a terminate signal is received, exit the server coroutine
        #  gracefully, e.g. via self.server.close()

    async def serve(self, address: str = '127.0.0.1', port: int = 8888):
        """Listen on `address`:`port` and serve requests forever.

        Once the listening socket is open, the stage is set to SERVER_ONLINE
        and SERVER_ONLINE is posted to `server_messages`.
        """
        self.server = await asyncio.start_server(self.handle_compilation_request, address, port)
        async with self.server:
            self.stage.value = SERVER_ONLINE
            self.server_messages.put(SERVER_ONLINE)
            await self.server.serve_forever()

    def run_server(self, address: str = '127.0.0.1', port: int = 8888):
        """Run the server in the current process; blocks until it stops."""
        self.stage.value = SERVER_STARTING
        asyncio.run(self.serve(address, port))

    def wait_until_server_online(self):
        """Block until the server reports SERVER_ONLINE.

        :raises AssertionError: if the server reports any other message first.
        """
        if self.stage.value != SERVER_ONLINE:
            if self.server_messages.get() != SERVER_ONLINE:
                raise AssertionError('could not start server!?')
            assert self.stage.value == SERVER_ONLINE

    def run_as_process(self):
        """Start the server in a separate process and wait until it is online."""
        self.server_process = Process(target=self.run_server)
        self.server_process.start()
        self.wait_until_server_online()

    def terminate_server_process(self):
        """Terminate the server process (if any) and wait for it to exit."""
        if self.server_process is None:
            return
        self.server_process.terminate()
        self.server_process.join()  # reap the process so it does not linger as a zombie
        # The killed process cannot update the shared stage itself.
        self.stage.value = SERVER_OFFLINE

    def wait_for_termination_request(self):
        """Block until SERVER_TERMINATE arrives on the message queue, then
        terminate the server process."""
        assert self.server_process
        while self.server_messages.get() != SERVER_TERMINATE:
            pass
        self.terminate_server_process()
        self.server_process = None