tidb base_func 源码

  • 2022-09-19
  • 浏览 (504)

tidb base_func 代码

文件路径:/expression/aggregation/base_func.go

// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package aggregation

import (
	"bytes"
	"fmt"
	"math"
	"strings"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/expression"
	"github.com/pingcap/tidb/parser/ast"
	"github.com/pingcap/tidb/parser/charset"
	"github.com/pingcap/tidb/parser/mysql"
	"github.com/pingcap/tidb/sessionctx"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/util/chunk"
	"github.com/pingcap/tidb/util/mathutil"
	"github.com/pingcap/tidb/util/size"
)

// baseFuncDesc describes a function signature, only used in planner.
type baseFuncDesc struct {
	// Name represents the function name.
	Name string
	// Args represents the arguments of the function.
	Args []expression.Expression
	// RetTp represents the return type of the function.
	RetTp *types.FieldType
}

// newBaseFuncDesc builds a baseFuncDesc from the lower-cased name and the
// given args, then runs TypeInfer on it. The descriptor is returned together
// with whatever error TypeInfer reported.
func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
	desc := baseFuncDesc{
		Name: strings.ToLower(name),
		Args: args,
	}
	if err := desc.TypeInfer(ctx); err != nil {
		return desc, err
	}
	return desc, nil
}

// equal reports whether other has the same name and the same number of
// arguments as a, with every argument pair comparing equal via
// Expression.Equal. RetTp is not compared.
func (a *baseFuncDesc) equal(ctx sessionctx.Context, other *baseFuncDesc) bool {
	if a.Name != other.Name {
		return false
	}
	if len(a.Args) != len(other.Args) {
		return false
	}
	for idx, arg := range a.Args {
		if !arg.Equal(ctx, other.Args[idx]) {
			return false
		}
	}
	return true
}

// clone returns a copy of a whose RetTp points at a fresh copy of the field
// type and whose Args slice holds the result of Clone on each argument.
func (a *baseFuncDesc) clone() *baseFuncDesc {
	dup := new(baseFuncDesc)
	*dup = *a

	retTp := *a.RetTp
	dup.RetTp = &retTp

	args := make([]expression.Expression, len(a.Args))
	for idx, arg := range a.Args {
		args[idx] = arg.Clone()
	}
	dup.Args = args
	return dup
}

// String implements the fmt.Stringer interface. The result has the form
// "name(arg1, arg2, ...)", using each argument's own String output.
func (a *baseFuncDesc) String() string {
	var buf bytes.Buffer
	buf.WriteString(a.Name)
	buf.WriteString("(")
	last := len(a.Args) - 1
	for idx, arg := range a.Args {
		buf.WriteString(arg.String())
		// Separate arguments, but do not emit a trailing separator.
		if idx < last {
			buf.WriteString(", ")
		}
	}
	buf.WriteString(")")
	return buf.String()
}

// TypeInfer infers the arguments and return types of a function by
// dispatching on a.Name to the matching typeInfer4* helper. It returns the
// helper's error for the helpers that report one, an "unsupported agg
// function" error for an unknown name, and nil otherwise.
func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
	switch name := a.Name; name {
	// Helpers that can fail: their result is returned directly.
	case ast.AggFuncApproxPercentile:
		return a.typeInfer4ApproxPercentile(ctx)
	case ast.AggFuncJsonObjectAgg:
		return a.typeInfer4JsonObjectAgg(ctx)

	// Aggregate functions.
	case ast.AggFuncCount:
		a.typeInfer4Count(ctx)
	case ast.AggFuncApproxCountDistinct:
		a.typeInfer4ApproxCountDistinct(ctx)
	case ast.AggFuncSum:
		a.typeInfer4Sum(ctx)
	case ast.AggFuncAvg:
		a.typeInfer4Avg(ctx)
	case ast.AggFuncGroupConcat:
		a.typeInfer4GroupConcat(ctx)
	case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
		a.typeInfer4BitFuncs(ctx)
	case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
		a.typeInfer4PopOrSamp(ctx)
	case ast.AggFuncJsonArrayagg:
		a.typeInfer4JsonArrayAgg(ctx)

	// Functions whose return type follows the first argument.
	case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
		ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue:
		a.typeInfer4MaxMin(ctx)

	// Window functions.
	case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
		a.typeInfer4NumberFuncs()
	case ast.WindowFuncCumeDist:
		a.typeInfer4CumeDist()
	case ast.WindowFuncNtile:
		a.typeInfer4Ntile()
	case ast.WindowFuncPercentRank:
		a.typeInfer4PercentRank()
	case ast.WindowFuncLead, ast.WindowFuncLag:
		a.typeInfer4LeadLag(ctx)

	default:
		return errors.Errorf("unsupported agg function: %s", name)
	}
	return nil
}

