Skip to content

Commit 36ec90a

Browse files
committed
fix higher role user registration requirements
1 parent f4630c4 commit 36ec90a

File tree

6 files changed

+147
-73
lines changed

6 files changed

+147
-73
lines changed

compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
services:
2-
opp-backend:
2+
opp-auth:
33
build:
44
context: .
55
dockerfile: Containerfile.dev

src/auth/auth.go

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"OPP/auth/api"
45
opp_jwt "OPP/auth/jwt"
56
"context"
67
"errors"
@@ -14,59 +15,74 @@ import (
1415

1516
var DEBUG_MODE = os.Getenv("DEBUG_MODE")
1617

17-
func AuthenticationFunc(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
18+
func AuthenticationWrapperFunc(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
1819
req := input.RequestValidationInput.Request
1920
if req == nil {
2021
return errors.New("missing HTTP request in authentication input")
2122
}
2223

23-
authHeader := req.Header.Get("Authorization")
24+
username, role, err := AuthenticationFunc(req.Header.Get("Authorization"))
25+
if err != nil {
26+
return err
27+
}
28+
29+
// Update the request context with the username and role
30+
ctx = context.WithValue(ctx, "username", username)
31+
ctx = context.WithValue(ctx, "role", role)
32+
33+
*req = *req.WithContext(ctx)
34+
35+
return nil
36+
}
37+
38+
// AuthenticationFunc can be used for endpoints that aren't marked as requiring authentication
39+
// but still need to check auth tokens when provided.
40+
// Returns (username, role, error) where error is nil if authentication succeeded
41+
func AuthenticationFunc(authHeader string) (string, api.UserRequestRole, error) {
42+
// Debug mode: override username and role
43+
if DEBUG_MODE == "true" {
44+
return "admin_debug", api.UserRequestRoleAdmin, nil
45+
}
46+
2447
if authHeader == "" {
25-
return errors.New("missing Authorization header")
48+
return "", "", errors.New("missing Authorization header")
2649
}
2750

2851
if !strings.HasPrefix(authHeader, "Bearer ") {
29-
return errors.New("invalid Authorization header format")
52+
return "", "", errors.New("invalid Authorization header format")
3053
}
3154

32-
tokenstr := strings.TrimPrefix(authHeader, "Bearer ")
33-
token, err := opp_jwt.ValidateToken(tokenstr)
55+
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
56+
token, err := opp_jwt.ValidateToken(tokenStr)
3457
if err != nil {
35-
return errors.New("failed to parse token")
58+
return "", "", errors.New("failed to parse token: " + err.Error())
3659
}
60+
3761
claims, ok := token.Claims.(jwt.MapClaims)
3862
if !ok || !token.Valid {
39-
return errors.New("invalid token")
63+
return "", "", errors.New("invalid token")
4064
}
4165

66+
// Validate expiration time
4267
expire, err := claims.GetExpirationTime()
4368
if err != nil {
44-
return errors.New("failed to get expiration time")
69+
return "", "", errors.New("failed to get expiration time")
4570
}
4671

4772
if expire.Before(time.Now()) {
48-
return errors.New("token expired")
73+
return "", "", errors.New("token expired")
4974
}
5075

5176
username, ok := claims["username"].(string)
5277
if !ok {
53-
return errors.New("missing username in token claims")
54-
}
55-
role, ok := claims["role"].(string)
56-
if !ok {
57-
return errors.New("missing role in token claims")
78+
return "", "", errors.New("missing username in token claims")
5879
}
5980

60-
// Update the request context with the username and role
61-
ctx = context.WithValue(ctx, "username", username)
62-
ctx = context.WithValue(ctx, "role", role)
63-
64-
// Debug mode: override role with "admin"
65-
if DEBUG_MODE == "true" {
66-
ctx = context.WithValue(ctx, "role", "admin")
81+
roleStr, ok := claims["role"].(string)
82+
if !ok {
83+
return "", "", errors.New("missing role in token claims")
6784
}
6885

69-
*req = *req.WithContext(ctx)
70-
71-
return nil
86+
role := api.UserRequestRole(roleStr)
87+
return username, role, nil
7288
}

src/dao/user.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (d *UserDao) AddUser(c context.Context, user api.UserRequest) (int64, error
7272
return id, nil
7373
}
7474

75-
func (d *UserDao) GetUser(c context.Context, username string) (*api.UserResponse, error) {
75+
func (d *UserDao) GetUserByUsername(c context.Context, username string) (*api.UserResponse, error) {
7676
query := "SELECT user_id, username, name, surname, email, role FROM users WHERE username = $1"
7777
rows, err := d.db.Query(c, query, username)
7878
if err != nil {
@@ -92,6 +92,25 @@ func (d *UserDao) GetUser(c context.Context, username string) (*api.UserResponse
9292
return nil, ErrUserNotFound
9393
}
9494

95+
func (d *UserDao) GetUserByEmail(c context.Context, email string) (*api.UserResponse, error) {
96+
query := "SELECT user_id, username, name, surname, email, role FROM users WHERE email = $1"
97+
rows, err := d.db.Query(c, query, email)
98+
if err != nil {
99+
return nil, fmt.Errorf("db error: %w", err)
100+
}
101+
defer rows.Close()
102+
var user api.UserResponse
103+
var roleStr string
104+
if rows.Next() {
105+
if err := rows.Scan(&user.Id, &user.Username, &user.Name, &user.Surname, &user.Email, &roleStr); err != nil {
106+
return nil, fmt.Errorf("failed to scan user: %w", err)
107+
}
108+
user.Role = api.UserResponseRole(roleStr)
109+
return &user, nil
110+
}
111+
return nil, ErrUserNotFound
112+
}
113+
95114
func (d *UserDao) GetUserRole(c context.Context, username string) (api.UserRequestRole, error) {
96115
query := "SELECT role FROM users WHERE username = $1"
97116
rows, err := d.db.Query(c, query, username)

src/handlers/session.go

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package handlers
22

33
import (
44
"OPP/auth/api"
5+
"OPP/auth/auth"
56
"OPP/auth/dao"
67
"OPP/auth/jwt"
78
"net/http"
@@ -19,6 +20,41 @@ func NewSessionHandler() *SessionHandlers {
1920
}
2021
}
2122

23+
func getLoggedUser(c *gin.Context, userDao dao.UserDao) (*api.UserResponse, api.UserRequestRole, error) {
24+
curUsername := c.Request.Context().Value("username")
25+
if curUsername == nil {
26+
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
27+
return nil, "", nil
28+
}
29+
curUsernameStr, ok := curUsername.(string)
30+
if !ok {
31+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get username"})
32+
return nil, "", nil
33+
}
34+
user, err := userDao.GetUserByUsername(c.Request.Context(), curUsernameStr)
35+
if err != nil {
36+
if err == dao.ErrUserNotFound {
37+
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
38+
return nil, "", nil
39+
}
40+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
41+
return nil, "", err
42+
}
43+
44+
curRole := c.Request.Context().Value("role")
45+
if curRole == nil {
46+
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
47+
return nil, "", nil
48+
}
49+
curRoleStr, ok := curRole.(string)
50+
if !ok {
51+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get role"})
52+
return nil, "", nil
53+
}
54+
curUserRole := api.UserRequestRole(curRoleStr)
55+
return user, curUserRole, nil
56+
}
57+
2258
func (h *SessionHandlers) GetPubKey(c *gin.Context) {
2359
if jwt.PublicKeyBase64 == "" {
2460
c.JSON(http.StatusInternalServerError, gin.H{"error": "Public key not available"})
@@ -28,28 +64,57 @@ func (h *SessionHandlers) GetPubKey(c *gin.Context) {
2864
}
2965

3066
func (h *SessionHandlers) Register(c *gin.Context) {
31-
user := api.UserRequest{}
32-
if err := c.ShouldBindJSON(&user); err != nil {
67+
newUser := api.UserRequest{}
68+
if err := c.ShouldBindJSON(&newUser); err != nil {
3369
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
3470
return
3571
}
36-
if *user.Role == api.UserRequestRoleAdmin || *user.Role == api.UserRequestRoleController {
37-
c.JSON(http.StatusForbidden, gin.H{"error": "Cannot register as" + *user.Role + ", permission denied"})
72+
73+
// Default to creating a normal user if role is not specified
74+
if newUser.Role == nil {
75+
defaultRole := api.UserRequestRoleDriver
76+
newUser.Role = &defaultRole
77+
}
78+
79+
// If registering as an admin or controller, verify permissions
80+
if *newUser.Role == api.UserRequestRoleAdmin || *newUser.Role == api.UserRequestRoleController {
81+
// Check if the current user is authenticated with admin privileges
82+
_, role, err := auth.AuthenticationFunc(c.GetHeader("Authorization"))
83+
if err != nil {
84+
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication failed: " + err.Error()})
85+
return
86+
}
87+
88+
// Check if the user has admin privileges
89+
if role != api.UserRequestRoleAdmin {
90+
c.JSON(http.StatusForbidden, gin.H{"error": "Admin privileges required to register " + string(*newUser.Role) + " accounts"})
91+
return
92+
}
3893
}
94+
95+
// For regular users, allow unauthenticated registration
96+
3997
// Check if the user already exists
40-
_, err := h.dao.GetUser(c.Request.Context(), user.Username)
98+
_, err := h.dao.GetUserByUsername(c.Request.Context(), newUser.Username)
4199
if err != dao.ErrUserNotFound {
42100
c.JSON(http.StatusConflict, gin.H{"error": "User already exists"})
43101
return
44102
}
103+
emailstr := string(newUser.Email)
104+
_, err = h.dao.GetUserByEmail(c.Request.Context(), emailstr)
105+
if err != dao.ErrUserNotFound {
106+
c.JSON(http.StatusConflict, gin.H{"error": "Email already in use"})
107+
return
108+
}
109+
45110
// Add the user to the database
46111
var id int64
47-
id, err = h.dao.AddUser(c.Request.Context(), user)
112+
id, err = h.dao.AddUser(c.Request.Context(), newUser)
48113
if err != nil {
49114
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add user"})
50115
return
51116
}
52-
token, err := jwt.GenerateToken(user.Username, *user.Role)
117+
token, err := jwt.GenerateToken(newUser.Username, *newUser.Role)
53118
if err != nil {
54119
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
55120
return
@@ -61,11 +126,11 @@ func (h *SessionHandlers) Register(c *gin.Context) {
61126
TokenType: "Bearer",
62127
User: api.UserResponse{
63128
Id: id,
64-
Role: api.UserResponseRole(*user.Role),
65-
Username: user.Username,
66-
Email: user.Email,
67-
Name: user.Name,
68-
Surname: user.Surname,
129+
Role: api.UserResponseRole(*newUser.Role),
130+
Username: newUser.Username,
131+
Email: newUser.Email,
132+
Name: newUser.Name,
133+
Surname: newUser.Surname,
69134
},
70135
}
71136
c.JSON(http.StatusCreated, response)
@@ -78,7 +143,7 @@ func (h *SessionHandlers) Login(c *gin.Context) {
78143
return
79144
}
80145
// Check if the user exists
81-
user, err := h.dao.GetUser(c.Request.Context(), session.Username)
146+
user, err := h.dao.GetUserByUsername(c.Request.Context(), session.Username)
82147
if err == dao.ErrUserNotFound {
83148
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
84149
return
@@ -123,40 +188,14 @@ func (h *SessionHandlers) Login(c *gin.Context) {
123188
}
124189

125190
func (h *SessionHandlers) GetSession(c *gin.Context) {
126-
username := c.Request.Context().Value("username")
127-
if username == nil {
128-
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
129-
return
130-
}
131-
usernameStr, ok := username.(string)
132-
if !ok {
133-
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get username"})
134-
return
135-
}
136-
user, err := h.dao.GetUser(c.Request.Context(), usernameStr)
137-
if err != nil {
138-
if err == dao.ErrUserNotFound {
139-
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
140-
return
141-
}
142-
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
143-
return
144-
}
145-
146-
role := c.Request.Context().Value("role")
147-
if role == nil {
148-
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
149-
return
150-
}
151-
roleStr, ok := role.(string)
152-
if !ok {
153-
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get role"})
191+
user, userRole, err := getLoggedUser(c, h.dao)
192+
if err != nil || user == nil {
154193
return
155194
}
156-
userRole := api.UserRequestRole(roleStr)
157195

196+
// No need to fetch the user again since we already have it
158197
// return a new token for the authenticated user
159-
token, err := jwt.GenerateToken(usernameStr, userRole)
198+
token, err := jwt.GenerateToken(user.Username, userRole)
160199
if err != nil {
161200
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
162201
return
@@ -171,7 +210,7 @@ func (h *SessionHandlers) GetSession(c *gin.Context) {
171210
Email: user.Email,
172211
Name: user.Name,
173212
Surname: user.Surname,
174-
Role: api.UserResponseRole(roleStr),
213+
Role: api.UserResponseRole(string(userRole)),
175214
},
176215
}
177216
c.JSON(http.StatusOK, sessionResponse)

src/handlers/user.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (uh *UserHandlers) GetUser(c *gin.Context) {
6363
return
6464
}
6565

66-
user, err := uh.dao.GetUser(c.Request.Context(), usernameStr)
66+
user, err := uh.dao.GetUserByUsername(c.Request.Context(), usernameStr)
6767
if err != nil {
6868
if errors.Is(err, dao.ErrUserNotFound) {
6969
c.JSON(http.StatusNotFound, gin.H{"error": "user not found"})

src/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func main() {
7171
validatorOptions := &ginmiddleware.Options{
7272
Options: openapi3filter.Options{
7373
AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
74-
return auth.AuthenticationFunc(ctx, input)
74+
return auth.AuthenticationWrapperFunc(ctx, input)
7575
},
7676
},
7777
SilenceServersWarning: silenceServersWarning,

0 commit comments

Comments
 (0)