Commit 5fa9fc7a authored by Edmond Irani Liu's avatar Edmond Irani Liu 🌊
Browse files

add sys path in motion primitive notebook

parent f929321b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
%% Cell type:markdown id: tags:

## Generation of Motion Primitives

This piece of code demonstrates how the motion primitves used in the search problem are generated.

%% Cell type:code id: tags:

``` python
# add directories
import sys
sys.path.append("../../GSMP/motion_automata/")
sys.path.append("../../GSMP/motion_automata/vehicle_model/")
sys.path.append("../../GSMP/tools/")
sys.path.append("…/…/GSMP/tools/commonroad-collision-checker/")

import os
import numpy as np
import itertools
from math import atan2, sin, cos
import xml.etree.ElementTree as et

# import vehicle model (https://gitlab.lrz.de/tum-cps/commonroad-vehicle-models/tree/master/Python)
from vehicleDynamics_KS import vehicleDynamics_KS
from parameters_vehicle1 import parameters_vehicle1
from parameters_vehicle2 import parameters_vehicle2
from parameters_vehicle3 import parameters_vehicle3

# import solution checker
from solution_checker import *

from commonroad.common.solution_writer import CommonRoadSolutionWriter, VehicleModel, VehicleType, CostFunction
```

%% Cell type:markdown id: tags:

### Helper function to create a Trajectory object from a list of states.

The elements of the states are: x, y, steering angle, velocity, orientation, time step

%% Cell type:code id: tags:

``` python
def create_trajectory_from_states(list_states):
    # list to hold states for final trajectory
    list_states_new = list()

    # iterate through trajectory states
    for state in list_states:
        # feed in required slots
        kwarg = {'position': np.array([state[0], state[1]]),
                 'velocity': state[3],
                 'steering_angle': state[2],
                 'orientation': state[4],
                 'time_step': state[5].astype(int)}

        # append state
        list_states_new.append(State(**kwarg))

    # create new trajectory for evaluation
    trajectory_new = Trajectory(initial_time_step=0, state_list=list_states_new)

    return trajectory_new
```

%% Cell type:markdown id: tags:

### Helper function to check the validity of a given trajectory.

The validity is checked with the help of a solution checher under Kinematic Single-Track Model.

%% Cell type:code id: tags:

``` python
def check_validity(trajectory_input, veh_type):
    csw = CommonRoadSolutionWriter(output_dir=os.getcwd(),
                                   scenario_id=0,
                                   step_size=0.1,
                                   vehicle_type=veh_type,
                                   vehicle_model=VehicleModel.KS,
                                   cost_function=CostFunction.JB1)

    # use solution writer to generate target xml file
    csw.add_solution_trajectory(trajectory=trajectory_input, planning_problem_id=100)
    xmlTree = csw.root_node

    # generate states to be checked
    [node] = xmlTree.findall('ksTrajectory')
    veh_trajectory = KSTrajectory.from_xml(node)
    veh_model = KinematicSingleTrackModel(veh_type_id, veh_trajectory, None)

    # validate
    result = TrajectoryValidator.is_trajectory_valid(veh_trajectory, veh_model, 0.1)

    return result
```

%% Cell type:markdown id: tags:

### Congiguration of parameters.

%% Cell type:code id: tags:

``` python
# settings

flag_print_info = False

# vehicle parameter
# 1: FORD_ESCORT 2: BMW_320i 3: VW_VANAGON
veh_type_id = 3

if veh_type_id == 1:
    veh_type = VehicleType.FORD_ESCORT
    veh_param = parameters_vehicle1()
    veh_name = "FORD_ESCORT"
elif veh_type_id == 2:
    veh_type = VehicleType.BMW_320i
    veh_param = parameters_vehicle2()
    veh_name = "BMW320i"
elif veh_type_id == 3:
    veh_type = VehicleType.VW_VANAGON
    veh_param = parameters_vehicle3()
    veh_name = "VW_VANAGON"

# total length of trajectory, in seconds
T = 0.5
# time step for states, in seconds
# commonroad scenarios have dt of 0.1 seconds; for higher accuracy of forward simulation, dt here is set to 0.05
# the simulated states will be down-sampled
dt = 0.05
# calculate time_stamps, *100 for 2 digits accuracy
time_stamps = ((np.arange(0, T, dt) + dt) * 100).astype(int)

# sampling range
# in m/s
min_range_sample_velocity = 0.0
max_range_sample_velocity = 20.0
num_sample_velocity = 20 + 1 # step v is then 1.0 m/s

# steer to one side only, we can mirror the primitives afterwards
# in rad
min_range_sample_steering_angle = 0
max_range_sample_steering_angle = veh_param.steering.max
num_sample_steering_angle = 8 + 1 # step is roughly 10 degrees
```

%% Cell type:code id: tags:

