Skip to content

Commit 5b29946

Browse files
committed
fix higher role user registration requirements
1 parent f4630c4 commit 5b29946

File tree

4 files changed

+119
-70
lines changed

4 files changed

+119
-70
lines changed

compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
services:
2-
opp-backend:
2+
opp-auth:
33
build:
44
context: .
55
dockerfile: Containerfile.dev

src/auth/auth.go

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"OPP/auth/api"
45
opp_jwt "OPP/auth/jwt"
56
"context"
67
"errors"
@@ -14,59 +15,74 @@ import (
1415

1516
var DEBUG_MODE = os.Getenv("DEBUG_MODE")
1617

17-
func AuthenticationFunc(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
18+
func AuthenticationWrapperFunc(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
1819
req := input.RequestValidationInput.Request
1920
if req == nil {
2021
return errors.New("missing HTTP request in authentication input")
2122
}
2223

23-
authHeader := req.Header.Get("Authorization")
24+
username, role, err := AuthenticationFunc(req.Header.Get("Authorization"))
25+
if err != nil {
26+
return err
27+
}
28+
29+
// Update the request context with the username and role
30+
ctx = context.WithValue(ctx, "username", username)
31+
ctx = context.WithValue(ctx, "role", role)
32+
33+
*req = *req.WithContext(ctx)
34+
35+
return nil
36+
}
37+
38+
// AuthenticationFunc can be used for endpoints that aren't marked as requiring authentication
39+
// but still need to check auth tokens when provided.
40+
// Returns (username, role, error) where error is nil if authentication succeeded
41+
func AuthenticationFunc(authHeader string) (string, api.UserRequestRole, error) {
42+
// Debug mode: override username and role
43+
if DEBUG_MODE == "true" {
44+
return "admin_debug", api.UserRequestRoleAdmin, nil
45+
}
46+
2447
if authHeader == "" {
25-
return errors.New("missing Authorization header")
48+
return "", "", errors.New("missing Authorization header")
2649
}
2750

2851
if !strings.HasPrefix(authHeader, "Bearer ") {
29-
return errors.New("invalid Authorization header format")
52+
return "", "", errors.New("invalid Authorization header format")
3053
}
3154

32-
tokenstr := strings.TrimPrefix(authHeader, "Bearer ")
33-
token, err := opp_jwt.ValidateToken(tokenstr)
55+
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
56+
token, err := opp_jwt.ValidateToken(tokenStr)
3457
if err != nil {
35-
return errors.New("failed to parse token")
58+
return "", "", errors.New("failed to parse token: " + err.Error())
3659
}
60+
3761
claims, ok := token.Claims.(jwt.MapClaims)
3862
if !ok || !token.Valid {
39-
return errors.New("invalid token")
63+
return "", "", errors.New("invalid token")
4064
}
4165

66+
// Validate expiration time
4267
expire, err := claims.GetExpirationTime()
4368
if err != nil {
44-
return errors.New("failed to get expiration time")
69+
return "", "", errors.New("failed to get expiration time")
4570
}
4671

4772
if expire.Before(time.Now()) {
48-
return errors.New("token expired")
73+
return "", "", errors.New("token expired")
4974
}
5075

5176
username, ok := claims["username"].(string)
5277
if !ok {
53-
return errors.New("missing username in token claims")
54-
}
55-
role, ok := claims["role"].(string)
56-
if !ok {
57-
return errors.New("missing role in token claims")
78+
return "", "", errors.New("missing username in token claims")
5879
}
5980

60-
// Update the request context with the username and role
61-
ctx = context.WithValue(ctx, "username", username)
62-
ctx = context.WithValue(ctx, "role", role)
63-
64-
// Debug mode: override role with "admin"
65-
if DEBUG_MODE == "true" {
66-
ctx = context.WithValue(ctx, "role", "admin")
81+
roleStr, ok := claims["role"].(string)
82+
if !ok {
83+
return "", "", errors.New("missing role in token claims")
6784
}
6885

69-
*req = *req.WithContext(ctx)
70-
71-
return nil
86+
role := api.UserRequestRole(roleStr)
87+
return username, role, nil
7288
}

src/handlers/session.go

Lines changed: 76 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package handlers
22

33
import (
44
"OPP/auth/api"
5+
"OPP/auth/auth"
56
"OPP/auth/dao"
67
"OPP/auth/jwt"
78
"net/http"
@@ -19,6 +20,41 @@ func NewSessionHandler() *SessionHandlers {
1920
}
2021
}
2122

23+
func getLoggedUser(c *gin.Context, userDao dao.UserDao) (*api.UserResponse, api.UserRequestRole, error) {
24+
curUsername := c.Request.Context().Value("username")
25+
if curUsername == nil {
26+
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
27+
return nil, "", nil
28+
}
29+
curUsernameStr, ok := curUsername.(string)
30+
if !ok {
31+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get username"})
32+
return nil, "", nil
33+
}
34+
user, err := userDao.GetUser(c.Request.Context(), curUsernameStr)
35+
if err != nil {
36+
if err == dao.ErrUserNotFound {
37+
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
38+
return nil, "", nil
39+
}
40+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
41+
return nil, "", err
42+
}
43+
44+
curRole := c.Request.Context().Value("role")
45+
if curRole == nil {
46+
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
47+
return nil, "", nil
48+
}
49+
curRoleStr, ok := curRole.(string)
50+
if !ok {
51+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get role"})
52+
return nil, "", nil
53+
}
54+
curUserRole := api.UserRequestRole(curRoleStr)
55+
return user, curUserRole, nil
56+
}
57+
2258
func (h *SessionHandlers) GetPubKey(c *gin.Context) {
2359
if jwt.PublicKeyBase64 == "" {
2460
c.JSON(http.StatusInternalServerError, gin.H{"error": "Public key not available"})
@@ -28,28 +64,51 @@ func (h *SessionHandlers) GetPubKey(c *gin.Context) {
2864
}
2965

3066
func (h *SessionHandlers) Register(c *gin.Context) {
31-
user := api.UserRequest{}
32-
if err := c.ShouldBindJSON(&user); err != nil {
67+
newUser := api.UserRequest{}
68+
if err := c.ShouldBindJSON(&newUser); err != nil {
3369
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
3470
return
3571
}
36-
if *user.Role == api.UserRequestRoleAdmin || *user.Role == api.UserRequestRoleController {
37-
c.JSON(http.StatusForbidden, gin.H{"error": "Cannot register as" + *user.Role + ", permission denied"})
72+
73+
// Default to creating a normal user if role is not specified
74+
if newUser.Role == nil {
75+
defaultRole := api.UserRequestRoleDriver
76+
newUser.Role = &defaultRole
77+
}
78+
79+
// If registering as an admin or controller, verify permissions
80+
if *newUser.Role == api.UserRequestRoleAdmin || *newUser.Role == api.UserRequestRoleController {
81+
// Check if the current user is authenticated with admin privileges
82+
_, role, err := auth.AuthenticationFunc(c.GetHeader("Authorization"))
83+
if err != nil {
84+
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication failed: " + err.Error()})
85+
return
86+
}
87+
88+
// Check if the user has admin privileges
89+
if role != api.UserRequestRoleAdmin {
90+
c.JSON(http.StatusForbidden, gin.H{"error": "Admin privileges required to register " + string(*newUser.Role) + " accounts"})
91+
return
92+
}
3893
}
94+
95+
// For regular users, allow unauthenticated registration
96+
3997
// Check if the user already exists
40-
_, err := h.dao.GetUser(c.Request.Context(), user.Username)
98+
_, err := h.dao.GetUser(c.Request.Context(), newUser.Username)
4199
if err != dao.ErrUserNotFound {
42100
c.JSON(http.StatusConflict, gin.H{"error": "User already exists"})
43101
return
44102
}
103+
45104
// Add the user to the database
46105
var id int64
47-
id, err = h.dao.AddUser(c.Request.Context(), user)
106+
id, err = h.dao.AddUser(c.Request.Context(), newUser)
48107
if err != nil {
49108
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add user"})
50109
return
51110
}
52-
token, err := jwt.GenerateToken(user.Username, *user.Role)
111+
token, err := jwt.GenerateToken(newUser.Username, *newUser.Role)
53112
if err != nil {
54113
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
55114
return
@@ -61,11 +120,11 @@ func (h *SessionHandlers) Register(c *gin.Context) {
61120
TokenType: "Bearer",
62121
User: api.UserResponse{
63122
Id: id,
64-
Role: api.UserResponseRole(*user.Role),
65-
Username: user.Username,
66-
Email: user.Email,
67-
Name: user.Name,
68-
Surname: user.Surname,
123+
Role: api.UserResponseRole(*newUser.Role),
124+
Username: newUser.Username,
125+
Email: newUser.Email,
126+
Name: newUser.Name,
127+
Surname: newUser.Surname,
69128
},
70129
}
71130
c.JSON(http.StatusCreated, response)
@@ -123,40 +182,14 @@ func (h *SessionHandlers) Login(c *gin.Context) {
123182
}
124183

125184
func (h *SessionHandlers) GetSession(c *gin.Context) {
126-
username := c.Request.Context().Value("username")
127-
if username == nil {
128-
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
129-
return
130-
}
131-
usernameStr, ok := username.(string)
132-
if !ok {
133-
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get username"})
134-
return
135-
}
136-
user, err := h.dao.GetUser(c.Request.Context(), usernameStr)
137-
if err != nil {
138-
if err == dao.ErrUserNotFound {
139-
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
140-
return
141-
}
142-
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
143-
return
144-
}
145-
146-
role := c.Request.Context().Value("role")
147-
if role == nil {
148-
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
149-
return
150-
}
151-
roleStr, ok := role.(string)
152-
if !ok {
153-
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get role"})
185+
user, userRole, err := getLoggedUser(c, h.dao)
186+
if err != nil || user == nil {
154187
return
155188
}
156-
userRole := api.UserRequestRole(roleStr)
157189

190+
// No need to fetch the user again since we already have it
158191
// return a new token for the authenticated user
159-
token, err := jwt.GenerateToken(usernameStr, userRole)
192+
token, err := jwt.GenerateToken(user.Username, userRole)
160193
if err != nil {
161194
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
162195
return
@@ -171,7 +204,7 @@ func (h *SessionHandlers) GetSession(c *gin.Context) {
171204
Email: user.Email,
172205
Name: user.Name,
173206
Surname: user.Surname,
174-
Role: api.UserResponseRole(roleStr),
207+
Role: api.UserResponseRole(string(userRole)),
175208
},
176209
}
177210
c.JSON(http.StatusOK, sessionResponse)

src/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func main() {
7171
validatorOptions := &ginmiddleware.Options{
7272
Options: openapi3filter.Options{
7373
AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
74-
return auth.AuthenticationFunc(ctx, input)
74+
return auth.AuthenticationWrapperFunc(ctx, input)
7575
},
7676
},
7777
SilenceServersWarning: silenceServersWarning,

0 commit comments

Comments
 (0)