diff --git a/usermanagerprovider_core.go b/usermanagerprovider_core.go index aac9f20..411aabe 100644 --- a/usermanagerprovider_core.go +++ b/usermanagerprovider_core.go @@ -168,54 +168,14 @@ func (um *userManagerProviderCore) UpsertUser(user User, opts *UpsertUserOptions opts.DomainName = string(LocalDomain) } - parseWildcard := func(str string) string { - if str == "*" { - return "" - } - - return str - } - - isNullOrWildcard := func(str string) bool { - if str == "*" || str == "" { - return true - } - - return false - } - path := fmt.Sprintf("/settings/rbac/users/%s/%s", url.PathEscape(opts.DomainName), url.PathEscape(user.Username)) span := um.tracer.createSpan(opts.ParentSpan, "manager_users_upsert_user", "management") span.SetAttribute("db.operation", "PUT "+path) defer span.End() - var reqRoleStrs []string - for _, roleData := range user.Roles { - if roleData.Bucket == "" { - reqRoleStrs = append(reqRoleStrs, roleData.Name) - } else { - scope := parseWildcard(roleData.Scope) - collection := parseWildcard(roleData.Collection) - - if scope != "" && isNullOrWildcard(roleData.Bucket) { - return makeInvalidArgumentsError("when a scope is specified, the bucket cannot be null or wildcard") - } - if collection != "" && isNullOrWildcard(scope) { - return makeInvalidArgumentsError("when a collection is specified, the scope cannot be null or wildcard") - } - - roleStr := fmt.Sprintf("%s[%s", roleData.Name, roleData.Bucket) - if scope != "" { - roleStr += ":" + roleData.Scope - } - if collection != "" { - roleStr += ":" + roleData.Collection - } - roleStr += "]" - - reqRoleStrs = append(reqRoleStrs, roleStr) - - } + rolesString, err := getRolesString(user.Roles) + if err != nil { + return err } reqForm := make(url.Values) @@ -226,7 +186,7 @@ func (um *userManagerProviderCore) UpsertUser(user User, opts *UpsertUserOptions if len(user.Groups) > 0 { reqForm.Add("groups", strings.Join(user.Groups, ",")) } - reqForm.Add("roles", strings.Join(reqRoleStrs, ",")) + reqForm.Add("roles", rolesString) req := mgmtRequest{ Service: ServiceTypeManagement, @@ -470,19 +430,15 @@ func (um *userManagerProviderCore) UpsertGroup(group Group, opts *UpsertGroupOpt span.SetAttribute("db.operation", "PUT "+path) defer span.End() - var reqRoleStrs []string - for _, roleData := range group.Roles { - if roleData.Bucket == "" { - reqRoleStrs = append(reqRoleStrs, roleData.Name) - } else { - reqRoleStrs = append(reqRoleStrs, fmt.Sprintf("%s[%s]", roleData.Name, roleData.Bucket)) - } + rolesString, err := getRolesString(group.Roles) + if err != nil { + return err } reqForm := make(url.Values) reqForm.Add("description", group.Description) reqForm.Add("ldap_group_ref", group.LDAPGroupReference) - reqForm.Add("roles", strings.Join(reqRoleStrs, ",")) + reqForm.Add("roles", rolesString) req := mgmtRequest{ Service: ServiceTypeManagement, @@ -599,3 +555,50 @@ func (um *userManagerProviderCore) ChangePassword(newPassword string, opts *Chan return nil } + +func getRolesString(roles []Role) (string, error) { + var reqRoleStrs []string + for _, roleData := range roles { + if roleData.Bucket == "" { + reqRoleStrs = append(reqRoleStrs, roleData.Name) + } else { + scope := parseWildcard(roleData.Scope) + collection := parseWildcard(roleData.Collection) + + if scope != "" && isNullOrWildcard(roleData.Bucket) { + return "", makeInvalidArgumentsError("when a scope is specified, the bucket cannot be null or wildcard") + } + if collection != "" && isNullOrWildcard(scope) { + return "", makeInvalidArgumentsError("when a collection is specified, the scope cannot be null or wildcard") + } + + roleStr := fmt.Sprintf("%s[%s", roleData.Name, roleData.Bucket) + if scope != "" { + roleStr += ":" + roleData.Scope + } + if collection != "" { + roleStr += ":" + roleData.Collection + } + roleStr += "]" + + reqRoleStrs = append(reqRoleStrs, roleStr) + + } + } + + return strings.Join(reqRoleStrs, ","), nil +} + +func parseWildcard(str string) string { + if str == "*" { + return "" + } + return str +} + +func isNullOrWildcard(str string) bool { + if str == "*" || str == "" { + return true + } + return false +}