server.py 6.53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# server.py - an asynchronous tcp-server to compile sources with DHParser
#
# Copyright 2019  by Eckhart Arnold (arnold@badw.de)
#                 Bavarian Academy of Sciences and Humanities (badw.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.  See the License for the specific language governing
# permissions and limitations under the License.


"""
Module `server` contains an asynchronous tcp-server that receives compilation
requests and runs custom compilation functions in a multiprocessing.Pool.

This allows starting a DHParser compilation environment just once, which saves
the startup time of DHParser for each subsequent compilation. In particular,
with a just-in-time compiler like PyPy (https://pypy.org) setting up a
compilation server is highly recommended, because jit-compilers typically
sacrifice startup speed for running speed.

It is up to the compilation function to either return the result of the
compilation in serialized form, or just save the compilation results on the
file system and merely return a success or failure message. Module `server`
does not define any of these messages. This is completely up to the clients
of module `server`, i.e. the compilation modules, to decide.

The communication, i.e. requests and responses, follows the JSON-RPC 2.0
protocol (https://www.jsonrpc.org/specification).
"""


import asyncio
41
import json
eckhart's avatar
eckhart committed
42
from multiprocessing import Process, Value, Queue
43
from typing import Callable, Optional, Union, Dict, List, Sequence, cast
44
45

from DHParser.toolkit import get_config_value
46

47
48
# Maps the method names that clients may call to the callables implementing them.
RPC_Table = Dict[str, Callable]
# Accepted by `Server.__init__`: a ready-made RPC table, a list of callables
# (registered under their ``__name__``) or a single callable.
RPC_Type = Union[RPC_Table, List[Callable], Callable]


# Error marker string; it is not used in this part of the module.
# NOTE(review): presumably meant for clients to recognize server failures - confirm.
SERVER_ERROR = "COMPILER-SERVER-ERROR"


# Life-cycle stages of a server. They are stored in `Server.stage` and
# exchanged through `Server.server_messages`.
SERVER_OFFLINE = 0
SERVER_STARTING = 1
SERVER_ONLINE = 2
SERVER_TERMINATE = 3