``` python
# create list of possible samples for velocity and steering angle
list_samples_v = np.linspace(min_range_sample_velocity,
                             max_range_sample_velocity,
                             num_sample_velocity)

list_samples_steering_angle = np.linspace(min_range_sample_steering_angle,
                                         max_range_sample_steering_angle,
                                         num_sample_steering_angle)

# for some statistics
num_possible_states = num_sample_velocity * num_sample_steering_angle
num_possible_start_end_combinations = num_possible_states ** 2
count_processed = 0
count_validated =0
count_accepted = 0

# for saving the results
list_traj_accepted = []
list_traj_failed = []

print("Total possible combination of states: ", num_possible_start_end_combinations)

# v = velocity, d = delta = steering_angle
# iterate through possible instance of the cartesian product of list_samples_v and list_samples_steering_angle
# create all possible combinations of (v_start, d_start) and (v_end, d_end)

for (v_start, d_start) in itertools.product(list_samples_v, list_samples_steering_angle):
    for (v_end, d_end) in itertools.product(list_samples_v, list_samples_steering_angle):

        count_processed += 1

        # Print progress
        if count_processed % 2000 == 0 and count_processed:
            print("Progress: {} primitives checked.".format(count_processed))

        # compute required inputs
        a_input = (v_end - v_start) / T
        steering_rate_input = (d_end - d_start) / T

        # check if the constraints are respected
        if (a_input > veh_param.longitudinal.a_max) or (a_input < -veh_param.longitudinal.a_max) or \
           (steering_rate_input > veh_param.steering.v_max) or (steering_rate_input < -veh_param.steering.v_max):
            continue

        if flag_print_info:
            print("{:^8}, {:^8}, {:^8}, {:^8}, {:^8}".format("x", "y","steer", "v", "theta"))
            print("{:=<46}".format(""))

        # list to store the states
        list_states = []

        # trajectory always starts at position (0, 0) m with orientation of 0 rad
        x_input = np.array([0.0, 0.0, d_start, v_start, 0.0])
        u_input = np.array([steering_rate_input, a_input])

        # time stamp = 0
        list_states.append(np.append(x_input, 0))

        flag_friction_constraint_satisfied = True

        # forward simulation of states
        # ref: https://gitlab.lrz.de/tum-cps/commonroad-vehicle-models/blob/master/vehicleModels_commonRoad.pdf, page 4
        for time_stamp in time_stamps:
            # simulate state transition
            x_dot = np.array(vehicleDynamics_KS(x_input, u_input, veh_param))

            # check friction circle constraint
            if (a_input ** 2 + (x_input[3] * x_dot[4]) ** 2) ** 0.5 > veh_param.longitudinal.a_max:
                flag_friction_constraint_satisfied = False
                break

            # generate new state
            x_output = x_input + x_dot * dt

            # subsample the states with step size of 0.1 seconds
            if time_stamp % 10 == 0:
                # add state to list
                list_states.append(np.append(x_output, time_stamp / 10))

            if flag_print_info:
                print("{:^8.3f}, {:^8.3f}, {:^8.3f}, {:^8.3f}, {:^8.3f}".format(
                    x_output[0], x_output[1], x_output[2], x_output[3], x_output[4]))

            # prepare for next iteration
            x_input = x_output

        # skip this trajectory if the friction constraint is not satisfied
        if not flag_friction_constraint_satisfied:
            continue

        # create trajectory from the list of states
        trajectory_new = create_trajectory_from_states(list_states)

        result = check_validity(trajectory_new, veh_type)
        count_validated += 1

        if result:
            count_accepted += 1
            list_traj_accepted.append(trajectory_new)

print("============================================")

if count_validated != 0:
    percentage_accept = round(count_accepted / count_validated, 2) * 100
else:
    percentage_accept = 0

print("Validated: {}, Accepted: {}, Rate: {}%".format(count_validated, count_accepted, percentage_accept))
```

%% Cell type:markdown id: tags:

### Plot generate motion primitives.

%% Cell type:code id: tags:

``` python
import matplotlib.pyplot as plt

for traj in list_traj_accepted:

    list_x = [state.position[0] for state in traj.state_list]
    list_y = [state.position[1] for state in traj.state_list]

    plt.plot(list_x, list_y)

plt.axis('equal')
plt.show()
```

%% Cell type:markdown id: tags:

### Create mirrored primitives

%% Cell type:code id: tags:

``` python
# make mirrored ones
import copy

# make sure to make a deep copy
list_traj_accepted_mirrored = copy.deepcopy(list_traj_accepted)

count_acc = 0
for traj in list_traj_accepted:
    list_states_mirrored = []

    for state in traj.state_list:
        # add mirrored state into list
        list_states_mirrored.append([state.position[0],
                                    -state.position[1],
                                    -state.steering_angle,
                                     state.velocity,
                                    -state.orientation,
                                     state.time_step])

    trajectory_new = create_trajectory_from_states(list_states_mirrored)

    # double check the validity before adding to the list
    if check_validity(trajectory_new, veh_type):
        list_traj_accepted_mirrored.append(trajectory_new)
        count_acc += 1

print("Total number of primitives (mirrored included): ", len(list_traj_accepted_mirrored))
```

