Skip to content

Added support for @@port and @@hostname system variables #3005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions enginetest/queries/variable_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ import (
)

var VariableQueries = []ScriptTest{
{
Name: "use string name for foreign_key checks",
SetUpScript: []string{},
Query: "select @@GLOBAL.unknown",
ExpectedErr: sql.ErrUnknownSystemVariable,
},
{
Name: "use string name for foreign_key checks",
SetUpScript: []string{},
Expand Down Expand Up @@ -649,6 +643,10 @@ var VariableQueries = []ScriptTest{
}

var VariableErrorTests = []QueryErrorTest{
{
Query: "select @@GLOBAL.unknown",
ExpectedErr: sql.ErrUnknownSystemVariable,
},
{
Query: "set @@does_not_exist = 100",
ExpectedErr: sql.ErrUnknownSystemVariable,
Expand Down
91 changes: 91 additions & 0 deletions enginetest/server_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"math"
"net"
"os"
"testing"

"github.com/dolthub/vitess/go/mysql"
Expand Down Expand Up @@ -375,3 +376,93 @@ func TestServerPreparedStatements(t *testing.T) {
})
}
}

func TestServerVariables(t *testing.T) {
hostname, herr := os.Hostname()
require.NoError(t, herr)

port, perr := findEmptyPort()
require.NoError(t, perr)

s, serr := initTestServer(port)
require.NoError(t, serr)

go s.Start()
defer s.Close()

tests := []serverScriptTest{
{
name: "test that config system variables are properly set",
setup: []string{},
assertions: []serverScriptTestAssertion{
{
query: "select @@hostname, @@port, @@max_connections, @@net_read_timeout, @@net_write_timeout",
isExec: false,
expectedRows: []any{
sql.Row{hostname, port, 1, 1, 1},
},
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
var resHostname string
var resPort int
var resMaxConnections int
var resNetReadTimeout int
var resNetWriteTimeout int
var rowNum int
for rows.Next() {
if err := rows.Scan(&resHostname, &resPort, &resMaxConnections, &resNetReadTimeout, &resNetWriteTimeout); err != nil {
return false, err
}
if rowNum >= len(expectedRows) {
return false, nil
}
expectedRow := expectedRows[rowNum].(sql.Row)
require.Equal(t, expectedRow[0].(string), resHostname)
require.Equal(t, expectedRow[1].(int), resPort)
}
return true, nil
},
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil)
require.NoError(t, cerr)
defer conn.Close()
commonSetup := []string{
"create database test_db;",
"use test_db;",
}
commonTeardown := []string{
"drop database test_db",
}
for _, stmt := range append(commonSetup, test.setup...) {
_, err := conn.Exec(stmt)
require.NoError(t, err)
}
for _, assertion := range test.assertions {
t.Run(assertion.query, func(t *testing.T) {
if assertion.skip {
t.Skip()
}
rows, err := conn.Query(assertion.query, assertion.args...)
if assertion.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)

ok, err := assertion.checkRows(t, rows, assertion.expectedRows)
require.NoError(t, err)
require.True(t, ok)
})
}
for _, stmt := range append(commonTeardown) {
_, err := conn.Exec(stmt)
require.NoError(t, err)
}
})
}
}
52 changes: 46 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"time"

"github.com/dolthub/vitess/go/mysql"
Expand Down Expand Up @@ -118,15 +119,49 @@ func portInUse(hostPort string) bool {
return false
}

func getPortOrDefault(cfg mysql.ListenerConfig) int64 {
// TODO read this values from systemVars
defaultPort := int64(3606)
_, port, err := net.SplitHostPort(cfg.Listener.Addr().String())
if err != nil {
return defaultPort
}
portInt, err := strconv.ParseInt(port, 10, 64)
if err != nil {
return defaultPort
}
return portInt
}

