Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 40216a6

Browse files
authored
Merge pull request #534 from theodesp/feature/pow_sqrt
Feature/pow sqrt
2 parents 94eaa89 + 94918c1 commit 40216a6

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

sql/expression/function/registry.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,7 @@ var Defaults = sql.Functions{
5555
"log": sql.FunctionN(NewLog),
5656
"rpad": sql.FunctionN(NewPadFunc(rPadType)),
5757
"lpad": sql.FunctionN(NewPadFunc(lPadType)),
58+
"sqrt": sql.Function1(NewSqrt),
59+
"pow": sql.Function2(NewPower),
60+
"power": sql.Function2(NewPower),
5861
}

sql/expression/function/sqrt_power.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
"math"
6+
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql"
9+
)
10+
11+
// Sqrt is a function that returns the square value of the number provided.
12+
type Sqrt struct {
13+
expression.UnaryExpression
14+
}
15+
16+
// NewSqrt creates a new Sqrt expression.
17+
func NewSqrt(e sql.Expression) sql.Expression {
18+
return &Sqrt{expression.UnaryExpression{Child: e}}
19+
}
20+
21+
func (s *Sqrt) String() string {
22+
return fmt.Sprintf("sqrt(%s)", s.Child.String())
23+
}
24+
25+
// Type implements the Expression interface.
26+
func (s *Sqrt) Type() sql.Type {
27+
return sql.Float64
28+
}
29+
30+
// IsNullable implements the Expression interface.
31+
func (s *Sqrt) IsNullable() bool {
32+
return s.Child.IsNullable()
33+
}
34+
35+
// TransformUp implements the Expression interface.
36+
func (s *Sqrt) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
37+
child, err := s.Child.TransformUp(fn)
38+
if err != nil {
39+
return nil, err
40+
}
41+
return fn(NewSqrt(child))
42+
}
43+
44+
// Eval implements the Expression interface.
45+
func (s *Sqrt) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
46+
child, err := s.Child.Eval(ctx, row)
47+
48+
if err != nil {
49+
return nil, err
50+
}
51+
52+
if child == nil {
53+
return nil, nil
54+
}
55+
56+
child, err = sql.Float64.Convert(child)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
return math.Sqrt(child.(float64)), nil
62+
}
63+
64+
// Power is a function that returns value of X raised to the power of Y.
65+
type Power struct {
66+
expression.BinaryExpression
67+
}
68+
69+
// NewPower creates a new Power expression.
70+
func NewPower(e1, e2 sql.Expression) sql.Expression {
71+
return &Power{
72+
expression.BinaryExpression{
73+
Left: e1,
74+
Right: e2,
75+
},
76+
}
77+
}
78+
79+
// Type implements the Expression interface.
80+
func (p *Power) Type() sql.Type { return sql.Float64 }
81+
82+
// IsNullable implements the Expression interface.
83+
func (p *Power) IsNullable() bool { return p.Left.IsNullable() || p.Right.IsNullable() }
84+
85+
func (p *Power) String() string {
86+
return fmt.Sprintf("power(%s, %s)", p.Left, p.Right)
87+
}
88+
89+
// TransformUp implements the Expression interface.
90+
func (p *Power) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
91+
left, err := p.Left.TransformUp(fn)
92+
if err != nil {
93+
return nil, err
94+
}
95+
96+
right, err := p.Right.TransformUp(fn)
97+
if err != nil {
98+
return nil, err
99+
}
100+
101+
return fn(NewPower(left, right))
102+
}
103+
104+
// Eval implements the Expression interface.
105+
func (p *Power) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
106+
left, err := p.Left.Eval(ctx, row)
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
if left == nil {
112+
return nil, nil
113+
}
114+
115+
left, err = sql.Float64.Convert(left)
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
right, err := p.Right.Eval(ctx, row)
121+
if err != nil {
122+
return nil, err
123+
}
124+
125+
if right == nil {
126+
return nil, nil
127+
}
128+
129+
right, err = sql.Float64.Convert(right)
130+
if err != nil {
131+
return nil, err
132+
}
133+
134+
return math.Pow(left.(float64), right.(float64)), nil
135+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
"math"
6+
7+
"github.com/stretchr/testify/require"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql"
9+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
10+
)
11+
12+
func TestSqrt(t *testing.T) {
13+
f := NewSqrt(
14+
expression.NewGetField(0, sql.Float64, "n", false),
15+
)
16+
testCases := []struct {
17+
name string
18+
row sql.Row
19+
expected interface{}
20+
err bool
21+
}{
22+
{"null input", sql.NewRow(nil), nil, false},
23+
{"invalid string", sql.NewRow("foo"), nil, true},
24+
{"valid string", sql.NewRow("9"), float64(3), false},
25+
{"number is zero", sql.NewRow(0), float64(0), false},
26+
{"positive number", sql.NewRow(8), float64(2.8284271247461903), false},
27+
}
28+
for _, tt := range testCases {
29+
t.Run(tt.name, func(t *testing.T) {
30+
t.Helper()
31+
require := require.New(t)
32+
ctx := sql.NewEmptyContext()
33+
34+
v, err := f.Eval(ctx, tt.row)
35+
if tt.err {
36+
require.Error(err)
37+
} else {
38+
require.NoError(err)
39+
require.Equal(tt.expected, v)
40+
}
41+
})
42+
}
43+
44+
// Test negative number
45+
f = NewSqrt(
46+
expression.NewGetField(0, sql.Float64, "n", false),
47+
)
48+
require := require.New(t)
49+
v, err := f.Eval(sql.NewEmptyContext(), []interface{}{float64(-4)})
50+
require.NoError(err)
51+
require.IsType(float64(0), v)
52+
require.True(math.IsNaN(v.(float64)))
53+
}
54+
55+
func TestPower(t *testing.T) {
56+
testCases := []struct {
57+
name string
58+
rowType sql.Type
59+
row sql.Row
60+
expected interface{}
61+
err bool
62+
}{
63+
{"Base and exp are nil", sql.Float64, sql.NewRow(nil, nil), nil, false},
64+
{"Base is nil", sql.Float64, sql.NewRow(2, nil), nil, false},
65+
{"Exp is nil", sql.Float64, sql.NewRow(nil, 2), nil, false},
66+
67+
{"Base is 0", sql.Float64, sql.NewRow(0, 2), float64(0), false},
68+
{"Base and exp is 0", sql.Float64, sql.NewRow(0, 0), float64(1), false},
69+
{"Exp is 0", sql.Float64, sql.NewRow(2, 0), float64(1), false},
70+
{"Base is negative", sql.Float64, sql.NewRow(-2, 2), float64(4), false},
71+
{"Exp is negative", sql.Float64, sql.NewRow(2, -2), float64(0.25), false},
72+
{"Base and exp are invalid strings", sql.Float64, sql.NewRow("a", "b"), nil, true},
73+
{"Base and exp are valid strings", sql.Float64, sql.NewRow("2", "2"), float64(4), false},
74+
}
75+
for _, tt := range testCases {
76+
f := NewPower(
77+
expression.NewGetField(0, tt.rowType, "", false),
78+
expression.NewGetField(1, tt.rowType, "", false),
79+
)
80+
t.Run(tt.name, func(t *testing.T) {
81+
t.Helper()
82+
require := require.New(t)
83+
ctx := sql.NewEmptyContext()
84+
85+
v, err := f.Eval(ctx, tt.row)
86+
if tt.err {
87+
require.Error(err)
88+
} else {
89+
require.NoError(err)
90+
require.Equal(tt.expected, v)
91+
}
92+
})
93+
}
94+
95+
// Test inf numbers
96+
f := NewPower(
97+
expression.NewGetField(0, sql.Float64, "", false),
98+
expression.NewGetField(1, sql.Float64, "", false),
99+
)
100+
require := require.New(t)
101+
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(2, math.Inf(1)))
102+
require.NoError(err)
103+
require.IsType(float64(0), v)
104+
require.True(math.IsInf(v.(float64), 1))
105+
106+
v, err = f.Eval(sql.NewEmptyContext(), sql.NewRow(math.Inf(1), 2))
107+
require.NoError(err)
108+
require.IsType(float64(0), v)
109+
require.True(math.IsInf(v.(float64), 1))
110+
}

0 commit comments

Comments
 (0)