Optimization opportunities in the Flatland environment

Here are some potential optimizations in the Flatland environment discovered by Adrian Egli from SBB. They will eventually be integrated into the Flatland codebase, but you are welcome to take advantage of them already.

If you do test and integrate them, you are encouraged to submit PRs to the Flatland repository, which would make you a Flatland contributor!

#---- SpeedUp ~7x -----------------------------------------------------------------------------------------------------
#ncalls  tottime  percall  cumtime  percall filename:lineno(function)
#109161    0.131    0.000    0.131    0.000 grid4_utils.py:29(get_new_position)
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def get_new_position(position, movement):
    return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1])
#---- ORIGINAL -----------------------------------------------------------------------------------------------------
#ncalls  tottime  percall  cumtime  percall filename:lineno(function)
#112703    0.893    0.000    1.355    0.000 grid4_utils.py:32(get_new_position)
def get_new_position(position, movement):
    """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
    if movement == Grid4TransitionsEnum.NORTH:
        return (position[0] - 1, position[1])
    elif movement == Grid4TransitionsEnum.EAST:
        return (position[0], position[1] + 1)
    elif movement == Grid4TransitionsEnum.SOUTH:
        return (position[0] + 1, position[1])
    elif movement == Grid4TransitionsEnum.WEST:
        return (position[0], position[1] - 1)
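To check that the lookup-table version behaves like the original if/elif chain, a minimal sketch follows. It assumes Grid4TransitionsEnum is an integer enum with NORTH=0, EAST=1, SOUTH=2, WEST=3; the stand-in enum and the two function names (get_new_position_table, get_new_position_branches) are introduced here for illustration only and are not part of Flatland.

```python
import timeit
from enum import IntEnum


class Grid4TransitionsEnum(IntEnum):
    # Stand-in for Flatland's enum; the values 0..3 are an assumption.
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3


# (row delta, column delta) indexed by direction: N, E, S, W
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]


def get_new_position_table(position, movement):
    # Lookup-table variant: one list index instead of up to four comparisons.
    delta = MOVEMENT_ARRAY[movement]
    return (position[0] + delta[0], position[1] + delta[1])


def get_new_position_branches(position, movement):
    # Branching variant, mirroring the original if/elif chain.
    if movement == Grid4TransitionsEnum.NORTH:
        return (position[0] - 1, position[1])
    elif movement == Grid4TransitionsEnum.EAST:
        return (position[0], position[1] + 1)
    elif movement == Grid4TransitionsEnum.SOUTH:
        return (position[0] + 1, position[1])
    elif movement == Grid4TransitionsEnum.WEST:
        return (position[0], position[1] - 1)


# Both variants should agree for every valid direction.
for movement in Grid4TransitionsEnum:
    assert get_new_position_table((5, 7), movement) == get_new_position_branches((5, 7), movement)

# Rough timing; absolute numbers depend on the machine and Python version.
for func in (get_new_position_table, get_new_position_branches):
    seconds = timeit.timeit(lambda: func((5, 7), Grid4TransitionsEnum.WEST), number=100_000)
    print(func.__name__, round(seconds, 4))
```

One behavioural difference to keep in mind: for an invalid movement value the original silently returns None, while the table version raises IndexError (or wraps around for negative indices such as -1).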
#---- SpeedUp ~3x -----------------------------------------------------------------------------------------------------
#ncalls  tottime  percall  cumtime  percall filename:lineno(function)
#27121    0.041    0.000    0.273    0.000 grid4.py:66(get_transitions)
from numba import njit, jit

def get_transitions(self, cell_transition, orientation):
    return opt_get_transitions(cell_transition, orientation)

@jit()
def opt_get_transitions(cell_transition, orientation):
    """
    Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
    if no diagonal transitions allowed) available for an agent oriented
    in direction `orientation` and inside a cell with
    transitions `cell_transition`.
    Parameters
    ----------
    cell_transition : int
        16 bits used to encode the valid transitions for a cell.
    orientation : int
        Orientation of the agent inside the cell.
    Returns
    -------
    tuple
        List of the validity of transitions in the cell.
    """
    bits = (cell_transition >> ((3 - orientation) * 4))
    return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)

#---- ORIGINAL -----------------------------------------------------------------------------------------------------
#ncalls  tottime  percall  cumtime  percall filename:lineno(function)
#25399    0.146    0.000    0.146    0.000 grid4.py:66(get_transitions)
def get_transitions(self, cell_transition, orientation):
    ...  # body omitted in the source

I think we could use numba to improve performance, especially for pure NumPy and Python methods that can be made “static”, i.e. moved out of the class because they do not need access to self.
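As a sketch of that “static” pattern: the bit logic from the get_transitions snippet above lives in a module-level function that numba can compile, and the method only delegates to it. The fallback decorator and the class name Grid4Sketch are assumptions added here so the example also runs when numba is not installed; they are not part of Flatland.

```python
try:
    from numba import njit  # optional: compiles the function on first call
except ImportError:
    def njit(func):
        # Fallback when numba is missing: leave the function as plain Python.
        return func


@njit
def opt_get_transitions(cell_transition, orientation):
    # Shift the 4-bit group for this orientation into the lowest bits,
    # then read out the N, E, S, W flags one bit at a time.
    bits = cell_transition >> ((3 - orientation) * 4)
    return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, bits & 1)


class Grid4Sketch:
    # Hypothetical stand-in for the class that owns get_transitions.
    def get_transitions(self, cell_transition, orientation):
        # The method keeps its signature but holds no logic of its own.
        return opt_get_transitions(cell_transition, orientation)


# Example cell: each orientation may only continue straight ahead.
cell = 0b1000_0100_0010_0001
grid = Grid4Sketch()
for orientation in range(4):
    print(orientation, grid.get_transitions(cell, orientation))
```

Note that the first call to a numba-compiled function includes compilation time, so a short profile can make the compiled version look slower than it is in steady state.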

I highly recommend replacing np.isclose with a simple equivalent function:

def my_isclose(x, y, rtol=1.e-5, atol=1.e-8):
    return abs(x - y) <= atol + rtol * abs(y)
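A small sketch of how this replacement behaves on plain floats, using only the standard library. It uses the same formula as np.isclose, but without the special handling of infinities: np.isclose treats two infinities of the same sign as close, whereas the plain formula does not. The sample values are chosen here for illustration.

```python
import math


def my_isclose(x, y, rtol=1.e-5, atol=1.e-8):
    # Same tolerance test as np.isclose: |x - y| <= atol + rtol * |y|
    return abs(x - y) <= atol + rtol * abs(y)


print(my_isclose(1.0, 1.0 + 1e-9))      # difference well inside the tolerance
print(my_isclose(1.0, 1.1))             # difference far outside the tolerance
print(my_isclose(math.inf, math.inf))   # inf - inf is nan, so the comparison fails
```

The test is also asymmetric in x and y, because the relative tolerance is scaled by abs(y) only; that matches np.isclose, so swapping it in should not change results for finite inputs.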

Here is performance for RailEnv.step before:

912200 function calls (870670 primitive calls) in 1.818 seconds

After:

492968 function calls (477640 primitive calls) in 1.161 seconds
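Summary lines like the ones above come from Python's cProfile/pstats output. A minimal sketch of how to collect them follows; dummy_step is a hypothetical stand-in for the call being measured (for example one RailEnv.step call), and profile_call is a helper name introduced here.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, **kwargs):
    # Run func under cProfile and return its result plus the text report.
    profiler = cProfile.Profile()
    profiler.enable()
    result = func(*args, **kwargs)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(10)  # top 10 by cumulative time
    return result, stream.getvalue()


def dummy_step(n=1000):
    # Hypothetical workload standing in for the function under test.
    return sum(i * i for i in range(n))


result, report = profile_call(dummy_step, 1000)
print(report)
```

The report starts with the "N function calls (M primitive calls) in X seconds" line, followed by the ncalls/tottime/percall/cumtime table quoted in the snippets above.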