// typeInfer4Count sets the return type to a not-null longlong with flen 21
// and decimal 0, then applies types.SetBinChsClnFlag to it.
func (a *baseFuncDesc) typeInfer4Count(ctx sessionctx.Context) {
	retTp := types.NewFieldType(mysql.TypeLonglong)
	retTp.SetFlen(21)
	retTp.SetDecimal(0)
	// count never returns null
	retTp.AddFlag(mysql.NotNullFlag)
	types.SetBinChsClnFlag(retTp)
	a.RetTp = retTp
}

// typeInfer4ApproxCountDistinct infers the same return type as
// typeInfer4Count by delegating to it.
func (a *baseFuncDesc) typeInfer4ApproxCountDistinct(ctx sessionctx.Context) {
	a.typeInfer4Count(ctx)
}

// typeInfer4ApproxPercentile validates the arguments of APPROX_PERCENTILE and
// derives its return type from the first argument.
//
// It returns an error when the function does not have exactly two arguments,
// when the second argument is not a constant item, when evaluating it as an
// integer fails, or when the evaluated percentage is NULL or outside [1, 100].
func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error {
	if len(a.Args) != 2 {
		return errors.New("APPROX_PERCENTILE should take 2 arguments")
	}

	percentArg := a.Args[1]
	if !percentArg.ConstItem(ctx.GetSessionVars().StmtCtx) {
		return errors.New("APPROX_PERCENTILE should take a constant expression as percentage argument")
	}
	percent, isNull, err := percentArg.EvalInt(ctx, chunk.Row{})
	if err != nil {
		return fmt.Errorf("APPROX_PERCENTILE: Invalid argument %s", a.Args[1].String())
	}
	// NULL is reported in preference to the range error.
	if isNull {
		return errors.New("APPROX_PERCENTILE: Percentage value cannot be NULL")
	}
	if percent <= 0 || percent > 100 {
		return fmt.Errorf("Percentage value %d is out of range [1, 100]", percent)
	}

	argTp := a.Args[0].GetType()
	switch argTp.GetType() {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		a.RetTp = types.NewFieldType(mysql.TypeLonglong)
	case mysql.TypeDouble, mysql.TypeFloat:
		a.RetTp = types.NewFieldType(mysql.TypeDouble)
	case mysql.TypeNewDecimal:
		// Keep the argument's scale unless it is negative or exceeds the
		// maximum, in which case fall back to the maximum decimal scale.
		scale := argTp.GetDecimal()
		if scale < 0 || scale > mysql.MaxDecimalScale {
			scale = mysql.MaxDecimalScale
		}
		a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
		a.RetTp.SetFlen(mysql.MaxDecimalWidth)
		a.RetTp.SetDecimal(scale)
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeTimestamp:
		a.RetTp = argTp.Clone()
	default:
		// Any other type: copy the argument type and drop its not-null flag.
		a.RetTp = argTp.Clone()
		a.RetTp.DelFlag(mysql.NotNullFlag)
	}
	return nil
}

