create.py 7.99 KB
Newer Older
1
#!/usr/bin/env python3
2
import abc
3
import multiprocessing
4
import os
5
import warnings
6
from distutils.dir_util import copy_tree
7

8
import suqc.request  # no "from suqc.request import ..." works because of circular imports
9
from suqc.environment import AbstractEnvironmentManager, VadereEnvironmentManager
10 11
from suqc.parameter.postchanges import PostScenarioChangesBase
from suqc.parameter.sampling import ParameterVariationBase
12
from suqc.utils.dict_utils import change_dict, change_dict_ini, deep_dict_lookup
13
from suqc.utils.general import create_folder, njobs_check_and_set, remove_folder
14 15


16 17 18 19 20 21 22
class AbstractScenarioCreation(object):
    """Base class that writes parameter-varied scenario files to disk.

    Subclasses implement the key check (``_sampling_check_selected_keys``) and
    the single-/multi-process creation loops (``_sp_creation`` /
    ``_mp_creation``). This class provides the shared driver
    (``generate_scenarios``) and the per-variation helpers for Vadere and
    OMNeT++ scenarios.

    NOTE(review): the class derives from ``object`` rather than ``abc.ABC``,
    so the ``@abc.abstractmethod`` decorators are not enforced when an
    instance is created; an unimplemented hook only fails when it is called
    (with ``NotImplementedError``).
    """

    def __init__(
        self,
        env_man: AbstractEnvironmentManager,
        parameter_variation: ParameterVariationBase,
        post_change: PostScenarioChangesBase = None,
    ):
        """
        Args:
            env_man: environment manager that provides the basis scenario(s)
                and the output paths.
            parameter_variation: the parameter variations to create.
            post_change: optional changes applied to every Vadere scenario
                after the parameter variation; ``None`` disables them.
        """
        self._env_man = env_man
        self._parameter_variation = parameter_variation
        self._post_changes = post_change

        # Validate the sampled keys against the basis scenario(s) right away
        # (implemented by the subclass).
        self._sampling_check_selected_keys()

    @abc.abstractmethod
    def _sampling_check_selected_keys(self):
        """Check the sampled parameter keys against the basis scenario(s)."""
        raise NotImplementedError

    @abc.abstractmethod
    def _sp_creation(self):
        """Create all scenarios in the current process; return the request items."""
        raise NotImplementedError

    @abc.abstractmethod
    def _mp_creation(self, njobs):
        """Create all scenarios with ``njobs`` processes; return the request items."""
        raise NotImplementedError

    # public methods
    def generate_scenarios(self, njobs):
        """Write all scenario variations into a freshly emptied output folder.

        Args:
            njobs: requested number of processes. It is normalised by
                ``njobs_check_and_set`` against the number of tasks (the
                number of rows of the parameter variation's ``points``).

        Returns:
            The request items from ``_sp_creation`` when the normalised
            ``njobs`` is 1, otherwise from ``_mp_creation(njobs)``.
        """
        ntasks = self._parameter_variation.points.shape[0]
        njobs = njobs_check_and_set(njobs=njobs, ntasks=ntasks)

        # increases readability and promotes shorter paths (apparently lengthy paths can cause problems on Windows)
        # see issue #76
        self._adapt_nr_digits_env_man(
            nr_variations=self._parameter_variation.nr_parameter_variations(),
            nr_runs=self._parameter_variation.nr_scenario_runs(),
        )

        target_path = self._env_man.get_env_outputfolder_path()

        # For security: start from an empty output folder. Any existing
        # content of ``target_path`` is removed here.
        remove_folder(target_path)
        create_folder(target_path)

        if njobs == 1:
            return self._sp_creation()
        return self._mp_creation(njobs)

    # private methods
    def _adapt_nr_digits_env_man(self, nr_variations, nr_runs):
        """Store the number of decimal digits of the variation/run counts.

        Example: 120 variations -> ``nr_digits_variation == 3``. The values
        are set on the environment manager, presumably to format variation
        names/paths compactly (see issue #76).
        """
        self._env_man.nr_digits_variation = len(str(nr_variations))
        self._env_man.nr_digits_runs = len(str(nr_runs))

    ## vadere specific

    def _create_vadere_scenario(self, args):
        """Set up a new Vadere scenario and return its request item.

        Args:
            args: a ``(parameter_id, run_id, parameter_variation)`` sequence.
                The three values are packed into one argument because
                ``multiprocessing.Pool.map`` calls the worker with a single
                item (``Pool.starmap`` would be the unpacking alternative).

        Returns:
            A ``suqc.request.RequestItem`` holding the ids, the saved scenario
            path, the environment base path and the variation output folder.
        """
        parameter_id, run_id, parameter_variation = args

        par_var_scenario = change_dict(
            self._env_man.vadere_basis_scenario, changes=parameter_variation
        )

        if self._post_changes is not None:
            # Apply pre-defined changes to each scenario file
            new_scenario = self._post_changes.change_scenario(
                scenario=par_var_scenario,
                parameter_id=parameter_id,
                run_id=run_id,
                parameter_variation=parameter_variation,
            )
        else:
            new_scenario = par_var_scenario

        output_folder = self._env_man.get_variation_output_folder(parameter_id, run_id)
        self._print_scenario_warnings(new_scenario)
        scenario_path = self._env_man.save_scenario_variation(
            parameter_id, run_id, new_scenario
        )

        return suqc.request.RequestItem(
            parameter_id=parameter_id,
            run_id=run_id,
            scenario_path=scenario_path,
            base_path=self._env_man.base_path,
            output_folder=output_folder,
        )

    def _print_scenario_warnings(self, scenario):
        """Warn if the scenario sets a positive ``realTimeSimTimeRatio``."""
        try:
            real_time_sim_time_ratio, _ = deep_dict_lookup(
                scenario, "realTimeSimTimeRatio"
            )
        except Exception:
            # Deliberate best effort: skip the warning if the lookup fails
            # for whatever reason.
            real_time_sim_time_ratio = 0

        if real_time_sim_time_ratio > 1e-14:
            warnings.warn(
                f"In a scenario the key 'realTimeSimTimeRatio={real_time_sim_time_ratio}'. Large values "
                f"slow down the evaluation speed."
            )

    ## omnet specific
    def _create_omnet_scenario(self, args):
        """Write one OMNeT++ ini variation and copy the additional rover files.

        Args:
            args: a ``(parameter_id, run_id, parameter_variation)`` sequence.
                It is packed into one argument for ``Pool.map`` compatibility.

        Returns:
            None. Unlike the Vadere counterpart, no request item is created.
        """
        parameter_id, run_id, parameter_variation = args

        par_var_scenario = change_dict_ini(
            self._env_man.omnet_basis_ini, changes=parameter_variation
        )
        output_path = self._env_man.scenario_variation_path(
            parameter_id, run_id, simulator="omnet"
        )

        with open(output_path, "w") as outfile:
            par_var_scenario.writer(outfile)

        # Copy the environment's "additional_rover_files" next to this ini file.
        # NOTE(review): distutils (and copy_tree) is deprecated by PEP 632 and
        # removed in Python 3.12; shutil.copytree(..., dirs_exist_ok=True)
        # (Python >= 3.8) is the replacement.
        folder = os.path.dirname(output_path)
        ini_path = os.path.join(self._env_man.env_path, "additional_rover_files")
        copy_tree(ini_path, folder)


