Skip to content

LLM Version #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ terra/digbench/
data/*
!data/custom/
!data/custom/**
*.pkl
*.pkl
.DS_Store
197 changes: 182 additions & 15 deletions terra/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import NamedTuple
from typing import NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -49,21 +49,82 @@ def new(
max_traversable_y: int,
padding_mask: Array,
action_map: Array,
custom_pos: Optional[Tuple[int, int]] = None,
custom_angle: Optional[int] = None,

) -> tuple["Agent", jax.random.PRNGKey]:
"""
Create a new agent with specified parameters.

Args:
key: JAX random key
env_cfg: Environment configuration
max_traversable_x: Maximum traversable x coordinate
max_traversable_y: Maximum traversable y coordinate
padding_mask: Mask indicating obstacles
custom_pos: Optional custom position (x, y) to place the agent
custom_angle: Optional custom angle for the agent

Returns:
New agent instance and updated random key
"""
# Handle custom position or default based on config
has_custom_args = (custom_pos is not None) or (custom_angle is not None)

def use_custom_position(k):
# Create position based on custom args or defaults
temp_pos = IntMap(jnp.array(custom_pos)) if custom_pos is not None else IntMap(jnp.array([-1, -1]))
temp_angle = jnp.full((1,), custom_angle, dtype=IntMap) if custom_angle is not None else jnp.full((1,), -1, dtype=IntMap)

# Get default position for missing components
def_pos, def_angle, _ = _get_top_left_init_state(k, env_cfg)

# Combine custom and default values
pos = jnp.where(jnp.any(temp_pos < 0), def_pos, temp_pos)
angle = jnp.where(jnp.any(temp_angle < 0), def_angle, temp_angle)

# Check validity and return result using jax.lax.cond
valid = _validate_agent_position(
pos, angle, env_cfg, padding_mask,
env_cfg.agent.width, env_cfg.agent.height
)

# Define the true and false branches for jax.lax.cond
def true_fn(_):
return (pos, angle, k)

def false_fn(_):
return jax.lax.cond(
env_cfg.agent.random_init_state,
lambda k_inner: _get_random_init_state(
k_inner, env_cfg, max_traversable_x, max_traversable_y,
padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height,
),
lambda k_inner: _get_top_left_init_state(k_inner, env_cfg),
k
)

# Use jax.lax.cond to handle the validity check
return jax.lax.cond(valid, true_fn, false_fn, None)

def use_default_position(k):
# Use existing logic for random or top-left position
return jax.lax.cond(
env_cfg.agent.random_init_state,
lambda k_inner: _get_random_init_state(
k_inner, env_cfg, max_traversable_x, max_traversable_y,
padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height,
),
lambda k_inner: _get_top_left_init_state(k_inner, env_cfg),
k
)

# Use jax.lax.cond for JAX-compatible control flow
pos_base, angle_base, key = jax.lax.cond(
env_cfg.agent.random_init_state,
lambda k: _get_random_init_state(
k,
env_cfg,
max_traversable_x,
max_traversable_y,
padding_mask,
action_map,
env_cfg.agent.width,
env_cfg.agent.height,
),
lambda k: _get_top_left_init_state(k, env_cfg),
key,
has_custom_args,
use_custom_position,
use_default_position,
key
)

agent_state = AgentState(
Expand All @@ -82,6 +143,112 @@ def new(
return Agent(agent_state=agent_state, width=width, height=height, moving_dumped_dirt=moving_dumped_dirt), key


def _validate_agent_position(
    pos_base: Array,
    angle_base: Array,
    env_cfg: EnvConfig,
    padding_mask: Array,
    agent_width: int,
    agent_height: int,
) -> Array:
    """
    Validate if an agent position is valid (within bounds and not intersecting obstacles).

    Args:
        pos_base: Candidate base position; element 0 is checked against the
            first map axis and element 1 against the second map axis.
        angle_base: Candidate base angle, forwarded to get_agent_corners.
        env_cfg: Environment configuration; maps.edge_length_px and
            agent.angles_base are read from it.
        padding_mask: Map mask with at least two axes; cells equal to 1 are
            treated as obstacles.
        agent_width: Agent width, used for the edge margin and the footprint.
        agent_height: Agent height, used for the edge margin and the footprint.

    Returns:
        JAX array with boolean value indicating if the position is valid
    """
    map_width = padding_mask.shape[0]
    map_height = padding_mask.shape[1]

    # Minimum distance the agent centre must keep from every map edge,
    # derived from the larger of the two half-extents.
    max_center_coord = jnp.ceil(
        jnp.max(jnp.array([agent_width / 2 - 1, agent_height / 2 - 1]))
    ).astype(IntMap)

    # The usable extent is capped by both the configured edge length and the
    # actual size of the mask.
    max_w = jnp.minimum(env_cfg.maps.edge_length_px, map_width)
    max_h = jnp.minimum(env_cfg.maps.edge_length_px, map_height)

    # Check if position is within bounds
    within_bounds = jnp.logical_and(
        jnp.logical_and(
            pos_base[0] >= max_center_coord,
            pos_base[0] < max_w - max_center_coord,
        ),
        jnp.logical_and(
            pos_base[1] >= max_center_coord,
            pos_base[1] < max_h - max_center_coord,
        ),
    )

    # Check if position intersects with obstacles
    def check_obstacle_intersection(_):
        agent_corners_xy = get_agent_corners(
            pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base
        )
        polygon_mask = compute_polygon_mask(agent_corners_xy, map_width, map_height)
        has_obstacle = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1))
        return jnp.logical_not(has_obstacle)

    def return_false(_):
        return jnp.array(False)

    # Out-of-bounds positions are rejected without consulting the obstacle
    # mask; the footprint test only decides the result when within bounds.
    valid = jax.lax.cond(
        within_bounds,
        check_obstacle_intersection,
        return_false,
        None,
    )

    return valid


def _get_top_left_init_state(key: jax.random.PRNGKey, env_cfg: EnvConfig):
max_center_coord = jnp.ceil(
jnp.max(
Expand Down Expand Up @@ -174,4 +341,4 @@ def _check_intersection():
),
)

return pos_base, angle_base, key
return pos_base, angle_base, key
74 changes: 37 additions & 37 deletions terra/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,47 +176,47 @@ class CurriculumGlobalConfig(NamedTuple):
# NOTE: all maps need to have the same size
levels = [
{
"maps_path": "terra/foundations",
"maps_path": "foundations",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": False,
},
{
"maps_path": "terra/trenches/single",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/trenches/double",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/trenches/double_diagonal",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/foundations",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": False,
},
{
"maps_path": "terra/trenches/triple_diagonal",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/foundations_large",
"max_steps_in_episode": 500,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": False,
},
# {
# "maps_path": "trenches/single",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "trenches/double",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "terra/trenches/double_diagonal",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "trenches/triple",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": False,
# },
# {
# "maps_path": "terra/trenches/triple_diagonal",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "terra/foundations_large",
# "max_steps_in_episode": 500,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": False,
# },
]


Expand Down
Loading