// typeInfer4Sum infers the return type of SUM from its first argument:
// integer, year and decimal arguments produce a "decimal" result, everything
// else produces a "double". types.SetBinChsClnFlag is applied to the result
// in every case.
func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
	argTp := a.Args[0].GetType()

	var retTp *types.FieldType
	switch argTp.GetType() {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		retTp = types.NewFieldType(mysql.TypeNewDecimal)
		argFlen := argTp.GetFlen()
		retTp.SetFlenUnderLimit(argFlen + 21)
		retTp.SetDecimal(0)
		// A negative argument flen overrides the width with the maximum.
		if argFlen < 0 {
			retTp.SetFlen(mysql.MaxDecimalWidth)
		}
	case mysql.TypeNewDecimal:
		retTp = types.NewFieldType(mysql.TypeNewDecimal)
		retTp.UpdateFlenAndDecimalUnderLimit(argTp, 0, 22)
	case mysql.TypeDouble, mysql.TypeFloat:
		retTp = types.NewFieldType(mysql.TypeDouble)
		retTp.SetFlen(mysql.MaxRealWidth)
		retTp.SetDecimal(argTp.GetDecimal())
	default:
		retTp = types.NewFieldType(mysql.TypeDouble)
		retTp.SetFlen(mysql.MaxRealWidth)
		retTp.SetDecimal(types.UnspecifiedLength)
	}
	types.SetBinChsClnFlag(retTp)
	a.RetTp = retTp
}

// TypeInfer4AvgSum infers the type of sum from avg, which should extend the
// precision of decimal compatible with mysql. When avgRetType is a decimal,
// the flen of a.RetTp grows by 22, capped at mysql.MaxDecimalWidth; for any
// other type a.RetTp is left untouched.
func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) {
	if avgRetType.GetType() != mysql.TypeNewDecimal {
		return
	}
	widened := a.RetTp.GetFlen() + 22
	a.RetTp.SetFlen(mathutil.Min(mysql.MaxDecimalWidth, widened))
}

// typeInfer4Avg infers the return type of AVG from its first argument:
// integer, year and decimal arguments produce a "decimal" result, everything
// else produces a "double". types.SetBinChsClnFlag is applied to the result
// in every case.
func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
	argTp := a.Args[0].GetType()

	var retTp *types.FieldType
	switch argTp.GetType() {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		retTp = types.NewFieldType(mysql.TypeNewDecimal)
		retTp.SetDecimalUnderLimit(types.DivFracIncr)
		// The width starts from the default length of the integer type,
		// not from the argument's own flen.
		defaultFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(argTp.GetType())
		retTp.SetFlenUnderLimit(defaultFlen + types.DivFracIncr)
	case mysql.TypeYear, mysql.TypeNewDecimal:
		retTp = types.NewFieldType(mysql.TypeNewDecimal)
		retTp.UpdateFlenAndDecimalUnderLimit(argTp, types.DivFracIncr, types.DivFracIncr)
	case mysql.TypeDouble, mysql.TypeFloat:
		retTp = types.NewFieldType(mysql.TypeDouble)
		retTp.SetFlen(mysql.MaxRealWidth)
		retTp.SetDecimal(argTp.GetDecimal())
	case mysql.TypeDate, mysql.TypeDuration, mysql.TypeDatetime, mysql.TypeTimestamp:
		retTp = types.NewFieldType(mysql.TypeDouble)
		retTp.SetFlen(mysql.MaxRealWidth)
		retTp.SetDecimal(4)
	default:
		retTp = types.NewFieldType(mysql.TypeDouble)
		retTp.SetFlen(mysql.MaxRealWidth)
		retTp.SetDecimal(types.UnspecifiedLength)
	}
	types.SetBinChsClnFlag(retTp)
	a.RetTp = retTp
}