%% Cell type:markdown id: tags:

### check average number of successors

%% Cell type:code id: tags:

``` python
list_count_seccesors = []
for prim_main in list_traj_accepted_mirrored:
    count_successors = 0

    for prim_2bc in list_traj_accepted_mirrored:

        state_final_prim_main = prim_main.state_list[-1]
        state_initial_prim_2bc = prim_2bc.state_list[0]

        if abs(state_final_prim_main.velocity - state_initial_prim_2bc.velocity) < 0.02 and \
           abs(state_final_prim_main.steering_angle - state_initial_prim_2bc.steering_angle) < 0.02:
            count_successors += 1

    list_count_seccesors.append(count_successors)

print("average number of successors: ", np.mean(list_count_seccesors))
```

%% Cell type:markdown id: tags:

### Plot final primitives

%% Cell type:code id: tags:

``` python
import matplotlib.pyplot as plt

fig = plt.figure()

for i in range(len(list_traj_accepted_mirrored)):
    traj = list_traj_accepted_mirrored[i]

    list_x = [state.position[0] for state in traj.state_list]
    list_y = [state.position[1] for state in traj.state_list]

    # length constraint
    x_start, y_start = list_x[0], list_y[0]
    x_end, y_end = list_x[-1], list_y[-1]

    # only plot trajectories that are not longer than the threshold
    if np.linalg.norm([x_end - x_start, y_end - y_start]) > 20:
        continue

    plt.plot(list_x, list_y)

plt.axis('equal')
plt.show()
```

%% Cell type:markdown id: tags:

### Generate sample trajectories and check if they pass the check

%% Cell type:code id: tags:

``` python
list_traj_accepted_mirrored_backup = copy.deepcopy(list_traj_accepted_mirrored)
```

%% Cell type:code id: tags:

``` python
# generate sample path and check if they pass the solution checker
import random
import copy
random.seed()

# number of primitives to be connected
num_seg_path = 10
num_simulation = 200
for count_run in range(num_simulation):
    print(count_run)
    count_seg_path = 0

    num_prim = len(list_traj_accepted_mirrored)
    # get a random start primitive id
    idx_prim = random.randrange(num_prim)
    list_trajectories = []

    while count_seg_path < num_seg_path:
        # retrieve primitive
        prim_main = copy.deepcopy(list_traj_accepted_mirrored[idx_prim])
        list_trajectories.append(prim_main)
        list_successors_prim_main = []
        count_seg_path += 1

        # obtain its successors
        for j in range(num_prim):
            prim_2bc = list_traj_accepted_mirrored[j]

            state_final_prim_main = prim_main.state_list[-1]
            state_initial_prim_2bc = prim_2bc.state_list[0]

            if abs(state_final_prim_main.velocity - state_initial_prim_2bc.velocity) < 0.02 and \
               abs(state_final_prim_main.steering_angle - state_initial_prim_2bc.steering_angle) < 0.02:

                list_successors_prim_main.append(j)

        # start over if a primitive does not have a valid successor
        num_successors_prim_main = len(list_successors_prim_main)
        if num_successors_prim_main == 0:
            count_seg_path = 0
            idx_prim = random.randrange(num_prim)
            list_trajectories = []
        else:
            # else add a random succesor into list of prims
            idx_successor = random.randrange(num_successors_prim_main)
            idx_prim = list_successors_prim_main[idx_successor]

#     fig = plt.figure()

    # plot first prim
    list_states_final = copy.deepcopy(list_trajectories[0].state_list)
    list_x = [state.position[0] for state in list_states_final]
    list_y = [state.position[1] for state in list_states_final]
    plt.scatter(list_x, list_y)

    list_trajectories_backup = copy.deepcopy(list_trajectories)

    # plot remaining prims
    for i in range(1, len(list_trajectories)):

        traj_pre = list_trajectories[i - 1]
        traj_cur = list_trajectories[i]

        # retrieve states
        state_final_traj_pre   = traj_pre.state_list[-1]
        state_initial_traj_cur = traj_cur.state_list[0]

        while not(state_initial_traj_cur.orientation < 0.001):
            print("error in orientation", state_initial_traj_cur.orientation, i)
            traj_cur.translate_rotate(np.zeros(2), -state_initial_traj_cur.orientation)
            state_initial_traj_cur = traj_cur.state_list[0]
            print(state_initial_traj_cur.orientation)

        while not (state_initial_traj_cur.position[0] < 0.001 and state_initial_traj_cur.position[1] < 0.001):
            print("error in position", state_initial_traj_cur.position, i)
            traj_cur.translate_rotate(-state_initial_traj_cur.position, 0)
            state_initial_traj_cur = traj_cur.state_list[0]
            print(state_initial_traj_cur.position)

        # rotate + translate with regard to the last state of preivous trajectory
        traj_cur.translate_rotate(np.zeros(2), state_final_traj_pre.orientation)
        traj_cur.translate_rotate(state_final_traj_pre.position, 0)

        # retrieve new states
        state_final_traj_pre   = traj_pre.state_list[-1]
        state_initial_traj_cur = traj_cur.state_list[0]

        list_x = [state.position[0] for state in traj_cur.state_list]
        list_y = [state.position[1] for state in traj_cur.state_list]

        plt.scatter(list_x, list_y)

        # discard the first state of second primitive onward
        traj_cur.state_list.pop(0)
        list_states_final += traj_cur.state_list


    list_x = [state.position[0] for state in list_states_final]
    list_y = [state.position[1] for state in list_states_final]

    plt.xlim([min(list_x) - 2, max(list_x) + 2])
    plt.ylim([min(list_y) - 2, max(list_y) + 2])
    plt.axis('equal')
#     plt.show()

     # save as cr node to be validated via the solution checker

    list_states_for_traj = []
    count = 0
    for state in list_states_final:
        list_states_for_traj.append([state.position[0],
                                   state.position[1],
                                   state.steering_angle,
                                   state.velocity,
                                   state.orientation,
                                   np.int64(count)])
        count += 1

    trajectory_new = create_trajectory_from_states(list_states_for_traj)
    result = check_validity(trajectory_new, veh_type)
#     print("Validation result of sample path: ", result)
    if not result:
        print("This is not good :(")
```

