Here are some potential optimizations for the Flatland environment, discovered by Adrian Egli from SBB. They will eventually be integrated into the Flatland codebase, but you are already welcome to take advantage of them.
If you do test and integrate them, you are encouraged to submit PRs to the Flatland repository, which would make you a Flatland contributor!
#---- SpeedUp ~7x -----------------------------------------------------------------------------------------------------
#ncalls tottime percall cumtime percall filename:lineno(function)
#109161 0.131 0.000 0.131 0.000 grid4_utils.py:29(get_new_position)
# (row, column) offsets, listed in the order NORTH, EAST, SOUTH, WEST
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def get_new_position(position, movement):
    return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1])
#---- ORIGINAL -----------------------------------------------------------------------------------------------------
#ncalls tottime percall cumtime percall filename:lineno(function)
#112703 0.893 0.000 1.355 0.000 grid4_utils.py:32(get_new_position)
def get_new_position(position, movement):
    """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
    if movement == Grid4TransitionsEnum.NORTH:
        return (position[0] - 1, position[1])
    elif movement == Grid4TransitionsEnum.EAST:
        return (position[0], position[1] + 1)
    elif movement == Grid4TransitionsEnum.SOUTH:
        return (position[0] + 1, position[1])
    elif movement == Grid4TransitionsEnum.WEST:
        return (position[0], position[1] - 1)
#---- SpeedUp ~3x -----------------------------------------------------------------------------------------------------
#ncalls tottime percall cumtime percall filename:lineno(function)
#27121 0.041 0.000 0.273 0.000 grid4.py:66(get_transitions)
from numba import njit, jit
def get_transitions(self, cell_transition, orientation):
    return opt_get_transitions(cell_transition, orientation)
@jit()
def opt_get_transitions(cell_transition, orientation):
    """
    Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
    if no diagonal transitions allowed) available for an agent oriented
    in direction `orientation` and inside a cell with
    transitions `cell_transition`.

    Parameters
    ----------
    cell_transition : int
        16 bits used to encode the valid transitions for a cell.
    orientation : int
        Orientation of the agent inside the cell.

    Returns
    -------
    tuple
        List of the validity of transitions in the cell.
    """
    bits = (cell_transition >> ((3 - orientation) * 4))
    return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
#---- ORIGINAL -----------------------------------------------------------------------------------------------------
#ncalls tottime percall cumtime percall filename:lineno(function)
#25399 0.146 0.000 0.146 0.000 grid4.py:66(get_transitions)
def get_transitions(self, cell_transition, orientation):
I think we could use numba to improve performance further, especially for pure NumPy and Python methods that can be made “static”, i.e. that do not depend on instance state.
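As a rough illustration of the “static” idea, the sketch below moves the bit manipulation from get_transitions into a module-level function compiled with numba's njit and keeps the method as a thin wrapper. The names transitions_for and Grid4TransitionsSketch are hypothetical and not part of Flatland. The fallback decorator is an assumption added only so the snippet still runs when numba is not installed.

```python
try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: without numba, run the plain Python function.
    def njit(func):
        return func


@njit
def transitions_for(cell_transition, orientation):
    # Select the 4-bit group belonging to `orientation` from the 16-bit cell
    # encoding, then unpack it into an (N, E, S, W) tuple of 0/1 flags.
    bits = cell_transition >> ((3 - orientation) * 4)
    return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, bits & 1)


class Grid4TransitionsSketch:
    """Hypothetical stand-in for the class that owns get_transitions."""

    def get_transitions(self, cell_transition, orientation):
        # The method does not use `self`, so it only delegates to the
        # module-level function that numba can compile.
        return transitions_for(cell_transition, orientation)


if __name__ == "__main__":
    grid = Grid4TransitionsSketch()
    print(grid.get_transitions(0x8000, 0))
```

Note that the first call to a numba-compiled function includes compilation time, so it is worth excluding that call when profiling a change like this one.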