// typeInfer4GroupConcat sets the return type to a var-string with the default
// charset and collation, flen mysql.MaxBlobWidth and decimal 0. Every
// argument except the last one that has a decimal type is wrapped in a cast
// to its own type.
func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {
	retTp := types.NewFieldType(mysql.TypeVarString)
	chs, coll := charset.GetDefaultCharsetAndCollate()
	retTp.SetCharset(chs)
	retTp.SetCollate(coll)
	retTp.SetFlen(mysql.MaxBlobWidth)
	retTp.SetDecimal(0)
	a.RetTp = retTp

	// TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i])
	// NOTE(review): the last argument is skipped; presumably it is the
	// separator — confirm against the caller.
	last := len(a.Args) - 1
	for i := 0; i < last; i++ {
		tp := a.Args[i].GetType()
		if tp.GetType() != mysql.TypeNewDecimal {
			continue
		}
		a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
	}
}

// typeInfer4MaxMin infers a return type that follows the first argument.
//
// A float-typed scalar function argument is first wrapped in a cast to
// double. For max, min, lead and lag the return type is a clone of the
// argument type with the not-null flag removed; for every other name it is
// the argument's own field type. Enum and set return types are replaced by a
// string type unless the function is first_row, max or min.
func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
	if _, isScalarFunc := a.Args[0].(*expression.ScalarFunction); isScalarFunc &&
		a.Args[0].GetType().GetType() == mysql.TypeFloat {
		// For scalar function, the result of "float32" is set to the "float64"
		// field in the "Datum". If we do not wrap a cast-as-double function on a.Args[0],
		// error would happen when extracting the evaluation of a.Args[0] to a ProjectionExec.
		doubleTp := types.NewFieldType(mysql.TypeDouble)
		doubleTp.SetFlen(mysql.MaxRealWidth)
		doubleTp.SetDecimal(types.UnspecifiedLength)
		types.SetBinChsClnFlag(doubleTp)
		a.Args[0] = expression.BuildCastFunction(ctx, a.Args[0], doubleTp)
	}

	switch a.Name {
	case ast.AggFuncMax, ast.AggFuncMin, ast.WindowFuncLead, ast.WindowFuncLag:
		// Work on a clone so removing the flag does not touch the argument's type.
		a.RetTp = a.Args[0].GetType().Clone()
		a.RetTp.DelFlag(mysql.NotNullFlag)
	default:
		a.RetTp = a.Args[0].GetType()
	}

	// issue #13027, #13961
	isEnumOrSet := a.RetTp.GetType() == mysql.TypeEnum || a.RetTp.GetType() == mysql.TypeSet
	keepsEnumOrSet := a.Name == ast.AggFuncFirstRow || a.Name == ast.AggFuncMax || a.Name == ast.AggFuncMin
	if isEnumOrSet && !keepsEnumOrSet {
		a.RetTp = types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(mysql.MaxFieldCharLength).BuildP()
	}
}

// typeInfer4BitFuncs sets the return type to an unsigned, not-null longlong
// with flen 21 and wraps the first argument in a cast to int.
func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
	retTp := types.NewFieldType(mysql.TypeLonglong)
	retTp.SetFlen(21)
	types.SetBinChsClnFlag(retTp)
	retTp.AddFlag(mysql.UnsignedFlag | mysql.NotNullFlag)
	a.RetTp = retTp

	a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

// typeInfer4JsonArrayAgg sets the return type to JSON and applies
// types.SetBinChsClnFlag to it.
func (a *baseFuncDesc) typeInfer4JsonArrayAgg(ctx sessionctx.Context) {
	jsonTp := types.NewFieldType(mysql.TypeJSON)
	types.SetBinChsClnFlag(jsonTp)
	a.RetTp = jsonTp
}

// typeInfer4JsonObjectAgg sets the return type to JSON, applies
// types.SetBinChsClnFlag to it, and wraps the first argument in a cast to
// string. It always returns nil.
func (a *baseFuncDesc) typeInfer4JsonObjectAgg(ctx sessionctx.Context) error {
	jsonTp := types.NewFieldType(mysql.TypeJSON)
	types.SetBinChsClnFlag(jsonTp)
	a.RetTp = jsonTp

	a.Args[0] = expression.WrapWithCastAsString(ctx, a.Args[0])
	return nil
}