%% Cell type:markdown id: tags:

### save the result as xml file

%% Cell type:code id: tags:

``` python
# create Trajectories tag
node_trajectories = et.Element('Trajectories')

for trajectory in list_traj_accepted_mirrored:
# trajectory = list_traj_accepted_mirrored[137]

    # create a tag for individual trajectory
    node_trajectory = et.SubElement(node_trajectories, 'Trajectory')

    # add time duration tag
    node_duration = et.SubElement(node_trajectory, 'Duration')
    node_duration.set('unit','deci-second')
    node_duration.text = "10"

    list_states = trajectory.state_list

    # add start state
    node_start = et.SubElement(node_trajectory, 'Start')
    node_x = et.SubElement(node_start, 'x')
    node_y = et.SubElement(node_start, 'y')
    node_sa = et.SubElement(node_start, 'steering_angle')
    node_v = et.SubElement(node_start, 'velocity')
    node_o = et.SubElement(node_start, 'orientation')
    node_t = et.SubElement(node_start, 'time_step')

    state_start = list_states[0]

    node_x.text = str(state_start.position[0])
    node_y.text = str(state_start.position[1])
    node_sa.text = str(state_start.steering_angle)
    node_v.text = str(state_start.velocity)
    node_o.text = str(state_start.orientation)
    node_t.text = str(state_start.time_step)

    # add final state
    node_final = et.SubElement(node_trajectory, 'Final')
    node_x = et.SubElement(node_final, 'x')
    node_y = et.SubElement(node_final, 'y')
    node_sa = et.SubElement(node_final, 'steering_angle')
    node_v = et.SubElement(node_final, 'velocity')
    node_o = et.SubElement(node_final, 'orientation')
    node_t = et.SubElement(node_final, 'time_step')

    state_final = list_states[-1]

    node_x.text = str(state_final.position[0])
    node_y.text = str(state_final.position[1])
    node_sa.text = str(state_final.steering_angle)
    node_v.text = str(state_final.velocity)
    node_o.text = str(state_final.orientation)
    node_t.text = str(state_final.time_step)

    # add states in between
    list_states_in_between = list_states[1:-1]

    node_path = et.SubElement(node_trajectory, 'Path')

    for state in list_states_in_between:
        node_state = et.SubElement(node_path, 'State')

        node_x = et.SubElement(node_state, 'x')
        node_y = et.SubElement(node_state, 'y')
        node_sa = et.SubElement(node_state, 'steering_angle')
        node_v = et.SubElement(node_state, 'velocity')
        node_o = et.SubElement(node_state, 'orientation')
        node_t = et.SubElement(node_state, 'time_step')

        node_x.text = str(state.position[0])
        node_y.text = str(state.position[1])
        node_sa.text = str(state.steering_angle)
        node_v.text = str(state.velocity)
        node_o.text = str(state.orientation)
        node_t.text = str(state.time_step)
```

%% Cell type:code id: tags:

``` python
from xml.dom import minidom
prefix = "../../GSMP/motion_automata/motion_primitives/"
vtep = round((max_range_sample_velocity - min_range_sample_velocity) / (num_sample_velocity - 1), 2)
sastep = round(max_range_sample_steering_angle * 2 / (num_sample_steering_angle - 1), 2)
file_name = "V_{}_{}_Vstep_{}_SA_{}_{}_SAstep_{}_T_{}_Model_{}.xml".format(min_range_sample_velocity, max_range_sample_velocity,
                                                                           vtep,
                                                                           -max_range_sample_steering_angle, max_range_sample_steering_angle,
                                                                           sastep,
                                                                           round(T, 1),
                                                                           veh_name)

xml_prettified = minidom.parseString(et.tostring(node_trajectories)).toprettyxml(indent="   ")
with open(prefix + file_name, "w") as f:
    f.write(xml_prettified)
    print("file saved: {}".format(file_name))
```
+1 −0
Original line number Diff line number Diff line
%% Cell type:markdown id: tags:

## Generation of Motion Primitives

This piece of code demonstrates how the motion primitves used in the search problem are generated.

%% Cell type:code id: tags:

``` python
# add directories
import sys
sys.path.append("../../GSMP/motion_automata/")
sys.path.append("../../GSMP/motion_automata/vehicle_model/")
sys.path.append("../../GSMP/tools/")
sys.path.append("…/…/GSMP/tools/commonroad-collision-checker/")

import os
import numpy as np
import itertools
from math import atan2, sin, cos
import xml.etree.ElementTree as et

# import vehicle model (https://gitlab.lrz.de/tum-cps/commonroad-vehicle-models/tree/master/Python)
from vehicleDynamics_KS import vehicleDynamics_KS
from parameters_vehicle1 import parameters_vehicle1
from parameters_vehicle2 import parameters_vehicle2
from parameters_vehicle3 import parameters_vehicle3

# import solution checker
from solution_checker import *

from commonroad.common.solution_writer import CommonRoadSolutionWriter, VehicleModel, VehicleType, CostFunction
```

%% Cell type:markdown id: tags:

### Helper function to create a Trajectory object from a list of states.

The elements of the states are: x, y, steering angle, velocity, orientation, time step

%% Cell type:code id: tags:

``` python
def create_trajectory_from_states(list_states):
    # list to hold states for final trajectory
    list_states_new = list()

    # iterate through trajectory states
    for state in list_states:
        # feed in required slots
        kwarg = {'position': np.array([state[0], state[1]]),
                 'velocity': state[3],
                 'steering_angle': state[2],
                 'orientation': state[4],
                 'time_step': state[5].astype(int)}

        # append state
        list_states_new.append(State(**kwarg))

    # create new trajectory for evaluation
    trajectory_new = Trajectory(initial_time_step=0, state_list=list_states_new)

    return trajectory_new
```

%% Cell type:markdown id: tags:

### Helper function to check the validity of a given trajectory.

The validity is checked with the help of a solution checher under Kinematic Single-Track Model.

%% Cell type:code id: tags:

``` python
def check_validity(trajectory_input, veh_type):
    csw = CommonRoadSolutionWriter(output_dir=os.getcwd(),
                                   scenario_id=0,
                                   step_size=0.1,
                                   vehicle_type=veh_type,
                                   vehicle_model=VehicleModel.KS,
                                   cost_function=CostFunction.JB1)

    # use solution writer to generate target xml file
    csw.add_solution_trajectory(trajectory=trajectory_input, planning_problem_id=100)
    xmlTree = csw.root_node

    # generate states to be checked
    [node] = xmlTree.findall('ksTrajectory')
    veh_trajectory = KSTrajectory.from_xml(node)
    veh_model = KinematicSingleTrackModel(veh_type_id, veh_trajectory, None)

    # validate
    result = TrajectoryValidator.is_trajectory_valid(veh_trajectory, veh_model, 0.1)

    return result
```

%% Cell type:markdown id: tags:

### Congiguration of parameters.

%% Cell type:code id: tags:

``` python
# settings

flag_print_info = False

# vehicle parameter
# 1: FORD_ESCORT 2: BMW_320i 3: VW_VANAGON
veh_type_id = 3

if veh_type_id == 1:
    veh_type = VehicleType.FORD_ESCORT
    veh_param = parameters_vehicle1()
    veh_name = "FORD_ESCORT"
elif veh_type_id == 2:
    veh_type = VehicleType.BMW_320i
    veh_param = parameters_vehicle2()
    veh_name = "BMW320i"
elif veh_type_id == 3:
    veh_type = VehicleType.VW_VANAGON
    veh_param = parameters_vehicle3()
    veh_name = "VW_VANAGON"

# total length of trajectory, in seconds
T = 0.5
# time step for states, in seconds
# commonroad scenarios have dt of 0.1 seconds; for higher accuracy of forward simulation, dt here is set to 0.05
# the simulated states will be down-sampled
dt = 0.05
# calculate time_stamps, *100 for 2 digits accuracy
time_stamps = ((np.arange(0, T, dt) + dt) * 100).astype(int)

# sampling range
# in m/s
min_range_sample_velocity = 0.0
max_range_sample_velocity = 20.0
num_sample_velocity = 20 + 1 # step v is then 1.0 m/s

# steer to one side only, we can mirror the primitives afterwards
# in rad
min_range_sample_steering_angle = 0
max_range_sample_steering_angle = veh_param.steering.max
num_sample_steering_angle = 8 + 1 # step is roughly 10 degrees
```

%% Cell type:code id: tags:

``` python
# create list of possible samples for velocity and steering angle
list_samples_v = np.linspace(min_range_sample_velocity,
                             max_range_sample_velocity,
                             num_sample_velocity)

list_samples_steering_angle = np.linspace(min_range_sample_steering_angle,
                                         max_range_sample_steering_angle,
                                         num_sample_steering_angle)

# for some statistics
num_possible_states = num_sample_velocity * num_sample_steering_angle
num_possible_start_end_combinations = num_possible_states ** 2
count_processed = 0
count_validated =0
count_accepted = 0

# for saving the results
list_traj_accepted = []
list_traj_failed = []

print("Total possible combination of states: ", num_possible_start_end_combinations)

# v = velocity, d = delta = steering_angle
# iterate through possible instance of the cartesian product of list_samples_v and list_samples_steering_angle
# create all possible combinations of (v_start, d_start) and (v_end, d_end)

for (v_start, d_start) in itertools.product(list_samples_v, list_samples_steering_angle):
    for (v_end, d_end) in itertools.product(list_samples_v, list_samples_steering_angle):

        count_processed += 1

        # Print progress
        if count_processed % 2000 == 0 and count_processed:
            print("Progress: {} primitives checked.".format(count_processed))

        # compute required inputs
        a_input = (v_end - v_start) / T
        steering_rate_input = (d_end - d_start) / T

        # check if the constraints are respected
        if (a_input > veh_param.longitudinal.a_max) or (a_input < -veh_param.longitudinal.a_max) or \
           (steering_rate_input > veh_param.steering.v_max) or (steering_rate_input < -veh_param.steering.v_max):
            continue

        if flag_print_info:
            print("{:^8}, {:^8}, {:^8}, {:^8}, {:^8}".format("x", "y","steer", "v", "theta"))
            print("{:=<46}".format(""))

        # list to store the states
        list_states = []

        # trajectory always starts at position (0, 0) m with orientation of 0 rad
        x_input = np.array([0.0, 0.0, d_start, v_start, 0.0])
        u_input = np.array([steering_rate_input, a_input])

        # time stamp = 0
        list_states.append(np.append(x_input, 0))

        flag_friction_constraint_satisfied = True

        # forward simulation of states
        # ref: https://gitlab.lrz.de/tum-cps/commonroad-vehicle-models/blob/master/vehicleModels_commonRoad.pdf, page 4
        for time_stamp in time_stamps:
            # simulate state transition
            x_dot = np.array(vehicleDynamics_KS(x_input, u_input, veh_param))

            # check friction circle constraint
            if (a_input ** 2 + (x_input[3] * x_dot[4]) ** 2) ** 0.5 > veh_param.longitudinal.a_max:
                flag_friction_constraint_satisfied = False
                break

            # generate new state
            x_output = x_input + x_dot * dt

            # subsample the states with step size of 0.1 seconds
            if time_stamp % 10 == 0:
                # add state to list
                list_states.append(np.append(x_output, time_stamp / 10))

            if flag_print_info:
                print("{:^8.3f}, {:^8.3f}, {:^8.3f}, {:^8.3f}, {:^8.3f}".format(
                    x_output[0], x_output[1], x_output[2], x_output[3], x_output[4]))

            # prepare for next iteration
            x_input = x_output

        # skip this trajectory if the friction constraint is not satisfied
        if not flag_friction_constraint_satisfied:
            continue

        # create trajectory from the list of states
        trajectory_new = create_trajectory_from_states(list_states)

        result = check_validity(trajectory_new, veh_type)
        count_validated += 1

        if result:
            count_accepted += 1
            list_traj_accepted.append(trajectory_new)

print("============================================")

if count_validated != 0:
    percentage_accept = round(count_accepted / count_validated, 2) * 100
else:
    percentage_accept = 0

print("Validated: {}, Accepted: {}, Rate: {}%".format(count_validated, count_accepted, percentage_accept))
```

%% Cell type:markdown id: tags:

### Plot generate motion primitives.

%% Cell type:code id: tags:

``` python
import matplotlib.pyplot as plt

for traj in list_traj_accepted:

    list_x = [state.position[0] for state in traj.state_list]
    list_y = [state.position[1] for state in traj.state_list]

    plt.plot(list_x, list_y)

plt.axis('equal')
plt.show()
```

%% Cell type:markdown id: tags:

### Create mirrored primitives

%% Cell type:code id: tags:

``` python
# make mirrored ones
import copy

# make sure to make a deep copy
list_traj_accepted_mirrored = copy.deepcopy(list_traj_accepted)

count_acc = 0
for traj in list_traj_accepted:
    list_states_mirrored = []

    for state in traj.state_list:
        # add mirrored state into list
        list_states_mirrored.append([state.position[0],
                                    -state.position[1],
                                    -state.steering_angle,
                                     state.velocity,
                                    -state.orientation,
                                     state.time_step])

    trajectory_new = create_trajectory_from_states(list_states_mirrored)

    # double check the validity before adding to the list
    if check_validity(trajectory_new, veh_type):
        list_traj_accepted_mirrored.append(trajectory_new)
        count_acc += 1

print("Total number of primitives (mirrored included): ", len(list_traj_accepted_mirrored))
```

%% Cell type:markdown id: tags:

### check average number of successors

%% Cell type:code id: tags:

``` python
list_count_seccesors = []
for prim_main in list_traj_accepted_mirrored:
    count_successors = 0

    for prim_2bc in list_traj_accepted_mirrored:

        state_final_prim_main = prim_main.state_list[-1]
        state_initial_prim_2bc = prim_2bc.state_list[0]

        if abs(state_final_prim_main.velocity - state_initial_prim_2bc.velocity) < 0.02 and \
           abs(state_final_prim_main.steering_angle - state_initial_prim_2bc.steering_angle) < 0.02:
            count_successors += 1

    list_count_seccesors.append(count_successors)

print("average number of successors: ", np.mean(list_count_seccesors))
```

%% Cell type:markdown id: tags:

### Plot final primitives

%% Cell type:code id: tags:

``` python
import matplotlib.pyplot as plt

fig = plt.figure()

for i in range(len(list_traj_accepted_mirrored)):
    traj = list_traj_accepted_mirrored[i]

    list_x = [state.position[0] for state in traj.state_list]
    list_y = [state.position[1] for state in traj.state_list]

    # length constraint
    x_start, y_start = list_x[0], list_y[0]
    x_end, y_end = list_x[-1], list_y[-1]

    # only plot trajectories that are not longer than the threshold
    if np.linalg.norm([x_end - x_start, y_end - y_start]) > 20:
        continue

    plt.plot(list_x, list_y)

plt.axis('equal')
plt.show()
```

%% Cell type:markdown id: tags:

### Generate sample trajectories and check if they pass the check

%% Cell type:code id: tags:

``` python
list_traj_accepted_mirrored_backup = copy.deepcopy(list_traj_accepted_mirrored)
```

%% Cell type:code id: tags:

``` python
# generate sample path and check if they pass the solution checker
import random
import copy
random.seed()

# number of primitives to be connected
num_seg_path = 10
num_simulation = 200
for count_run in range(num_simulation):
    print(count_run)
    count_seg_path = 0

    num_prim = len(list_traj_accepted_mirrored)
    # get a random start primitive id
    idx_prim = random.randrange(num_prim)
    list_trajectories = []

    while count_seg_path < num_seg_path:
        # retrieve primitive
        prim_main = copy.deepcopy(list_traj_accepted_mirrored[idx_prim])
        list_trajectories.append(prim_main)
        list_successors_prim_main = []
        count_seg_path += 1

        # obtain its successors
        for j in range(num_prim):
            prim_2bc = list_traj_accepted_mirrored[j]

            state_final_prim_main = prim_main.state_list[-1]
            state_initial_prim_2bc = prim_2bc.state_list[0]

            if abs(state_final_prim_main.velocity - state_initial_prim_2bc.velocity) < 0.02 and \
               abs(state_final_prim_main.steering_angle - state_initial_prim_2bc.steering_angle) < 0.02:

                list_successors_prim_main.append(j)

        # start over if a primitive does not have a valid successor
        num_successors_prim_main = len(list_successors_prim_main)
        if num_successors_prim_main == 0:
            count_seg_path = 0
            idx_prim = random.randrange(num_prim)
            list_trajectories = []
        else:
            # else add a random succesor into list of prims
            idx_successor = random.randrange(num_successors_prim_main)
            idx_prim = list_successors_prim_main[idx_successor]

#     fig = plt.figure()

    # plot first prim
    list_states_final = copy.deepcopy(list_trajectories[0].state_list)
    list_x = [state.position[0] for state in list_states_final]
    list_y = [state.position[1] for state in list_states_final]
    plt.scatter(list_x, list_y)

    list_trajectories_backup = copy.deepcopy(list_trajectories)

    # plot remaining prims
    for i in range(1, len(list_trajectories)):

        traj_pre = list_trajectories[i - 1]
        traj_cur = list_trajectories[i]

        # retrieve states
        state_final_traj_pre   = traj_pre.state_list[-1]
        state_initial_traj_cur = traj_cur.state_list[0]

        while not(state_initial_traj_cur.orientation < 0.001):
            print("error in orientation", state_initial_traj_cur.orientation, i)
            traj_cur.translate_rotate(np.zeros(2), -state_initial_traj_cur.orientation)
            state_initial_traj_cur = traj_cur.state_list[0]
            print(state_initial_traj_cur.orientation)

        while not (state_initial_traj_cur.position[0] < 0.001 and state_initial_traj_cur.position[1] < 0.001):
            print("error in position", state_initial_traj_cur.position, i)
            traj_cur.translate_rotate(-state_initial_traj_cur.position, 0)
            state_initial_traj_cur = traj_cur.state_list[0]
            print(state_initial_traj_cur.position)

        # rotate + translate with regard to the last state of preivous trajectory
        traj_cur.translate_rotate(np.zeros(2), state_final_traj_pre.orientation)
        traj_cur.translate_rotate(state_final_traj_pre.position, 0)

        # retrieve new states
        state_final_traj_pre   = traj_pre.state_list[-1]
        state_initial_traj_cur = traj_cur.state_list[0]

        list_x = [state.position[0] for state in traj_cur.state_list]
        list_y = [state.position[1] for state in traj_cur.state_list]

        plt.scatter(list_x, list_y)

        # discard the first state of second primitive onward
        traj_cur.state_list.pop(0)
        list_states_final += traj_cur.state_list


    list_x = [state.position[0] for state in list_states_final]
    list_y = [state.position[1] for state in list_states_final]

    plt.xlim([min(list_x) - 2, max(list_x) + 2])
    plt.ylim([min(list_y) - 2, max(list_y) + 2])
    plt.axis('equal')
#     plt.show()

     # save as cr node to be validated via the solution checker

    list_states_for_traj = []
    count = 0
    for state in list_states_final:
        list_states_for_traj.append([state.position[0],
                                   state.position[1],
                                   state.steering_angle,
                                   state.velocity,
                                   state.orientation,
                                   np.int64(count)])
        count += 1

    trajectory_new = create_trajectory_from_states(list_states_for_traj)
    result = check_validity(trajectory_new, veh_type)
#     print("Validation result of sample path: ", result)
    if not result:
        print("This is not good :(")
```