class Server:
    """Asynchronous JSON-RPC 2.0 server that dispatches requests to Python
    functions.

    Every connection carries exactly one request. The server:
    - reads at most ``max_rpc_size`` bytes (a DHParser configuration value);
    - decodes them as a JSON-RPC request object;
    - calls the requested function from the RPC table;
    - writes back a single JSON-RPC response and closes the connection.

    :param rpc_functions: one of the following:
        - a dictionary that maps method names to callables;
        - a list (or tuple) of callables, registered under their ``__name__``;
        - a single callable.
    """

    def __init__(self, rpc_functions: 'RPC_Type'):
        if isinstance(rpc_functions, dict):
            # The caller's dictionary is used as-is, not copied.
            self.rpc_table = rpc_functions  # type: RPC_Table
        elif isinstance(rpc_functions, (list, tuple)):
            self.rpc_table = {func.__name__: func for func in rpc_functions}
        else:
            assert callable(rpc_functions), \
                "rpc_functions must be a dict, a list of callables or a callable"
            self.rpc_table = {rpc_functions.__name__: rpc_functions}

        # Maximum request size in bytes; larger requests are rejected.
        self.max_source_size = get_config_value('max_rpc_size')
        # Life-cycle stage (see the SERVER_* constants), kept in a
        # multiprocessing.Value so that it is shared with the server process.
        self.stage = Value('b', SERVER_OFFLINE)
        self.server = None  # type: Optional[asyncio.base_events.Server]
        # Stage notifications (e.g. SERVER_ONLINE) exchanged with the server process.
        self.server_messages = Queue()  # type: Queue
        self.server_process = None  # type: Optional[Process]

    @staticmethod
    def _rpc_error(code: int, message: str, json_id=None) -> dict:
        """Return a JSON-RPC error response object."""
        return {"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": json_id}

    def _execute_request(self, data: bytes) -> dict:
        """Decode a raw JSON-RPC request, call the requested method and
        return the response object (either a result or an error)."""
        try:
            obj = json.loads(data)
        except ValueError as e:  # covers json.JSONDecodeError and UnicodeDecodeError
            return self._rpc_error(-32700, 'Parse error: ' + str(e))
        if not isinstance(obj, dict):
            return self._rpc_error(
                -32700, 'Parse error: Request does not appear to be an RPC-call!?')

        json_id = obj.get('id')  # None is serialized as JSON null
        version = obj.get('jsonrpc', 'unknown')
        if version != '2.0':
            return self._rpc_error(
                -32600, 'Invalid Request: jsonrpc version 2.0 needed, version "%s" '
                        'found.' % version, json_id)
        if 'method' not in obj:
            return self._rpc_error(-32600, 'Invalid Request: No method specified.', json_id)
        method_name = obj['method']
        # The str check guards the dict lookup against unhashable method names.
        if not isinstance(method_name, str) or method_name not in self.rpc_table:
            return self._rpc_error(-32601, 'Method not found: ' + str(method_name), json_id)

        method = self.rpc_table[method_name]
        params = obj.get('params', [])
        # JSON-RPC only allows by-position (array) or by-name (object) parameters.
        if not isinstance(params, (list, dict)):
            return self._rpc_error(
                -32602, 'Invalid Params: params must be an array or an object', json_id)
        # NOTE: the method runs directly in the event loop's thread, so a
        # long-running method blocks all other connections meanwhile.
        try:
            result = method(**params) if isinstance(params, dict) else method(*params)
        except Exception as e:
            # Exceptions raised inside the method are reported as invalid
            # params as well, because they cannot be told apart from
            # argument-binding errors here.
            return self._rpc_error(-32602, "Invalid Params: " + str(e), json_id)
        return {"jsonrpc": "2.0", "result": result, "id": json_id}

    def _encode_response(self, response: dict) -> bytes:
        """Serialize a response object to UTF-8 encoded JSON.

        Falls back to an internal-error response (-32603) if the method's
        result cannot be converted to JSON.
        """
        try:
            return json.dumps(response).encode('utf-8')
        except (TypeError, ValueError) as e:
            fallback = self._rpc_error(-32603, 'Internal error: ' + str(e), response.get('id'))
            return json.dumps(fallback).encode('utf-8')

    async def handle_compilation_request(self,
                                         reader: asyncio.StreamReader,
                                         writer: asyncio.StreamWriter):
        """Answer a single JSON-RPC request received on a client connection.

        The request is read from ``reader``, executed, and a JSON-RPC
        response is written to ``writer``. The connection is closed
        afterwards, even if sending the response fails.
        """
        try:
            # Read one byte more than allowed so that oversized requests
            # can be detected.
            # NOTE(review): StreamReader.read(n) returns as soon as some data is
            # available, so a request split across several TCP segments may be
            # truncated - confirm that clients always send a request in one piece.
            data = await reader.read(self.max_source_size + 1)
            if len(data) > self.max_source_size:
                response = self._rpc_error(
                    -32600, 'Invalid Request: Source code too large! Only %i MB allowed'
                    % (self.max_source_size // (1024 ** 2)))
            else:
                response = self._execute_request(data)
            writer.write(self._encode_response(response))
            await writer.drain()
        finally:
            writer.close()
        # TODO: exit the server coroutine gracefully in case a terminate
        #  signal is received.

    async def serve(self, address: str = '127.0.0.1', port: int = 8888):
        """Start the TCP server on ``address``:``port`` and serve requests
        until the coroutine is cancelled.

        After the server has started, the stage is set to SERVER_ONLINE and
        SERVER_ONLINE is put into `server_messages`.
        """
        self.server = await asyncio.start_server(self.handle_compilation_request, address, port)
        async with self.server:
            self.stage.value = SERVER_ONLINE
            self.server_messages.put(SERVER_ONLINE)
            await self.server.serve_forever()

    def run_server(self, address: str = '127.0.0.1', port: int = 8888):
        """Run the server in the current process; blocks while it is serving."""
        self.stage.value = SERVER_STARTING
        asyncio.run(self.serve(address, port))

    def wait_until_server_online(self):
        """Block until the server reports that it is online.

        :raises AssertionError: if the first stage message received from the
            server is not SERVER_ONLINE.
        """
        if self.stage.value == SERVER_ONLINE:
            return
        if self.server_messages.get() != SERVER_ONLINE:
            raise AssertionError('could not start server!?')
        assert self.stage.value == SERVER_ONLINE

    def run_as_process(self):
        """Start the server in a separate process and wait until it is online."""
        self.server_process = Process(target=self.run_server)
        self.server_process.start()
        self.wait_until_server_online()

    def terminate_server_process(self):
        """Terminate the server process started by `run_as_process`, if any,
        and wait for it to exit."""
        if self.server_process is None:
            return
        # NOTE: Process.terminate() does not run cleanup code in the server
        # process; data it was putting into `server_messages` may be lost.
        self.server_process.terminate()
        self.server_process.join()

    def wait_for_termination_request(self):
        """Block until SERVER_TERMINATE arrives on `server_messages`, then
        terminate the server process.

        All other messages received in the meantime are discarded.
        """
        assert self.server_process
        while self.server_messages.get() != SERVER_TERMINATE:
            pass
        self.terminate_server_process()
        self.server_process = None