// typeInfer4NumberFuncs sets the return type to a longlong with flen 21 and
// applies types.SetBinChsClnFlag to it.
func (a *baseFuncDesc) typeInfer4NumberFuncs() {
	retTp := types.NewFieldType(mysql.TypeLonglong)
	retTp.SetFlen(21)
	types.SetBinChsClnFlag(retTp)
	a.RetTp = retTp
}

// typeInfer4CumeDist sets the return type to a double with flen
// mysql.MaxRealWidth and decimal mysql.NotFixedDec.
func (a *baseFuncDesc) typeInfer4CumeDist() {
	retTp := types.NewFieldType(mysql.TypeDouble)
	retTp.SetFlen(mysql.MaxRealWidth)
	retTp.SetDecimal(mysql.NotFixedDec)
	a.RetTp = retTp
}

// typeInfer4Ntile sets the return type to an unsigned longlong with flen 21
// and applies types.SetBinChsClnFlag to it.
func (a *baseFuncDesc) typeInfer4Ntile() {
	retTp := types.NewFieldType(mysql.TypeLonglong)
	retTp.SetFlen(21)
	types.SetBinChsClnFlag(retTp)
	retTp.AddFlag(mysql.UnsignedFlag)
	a.RetTp = retTp
}

// typeInfer4PercentRank sets the return type to a double with flen
// mysql.MaxRealWidth and decimal mysql.NotFixedDec, matching what
// typeInfer4CumeDist produces.
func (a *baseFuncDesc) typeInfer4PercentRank() {
	a.RetTp = types.NewFieldType(mysql.TypeDouble)
	// MaxRealWidth is a display width and belongs in flen. Passing it to
	// SetFlag would overwrite the flag bitmask with an unrelated number and
	// leave flen unset.
	a.RetTp.SetFlen(mysql.MaxRealWidth)
	a.RetTp.SetDecimal(mysql.NotFixedDec)
}

// typeInfer4LeadLag infers the return type of lead/lag. With fewer than three
// arguments it defers to typeInfer4MaxMin; otherwise the return type is
// inferred from the first and third arguments together.
func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
	if len(a.Args) < 3 {
		a.typeInfer4MaxMin(ctx)
		return
	}
	// Merge the type of first and third argument.
	// FIXME: select lead(b collate utf8mb4_unicode_ci, 1, 'lead' collate utf8mb4_general_ci) over() as a from t; should report error.
	// The error returned by InferType4ControlFuncs is discarded here.
	a.RetTp, _ = expression.InferType4ControlFuncs(ctx, a.Name, a.Args[0], a.Args[2])
}

// typeInfer4PopOrSamp sets the return type to a double with flen
// mysql.MaxRealWidth and an unspecified number of decimals.
func (a *baseFuncDesc) typeInfer4PopOrSamp(ctx sessionctx.Context) {
	// var_pop/std/var_samp/stddev_samp's return value type is double
	retTp := types.NewFieldType(mysql.TypeDouble)
	retTp.SetFlen(mysql.MaxRealWidth)
	retTp.SetDecimal(types.UnspecifiedLength)
	a.RetTp = retTp
}

// GetDefaultValue gets the default value when the function's input is null.
// Any function name not handled explicitly below yields the zero Datum.
// According to MySQL, default values of the function are listed as follows:
// e.g.
// Table t which is empty:
// +-------+---------+---------+
// | Table | Field   | Type    |
// +-------+---------+---------+
// | t     | a       | int(11) |
// +-------+---------+---------+
//
// Query: `select avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a), approx_count_distinct(a), approx_percentile(a, 50) from test.t;`
// +--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+--------------------------+
// | avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a)           | max(a) | min(a) | group_concat(a) | approx_count_distinct(a) | approx_percentile(a, 50) |
// +--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+--------------------------+
// |   NULL |   NULL |        0 |          0 |         0 | 18446744073709551615 |   NULL |   NULL | NULL            |                        0 |                     NULL |
// +--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+--------------------------+
func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) {
	switch a.Name {
	case ast.AggFuncBitAnd:
		// All bits set.
		return types.NewUintDatum(uint64(math.MaxUint64))
	case ast.AggFuncCount, ast.AggFuncBitOr, ast.AggFuncBitXor:
		return types.NewIntDatum(0)
	case ast.AggFuncApproxCountDistinct:
		// A string-typed return keeps the zero Datum instead of 0.
		if a.RetTp.GetType() == mysql.TypeString {
			return v
		}
		return types.NewIntDatum(0)
	case ast.AggFuncFirstRow, ast.AggFuncAvg, ast.AggFuncSum, ast.AggFuncMax,
		ast.AggFuncMin, ast.AggFuncGroupConcat, ast.AggFuncApproxPercentile:
		return types.Datum{}
	}
	return v
}

