Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions packages/api/internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import (
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/api/internal/api"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/api/internal/cfg"
"github.com/e2b-dev/infra/packages/api/internal/db"
"github.com/e2b-dev/infra/packages/api/internal/db/types"
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
)

Expand Down Expand Up @@ -134,13 +134,13 @@ func adminValidationFunction(adminToken string) func(context.Context, string) (s

func CreateAuthenticationFunc(
config cfg.Config,
teamValidationFunction func(context.Context, string) (authcache.AuthTeamInfo, *api.APIError),
teamValidationFunction func(context.Context, string) (*types.Team, *api.APIError),
userValidationFunction func(context.Context, string) (uuid.UUID, *api.APIError),
supabaseTokenValidationFunction func(context.Context, string) (uuid.UUID, *api.APIError),
supabaseTeamValidationFunction func(context.Context, string) (authcache.AuthTeamInfo, *api.APIError),
supabaseTeamValidationFunction func(context.Context, string) (*types.Team, *api.APIError),
) openapi3filter.AuthenticationFunc {
authenticators := []authenticator{
&commonAuthenticator[authcache.AuthTeamInfo]{
&commonAuthenticator[*types.Team]{
securitySchemeName: "ApiKeyAuth",
headerKey: headerKey{
name: "X-API-Key",
Expand Down Expand Up @@ -173,7 +173,7 @@ func CreateAuthenticationFunc(
contextKey: UserIDContextKey,
errorMessage: "Invalid Supabase token.",
},
&commonAuthenticator[authcache.AuthTeamInfo]{
&commonAuthenticator[*types.Team]{
securitySchemeName: "Supabase2TeamAuth",
headerKey: headerKey{
name: "X-Supabase-Team",
Expand Down
28 changes: 11 additions & 17 deletions packages/api/internal/cache/auth/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,22 @@ import (
"github.com/jellydator/ttlcache/v3"
"golang.org/x/sync/singleflight"

"github.com/e2b-dev/infra/packages/db/queries"
"github.com/e2b-dev/infra/packages/api/internal/db/types"
)

const (
authInfoExpiration = 5 * time.Minute
refreshInterval = 1 * time.Minute
)

type AuthTeamInfo struct {
Team *queries.Team
Tier *queries.Tier
}

type TeamInfo struct {
team *queries.Team
tier *queries.Tier
team *types.Team

lastRefresh time.Time
once singleflight.Group
}

type DataCallback = func(ctx context.Context, key string) (*queries.Team, *queries.Tier, error)
type DataCallback = func(ctx context.Context, key string) (*types.Team, error)

type TeamAuthCache struct {
cache *ttlcache.Cache[string, *TeamInfo]
Expand All @@ -45,21 +39,21 @@ func NewTeamAuthCache() *TeamAuthCache {
}

// TODO: save blocked teams to cache as well, handle the condition in the GetOrSet method
func (c *TeamAuthCache) GetOrSet(ctx context.Context, key string, dataCallback DataCallback) (team *queries.Team, tier *queries.Tier, err error) {
func (c *TeamAuthCache) GetOrSet(ctx context.Context, key string, dataCallback DataCallback) (team *types.Team, err error) {
var item *ttlcache.Item[string, *TeamInfo]
var templateInfo *TeamInfo

item = c.cache.Get(key)
if item == nil {
team, tier, err = dataCallback(ctx, key)
team, err = dataCallback(ctx, key)
if err != nil {
return nil, nil, fmt.Errorf("error while getting the team: %w", err)
return nil, fmt.Errorf("error while getting the team: %w", err)
}

templateInfo = &TeamInfo{team: team, tier: tier, lastRefresh: time.Now()}
templateInfo = &TeamInfo{team: team, lastRefresh: time.Now()}
c.cache.Set(key, templateInfo, authInfoExpiration)

return team, tier, nil
return team, nil
}

templateInfo = item.Value()
Expand All @@ -70,20 +64,20 @@ func (c *TeamAuthCache) GetOrSet(ctx context.Context, key string, dataCallback D
})
}

return templateInfo.team, templateInfo.tier, nil
return templateInfo.team, nil
}

// Refresh refreshes the cache for the given team ID.
func (c *TeamAuthCache) Refresh(key string, dataCallback DataCallback) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

team, tier, err := dataCallback(ctx, key)
team, err := dataCallback(ctx, key)
if err != nil {
c.cache.Delete(key)

return
}

c.cache.Set(key, &TeamInfo{team: team, tier: tier, lastRefresh: time.Now()}, authInfoExpiration)
c.cache.Set(key, &TeamInfo{team: team, lastRefresh: time.Now()}, authInfoExpiration)
}
10 changes: 6 additions & 4 deletions packages/api/internal/db/apikeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/e2b-dev/infra/packages/api/internal/db/types"
sqlcdb "github.com/e2b-dev/infra/packages/db/client"
"github.com/e2b-dev/infra/packages/db/queries"
)
Expand Down Expand Up @@ -36,18 +37,19 @@ func validateTeamUsage(team queries.Team) error {
return nil
}

func GetTeamAuth(ctx context.Context, db *sqlcdb.Client, apiKey string) (*queries.Team, *queries.Tier, error) {
func GetTeamAuth(ctx context.Context, db *sqlcdb.Client, apiKey string) (*types.Team, error) {
result, err := db.GetTeamWithTierByAPIKeyWithUpdateLastUsed(ctx, apiKey)
if err != nil {
errMsg := fmt.Errorf("failed to get team from API key: %w", err)

return nil, nil, errMsg
return nil, errMsg
}

err = validateTeamUsage(result.Team)
if err != nil {
return nil, nil, err
return nil, err
}

return &result.Team, &result.Tier, nil
team := types.NewTeam(&result.Team, &result.TeamLimit)
return team, nil
}
28 changes: 28 additions & 0 deletions packages/api/internal/db/teams.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package db

import (
"context"
"fmt"

"github.com/google/uuid"

"github.com/e2b-dev/infra/packages/api/internal/db/types"
sqlcdb "github.com/e2b-dev/infra/packages/db/client"
)

func GetTeamsByUser(ctx context.Context, db *sqlcdb.Client, userID uuid.UUID) ([]*types.TeamWithDefault, error) {
teams, err := db.GetTeamsWithUsersTeamsWithTier(ctx, userID)
if err != nil {
return nil, fmt.Errorf("error when getting default team: %w", err)
}

teamsWithLimits := make([]*types.TeamWithDefault, 0, len(teams))
for _, team := range teams {
teamsWithLimits = append(teamsWithLimits, &types.TeamWithDefault{
Team: types.NewTeam(&team.Team, &team.TeamLimit),
IsDefault: team.UsersTeam.IsDefault,
})
}

return teamsWithLimits, nil
}
11 changes: 11 additions & 0 deletions packages/api/internal/db/types/limits.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package types

type TeamLimits struct {
SandboxConcurrency int64
BuildConcurrency int64
MaxLengthHours int64

MaxVcpu int64
MaxRamMb int64
DiskMb int64
}
40 changes: 40 additions & 0 deletions packages/api/internal/db/types/teams.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package types

import (
"github.com/e2b-dev/infra/packages/db/queries"
)

type Team struct {
*queries.Team

Limits *TeamLimits
}

func newTeamLimits(
teamLimits *queries.TeamLimit,
) *TeamLimits {
return &TeamLimits{
SandboxConcurrency: int64(teamLimits.ConcurrentSandboxes),
BuildConcurrency: int64(teamLimits.ConcurrentTemplateBuilds),
MaxLengthHours: teamLimits.MaxLengthHours,
MaxVcpu: int64(teamLimits.MaxVcpu),
MaxRamMb: int64(teamLimits.MaxRamMb),
DiskMb: int64(teamLimits.DiskMb),
}
}

func NewTeam(
team *queries.Team,
teamLimits *queries.TeamLimit,
) *Team {
return &Team{
Team: team,
Limits: newTeamLimits(teamLimits),
}
}

type TeamWithDefault struct {
*Team

IsDefault bool
}
12 changes: 7 additions & 5 deletions packages/api/internal/db/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ import (

"github.com/google/uuid"

"github.com/e2b-dev/infra/packages/api/internal/db/types"
sqlcdb "github.com/e2b-dev/infra/packages/db/client"
"github.com/e2b-dev/infra/packages/db/queries"
)

func GetTeamByIDAndUserIDAuth(ctx context.Context, db *sqlcdb.Client, teamID string, userID uuid.UUID) (*queries.Team, *queries.Tier, error) {
func GetTeamByIDAndUserIDAuth(ctx context.Context, db *sqlcdb.Client, teamID string, userID uuid.UUID) (*types.Team, error) {
teamIDParsed, err := uuid.Parse(teamID)
if err != nil {
errMsg := fmt.Errorf("failed to parse team ID: %w", err)

return nil, nil, errMsg
return nil, errMsg
}

result, err := db.GetTeamWithTierByTeamAndUser(ctx, queries.GetTeamWithTierByTeamAndUserParams{
Expand All @@ -25,13 +26,14 @@ func GetTeamByIDAndUserIDAuth(ctx context.Context, db *sqlcdb.Client, teamID str
if err != nil {
errMsg := fmt.Errorf("failed to get team from teamID and userID key: %w", err)

return nil, nil, errMsg
return nil, errMsg
}

err = validateTeamUsage(result.Team)
if err != nil {
return nil, nil, err
return nil, err
}

return &result.Team, &result.Tier, nil
team := types.NewTeam(&result.Team, &result.TeamLimit)
return team, nil
}
49 changes: 25 additions & 24 deletions packages/api/internal/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,101 +10,102 @@ import (

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/db/queries"
dbapi "github.com/e2b-dev/infra/packages/api/internal/db"
"github.com/e2b-dev/infra/packages/api/internal/db/types"
)

func (a *APIStore) GetUserID(c *gin.Context) uuid.UUID {
return c.Value(auth.UserIDContextKey).(uuid.UUID)
}

func (a *APIStore) GetUserAndTeams(c *gin.Context) (*uuid.UUID, []queries.GetTeamsWithUsersTeamsWithTierRow, error) {
func (a *APIStore) GetUserAndTeams(c *gin.Context) (*uuid.UUID, []*types.TeamWithDefault, error) {
userID := a.GetUserID(c)
ctx := c.Request.Context()

teams, err := a.sqlcDB.GetTeamsWithUsersTeamsWithTier(ctx, userID)
teams, err := dbapi.GetTeamsByUser(ctx, a.sqlcDB, userID)
if err != nil {
return nil, nil, fmt.Errorf("error when getting default team: %w", err)
}

return &userID, teams, err
}

func (a *APIStore) GetTeamInfo(c *gin.Context) authcache.AuthTeamInfo {
return c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo)
func (a *APIStore) GetTeamInfo(c *gin.Context) *types.Team {
return c.Value(auth.TeamContextKey).(*types.Team)
}

func (a *APIStore) GetTeamAndTier(
func (a *APIStore) GetTeamAndLimits(
c *gin.Context,
// Deprecated: use API Token authentication instead.
teamID *string,
) (*queries.Team, *queries.Tier, *api.APIError) {
) (*types.Team, *api.APIError) {
_, span := tracer.Start(c.Request.Context(), "get-team-and-tier")
defer span.End()

if c.Value(auth.TeamContextKey) != nil {
teamInfo := c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo)
teamInfo := c.Value(auth.TeamContextKey).(*types.Team)

return teamInfo.Team, teamInfo.Tier, nil
return teamInfo, nil
} else if c.Value(auth.UserIDContextKey) != nil {
_, teams, err := a.GetUserAndTeams(c)
if err != nil {
return nil, nil, &api.APIError{
return nil, &api.APIError{
Code: http.StatusInternalServerError,
ClientMsg: "Error when getting user and teams",
Err: err,
}
}
team, tier, err := findTeamAndTier(teams, teamID)

team, err := findTeamAndLimits(teams, teamID)
if err != nil {
if teamID == nil {
return nil, nil, &api.APIError{
return nil, &api.APIError{
Code: http.StatusInternalServerError,
ClientMsg: "Default team not found",
Err: err,
}
}

return nil, nil, &api.APIError{
return nil, &api.APIError{
Code: http.StatusForbidden,
ClientMsg: "You are not allowed to access this team",
Err: err,
}
}

return team, tier, nil
return team, nil
}

return nil, nil, &api.APIError{
return nil, &api.APIError{
Code: http.StatusUnauthorized,
ClientMsg: "You are not authenticated",
Err: errors.New("invalid authentication context for team and tier"),
}
}

// findTeamAndTier finds the appropriate team and tier based on the provided teamID or returns the default team
func findTeamAndTier(teams []queries.GetTeamsWithUsersTeamsWithTierRow, teamID *string) (*queries.Team, *queries.Tier, error) {
// findTeamAndTier finds the appropriate team and limits based on the provided teamID or returns the default team
func findTeamAndLimits(teams []*types.TeamWithDefault, teamID *string) (*types.Team, error) {
if teamID != nil {
teamUUID, err := uuid.Parse(*teamID)
if err != nil {
return nil, nil, fmt.Errorf("invalid team ID: %s", *teamID)
return nil, fmt.Errorf("invalid team ID: %s", *teamID)
}

for _, t := range teams {
if t.Team.ID == teamUUID {
return &t.Team, &t.Tier, nil
return t.Team, nil
}
}

return nil, nil, fmt.Errorf("team '%s' not found", *teamID)
return nil, fmt.Errorf("team '%s' not found", *teamID)
}

// Find default team
for _, t := range teams {
if t.UsersTeam.IsDefault {
return &t.Team, &t.Tier, nil
if t.IsDefault {
return t.Team, nil
}
}

return nil, nil, fmt.Errorf("default team not found")
return nil, fmt.Errorf("default team not found")
}
Loading
Loading