%% Cell type:markdown id: tags:

### save the result as xml file

%% Cell type:code id: tags:

``` python
# create Trajectories tag
node_trajectories = et.Element('Trajectories')

for trajectory in list_traj_accepted_mirrored:
# trajectory = list_traj_accepted_mirrored[137]

    # create a tag for individual trajectory
    node_trajectory = et.SubElement(node_trajectories, 'Trajectory')

    # add time duration tag
    node_duration = et.SubElement(node_trajectory, 'Duration')
    node_duration.set('unit','deci-second')
    node_duration.text = "10"

    list_states = trajectory.state_list

    # add start state
    node_start = et.SubElement(node_trajectory, 'Start')
    node_x = et.SubElement(node_start, 'x')
    node_y = et.SubElement(node_start, 'y')
    node_sa = et.SubElement(node_start, 'steering_angle')
    node_v = et.SubElement(node_start, 'velocity')
    node_o = et.SubElement(node_start, 'orientation')
    node_t = et.SubElement(node_start, 'time_step')

    state_start = list_states[0]

    node_x.text = str(state_start.position[0])
    node_y.text = str(state_start.position[1])
    node_sa.text = str(state_start.steering_angle)
    node_v.text = str(state_start.velocity)
    node_o.text = str(state_start.orientation)
    node_t.text = str(state_start.time_step)

    # add final state
    node_final = et.SubElement(node_trajectory, 'Final')
    node_x = et.SubElement(node_final, 'x')
    node_y = et.SubElement(node_final, 'y')
    node_sa = et.SubElement(node_final, 'steering_angle')
    node_v = et.SubElement(node_final, 'velocity')
    node_o = et.SubElement(node_final, 'orientation')
    node_t = et.SubElement(node_final, 'time_step')

    state_final = list_states[-1]

    node_x.text = str(state_final.position[0])
    node_y.text = str(state_final.position[1])
    node_sa.text = str(state_final.steering_angle)
    node_v.text = str(state_final.velocity)
    node_o.text = str(state_final.orientation)
    node_t.text = str(state_final.time_step)

    # add states in between
    list_states_in_between = list_states[1:-1]

    node_path = et.SubElement(node_trajectory, 'Path')

    for state in list_states_in_between:
        node_state = et.SubElement(node_path, 'State')

        node_x = et.SubElement(node_state, 'x')
        node_y = et.SubElement(node_state, 'y')
        node_sa = et.SubElement(node_state, 'steering_angle')
        node_v = et.SubElement(node_state, 'velocity')
        node_o = et.SubElement(node_state, 'orientation')
        node_t = et.SubElement(node_state, 'time_step')

        node_x.text = str(state.position[0])
        node_y.text = str(state.position[1])
        node_sa.text = str(state.steering_angle)
        node_v.text = str(state.velocity)
        node_o.text = str(state.orientation)
        node_t.text = str(state.time_step)
```

%% Cell type:code id: tags:

``` python
from xml.dom import minidom
prefix = "../../GSMP/motion_automata/motion_primitives/"
vtep = round((max_range_sample_velocity - min_range_sample_velocity) / (num_sample_velocity - 1), 2)
sastep = round(max_range_sample_steering_angle * 2 / (num_sample_steering_angle - 1), 2)
file_name = "V_{}_{}_Vstep_{}_SA_{}_{}_SAstep_{}_T_{}_Model_{}.xml".format(min_range_sample_velocity, max_range_sample_velocity,
                                                                           vtep,
                                                                           -max_range_sample_steering_angle, max_range_sample_steering_angle,
                                                                           sastep,
                                                                           round(T, 1),
                                                                           veh_name)

xml_prettified = minidom.parseString(et.tostring(node_trajectories)).toprettyxml(indent="   ")
with open(prefix + file_name, "w") as f:
    f.write(xml_prettified)
    print("file saved: {}".format(file_name))
```