func updateSystemVariables(cfg mysql.ListenerConfig) error {
port := getPortOrDefault(cfg)

// TODO: add the rest of the config variables
err := sql.SystemVariables.AssignValues(map[string]interface{}{
"port": port,
"max_connections": cfg.MaxConns,
"net_read_timeout": cfg.ConnReadTimeout.Seconds(),
"net_write_timeout": cfg.ConnWriteTimeout.Seconds(),
})
if err != nil {
return err
}
return nil
}

func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
if cfg.ConnReadTimeout < 0 {
cfg.ConnReadTimeout = 0
oneSecond := time.Duration(1) * time.Second
if cfg.ConnReadTimeout < oneSecond {
// TODO set to MySQL default value
cfg.ConnReadTimeout = oneSecond
}
if cfg.ConnWriteTimeout < 0 {
cfg.ConnWriteTimeout = 0
if cfg.ConnWriteTimeout < oneSecond {
// TODO set to MySQL default value
cfg.ConnWriteTimeout = oneSecond
}
if cfg.MaxConnections < 0 {
cfg.MaxConnections = 0
if cfg.MaxConnections < 1 {
// TODO set to MySQL default value
cfg.MaxConnections = 1
}

for _, opt := range cfg.Options {
Expand Down Expand Up @@ -172,6 +207,11 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
return nil, err
}

err = updateSystemVariables(listenerCfg)
if err != nil {
return nil, err
}

return &Server{
Listener: protocolListener,
handler: handler,
Expand Down
4 changes: 2 additions & 2 deletions server/server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ func (c Config) NewConfig() (Config, error) {
if !ok {
return Config{}, sql.ErrUnknownSystemVariable.New("net_write_timeout")
}
c.ConnWriteTimeout = time.Duration(timeout) * time.Millisecond
c.ConnWriteTimeout = time.Duration(timeout) * time.Second
}
if _, val, ok := sql.SystemVariables.GetGlobal("net_read_timeout"); ok {
timeout, ok := val.(int64)
if !ok {
return Config{}, sql.ErrUnknownSystemVariable.New("net_read_timeout")
}
c.ConnReadTimeout = time.Duration(timeout) * time.Millisecond
c.ConnReadTimeout = time.Duration(timeout) * time.Second
}
return c, nil
}
4 changes: 2 additions & 2 deletions server/server_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ func TestConfigWithDefaults(t *testing.T) {
Type: types.NewSystemIntType("net_write_timeout", 1, 9223372036854775807, false),
ConfigField: "ConnWriteTimeout",
Default: int64(76),
ExpectedCmp: int64(76000000),
ExpectedCmp: int64(76000000000),
}, {
Name: "net_read_timeout",
Scope: sql.SystemVariableScope_Both,
Type: types.NewSystemIntType("net_read_timeout", 1, 9223372036854775807, false),
ConfigField: "ConnReadTimeout",
Default: int64(67),
ExpectedCmp: int64(67000000),
ExpectedCmp: int64(67000000000),
},
}

Expand Down
4 changes: 2 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import (
gsql "github.com/dolthub/go-mysql-server/sql"
)

// TestSeverCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
func TestSeverCustomListener(t *testing.T) {
// TestServerCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
func TestServerCustomListener(t *testing.T) {
dbName := "mydb"
// create a net.Conn thats based on a golang buffer
buffer := 1024
Expand Down
8 changes: 7 additions & 1 deletion sql/variables/system_variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package variables
import (
"fmt"
"math"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -187,6 +188,11 @@ func init() {
InitSystemVariables()
}

func getHostname() string {
hostname, _ := os.Hostname()
return hostname
}

// systemVars is the internal collection of all MySQL system variables according to the following pages:
// https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html
// https://dev.mysql.com/doc/refman/8.0/en/replication-options-gtids.html
Expand Down Expand Up @@ -1009,7 +1015,7 @@ var systemVars = map[string]sql.SystemVariable{
Dynamic: false,
SetVarHintApplies: false,
Type: types.NewSystemStringType("hostname"),
Default: "",
Default: getHostname(),
},
"immediate_server_version": &sql.MysqlSystemVariable{
Name: "immediate_server_version",
Expand Down
Loading