From 53dde234de74edb1e92b805b5c079f9005a43c8d Mon Sep 17 00:00:00 2001 From: Gioele Molinari Date: Mon, 21 Jul 2025 14:39:30 +0200 Subject: [PATCH 1/9] Adapt Terra to start with custom agent position and angle --- terra/agent.py | 197 ++++++++++++++++++++++++++++++++++++--- terra/env.py | 51 +++++++--- terra/state.py | 18 +++- terra/viz/game/game.py | 1 + terra/viz/main_manual.py | 12 ++- 5 files changed, 248 insertions(+), 31 deletions(-) diff --git a/terra/agent.py b/terra/agent.py index 1c78d1395..4eda547cd 100644 --- a/terra/agent.py +++ b/terra/agent.py @@ -1,4 +1,4 @@ -from typing import NamedTuple +from typing import NamedTuple, Optional, Tuple import jax import jax.numpy as jnp @@ -49,21 +49,82 @@ def new( max_traversable_y: int, padding_mask: Array, action_map: Array, + custom_pos: Optional[Tuple[int, int]] = None, + custom_angle: Optional[int] = None, + ) -> tuple["Agent", jax.random.PRNGKey]: + """ + Create a new agent with specified parameters. + + Args: + key: JAX random key + env_cfg: Environment configuration + max_traversable_x: Maximum traversable x coordinate + max_traversable_y: Maximum traversable y coordinate + padding_mask: Mask indicating obstacles + custom_pos: Optional custom position (x, y) to place the agent + custom_angle: Optional custom angle for the agent + + Returns: + New agent instance and updated random key + """ + # Handle custom position or default based on config + has_custom_args = (custom_pos is not None) or (custom_angle is not None) + + def use_custom_position(k): + # Create position based on custom args or defaults + temp_pos = IntMap(jnp.array(custom_pos)) if custom_pos is not None else IntMap(jnp.array([-1, -1])) + temp_angle = jnp.full((1,), custom_angle, dtype=IntMap) if custom_angle is not None else jnp.full((1,), -1, dtype=IntMap) + + # Get default position for missing components + def_pos, def_angle, _ = _get_top_left_init_state(k, env_cfg) + + # Combine custom and default values + pos = jnp.where(jnp.any(temp_pos < 0), def_pos, temp_pos) + angle = jnp.where(jnp.any(temp_angle < 0), def_angle, temp_angle) + + # Check validity and return result using jax.lax.cond + valid = _validate_agent_position( + pos, angle, env_cfg, padding_mask, + env_cfg.agent.width, env_cfg.agent.height + ) + + # Define the true and false branches for jax.lax.cond + def true_fn(_): + return (pos, angle, k) + + def false_fn(_): + return jax.lax.cond( + env_cfg.agent.random_init_state, + lambda k_inner: _get_random_init_state( + k_inner, env_cfg, max_traversable_x, max_traversable_y, + padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height, + ), + lambda k_inner: _get_top_left_init_state(k_inner, env_cfg), + k + ) + + # Use jax.lax.cond to handle the validity check + return jax.lax.cond(valid, true_fn, false_fn, None) + + def use_default_position(k): + # Use existing logic for random or top-left position + return jax.lax.cond( + env_cfg.agent.random_init_state, + lambda k_inner: _get_random_init_state( + k_inner, env_cfg, max_traversable_x, max_traversable_y, + padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height, + ), + lambda k_inner: _get_top_left_init_state(k_inner, env_cfg), + k + ) + + # Use jax.lax.cond for JAX-compatible control flow pos_base, angle_base, key = jax.lax.cond( - env_cfg.agent.random_init_state, - lambda k: _get_random_init_state( - k, - env_cfg, - max_traversable_x, - max_traversable_y, - padding_mask, - action_map, - env_cfg.agent.width, - env_cfg.agent.height, - ), - lambda k: _get_top_left_init_state(k, env_cfg), - key, + has_custom_args, + use_custom_position, + use_default_position, + key ) agent_state = AgentState( @@ -82,6 +143,112 @@ def new( return Agent(agent_state=agent_state, width=width, height=height, moving_dumped_dirt=moving_dumped_dirt), key +def _validate_agent_position( + pos_base: Array, + angle_base: Array, + env_cfg: EnvConfig, + padding_mask: Array, + agent_width: int, + agent_height: int, +) -> Array: + """ + Validate if an agent position is valid (within bounds and not intersecting obstacles). + + Returns: + JAX array with boolean value indicating if the position is valid + """ + map_width = padding_mask.shape[0] + map_height = padding_mask.shape[1] + + # Check if position is within bounds + max_center_coord = jnp.ceil( + jnp.max(jnp.array([agent_width / 2 - 1, agent_height / 2 - 1])) + ).astype(IntMap) + + max_w = jnp.minimum(env_cfg.maps.edge_length_px, map_width) + max_h = jnp.minimum(env_cfg.maps.edge_length_px, map_height) + + within_bounds = jnp.logical_and( + jnp.logical_and(pos_base[0] >= max_center_coord, pos_base[0] < max_w - max_center_coord), + jnp.logical_and(pos_base[1] >= max_center_coord, pos_base[1] < max_h - max_center_coord) + ) + + # Check if position intersects with obstacles + def check_obstacle_intersection(_): + agent_corners_xy = get_agent_corners( + pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base + ) + polygon_mask = compute_polygon_mask(agent_corners_xy, map_width, map_height) + has_obstacle = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1)) + return jnp.logical_not(has_obstacle) + + def return_false(_): + return jnp.array(False) + + # Only check obstacles if we're within bounds (to avoid unnecessary computations) + valid = jax.lax.cond( + within_bounds, + check_obstacle_intersection, + return_false, + None + ) + + return valid + + +def _validate_agent_position( + pos_base: Array, + angle_base: Array, + env_cfg: EnvConfig, + padding_mask: Array, + agent_width: int, + agent_height: int, +) -> Array: + """ + Validate if an agent position is valid (within bounds and not intersecting obstacles). + + Returns: + JAX array with boolean value indicating if the position is valid + """ + map_width = padding_mask.shape[0] + map_height = padding_mask.shape[1] + + # Check if position is within bounds + max_center_coord = jnp.ceil( + jnp.max(jnp.array([agent_width / 2 - 1, agent_height / 2 - 1])) + ).astype(IntMap) + + max_w = jnp.minimum(env_cfg.maps.edge_length_px, map_width) + max_h = jnp.minimum(env_cfg.maps.edge_length_px, map_height) + + within_bounds = jnp.logical_and( + jnp.logical_and(pos_base[0] >= max_center_coord, pos_base[0] < max_w - max_center_coord), + jnp.logical_and(pos_base[1] >= max_center_coord, pos_base[1] < max_h - max_center_coord) + ) + + # Check if position intersects with obstacles + def check_obstacle_intersection(_): + agent_corners_xy = get_agent_corners( + pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base + ) + polygon_mask = compute_polygon_mask(agent_corners_xy, map_width, map_height) + has_obstacle = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1)) + return jnp.logical_not(has_obstacle) + + def return_false(_): + return jnp.array(False) + + # Only check obstacles if we're within bounds (to avoid unnecessary computations) + valid = jax.lax.cond( + within_bounds, + check_obstacle_intersection, + return_false, + None + ) + + return valid + + def _get_top_left_init_state(key: jax.random.PRNGKey, env_cfg: EnvConfig): max_center_coord = jnp.ceil( jnp.max( @@ -174,4 +341,4 @@ def _check_intersection(): ), ) - return pos_base, angle_base, key + return pos_base, angle_base, key \ No newline at end of file diff --git a/terra/env.py b/terra/env.py index 9d29a9423..11b5653ce 100644 --- a/terra/env.py +++ b/terra/env.py @@ -1,6 +1,7 @@ from collections.abc import Callable from functools import partial from typing import NamedTuple +from typing import Any, Optional, Tuple import jax import jax.numpy as jnp @@ -82,6 +83,8 @@ def reset( dumpability_mask_init: Array, action_map: Array, env_cfg: EnvConfig, + custom_pos: Optional[Tuple[int, int]] = None, + custom_angle: Optional[int] = None, ) -> tuple[State, dict[str, Array]]: """ Resets the environment using values from config files, and a seed. @@ -95,6 +98,8 @@ def reset( trench_type, dumpability_mask_init, action_map, + custom_pos, + custom_angle, ) state = self.wrap_state(state) @@ -154,7 +159,7 @@ def render_obs_pygame( Renders the environment at a given observation. """ if info is not None: - target_tiles = info["target_tiles"] + target_tiles = info.get("target_tiles", None) else: target_tiles = None @@ -169,6 +174,7 @@ def render_obs_pygame( loaded=obs["agent_state"][..., [5]], target_tiles=target_tiles, generate_gif=generate_gif, + info=info, ) @partial(jax.jit, static_argnums=(0,)) @@ -360,7 +366,9 @@ def _get_map(self, maps_buffer_keys: jax.random.PRNGKey, env_cfgs: EnvConfig): return jax.vmap(self.maps_buffer.get_map)(maps_buffer_keys, env_cfgs) @partial(jax.jit, static_argnums=(0,)) - def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey) -> State: + def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey, + custom_pos: Optional[Tuple[int, int]] = None, + custom_angle: Optional[int] = None) -> State: env_cfgs = self.curriculum_manager.reset_cfgs(env_cfgs) env_cfgs = self.update_env_cfgs(env_cfgs) ( @@ -372,17 +380,36 @@ def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey) -> State: action_maps, new_rng_key, ) = self._get_map_init(rng_key, env_cfgs) - timestep = jax.vmap(self.terra_env.reset)( - rng_key, - target_maps, - padding_masks, - trench_axes, - trench_type, - dumpability_mask_init, - action_maps, - env_cfgs, - ) + timestep = jax.vmap( + self.terra_env.reset, + in_axes=(0, 0, 0, 0, 0, 0, 0, 0, None, None) + )( + rng_key, + target_maps, + padding_masks, + trench_axes, + trench_type, + dumpability_mask_init, + action_maps, + env_cfgs, + custom_pos, + custom_angle, + ) return timestep + + @property + def actions_size(self) -> int: + """ + Number of actions played at every env step. + """ + return self.num_actions + + @property + def num_actions(self) -> int: + """ + Total number of actions + """ + return self.batch_cfg.action_type.get_num_actions() @partial(jax.jit, static_argnums=(0,)) def step( diff --git a/terra/state.py b/terra/state.py index d3cd5731f..6e55fe898 100644 --- a/terra/state.py +++ b/terra/state.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional, Tuple from typing import NamedTuple import jax @@ -71,14 +71,24 @@ def new( trench_type: Array, dumpability_mask_init: Array, action_map: Array, + custom_pos: Optional[Tuple[int, int]] = None, + custom_angle: Optional[int] = None, ) -> "State": world = GridWorld.new( target_map, padding_mask, trench_axes, trench_type, dumpability_mask_init, action_map ) agent, key = Agent.new( - key, env_cfg, world.max_traversable_x, world.max_traversable_y, padding_mask, action_map + key, + env_cfg, + world.max_traversable_x, + world.max_traversable_y, + padding_mask, + action_map, + custom_pos=custom_pos, + custom_angle=custom_angle, ) + agent = jax.tree_map( lambda x: x if isinstance(x, Array) else jnp.array(x), agent ) @@ -100,6 +110,8 @@ def _reset( trench_type: Array, dumpability_mask_init: Array, action_map: Array, + custom_pos: Optional[Tuple[int, int]] = None, + custom_angle: Optional[int] = None, ) -> "State": """ Resets the already-existing State @@ -114,6 +126,8 @@ def _reset( trench_type=trench_type, dumpability_mask_init=dumpability_mask_init, action_map=action_map, + custom_pos=custom_pos, + custom_angle=custom_angle, ) def _step(self, action: Action) -> "State": diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index 7959ec983..2a93d93ed 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -85,6 +85,7 @@ def run( loaded, generate_gif, target_tiles=None, + info=None, ): # self.events() self.update( diff --git a/terra/viz/main_manual.py b/terra/viz/main_manual.py index aa9e15135..dd9387bd8 100644 --- a/terra/viz/main_manual.py +++ b/terra/viz/main_manual.py @@ -40,7 +40,15 @@ def main(): env_cfgs = jax.vmap(lambda x: EnvConfig.new())(jnp.arange(n_envs)) rng, _rng = jax.random.split(rng) _rng = _rng[None] - timestep = env.reset(env_cfgs, _rng) + + # initial_custom_pos = (32, 32) + # initial_custom_angle = 0 + initial_custom_pos = None # Use default random position + initial_custom_angle = None # Use default random angle + + timestep = env.reset(env_cfgs, _rng, + custom_pos=initial_custom_pos, + custom_angle=initial_custom_angle) print(f"{timestep.state.agent.width=}") print(f"{timestep.state.agent.height=}") @@ -49,7 +57,7 @@ def main(): def repeat_action(action, n_times=n_envs): return action_type.new(action.action[None].repeat(n_times, 0)) - + # Trigger the JIT compilation timestep = env.step(timestep, repeat_action(action_type.do_nothing()), _rng) end_time = time.time() From 9484156fccb10656ac73da3c8a79fb8f5b876afc Mon Sep 17 00:00:00 2001 From: Gioele Molinari Date: Mon, 21 Jul 2025 16:03:53 +0200 Subject: [PATCH 2/9] Fix visualization --- terra/viz/game/game.py | 325 +++++++++++++++++++++++++++++++++++------ 1 file changed, 280 insertions(+), 45 deletions(-) diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index 2a93d93ed..804d2bfe4 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -99,7 +99,7 @@ def run( loaded, target_tiles, ) - self.draw() + self.draw(info) if generate_gif: frame = pg.surfarray.array3d(pg.display.get_surface()) self.frames.append(frame.swapaxes(0, 1)) @@ -177,41 +177,63 @@ def update_world_agent( for thread in threads: thread.join() - def draw(self): - self.surface.fill("#F0F0F0") - agent_surfaces = [] - agent_positions = [] - for i, (world, agent) in enumerate(zip(self.worlds, self.agents)): - ix = i % self.n_envs_y - iy = i // self.n_envs_y + def _create_agent_surface(self, agent, base_dir, cabin_dir, loaded): + """Create a temporary agent with specific angles and return its surface.""" - total_offset_x = ( - ix * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size - ) - total_offset_y = ( - iy * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size + temp_agent = Agent( + agent.width if hasattr(agent, 'width') else agent.w, + agent.height if hasattr(agent, 'height') else agent.h, + agent.tile_size, + agent.angles_base, + agent.angles_cabin ) + + # Update it with the specific position and angles + temp_agent.update([0, 0], base_dir, cabin_dir, loaded) # Position doesn't matter for surface creation - # Draw terrain - for x in range(world.grid_length_x): - for y in range(world.grid_length_y): - sq = world.action_map[x][y]["cart_rect"] - c = world.action_map[x][y]["color"] - rect = pg.Rect( - sq[0][0] + total_offset_x, - sq[0][1] + total_offset_y, - self.tile_size, - self.tile_size, - ) - pg.draw.rect(self.surface, c, rect, 0) - - # Highlight target tiles (where the digger will dig / dump) - if hasattr(world, 'target_tiles') and world.target_tiles is not None: - flat_idx = y * world.grid_length_x + x - if flat_idx < len(world.target_tiles) and world.target_tiles[flat_idx]: - pg.draw.rect(self.surface, "#FF3300", rect, 2) - + # Get vertices for the body and cabin + body_vertices = temp_agent.agent["body"]["vertices"] + cabin_vertices = temp_agent.agent["cabin"]["vertices"] + body_color = temp_agent.agent["body"]["color"] + cabin_color = temp_agent.agent["cabin"]["color"] + + # Calculate bounding box + all_vertices = body_vertices + cabin_vertices + #all_vertices = body_vertices + min_x = min(v[0] for v in all_vertices) + min_y = min(v[1] for v in all_vertices) + max_x = max(v[0] for v in all_vertices) + max_y = max(v[1] for v in all_vertices) + + # Create surface + surface_width = math.ceil(max_x - min_x) + 2 + surface_height = math.ceil(max_y - min_y) + 2 + agent_surface = pg.Surface((surface_width, surface_height), pg.SRCALPHA) + + # Adjust vertices for the surface coordinate system + body_offset = [(v[0] - min_x, v[1] - min_y) for v in body_vertices] + cabin_offset = [(v[0] - min_x, v[1] - min_y) for v in cabin_vertices] + + # Draw agent parts + pg.draw.polygon(agent_surface, body_color, body_offset) + pg.draw.polygon(agent_surface, cabin_color, cabin_offset) + + return agent_surface, (min_x, min_y) + + def _render_agents_for_env(self, env_idx, world, agent, total_offset_x, total_offset_y, info): + """Render all agents for a specific environment and return surfaces and positions.""" + agent_surfaces = [] + agent_positions = [] + + # Check if we have additional agents in info - if so, skip primary agent to avoid duplication + has_additional_agents = (info and 'additional_agents' in info and + 'positions' in info['additional_agents'] and + len(info['additional_agents']['positions']) > 0) + + # Only render the primary agent if we don't have additional agents + if not has_additional_agents: + # Render the primary agent (from the original arrays) body_vertices = agent.agent["body"]["vertices"] ca = agent.agent["body"]["color"] @@ -225,32 +247,245 @@ def draw(self): surface_width = math.ceil(max_x - min_x) + 2 surface_height = math.ceil(max_y - min_y) + 2 - # Create surface for the agent - agent_surfaces.append( - pg.Surface((surface_width, surface_height), pg.SRCALPHA) - ) + # Create surface for the primary agent + primary_surface = pg.Surface((surface_width, surface_height), pg.SRCALPHA) + # Calculate surface position agent_x = min_x + total_offset_x agent_y = min_y + total_offset_y - agent_positions.append((agent_x, agent_y)) # Adjust vertices for the agent's surface offset_vertices = [(v[0] - min_x, v[1] - min_y) for v in body_vertices] - # Draw agent body as polygon - pg.draw.polygon(agent_surfaces[-1], ca, offset_vertices) + # Draw primary agent body as polygon + pg.draw.polygon(primary_surface, ca, offset_vertices) # Get cabin vertices and adjust for agent surface cabin = agent.agent["cabin"]["vertices"] cabin_offset = [(v[0] - min_x, v[1] - min_y) for v in cabin] cabin_color = agent.agent["cabin"]["color"] - pg.draw.polygon(agent_surfaces[-1], cabin_color, cabin_offset) + pg.draw.polygon(primary_surface, cabin_color, cabin_offset) + + agent_surfaces.append(primary_surface) + agent_positions.append((agent_x, agent_y)) + #print(f"Primary agent rendered at position: ({agent_x}, {agent_y})") + + # Render additional agents if they exist in info + if info and 'additional_agents' in info: + additional_agents = info['additional_agents'] + + # Check for the required keys in your specific format + if ('positions' in additional_agents and + 'angles base' in additional_agents and + 'angles cabin' in additional_agents and + 'loaded' in additional_agents): + + positions = additional_agents['positions'] + angles_base = additional_agents['angles base'] + angles_cabin = additional_agents['angles cabin'] + loaded_states = additional_agents['loaded'] + + # Handle both single environment and multi-environment cases + if env_idx == 0: # For single environment or first environment + #print(f"Processing additional agents for env_idx 0") + + # Convert numpy arrays/lists to Python lists if needed + if hasattr(positions, 'tolist'): + positions = positions.tolist() + if hasattr(angles_base, 'tolist'): + angles_base = angles_base.tolist() + if hasattr(angles_cabin, 'tolist'): + angles_cabin = angles_cabin.tolist() + if hasattr(loaded_states, 'tolist'): + loaded_states = loaded_states.tolist() + + # Render each additional agent + for i, pos in enumerate(positions): + + if i < len(angles_base) and i < len(angles_cabin) and i < len(loaded_states): + base_dir = angles_base[i] + cabin_dir = angles_cabin[i] + loaded = bool(loaded_states[i]) + + # Create agent surface with specific angles + additional_surface, offset = self._create_agent_surface( + agent, base_dir, cabin_dir, loaded + ) + + # Calculate position on screen + # pos is already in pixel coordinates based on your example + screen_x = pos[0] * self.tile_size + total_offset_x + offset[0] + screen_y = pos[1] * self.tile_size + total_offset_y + offset[1] + + agent_surfaces.append(additional_surface) + agent_positions.append((screen_x, screen_y)) + else: + print(f"Skipping additional agents for env_idx {env_idx}") - self.screen.blit(self.surface, (0, 0)) + return agent_surfaces, agent_positions + + def draw(self, info=None): + if info is None: + self.surface.fill("#F0F0F0") + agent_surfaces = [] + agent_positions = [] - for agent_surface, agent_position in zip(agent_surfaces, agent_positions): - self.screen.blit(agent_surface, agent_position) + for i, (world, agent) in enumerate(zip(self.worlds, self.agents)): + ix = i % self.n_envs_y + iy = i // self.n_envs_y - if self.display: - pg.display.flip() + total_offset_x = ( + ix * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size + ) + total_offset_y = ( + iy * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size + ) + + # Draw terrain + for x in range(world.grid_length_x): + for y in range(world.grid_length_y): + sq = world.action_map[x][y]["cart_rect"] + c = world.action_map[x][y]["color"] + rect = pg.Rect( + sq[0][0] + total_offset_x, + sq[0][1] + total_offset_y, + self.tile_size, + self.tile_size, + ) + pg.draw.rect(self.surface, c, rect, 0) + + # Highlight target tiles (where the digger will dig / dump) + if hasattr(world, 'target_tiles') and world.target_tiles is not None: + flat_idx = y * world.grid_length_x + x + if flat_idx < len(world.target_tiles) and world.target_tiles[flat_idx]: + pg.draw.rect(self.surface, "#FF3300", rect, 2) + + body_vertices = agent.agent["body"]["vertices"] + ca = agent.agent["body"]["color"] + + # Calculate the bounding box + min_x = min(v[0] for v in body_vertices) + min_y = min(v[1] for v in body_vertices) + max_x = max(v[0] for v in body_vertices) + max_y = max(v[1] for v in body_vertices) + + # Calculate surface size with a small padding + surface_width = math.ceil(max_x - min_x) + 2 + surface_height = math.ceil(max_y - min_y) + 2 + + # Create surface for the agent + agent_surfaces.append( + pg.Surface((surface_width, surface_height), pg.SRCALPHA) + ) + + # Calculate surface position + agent_x = min_x + total_offset_x + agent_y = min_y + total_offset_y + agent_positions.append((agent_x, agent_y)) + + # Adjust vertices for the agent's surface + offset_vertices = [(v[0] - min_x, v[1] - min_y) for v in body_vertices] + + # Draw agent body as polygon + pg.draw.polygon(agent_surfaces[-1], ca, offset_vertices) + + # Get cabin vertices and adjust for agent surface + cabin = agent.agent["cabin"]["vertices"] + cabin_offset = [(v[0] - min_x, v[1] - min_y) for v in cabin] + cabin_color = agent.agent["cabin"]["color"] + pg.draw.polygon(agent_surfaces[-1], cabin_color, cabin_offset) + + self.screen.blit(self.surface, (0, 0)) + + for agent_surface, agent_position in zip(agent_surfaces, agent_positions): + self.screen.blit(agent_surface, agent_position) + + if self.display: + pg.display.flip() + else: + self.surface.fill("#F0F0F0") + all_agent_surfaces = [] + all_agent_positions = [] + for i, (world, agent) in enumerate(zip(self.worlds, self.agents)): + ix = i % self.n_envs_y + iy = i // self.n_envs_y + + total_offset_x = ( + ix * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size + ) + total_offset_y = ( + iy * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size + ) + + # Draw terrain + for x in range(world.grid_length_x): + for y in range(world.grid_length_y): + sq = world.action_map[x][y]["cart_rect"] + c = world.action_map[x][y]["color"] + rect = pg.Rect( + sq[0][0] + total_offset_x, + sq[0][1] + total_offset_y, + self.tile_size, + self.tile_size, + ) + pg.draw.rect(self.surface, c, rect, 0) + + # # Highlight target tiles (where the digger will dig / dump) + # if hasattr(world, 'target_tiles') and world.target_tiles is not None: + # flat_idx = y * world.grid_length_x + x + # if flat_idx < len(world.target_tiles) and world.target_tiles[flat_idx]: + # pg.draw.rect(self.surface, "#FF3300", rect, 2) + # Draw partition rectangles (only for the first environment) + if info and info.get('show_partitions', False) and i == 0: + partitions = info.get('partitions', []) + self._draw_partition_rectangles(partitions, total_offset_x, total_offset_y) + + # Render all agents for this environment + env_agent_surfaces, env_agent_positions = self._render_agents_for_env( + i, world, agent, total_offset_x, total_offset_y, info + ) + + all_agent_surfaces.extend(env_agent_surfaces) + all_agent_positions.extend(env_agent_positions) + + self.screen.blit(self.surface, (0, 0)) + + + for agent_surface, agent_position in zip(all_agent_surfaces, all_agent_positions): + self.screen.blit(agent_surface, agent_position) + + if self.display: + pg.display.flip() + + def _draw_partition_rectangles(self, partitions, total_offset_x, total_offset_y): + """Draw simple rectangles around each partition.""" + #import pygame as pg + + for i, partition in enumerate(partitions): + y_start, x_start, y_end, x_end = partition['region_coords'] + status = partition['status'] + + # Convert to screen coordinates + rect_x = x_start * self.tile_size + total_offset_x + rect_y = y_start * self.tile_size + total_offset_y + rect_width = (x_end - x_start + 1) * self.tile_size + rect_height = (y_end - y_start + 1) * self.tile_size + + # Choose color based on status + if status == 'active': + color = (0, 255, 0) # Green + width = 3 + elif status == 'completed': + color = (0, 0, 255) # Blue + width = 2 + elif status == 'failed': + color = (255, 0, 0) # Red + width = 2 + else: # pending + color = (255, 255, 0) # Yellow + width = 1 + + # Draw the rectangle + pg.draw.rect(self.surface, color, + (rect_x, rect_y, rect_width, rect_height), width) \ No newline at end of file From d1c17011aad90006b9abe63d674f500191d0b29f Mon Sep 17 00:00:00 2001 From: Gioele Molinari Date: Mon, 21 Jul 2025 16:16:40 +0200 Subject: [PATCH 3/9] Small fix in visualization --- terra/viz/game/game.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index 804d2bfe4..e3ee0c662 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -326,7 +326,7 @@ def _render_agents_for_env(self, env_idx, world, agent, total_offset_x, total_of return agent_surfaces, agent_positions def draw(self, info=None): - if info is None: + if "additional_agents" not in info or info["additional_agents"] is None: self.surface.fill("#F0F0F0") agent_surfaces = [] agent_positions = [] From 334fd86de736f36efd5255ac10b4cb7e46c5c00d Mon Sep 17 00:00:00 2001 From: Gioele Molinari Date: Mon, 21 Jul 2025 16:54:18 +0200 Subject: [PATCH 4/9] Small fix --- terra/viz/game/game.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index e3ee0c662..e6726f018 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -326,7 +326,7 @@ def _render_agents_for_env(self, env_idx, world, agent, total_offset_x, total_of return agent_surfaces, agent_positions def draw(self, info=None): - if "additional_agents" not in info or info["additional_agents"] is None: + if info is None or "additional_agents" not in info or info["additional_agents"] is None: self.surface.fill("#F0F0F0") agent_surfaces = [] agent_positions = [] From 421b310a874c2a32e4e51aff65765b9c90cd632a Mon Sep 17 00:00:00 2001 From: Gioele Molinari Date: Thu, 24 Jul 2025 14:55:21 +0200 Subject: [PATCH 5/9] Fix config --- .gitignore | 3 +- terra/config.py | 74 ++++++++++++++++++++++++------------------------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 14e89c0fa..6c58d383f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ terra/digbench/ data/* !data/custom/ !data/custom/** -*.pkl \ No newline at end of file +*.pkl +.DS_Store \ No newline at end of file diff --git a/terra/config.py b/terra/config.py index 08715263a..fd1bc7730 100644 --- a/terra/config.py +++ b/terra/config.py @@ -176,47 +176,47 @@ class CurriculumGlobalConfig(NamedTuple): # NOTE: all maps need to have the same size levels = [ { - "maps_path": "terra/foundations", + "maps_path": "foundations", "max_steps_in_episode": 400, "rewards_type": RewardsType.DENSE, "apply_trench_rewards": False, }, - { - "maps_path": "terra/trenches/single", - "max_steps_in_episode": 400, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": True, - }, - { - "maps_path": "terra/trenches/double", - "max_steps_in_episode": 400, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": True, - }, - { - "maps_path": "terra/trenches/double_diagonal", - "max_steps_in_episode": 400, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": True, - }, - { - "maps_path": "terra/foundations", - "max_steps_in_episode": 400, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": False, - }, - { - "maps_path": "terra/trenches/triple_diagonal", - "max_steps_in_episode": 400, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": True, - }, - { - "maps_path": "terra/foundations_large", - "max_steps_in_episode": 500, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": False, - }, + # { + # "maps_path": "trenches/single", + # "max_steps_in_episode": 400, + # "rewards_type": RewardsType.DENSE, + # "apply_trench_rewards": True, + # }, + # { + # "maps_path": "trenches/double", + # "max_steps_in_episode": 400, + # "rewards_type": RewardsType.DENSE, + # "apply_trench_rewards": True, + # }, + # { + # "maps_path": "terra/trenches/double_diagonal", + # "max_steps_in_episode": 400, + # "rewards_type": RewardsType.DENSE, + # "apply_trench_rewards": True, + # }, + # { + # "maps_path": "trenches/triple", + # "max_steps_in_episode": 400, + # "rewards_type": RewardsType.DENSE, + # "apply_trench_rewards": False, + # }, + # { + # "maps_path": "terra/trenches/triple_diagonal", + # "max_steps_in_episode": 400, + # "rewards_type": RewardsType.DENSE, + # "apply_trench_rewards": True, + # }, + # { + # "maps_path": "terra/foundations_large", + # "max_steps_in_episode": 500, + # "rewards_type": RewardsType.DENSE, + # "apply_trench_rewards": False, + # }, ] From e89cefd8b63880eb0e9a601ce4753fa3020941ca Mon Sep 17 00:00:00 2001 From: Gioele Molinari <63927701+gioelemo@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:30:42 +0200 Subject: [PATCH 6/9] Update terra/agent.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- terra/agent.py | 52 +------------------------------------------------- 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/terra/agent.py b/terra/agent.py index 4eda547cd..78ccdc55e 100644 --- a/terra/agent.py +++ b/terra/agent.py @@ -196,57 +196,7 @@ def return_false(_): return valid -def _validate_agent_position( - pos_base: Array, - angle_base: Array, - env_cfg: EnvConfig, - padding_mask: Array, - agent_width: int, - agent_height: int, -) -> Array: - """ - Validate if an agent position is valid (within bounds and not intersecting obstacles). - - Returns: - JAX array with boolean value indicating if the position is valid - """ - map_width = padding_mask.shape[0] - map_height = padding_mask.shape[1] - - # Check if position is within bounds - max_center_coord = jnp.ceil( - jnp.max(jnp.array([agent_width / 2 - 1, agent_height / 2 - 1])) - ).astype(IntMap) - - max_w = jnp.minimum(env_cfg.maps.edge_length_px, map_width) - max_h = jnp.minimum(env_cfg.maps.edge_length_px, map_height) - - within_bounds = jnp.logical_and( - jnp.logical_and(pos_base[0] >= max_center_coord, pos_base[0] < max_w - max_center_coord), - jnp.logical_and(pos_base[1] >= max_center_coord, pos_base[1] < max_h - max_center_coord) - ) - - # Check if position intersects with obstacles - def check_obstacle_intersection(_): - agent_corners_xy = get_agent_corners( - pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base - ) - polygon_mask = compute_polygon_mask(agent_corners_xy, map_width, map_height) - has_obstacle = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1)) - return jnp.logical_not(has_obstacle) - - def return_false(_): - return jnp.array(False) - - # Only check obstacles if we're within bounds (to avoid unnecessary computations) - valid = jax.lax.cond( - within_bounds, - check_obstacle_intersection, - return_false, - None - ) - - return valid +# Removed duplicate definition of _validate_agent_position def _get_top_left_init_state(key: jax.random.PRNGKey, env_cfg: EnvConfig): From 9f7b5b2b353fbc12c53bd004a893f8f716f8fa7f Mon Sep 17 00:00:00 2001 From: Gioele Molinari <63927701+gioelemo@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:32:01 +0200 Subject: [PATCH 7/9] Update terra/viz/game/game.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- terra/viz/game/game.py | 1 - 1 file changed, 1 deletion(-) diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index e6726f018..7abacd7ed 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -200,7 +200,6 @@ def _create_agent_surface(self, agent, base_dir, cabin_dir, loaded): # Calculate bounding box all_vertices = body_vertices + cabin_vertices - #all_vertices = body_vertices min_x = min(v[0] for v in all_vertices) min_y = min(v[1] for v in all_vertices) max_x = max(v[0] for v in all_vertices) From 9f40fcff9b9831923b50772a8f1ceac9f392b6bd Mon Sep 17 00:00:00 2001 From: Gioele Molinari <63927701+gioelemo@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:32:42 +0200 Subject: [PATCH 8/9] Update terra/viz/game/game.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- terra/viz/game/game.py | 1 - 1 file changed, 1 deletion(-) diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index 7abacd7ed..a4d4d283f 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -459,7 +459,6 @@ def draw(self, info=None): def _draw_partition_rectangles(self, partitions, total_offset_x, total_offset_y): """Draw simple rectangles around each partition.""" - #import pygame as pg for i, partition in enumerate(partitions): y_start, x_start, y_end, x_end = partition['region_coords'] From 045b1e073a6333cf195f7e79d62ab4f2c0c2422c Mon Sep 17 00:00:00 2001 From: Gioele Molinari Date: Sat, 2 Aug 2025 16:34:08 +0200 Subject: [PATCH 9/9] Support big maps --- terra/config.py | 20 ++++++++++---------- terra/env.py | 3 +++ terra/viz/game/game.py | 19 ++++++++++++++++--- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/terra/config.py b/terra/config.py index fd1bc7730..a4bcd2228 100644 --- a/terra/config.py +++ b/terra/config.py @@ -175,24 +175,24 @@ class CurriculumGlobalConfig(NamedTuple): # NOTE: all maps need to have the same size levels = [ - { - "maps_path": "foundations", - "max_steps_in_episode": 400, - "rewards_type": RewardsType.DENSE, - "apply_trench_rewards": False, - }, # { - # "maps_path": "trenches/single", + # "maps_path": "foundations", # "max_steps_in_episode": 400, # "rewards_type": RewardsType.DENSE, - # "apply_trench_rewards": True, + # "apply_trench_rewards": False, # }, # { - # "maps_path": "trenches/double", + # "maps_path": "trenches/single", # "max_steps_in_episode": 400, # "rewards_type": RewardsType.DENSE, # "apply_trench_rewards": True, # }, + { + "maps_path": "trenches/double", + "max_steps_in_episode": 400, + "rewards_type": RewardsType.DENSE, + "apply_trench_rewards": True, + }, # { # "maps_path": "terra/trenches/double_diagonal", # "max_steps_in_episode": 400, @@ -206,7 +206,7 @@ class CurriculumGlobalConfig(NamedTuple): # "apply_trench_rewards": False, # }, # { - # "maps_path": "terra/trenches/triple_diagonal", + # "maps_path": "trenches/triple_diagonal", # "max_steps_in_episode": 400, # "rewards_type": RewardsType.DENSE, # "apply_trench_rewards": True, diff --git a/terra/env.py b/terra/env.py index 11b5653ce..9e5cf55ea 100644 --- a/terra/env.py +++ b/terra/env.py @@ -40,6 +40,7 @@ def new( n_envs_x: int = 1, n_envs_y: int = 1, display: bool = False, + agent_config_override: Optional[dict[str, Any]] = None, ) -> "TerraEnv": re = None tile_size_rendering = MAP_TILES // maps_size_px @@ -69,6 +70,8 @@ def new( n_envs_x=n_envs_x, n_envs_y=n_envs_y, display=display, + agent_config=agent_config_override if agent_config_override is not None else None, + ) return TerraEnv(rendering_engine=re) diff --git a/terra/viz/game/game.py b/terra/viz/game/game.py index a4d4d283f..a16ce4582 100644 --- a/terra/viz/game/game.py +++ b/terra/viz/game/game.py @@ -34,6 +34,7 @@ def __init__( n_envs_x=1, n_envs_y=1, display=True, + agent_config=None, ): self.screen = screen self.surface = surface @@ -51,10 +52,21 @@ def __init__( self.maps_size_px = maps_size_px tile_size = MAP_TILES // maps_size_px self.tile_size = tile_size - excavator_dims = ExcavatorDims() - agent_h, agent_w = get_agent_dims( + + if maps_size_px == 128: + # Extract dimensions from config + agent_config = {} + agent_config['height'] = 5 + agent_config['width'] = 9 + agent_h = int(agent_config['height'][0]) if hasattr(agent_config['height'], '__getitem__') else int(agent_config['height']) + agent_w = int(agent_config['width'][0]) if hasattr(agent_config['width'], '__getitem__') else int(agent_config['width']) + print(f"Using provided agent config: {agent_w}x{agent_h}") + + else: + excavator_dims = ExcavatorDims() + agent_h, agent_w = get_agent_dims( excavator_dims.WIDTH, excavator_dims.HEIGHT, tile_size_m - ) + ) angles_base = ImmutableAgentConfig().angles_base angles_cabin = ImmutableAgentConfig().angles_cabin print(f"Agent size (in rendering): {agent_w}x{agent_h}") @@ -406,6 +418,7 @@ def draw(self, info=None): self.surface.fill("#F0F0F0") all_agent_surfaces = [] all_agent_positions = [] + print("Map size in px:", self.maps_size_px) for i, (world, agent) in enumerate(zip(self.worlds, self.agents)): ix = i % self.n_envs_y iy = i // self.n_envs_y