# Source code for commonroad.scenario.trajectory

import warnings
from typing import List, Union

import numpy as np

from commonroad.common.validity import (
    is_natural_number,
    is_positive,
    is_real_number,
    is_real_number_vector,
    is_valid_orientation,
)
from commonroad.scenario.state import State, TraceState
from commonroad.visualization.draw_params import (
    OptionalSpecificOrAllDrawParams,
    TrajectoryParams,
)
from commonroad.visualization.drawable import IDrawable
from commonroad.visualization.renderer import IRenderer


class Trajectory(IDrawable):
    """Class to model the movement of an object over time.

    The states of the trajectory can be either exact or uncertain
    (see :class:`commonroad.scenario.state.State`); however, only exact time steps are allowed.
    """

    def __init__(self, initial_time_step: int, state_list: List[TraceState]):
        """
        :param initial_time_step: initial time step of the trajectory
        :param state_list: ordered sequence of states over time representing the trajectory. It is assumed that
            the time discretization between two states matches the time discretization of the scenario.
        """
        # The initial time step must be set first: check_state_list compares it with state_list[0].time_step.
        self.initial_time_step: int = initial_time_step
        self._state_list: List[TraceState] = self.check_state_list(state_list)

    def check_state_list(self, state_list: List[TraceState]) -> List[TraceState]:
        """
        Checks whether state list is valid.

        The list must be a non-empty ``list`` of :class:`State` objects with natural-number time steps and
        identical used attributes, and its first state must carry the initial time step of the trajectory.

        :param state_list: state list which should be evaluated
        :return: evaluated state list (the very same object that was passed in)
        :raises AssertionError: if any of the checks fails
        """
        assert isinstance(state_list, list), (
            "<Trajectory/state_list>: argument state_list of wrong type. "
            "Expected type: %s. Got type: %s." % (list, type(state_list))
        )
        assert len(state_list) >= 1, (
            "<Trajectory/state_list>: argument state_list must contain at least one state."
            " length of state_list: %s." % len(state_list)
        )
        assert all(isinstance(state, State) for state in state_list), (
            "<Trajectory/state_list>: element of "
            "state_list is of wrong type. Expected type: "
            "%s." % List[State]
        )
        assert all(
            is_natural_number(state.time_step) for state in state_list if hasattr(state, "time_step")
        ), "<Trajectory/state_list>: Element time_step of each state must be an integer."
        # The attribute collection is wrapped in a 1-tuple so that %-formatting also works if it is a tuple.
        assert all(set(state_list[0].used_attributes) == set(state.used_attributes) for state in state_list), (
            "<Trajectory/state_list>: all states must have the same attributes. Attributes of first state: %s."
            % (state_list[0].attributes,)
        )
        assert state_list[0].time_step == self.initial_time_step, (
            f"state_list[0].time_step={state_list[0].time_step} != "
            f"self.initial_time_step={self.initial_time_step}"
        )
        return state_list

    def __eq__(self, other):
        """Two trajectories are equal if their initial time steps and their state lists are equal."""
        if not isinstance(other, Trajectory):
            warnings.warn(f"Inequality between Trajectory {repr(self)} and different type {type(other)}")
            return False

        return self._initial_time_step == other.initial_time_step and list(self._state_list) == list(other.state_list)

    def __hash__(self):
        """Hash built from the same members that :meth:`__eq__` compares."""
        return hash((self._initial_time_step, tuple(self._state_list)))

    @property
    def initial_time_step(self) -> int:
        """Initial time step of the trajectory."""
        return self._initial_time_step

    @initial_time_step.setter
    def initial_time_step(self, initial_time_step):
        assert isinstance(initial_time_step, int), (
            "<Trajectory/initial_time_step>: argument initial_time_step of "
            "wrong type. Expected type: %s. Got type: %s." % (int, type(initial_time_step))
        )
        self._initial_time_step = initial_time_step

    def append_state(self, state: TraceState):
        """
        Append the state to the trajectory.

        :param state: The new state. Its time step must be larger than the time step of the last state in the
            trajectory, and it must use the same attributes as the states already stored.
        :raises AssertionError: if the state has the wrong type, different attributes or a too small time step
        """
        assert isinstance(
            state, State
        ), "<Trajectory/append_state>: argument state of wrong type. Expected type: %s. Got type: %s." % (
            State,
            type(state),
        )
        assert set(self._state_list[0].used_attributes) == set(state.used_attributes), (
            "<Trajectory/append_state>: attributes of the argument state do not match"
            " the attributes of the other states in the state list."
            " Expected attributes: '%s'. Got attributes: '%s'" % (self._state_list[0].attributes, state.attributes)
        )
        assert state.time_step > self.final_state.time_step, (
            "<Trajectory/append_state>: the time step of the argument state"
            " must be larger than the time step of the last state in the trajectory."
            " Time step of last state in trajectory: %s. Got time step: %s"
            % (self.final_state.time_step, state.time_step)
        )
        self._state_list.append(state)

    @property
    def state_list(self) -> List[TraceState]:
        """List of states of the trajectory over time."""
        return self._state_list

    @property
    def final_state(self) -> TraceState:
        """Final state of the trajectory."""
        return self._state_list[-1]

    def state_at_time_step(self, time_step: int) -> Union[TraceState, None]:
        """
        Function to get the state of a trajectory at a specific time instance.

        The state is looked up by its index offset from the initial time step; the time_step attribute of the
        stored states is not consulted.

        :param time_step: considered time step
        :return: state of the trajectory at time_step, or None if time_step lies outside the trajectory
        """
        offset = time_step - self._initial_time_step
        if 0 <= offset < len(self._state_list):
            return self._state_list[offset]
        return None

    def states_in_time_interval(self, time_begin: int, time_end: int) -> List[Union[TraceState, None]]:
        """
        Function to get the states of a trajectory at a specific time interval.

        :param time_begin: first considered time step
        :param time_end: last considered time step (inclusive)
        :return: list of states, one entry per time step; entries are None for time steps outside the trajectory
        :raises AssertionError: if time_end is smaller than time_begin
        """
        assert time_end >= time_begin
        return [self.state_at_time_step(time_step) for time_step in range(time_begin, time_end + 1)]

    def translate_rotate(self, translation: np.ndarray, angle: float):
        """
        First translates each state of the trajectory, then rotates each state of the trajectory around the
        origin.

        The stored state list is replaced by a new list holding the results of the per-state
        ``translate_rotate`` calls.

        :param translation: translation vector [x_off, y_off] in x- and y-direction
        :param angle: rotation angle in radian (counter-clockwise)
        :raises AssertionError: if translation is not a real vector of length 2 or angle is not a valid orientation
        """
        assert is_real_number_vector(translation, 2), (
            "<Trajectory/translate_rotate>: argument translation is not " "a vector of real numbers of length 2."
        )
        assert is_real_number(angle), (
            "<Trajectory/translate_rotate>: argument angle must be a scalar. " "angle = %s" % angle
        )
        assert is_valid_orientation(angle), (
            "<Trajectory/translate_rotate>: argument angle must be within the "
            "interval [-2pi,2pi]. angle = %s" % angle
        )

        self._state_list = [state.translate_rotate(translation, angle) for state in self._state_list]

    @classmethod
    def resample_continuous_time_state_list(
        cls,
        states: List[TraceState],
        time_stamps_cont: np.ndarray,
        resampled_dt: float,
        num_resampled_states: int,
        initial_time_cont: float = 0,
    ) -> "Trajectory":
        """
        This method resamples a given state list with continuous time vector in a fixed time resolution.
        The interpolation is done in a linear fashion.

        The sample times are ``initial_time_cont + k * resampled_dt`` for ``k = 0, ..., num_resampled_states``,
        i.e. the returned trajectory holds ``num_resampled_states + 1`` states with time steps ``0, 1, ...``.

        :param states: The list of states to interpolate
        :param time_stamps_cont: The vector of continuous time stamps (corresponding to the states)
        :param resampled_dt: Target time step length
        :param num_resampled_states: Number of resampling intervals. It must hold (t_0+N*dT) \\in time interval
        :param initial_time_cont: The initial continuous time stamp (default 0). It must hold t\\in time interval
        :return: The resampled trajectory
        :raises AssertionError: if one of the arguments is malformed or the requested horizon is not covered
        :raises ValueError: if an attribute set on the first state is None on another state
        """
        assert is_positive(
            resampled_dt
        ), "<Trajectory/interpolate_state_list>: Time step size must be a positive number! " "dT = {}".format(
            resampled_dt
        )
        assert isinstance(states, list) and all(isinstance(x, State) for x in states), (
            "<Trajectory/interpolate_state_list>: Provided state list "
            "is not in the correct format! State list = {}".format(states)
        )
        assert is_real_number_vector(time_stamps_cont), (
            "<Trajectory/interpolate_state_list>: Provided time vector is not in the "
            "correct format! time = {}".format(time_stamps_cont)
        )
        assert len(states) == len(time_stamps_cont), (
            "<Trajectory/interpolate_state_list>: Provided time and state lists do not "
            "share the same length! Time = {} / States = {}".format(len(time_stamps_cont), len(states))
        )
        assert is_positive(num_resampled_states) and is_natural_number(
            num_resampled_states
        ), "<Trajectory/interpolate_state_list>: Provided state horizon must be a " "positive Integer! N = {}".format(
            num_resampled_states
        )
        assert is_real_number(
            initial_time_cont
        ), "<Trajectory/interpolate_state_list>: Provided initial time must be a " "real number! t_0 = {}".format(
            initial_time_cont
        )
        assert any(time_stamps_cont <= initial_time_cont) and any(
            initial_time_cont <= time_stamps_cont
        ), "<Trajectory/interpolate_state_list>: Provided initial " "time is not within time vector! t_0 = {}".format(
            initial_time_cont
        )
        assert any(
            initial_time_cont + num_resampled_states * resampled_dt <= time_stamps_cont
        ), "<Trajectory/interpolate_state_list>: Provided end time is not within time vector! t_h = {}".format(
            initial_time_cont + num_resampled_states * resampled_dt
        )

        # Only attributes that are actually set on the first state take part in the interpolation.
        slots = [s for s in states[0].attributes if getattr(states[0], s) is not None]

        # Sample times t_0 + k * dT for k = 0..N. They are derived from an integer index range because
        # np.arange with a floating point step may yield one sample more than intended due to rounding.
        t_i = initial_time_cont + resampled_dt * np.arange(num_resampled_states + 1)

        values_i = []
        for s in slots:
            values = []
            multiple = False

            # Collect the samples of attribute s over all states.
            for x in states:
                val = getattr(x, s)
                if val is None:
                    raise ValueError(
                        "<Trajectory/interpolate_state_list>: States do not share the same amount of variables!"
                    )
                assert is_real_number(val) or is_real_number_vector(val), (
                    "<Trajectory/interpolate_state_list>: Currently, this method only "
                    "supports states with real numbers! val = {}".format(val)
                )

                # Vector-valued attribute: keep one sample series per component. NumPy scalars also expose
                # `shape` but have no len(), hence the additional dimensionality check.
                if not multiple and hasattr(val, "shape") and np.ndim(val) > 0 and len(val) > 1:
                    multiple = True
                    values.extend([] for _ in range(len(val)))

                if multiple:
                    for i, v in enumerate(val):
                        values[i].append(v)
                else:
                    values.append(val)

            # Linear interpolation; vector attributes are interpolated per component and stacked again so
            # that row k holds the full vector at sample time k.
            if multiple:
                components = [np.interp(t_i, time_stamps_cont, v) for v in values]
                values_i.append(np.array(components).transpose())
            else:
                values_i.append(np.interp(t_i, time_stamps_cont, values))

        state_type = states[0].__class__

        # Create the new states; the discrete time step is the sample index and overrides any interpolated
        # time_step value.
        states_new = []
        for i in range(len(t_i)):
            variables = {s: values_i[j][i] for j, s in enumerate(slots)}
            variables["time_step"] = i
            states_new.append(state_type(**variables))

        return cls(states_new[0].time_step, states_new)

    def __str__(self):
        """Short human-readable summary: initial time step, number of states and state attributes."""
        traffic_str = "\n"
        traffic_str += "Initial time step: {} \n".format(self.initial_time_step)
        traffic_str += "Number of states: {}\n".format(len(self.state_list))
        traffic_str += "State elements: {}".format(self.state_list[0].attributes)
        return traffic_str

    def draw(self, renderer: IRenderer, draw_params: OptionalSpecificOrAllDrawParams[TrajectoryParams] = None):
        """
        Draws the trajectory by delegating to ``renderer.draw_trajectory``.

        :param renderer: renderer that performs the drawing
        :param draw_params: optional drawing parameters, passed through to the renderer unchanged
        """
        renderer.draw_trajectory(self, draw_params)