class VadereScenarioCreation(AbstractScenarioCreation):
    """Create parameter variations of a single Vadere scenario."""

    def __init__(
        self,
        env_man: AbstractEnvironmentManager,
        parameter_variation: ParameterVariationBase,
        post_change: PostScenarioChangesBase = None,
    ):
        super().__init__(env_man, parameter_variation, post_change)

    def _sp_creation(self):
        """Single process loop to create all requested scenarios.

        Returns:
            List of request items, in ``par_iter()`` order.
        """
        return [
            self._create_vadere_scenario([par_id, run_id, par_change])
            for par_id, run_id, par_change in self._parameter_variation.par_iter()
        ]

    def _mp_creation(self, njobs):
        """Multi process function to create all requested scenarios.

        The pool is used as a context manager, so its worker processes are
        terminated once ``map`` has returned all results instead of being
        leaked.

        Returns:
            List of request items, in ``par_iter()`` order (``Pool.map``
            preserves input order).
        """
        with multiprocessing.Pool(processes=njobs) as pool:
            request_item_list = pool.map(
                self._create_vadere_scenario, self._parameter_variation.par_iter()
            )
        return request_item_list

    def _sampling_check_selected_keys(self):
        """Check the sampled keys against the Vadere basis scenario."""
        self._parameter_variation.check_vadere_keys(self._env_man.vadere_basis_scenario)
179 180


181 182 183 184 185 186 187 188
class CoupledScenarioCreation(AbstractScenarioCreation):
    """Create coupled variations: OMNeT++ ini files plus Vadere scenarios.

    Only the Vadere scenarios produce request items. The OMNeT++ files are
    written to disk as a side effect before the Vadere scenarios.
    """

    def __init__(
        self,
        env_man: AbstractEnvironmentManager,
        parameter_variation: ParameterVariationBase,
        post_change: PostScenarioChangesBase = None,
    ):
        super().__init__(env_man, parameter_variation, post_change)

    def _sp_creation(self):
        """Single process loop to create all requested scenarios.

        Returns:
            List of Vadere request items, in ``par_iter(simulator="vadere")`` order.
        """
        # omnet specific
        variations_omnet = self._parameter_variation.par_iter(simulator="omnet")
        for par_id, run_id, par_change in variations_omnet:
            self._create_omnet_scenario([par_id, run_id, par_change])

        # vadere specific
        variations_vadere = self._parameter_variation.par_iter(simulator="vadere")
        return [
            self._create_vadere_scenario([par_id, run_id, par_change])
            for par_id, run_id, par_change in variations_vadere
        ]

    def _mp_creation(self, njobs):
        """Multi process function to create all requested scenarios.

        A single pool serves both phases and is used as a context manager, so
        its worker processes are terminated after the last ``map`` instead of
        being leaked.

        Returns:
            List of Vadere request items, in ``par_iter(simulator="vadere")`` order.
        """
        with multiprocessing.Pool(processes=njobs) as pool:
            # Pool.map blocks, so all OMNeT++ files exist before the Vadere
            # phase starts (same ordering as the single-process path).
            variations_omnet = self._parameter_variation.par_iter(simulator="omnet")
            pool.map(self._create_omnet_scenario, variations_omnet)

            variations_vadere = self._parameter_variation.par_iter(simulator="vadere")
            request_item_list = pool.map(
                self._create_vadere_scenario, variations_vadere
            )
        return request_item_list

    def _sampling_check_selected_keys(self):
        """Check the sampled keys against both the Vadere and OMNeT++ basis files."""
        self._parameter_variation.check_vadere_keys(self._env_man.vadere_basis_scenario)
        self._parameter_variation.check_omnet_keys(self._env_man.omnet_basis_ini)