diff --git a/expression/builtin_math.go b/expression/builtin_math.go index 919097489e59e..ad2c0b87d7822 100644 --- a/expression/builtin_math.go +++ b/expression/builtin_math.go @@ -943,9 +943,23 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, argTps...) bt := bf if len(args) == 0 { - sig = &builtinRandSig{bt, nil} + seed := time.Now().UnixNano() + sig = &builtinRandSig{bt, rand.New(rand.NewSource(seed))} + } else if _, isConstant := args[0].(*Constant); isConstant { + // According to MySQL manual: + // If an integer argument N is specified, it is used as the seed value: + // With a constant initializer argument, the seed is initialized once + // when the statement is prepared, prior to execution. + seed, isNull, err := args[0].EvalInt(ctx, nil) + if err != nil { + return nil, err + } + if isNull { + seed = time.Now().UnixNano() + } + sig = &builtinRandSig{bt, rand.New(rand.NewSource(seed))} } else { - sig = &builtinRandWithSeedSig{bt, nil} + sig = &builtinRandWithSeedSig{bt} } return sig, nil } @@ -964,21 +978,11 @@ func (b *builtinRandSig) Clone() builtinFunc { // evalReal evals RAND(). // See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand func (b *builtinRandSig) evalReal(row types.Row) (float64, bool, error) { - if b.randGen == nil { - b.randGen = rand.New(rand.NewSource(time.Now().UnixNano())) - } return b.randGen.Float64(), false, nil } type builtinRandWithSeedSig struct { baseBuiltinFunc - randGen *rand.Rand -} - -func (b *builtinRandWithSeedSig) Clone() builtinFunc { - newSig := &builtinRandWithSeedSig{randGen: b.randGen} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig } // evalReal evals RAND(N). @@ -988,15 +992,22 @@ func (b *builtinRandWithSeedSig) evalReal(row types.Row) (float64, bool, error) if err != nil { return 0, true, errors.Trace(err) } - if b.randGen == nil { - if isNull { - // When seed is NULL, it is equal to RAND(). - b.randGen = rand.New(rand.NewSource(time.Now().UnixNano())) - } else { - b.randGen = rand.New(rand.NewSource(seed)) - } + // b.args[0] is promised to be a non-constant(such as a column name) in + // builtinRandWithSeedSig, the seed is initialized with the value for each + // invocation of RAND(). + var randGen *rand.Rand + if isNull { + randGen = rand.New(rand.NewSource(time.Now().UnixNano())) + } else { + randGen = rand.New(rand.NewSource(seed)) } - return b.randGen.Float64(), false, nil + return randGen.Float64(), false, nil +} + +func (b *builtinRandWithSeedSig) Clone() builtinFunc { + newSig := &builtinRandWithSeedSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig } type powFunctionClass struct { diff --git a/expression/integration_test.go b/expression/integration_test.go index 2bcb10f3862da..5cfd35f85772c 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -542,6 +542,13 @@ func (s *testIntegrationSuite) TestMathBuiltin(c *C) { // for radians result = tk.MustQuery("SELECT radians(1.0), radians(pi()), radians(pi()/2), radians(180), radians(1.009);") result.Check(testkit.Rows("0.017453292519943295 0.05483113556160754 0.02741556778080377 3.141592653589793 0.01761037215262278")) + + // for rand + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(1),(2),(3)") + tk.MustQuery("select rand(a) from t").Check(testkit.Rows("0.6046602879796196", "0.16729663442585624", "0.7199826688373036")) + tk.MustQuery("select rand(1), rand(2), rand(3)").Check(testkit.Rows("0.6046602879796196 0.16729663442585624 0.7199826688373036")) } func (s *testIntegrationSuite) TestStringBuiltin(c *C) {