// noNeedCastAggFuncs is the set of function names that WrapCastForAggArgs
// leaves untouched.
// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
var noNeedCastAggFuncs = map[string]struct{}{
	ast.AggFuncCount:               {},
	ast.AggFuncApproxCountDistinct: {},
	ast.AggFuncApproxPercentile:    {},
	ast.AggFuncMax:                 {},
	ast.AggFuncMin:                 {},
	ast.AggFuncFirstRow:            {},
	ast.WindowFuncNtile:            {},
	ast.AggFuncJsonArrayagg:        {},
	ast.AggFuncJsonObjectAgg:       {},
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast
// function chosen from the eval type of a.RetTp.
//
// Nothing happens when there are no arguments or when the function name is in
// noNeedCastAggFuncs. Arguments of NULL type are left as they are, and so is
// the second argument of lead, lag and nth_value. The method panics when the
// eval type of the return type is not one of the handled kinds.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
	if _, skip := noNeedCastAggFuncs[a.Name]; skip || len(a.Args) == 0 {
		return
	}

	retTp := a.RetTp
	var wrap func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
	switch retTp.EvalType() {
	case types.ETInt:
		wrap = expression.WrapWithCastAsInt
	case types.ETReal:
		wrap = expression.WrapWithCastAsReal
	case types.ETDecimal:
		wrap = expression.WrapWithCastAsDecimal
	case types.ETString:
		wrap = expression.WrapWithCastAsString
	case types.ETDuration:
		wrap = expression.WrapWithCastAsDuration
	case types.ETJson:
		wrap = expression.WrapWithCastAsJSON
	case types.ETDatetime, types.ETTimestamp:
		// The time cast additionally needs the target type, so bind it here.
		wrap = func(ctx sessionctx.Context, expr expression.Expression) expression.Expression {
			return expression.WrapWithCastAsTime(ctx, expr, retTp)
		}
	default:
		panic("should never happen in baseFuncDesc.WrapCastForAggArgs")
	}

	// Do not cast the second args of these functions, as they are simply non-negative numbers.
	keepSecondArg := a.Name == ast.WindowFuncLead || a.Name == ast.WindowFuncLag || a.Name == ast.WindowFuncNthValue
	for idx, arg := range a.Args {
		if idx == 1 && keepSecondArg {
			continue
		}
		if arg.GetType().GetType() == mysql.TypeNull {
			continue
		}
		a.Args[idx] = wrap(ctx, arg)
	}
}

// MemoryUsage returns the memory usage of baseFuncDesc: the string header
// size plus the length of Name, plus the usage reported by RetTp (when it is
// not nil) and by every argument. A nil receiver reports 0.
func (a *baseFuncDesc) MemoryUsage() (sum int64) {
	if a == nil {
		return 0
	}

	sum += size.SizeOfString
	sum += int64(len(a.Name))
	if retTp := a.RetTp; retTp != nil {
		sum += retTp.MemoryUsage()
	}
	for idx := range a.Args {
		sum += a.Args[idx].MemoryUsage()
	}
	return sum
}

相关信息

tidb 源码目录

相关文章

tidb agg_to_pb 源码

tidb aggregation 源码

tidb avg 源码

tidb bit_and 源码

tidb bit_or 源码

tidb bit_xor 源码

tidb concat 源码

tidb count 源码

tidb descriptor 源码

tidb explain 源码

0  赞