tidb expression_rewriter 源码

  • 2022-09-19
  • 浏览 (446)

tidb expression_rewriter 代码

文件路径:/planner/core/expression_rewriter.go

// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import (
	"context"
	"encoding/hex"
	"encoding/json"
	"strconv"
	"strings"
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/domain"
	"github.com/pingcap/tidb/expression"
	"github.com/pingcap/tidb/expression/aggregation"
	"github.com/pingcap/tidb/infoschema"
	"github.com/pingcap/tidb/parser/ast"
	"github.com/pingcap/tidb/parser/charset"
	"github.com/pingcap/tidb/parser/model"
	"github.com/pingcap/tidb/parser/mysql"
	"github.com/pingcap/tidb/parser/opcode"
	"github.com/pingcap/tidb/sessionctx"
	"github.com/pingcap/tidb/sessionctx/variable"
	"github.com/pingcap/tidb/table"
	"github.com/pingcap/tidb/table/tables"
	"github.com/pingcap/tidb/tablecodec"
	"github.com/pingcap/tidb/types"
	driver "github.com/pingcap/tidb/types/parser_driver"
	"github.com/pingcap/tidb/util/chunk"
	"github.com/pingcap/tidb/util/codec"
	"github.com/pingcap/tidb/util/collate"
	"github.com/pingcap/tidb/util/hint"
	"github.com/pingcap/tidb/util/sem"
	"github.com/pingcap/tidb/util/stringutil"
)

// EvalSubqueryFirstRow evaluates an uncorrelated subquery once and returns its first row.
// It is a function variable that is not assigned anywhere in this part of the file;
// presumably another package installs the implementation (NOTE(review): confirm where).
// handleExistSubquery calls it with an optimized physical plan and treats a nil row as
// "the subquery produced no rows".
var EvalSubqueryFirstRow func(ctx context.Context, p PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) (row []types.Datum, err error)

// evalAstExpr computes the value of an AST expression without any input row.
// A literal (*driver.ValueExpr) yields its stored datum directly; every other node is first
// rewritten by rewriteAstExpr (with no schema and no names) and then evaluated against an
// empty chunk.Row. A rewrite failure is returned together with a zero Datum.
func evalAstExpr(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) {
	switch node := expr.(type) {
	case *driver.ValueExpr:
		// Literal: nothing to rewrite.
		return node.Datum, nil
	default:
		rewritten, err := rewriteAstExpr(sctx, expr, nil, nil)
		if err != nil {
			return types.Datum{}, err
		}
		return rewritten.Eval(chunk.Row{})
	}
}

// rewriteAstExpr rewrites an AST expression into an expression.Expression without a real
// logical plan. A fresh PlanBuilder and a LogicalTableDual placeholder are used; when schema
// is non-nil the placeholder exposes schema and names, so column references inside expr can
// be resolved against them.
//
// The second value returned by PlanBuilder.Init (savedBlockNames) is written back to the
// session's PlannerSelectBlockAsName on every exit path. Previously this only happened when
// the rewrite succeeded, so a failed rewrite left the temporary builder's value in the
// session variables.
func rewriteAstExpr(sctx sessionctx.Context, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice) (expression.Expression, error) {
	var is infoschema.InfoSchema
	// in tests, it may be null
	if s, ok := sctx.GetInfoSchema().(infoschema.InfoSchema); ok {
		is = s
	}
	b, savedBlockNames := NewPlanBuilder().Init(sctx, is, &hint.BlockHintProcessor{})
	// Put the session variable back whether or not the rewrite below succeeds.
	defer func() {
		sctx.GetSessionVars().PlannerSelectBlockAsName = savedBlockNames
	}()

	fakePlan := LogicalTableDual{}.Init(sctx, 0)
	if schema != nil {
		fakePlan.schema = schema
		fakePlan.names = names
	}
	b.curClause = expressionClause
	newExpr, _, err := b.rewrite(context.TODO(), expr, fakePlan, nil, true)
	if err != nil {
		return nil, err
	}
	return newExpr, nil
}

// rewriteInsertOnDuplicateUpdate rewrites one expression from the assignment list of an
// INSERT statement. The rewriter is switched to scalar mode and is handed insertPlan, which
// makes Enter resolve ast.ValuesExpr nodes against the insert plan's table schema. Only the
// rewritten expression is returned; the plan produced by the rewrite is dropped.
func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(ctx context.Context, exprNode ast.ExprNode, mockPlan LogicalPlan, insertPlan *Insert) (expression.Expression, error) {
	b.rewriterCounter++
	defer func() { b.rewriterCounter-- }()

	b.curClause = fieldList
	er := b.getExpressionRewriter(ctx, mockPlan)
	// The rewriter may come from b.rewriterPool. If it still carries an error, some earlier
	// step failed to handle it, so surface it now instead of rewriting on top of it.
	if pending := er.err; pending != nil {
		return nil, pending
	}

	er.insertPlan, er.asScalar = insertPlan, true

	rewritten, _, err := b.rewriteExprNode(er, exprNode, true)
	return rewritten, err
}

// rewrite converts an AST expression into an expression.Expression.
// aggMapper maps each ast.AggregateFuncExpr to its column offset in p's output schema.
// asScalar tells the rewriter that the expression must be treated as a scalar expression.
// Besides the resulting expression it returns the new plan, which may have gained an apply
// or a semi-join. It is rewriteWithPreprocess without a window mapper or preprocess hook.
func (b *PlanBuilder) rewrite(ctx context.Context, exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, LogicalPlan, error) {
	return b.rewriteWithPreprocess(ctx, exprNode, p, aggMapper, nil, asScalar, nil)
}

// rewriteWithPreprocess rewrites exprNode like rewrite does, but additionally accepts a window
// function mapper and a preprocess hook. The hook covers the case where the input AST has to
// be adjusted before its nodes are really used in `expressionRewriter.Leave`: the rewriter
// calls preprocess on a node, gets a new node back, and uses that new node in `Leave`.
func (b *PlanBuilder) rewriteWithPreprocess(
	ctx context.Context,
	exprNode ast.ExprNode,
	p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int,
	windowMapper map[*ast.WindowFuncExpr]int,
	asScalar bool,
	preprocess func(ast.Node) ast.Node,
) (expression.Expression, LogicalPlan, error) {
	b.rewriterCounter++
	defer func() { b.rewriterCounter-- }()

	er := b.getExpressionRewriter(ctx, p)
	// The rewriter may come from b.rewriterPool. If it still carries an error, some earlier
	// step failed to handle it, so surface it now instead of rewriting on top of it.
	if leftover := er.err; leftover != nil {
		return nil, nil, leftover
	}

	er.aggrMap, er.windowMap = aggMapper, windowMapper
	er.asScalar, er.preprocess = asScalar, preprocess

	return b.rewriteExprNode(er, exprNode, asScalar)
}

// getExpressionRewriter returns the rewriter for the current nesting depth
// (b.rewriterCounter), creating and pooling a new one when the pool is not yet that deep.
//
// A pooled rewriter is reset so that it is indistinguishable from a freshly created one:
// every per-rewrite field is cleared, including windowMap, and schema/names are taken from p
// or cleared when p is nil. Previously windowMap, and schema/names for a nil p, kept whatever
// an unrelated earlier rewrite at the same depth had stored.
func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p LogicalPlan) (rewriter *expressionRewriter) {
	if len(b.rewriterPool) < b.rewriterCounter {
		// No rewriter exists for this depth yet: create one and remember it.
		rewriter = &expressionRewriter{p: p, b: b, sctx: b.ctx, ctx: ctx}
		rewriter.sctx.SetValue(expression.TiDBDecodeKeyFunctionKey, decodeKeyFromString)
		b.rewriterPool = append(b.rewriterPool, rewriter)
	} else {
		rewriter = b.rewriterPool[b.rewriterCounter-1]
		rewriter.p = p
		rewriter.asScalar = false
		rewriter.aggrMap = nil
		rewriter.windowMap = nil
		rewriter.preprocess = nil
		rewriter.insertPlan = nil
		rewriter.disableFoldCounter = 0
		rewriter.tryFoldCounter = 0
		// Keep the backing arrays of both stacks, only drop their contents.
		rewriter.ctxStack = rewriter.ctxStack[:0]
		rewriter.ctxNameStk = rewriter.ctxNameStk[:0]
		rewriter.ctx = ctx
		rewriter.err = nil
		// Cleared here and refilled below when p is non-nil.
		rewriter.schema = nil
		rewriter.names = nil
	}

	if p != nil {
		rewriter.schema = p.Schema()
		rewriter.names = p.OutputNames()
	}
	return rewriter
}

// rewriteExprNode walks exprNode with rewriter and returns the rewritten expression together
// with the rewriter's plan, which the walk may have replaced.
//
// Results:
//   - (nil, rewriter.p, nil) when asScalar is false and the walk left nothing on the stack;
//   - (expr, rewriter.p, nil) when exactly one expression is left on the stack;
//   - an error when the walk recorded one in rewriter.err, when the stack does not hold exactly
//     one expression, or when expression.CheckArgsNotMultiColumnRow rejects the result.
func (b *PlanBuilder) rewriteExprNode(rewriter *expressionRewriter, exprNode ast.ExprNode, asScalar bool) (expression.Expression, LogicalPlan, error) {
	if rewriter.p != nil {
		// Number of columns visible before the walk; columns added after this point get an
		// empty name when the function returns.
		curColLen := rewriter.p.Schema().Len()
		defer func() {
			// rewriter.p is read here, at exit, so this applies to the plan the walk ended with.
			names := rewriter.p.OutputNames().Shallow()[:curColLen]
			for i := curColLen; i < rewriter.p.Schema().Len(); i++ {
				names = append(names, types.EmptyName)
			}
			// After rewriting finished, only old columns are visible.
			// e.g. select * from t where t.a in (select t1.a from t1);
			// The output columns before we enter the subquery are the columns from t.
			// But when we leave the subquery `t.a in (select t1.a from t1)`, we got a Apply operator
			// and the output columns become [t.*, t1.*]. But t1.* is used only inside the subquery. If there's another filter
			// which is also a subquery where t1 is involved. The name resolving will fail if we still expose the column from
			// the previous subquery.
			// So here we just reset the names to empty to avoid this situation.
			// TODO: implement ScalarSubQuery and resolve it during optimizing. In building phase, we will not change the plan's structure.
			rewriter.p.SetOutputNames(names)
		}()
	}
	exprNode.Accept(rewriter)
	if rewriter.err != nil {
		return nil, nil, errors.Trace(rewriter.err)
	}
	// A non-scalar rewrite may produce no expression at all; only the plan is returned then.
	if !asScalar && len(rewriter.ctxStack) == 0 {
		return nil, rewriter.p, nil
	}
	if len(rewriter.ctxStack) != 1 {
		return nil, nil, errors.Errorf("context len %v is invalid", len(rewriter.ctxStack))
	}
	rewriter.err = expression.CheckArgsNotMultiColumnRow(rewriter.ctxStack[0])
	if rewriter.err != nil {
		return nil, nil, errors.Trace(rewriter.err)
	}
	return rewriter.ctxStack[0], rewriter.p, nil
}

// expressionRewriter is the AST visitor that turns ast expression nodes into
// expression.Expression values. Results are collected on ctxStack, and the plan p may be
// replaced along the way when subqueries are turned into apply / semi-join operators.
// Instances are pooled per nesting depth by PlanBuilder.getExpressionRewriter.
type expressionRewriter struct {
	ctxStack   []expression.Expression        // rewritten expressions; rewriteExprNode expects at most one entry once the walk ends
	ctxNameStk []*types.FieldName             // field names, pushed and popped together with ctxStack
	p          LogicalPlan                    // plan the expression is rewritten against; subquery handlers replace it
	schema     *expression.Schema             // taken from p.Schema() by getExpressionRewriter when p is non-nil
	names      []*types.FieldName             // taken from p.OutputNames() by getExpressionRewriter when p is non-nil
	err        error                          // first error recorded during the walk; checked by rewriteExprNode
	aggrMap    map[*ast.AggregateFuncExpr]int // aggregate node -> column index in schema (a negative index marks an outer-query aggregate)
	windowMap  map[*ast.WindowFuncExpr]int    // window function node -> column index in schema
	b          *PlanBuilder                   // builder that owns this rewriter
	sctx       sessionctx.Context             // session context used when constructing expressions
	ctx        context.Context                // context passed on to the subquery handlers

	// asScalar indicates the return value must be a scalar value.
	// NOTE: This value can be changed during expression rewritten.
	asScalar bool

	// preprocess is called for every ast.Node in Leave.
	preprocess func(ast.Node) ast.Node

	// insertPlan is only used to rewrite the expressions inside the assignment
	// of the "INSERT" statement.
	insertPlan *Insert

	// disableFoldCounter controls fold-disabled scope. If > 0, rewriter will NOT do constant folding.
	// Typically, during visiting AST, while entering the scope(disable), the counter will +1; while
	// leaving the scope(enable again), the counter will -1.
	// NOTE: This value can be changed during expression rewritten.
	disableFoldCounter int
	// tryFoldCounter is incremented by Enter for functions listed in expression.TryFoldFunctions
	// and for logical AND / OR operators.
	tryFoldCounter     int
}

// ctxStackLen reports how many expressions are currently on the rewriter's expression stack.
func (er *expressionRewriter) ctxStackLen() int {
	depth := len(er.ctxStack)
	return depth
}

// ctxStackPop drops the top num entries from the expression stack and from the name stack.
// Both stacks are cut to the same length, derived from the expression stack.
func (er *expressionRewriter) ctxStackPop(num int) {
	keep := len(er.ctxStack) - num
	er.ctxStack = er.ctxStack[:keep]
	er.ctxNameStk = er.ctxNameStk[:keep]
}

// ctxStackAppend pushes an expression and its field name, keeping both stacks in step.
func (er *expressionRewriter) ctxStackAppend(col expression.Expression, name *types.FieldName) {
	er.ctxNameStk = append(er.ctxNameStk, name)
	er.ctxStack = append(er.ctxStack, col)
}

// constructBinaryOpFunction builds the comparison `l op r`, expanding row (multi-column) operands.
//  1. If both operands have row length 1, it builds `l op r` directly.
//  2. If the operands have different row lengths, it returns ErrOperandColumns.
//  3. If op is EQ or NullEQ, (a0,a1,a2) op (b0,b1,b2) becomes (a0 op b0) and (a1 op b1) and (a2 op b2).
//     If op is NE, the per-column comparisons are combined with OR instead:
//     (a0 != b0) or (a1 != b1) or (a2 != b2).
//  4. Otherwise (a0,a1,a2) op (b0,b1,b2) becomes
//     `IF( a0 NE b0, a0 op b0,
//
//	IF ( isNull(a0 NE b0), Null,
//		IF ( a1 NE b1, a1 op b1,
//			IF ( isNull(a1 NE b1), Null, a2 op b2))))`
//
// Nested rows are handled by recursion on the per-column operands.
func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
	lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r)
	if lLen == 1 && rLen == 1 {
		return er.newFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
	} else if rLen != lLen {
		return nil, expression.ErrOperandColumns.GenWithStackByArgs(lLen)
	}
	switch op {
	case ast.EQ, ast.NE, ast.NullEQ:
		// Compare column by column, then combine the per-column results.
		funcs := make([]expression.Expression, lLen)
		for i := 0; i < lLen; i++ {
			var err error
			funcs[i], err = er.constructBinaryOpFunction(expression.GetFuncArg(l, i), expression.GetFuncArg(r, i), op)
			if err != nil {
				return nil, err
			}
		}
		if op == ast.NE {
			// Rows differ as soon as one column differs, hence the disjunction.
			return expression.ComposeDNFCondition(er.sctx, funcs...), nil
		}
		return expression.ComposeCNFCondition(er.sctx, funcs...), nil
	default:
		// Ordering comparison: decide on the first column when it differs, return NULL when
		// that cannot be determined, otherwise recurse on the remaining columns.
		larg0, rarg0 := expression.GetFuncArg(l, 0), expression.GetFuncArg(r, 0)
		var expr1, expr2, expr3, expr4, expr5 expression.Expression
		// expr1: a0 NE b0, expr2: a0 op b0, expr3: isNull(a0 NE b0).
		expr1 = expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
		expr2 = expression.NewFunctionInternal(er.sctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0)
		expr3 = expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1)
		var err error
		// Drop the first column from both rows before recursing.
		l, err = expression.PopRowFirstArg(er.sctx, l)
		if err != nil {
			return nil, err
		}
		r, err = expression.PopRowFirstArg(er.sctx, r)
		if err != nil {
			return nil, err
		}
		// expr4: the comparison of the remaining columns.
		expr4, err = er.constructBinaryOpFunction(l, r, op)
		if err != nil {
			return nil, err
		}
		// expr5: IF(isNull(a0 NE b0), Null, <rest>).
		expr5, err = er.newFunction(ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.NewNull(), expr4)
		if err != nil {
			return nil, err
		}
		return er.newFunction(ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5)
	}
}

// buildSubquery turns the subquery AST into a logical plan.
// rewriteHintCanTakeEffect is stored in the builder as checkSemiJoinHint while the subquery
// is built; currently only EXISTS can apply the rewrite hint (rewriting the semi join to an
// inner join with aggregation). hasSemiJoinRewriteHint reports the builder's
// hasValidSemiJoinHint flag as it stands right after the subquery has been built.
//
// Builder state touched while building (outer schemas/names, window specs and both semi-join
// hint flags) is put back when the function returns.
func (er *expressionRewriter) buildSubquery(ctx context.Context, subq *ast.SubqueryExpr, rewriteHintCanTakeEffect bool) (np LogicalPlan, hasSemiJoinRewriteHint bool, err error) {
	builder := er.b

	// Expose the current schema and names to the subquery as its outer scope.
	if er.schema != nil {
		builder.outerSchemas = append(builder.outerSchemas, er.schema.Clone())
		builder.outerNames = append(builder.outerNames, er.names)
		defer func() {
			builder.outerSchemas = builder.outerSchemas[:len(builder.outerSchemas)-1]
			builder.outerNames = builder.outerNames[:len(builder.outerNames)-1]
		}()
	}

	// Remember the outer values, then start the subquery from the defaults.
	savedCheckHint := builder.checkSemiJoinHint
	savedHasHint := builder.hasValidSemiJoinHint
	savedWindowSpecs := builder.windowSpecs
	builder.checkSemiJoinHint = rewriteHintCanTakeEffect
	builder.hasValidSemiJoinHint = false
	defer func() {
		builder.windowSpecs = savedWindowSpecs
		builder.checkSemiJoinHint = savedCheckHint
		builder.hasValidSemiJoinHint = savedHasHint
	}()

	np, err = builder.buildResultSetNode(ctx, subq.Query, false)
	if err != nil {
		return nil, false, err
	}
	// Pop the handle map generated by the subquery.
	builder.handleHelper.popMap()
	return np, builder.hasValidSemiJoinHint, nil
}

// Enter implements Visitor interface.
// It runs before a node's children are visited. Nodes that are resolved entirely here
// (aggregates, mapped columns, subquery forms, VALUES(), window functions) push their result
// on the stack, or record an error in er.err, and return true. For the remaining nodes it
// only updates er.asScalar and the folding counters and returns false. The boolean is
// presumably the visitor's "skip children" flag (NOTE(review): confirm against ast.Visitor).
func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
		index, ok := -1, false
		if er.aggrMap != nil {
			index, ok = er.aggrMap[v]
		}
		if ok {
			// index < 0 indicates this is a correlated aggregate belonging to outer query,
			// for which a correlated column will be created later, so we append a null constant
			// as a temporary result expression.
			if index < 0 {
				er.ctxStackAppend(expression.NewNull(), types.EmptyName)
			} else {
				// index >= 0 indicates this is a regular aggregate column
				er.ctxStackAppend(er.schema.Columns[index], er.names[index])
			}
			return inNode, true
		}
		// replace correlated aggregate in sub-query with its corresponding correlated column
		if col, ok := er.b.correlatedAggMapper[v]; ok {
			er.ctxStackAppend(col, types.EmptyName)
			return inNode, true
		}
		// The aggregate is known to neither mapper.
		er.err = ErrInvalidGroupFuncUse
		return inNode, true
	case *ast.ColumnNameExpr:
		// A column registered in the builder's colMapper resolves directly to a schema column.
		// Any other column falls out of the switch and is not resolved here.
		if index, ok := er.b.colMapper[v]; ok {
			er.ctxStackAppend(er.schema.Columns[index], er.names[index])
			return inNode, true
		}
	case *ast.CompareSubqueryExpr:
		return er.handleCompareSubquery(er.ctx, v)
	case *ast.ExistsSubqueryExpr:
		return er.handleExistSubquery(er.ctx, v)
	case *ast.PatternInExpr:
		if v.Sel != nil {
			return er.handleInSubquery(er.ctx, v)
		}
		if len(v.List) != 1 {
			break
		}
		// For 10 in ((select * from t)), the parser won't set v.Sel.
		// So we must process this case here.
		x := v.List[0]
		// Peel off parentheses until a subquery or some other kind of node is found.
		for {
			switch y := x.(type) {
			case *ast.SubqueryExpr:
				v.Sel = y
				return er.handleInSubquery(er.ctx, v)
			case *ast.ParenthesesExpr:
				x = y.Expr
			default:
				return inNode, false
			}
		}
	case *ast.SubqueryExpr:
		return er.handleScalarSubquery(er.ctx, v)
	case *ast.ParenthesesExpr:
		// Parentheses leave er.asScalar untouched.
	case *ast.ValuesExpr:
		schema, names := er.schema, er.names
		// NOTE: "er.insertPlan != nil" means that we are rewriting the
		// expressions inside the assignment of "INSERT" statement. we have to
		// use the "tableSchema" of that "insertPlan".
		if er.insertPlan != nil {
			schema = er.insertPlan.tableSchema
			names = er.insertPlan.tableColNames
		}
		idx, err := expression.FindFieldName(names, v.Column.Name)
		if err != nil {
			er.err = err
			return inNode, false
		}
		// A negative index is reported as an unknown column.
		if idx < 0 {
			er.err = ErrUnknownColumn.GenWithStackByArgs(v.Column.Name.OrigColName(), "field list")
			return inNode, false
		}
		col := schema.Columns[idx]
		er.ctxStackAppend(expression.NewValuesFunc(er.sctx, col.Index, col.RetType), types.EmptyName)
		return inNode, true
	case *ast.WindowFuncExpr:
		index, ok := -1, false
		if er.windowMap != nil {
			index, ok = er.windowMap[v]
		}
		// A window function missing from windowMap is rejected.
		if !ok {
			er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.F))
			return inNode, true
		}
		er.ctxStackAppend(er.schema.Columns[index], er.names[index])
		return inNode, true
	case *ast.FuncCallExpr:
		er.asScalar = true
		// Open a fold-disabled and/or try-fold scope for functions listed in the expression package.
		if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
			er.disableFoldCounter++
		}
		if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
			er.tryFoldCounter++
		}
	case *ast.CaseExpr:
		er.asScalar = true
		// CASE is looked up in the same tables under the name "case".
		if _, ok := expression.DisableFoldFunctions["case"]; ok {
			er.disableFoldCounter++
		}
		if _, ok := expression.TryFoldFunctions["case"]; ok {
			er.tryFoldCounter++
		}
	case *ast.BinaryOperationExpr:
		er.asScalar = true
		if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr {
			er.tryFoldCounter++
		}
	case *ast.SetCollationExpr:
		// Do nothing
	default:
		er.asScalar = true
	}
	return inNode, false
}

// buildSemiApplyFromEqualSubq builds a semi apply on top of er.p and np whose join condition
// is `l = r` (expanded column by column for rows), and stores the resulting plan in er.p.
// not is passed on to buildSemiApply; handleCompareSubquery passes true for `!= all`.
//
// When the result is needed as a scalar, or for the negated form, every column pair in which
// either side may be NULL gets InOperand set on a copy of the right column, and l is passed
// through expression.SetExprColumnInOperand. Errors are reported through er.err.
//
// NOTE(review): r (or each argument of the row r) is type-asserted to *expression.Column
// without a check; this assumes the subquery side always consists of plain columns.
func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r expression.Expression, not bool) {
	if er.asScalar || not {
		if expression.GetRowLen(r) == 1 {
			rCol := r.(*expression.Column)
			// If both input columns of `!= all / = any` expression are not null, we can treat the expression
			// as normal column equal condition.
			if !expression.ExprNotNull(l) || !expression.ExprNotNull(rCol) {
				// Work on a copy so the column held by the subquery's schema is left untouched.
				rColCopy := *rCol
				rColCopy.InOperand = true
				r = &rColCopy
				l = expression.SetExprColumnInOperand(l)
			}
		} else {
			rowFunc := r.(*expression.ScalarFunction)
			rargs := rowFunc.GetArgs()
			args := make([]expression.Expression, 0, len(rargs))
			modified := false
			for i, rarg := range rargs {
				larg := expression.GetFuncArg(l, i)
				if !expression.ExprNotNull(larg) || !expression.ExprNotNull(rarg) {
					rCol := rarg.(*expression.Column)
					rColCopy := *rCol
					rColCopy.InOperand = true
					rarg = &rColCopy
					modified = true
				}
				args = append(args, rarg)
			}
			// Rebuild the row only when at least one column was replaced by a copy.
			if modified {
				r, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
				if er.err != nil {
					return
				}
				l = expression.SetExprColumnInOperand(l)
			}
		}
	}
	var condition expression.Expression
	condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ)
	if er.err != nil {
		return
	}
	er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, not, false)
}

// handleCompareSubquery rewrites `expr op ANY|ALL (subquery)`.
// The left operand is rewritten first, then the subquery is built into a plan, and the
// comparison is dispatched by operator:
//   - `= any` and `!= all` become a semi apply via buildSemiApplyFromEqualSubq;
//   - `= all` and `!= any` go through handleEQAll and handleNEAny;
//   - `<=> any/all` is rejected as unsupported;
//   - every other operator goes through handleOtherComparableSubq (min / max aggregation).
//
// In scalar mode the top of the expression stack (the left operand) is replaced by the last
// column of the resulting plan. Errors are reported through er.err. The function always
// returns (v, true).
func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast.CompareSubqueryExpr) (ast.Node, bool) {
	ci := er.b.prepareCTECheckForSubQuery()
	defer resetCTECheckForSubQuery(ci)
	// Rewrite the left operand; its result is left on top of the stack.
	v.L.Accept(er)
	if er.err != nil {
		return v, true
	}
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.R.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T", v.R)
		return v, true
	}
	np, _, err := er.buildSubquery(ctx, subq, false)
	if err != nil {
		er.err = err
		return v, true
	}
	// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use row expression.
	canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
	if !canMultiCol && (expression.GetRowLen(lexpr) != 1 || np.Schema().Len() != 1) {
		er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
		return v, true
	}
	// Both sides must have the same number of columns.
	lLen := expression.GetRowLen(lexpr)
	if lLen != np.Schema().Len() {
		er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen)
		return v, true
	}
	// The right operand is the subquery's only column, or a row built from all of its columns.
	var rexpr expression.Expression
	if np.Schema().Len() == 1 {
		rexpr = np.Schema().Columns[0]
	} else {
		args := make([]expression.Expression, 0, np.Schema().Len())
		for _, col := range np.Schema().Columns {
			args = append(args, col)
		}
		rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
		if er.err != nil {
			return v, true
		}
	}

	// Lexpr cannot compare with rexpr by different collate
	opString := new(strings.Builder)
	v.Op.Format(opString)
	_, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, opString.String(), types.ETInt, lexpr, rexpr)
	if er.err != nil {
		return v, true
	}

	switch v.Op {
	// Only EQ, NE and NullEQ can be composed with and.
	case opcode.EQ, opcode.NE, opcode.NullEQ:
		if v.Op == opcode.EQ {
			if v.All {
				er.handleEQAll(lexpr, rexpr, np)
			} else {
				// `a = any(subq)` will be rewritten as `a in (subq)`.
				er.asScalar = true
				er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, false)
				if er.err != nil {
					return v, true
				}
			}
		} else if v.Op == opcode.NE {
			if v.All {
				// `a != all(subq)` will be rewritten as `a not in (subq)`.
				er.asScalar = true
				er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, true)
				if er.err != nil {
					return v, true
				}
			} else {
				er.handleNEAny(lexpr, rexpr, np)
			}
		} else {
			// TODO: Support this in future.
			er.err = errors.New("We don't support <=> all or <=> any now")
			return v, true
		}
	default:
		// When < all or > any , the agg function should use min.
		useMin := ((v.Op == opcode.LT || v.Op == opcode.LE) && v.All) || ((v.Op == opcode.GT || v.Op == opcode.GE) && !v.All)
		er.handleOtherComparableSubq(lexpr, rexpr, np, useMin, v.Op.String(), v.All)
	}
	if er.asScalar {
		// The parent expression only use the last column in schema, which represents whether the condition is matched.
		er.ctxStack[len(er.ctxStack)-1] = er.p.Schema().Columns[er.p.Schema().Len()-1]
		er.ctxNameStk[len(er.ctxNameStk)-1] = er.p.OutputNames()[er.p.Schema().Len()-1]
	}
	return v, true
}

// handleOtherComparableSubq handles ordering comparisons against a quantified subquery, such
// as `< any` or `> all`. The subquery is wrapped in an aggregation computing max (or min when
// useMin is set) of rexpr, and lexpr is compared with that single value using cmpFunc. For
// example, t.id < any (select s.id from s) is rewritten to t.id < (select max(s.id) from s).
// The NULL and empty-set handling is added by buildQuantifierPlan. Errors are reported
// through er.err.
func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) {
	aggPlan := LogicalAggregation{}.Init(er.sctx, er.b.getSelectOffset())
	if tableHints := er.b.TableHints(); tableHints != nil {
		aggPlan.aggHints = tableHints.aggHints
	}
	aggPlan.SetChildren(np)

	// Pick the aggregate that yields the bound to compare against.
	aggName := ast.AggFuncMin
	if !useMin {
		aggName = ast.AggFuncMax
	}
	boundDesc, err := aggregation.NewAggFuncDesc(er.sctx, aggName, []expression.Expression{rexpr}, false)
	if err != nil {
		er.err = err
		return
	}

	// The aggregation outputs one column holding the bound.
	boundCol := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  boundDesc.RetTp,
	}
	boundCol.SetCoercibility(rexpr.Coercibility())

	aggPlan.names = append(aggPlan.names, types.EmptyName)
	aggPlan.SetSchema(expression.NewSchema(boundCol))
	aggPlan.AggFuncs = []*aggregation.AggFuncDesc{boundDesc}

	comparison := expression.NewFunctionInternal(er.sctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, boundCol)
	er.buildQuantifierPlan(aggPlan, comparison, lexpr, rexpr, all)
}

// buildQuantifierPlan adds extra condition for any / all subquery.
// plan4Agg is the aggregation built over the subquery and cond is the comparison against its
// result. Two more aggregates are appended to plan4Agg, sum(rexpr is null) and count(1), and
// cond is extended with checks for NULLs in the subquery, an empty subquery and a NULL outer
// value. In non-scalar mode the final condition becomes the predicate of a semi apply;
// otherwise an inner-join apply plus a projection is built, with the condition as one extra
// output column. The new plan is stored in er.p and errors are reported through er.err.
func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) {
	innerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr)
	outerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr)

	// Build `sum(rexpr is null)` aggregation to check if the subquery contains NULL.
	funcSum, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false)
	if err != nil {
		er.err = err
		return
	}
	colSum := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  funcSum.RetTp,
	}
	plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum)
	plan4Agg.schema.Append(colSum)
	innerHasNull := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.NewZero())

	// Build `count(1)` aggregation to check if subquery is empty.
	funcCount, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false)
	if err != nil {
		er.err = err
		return
	}
	colCount := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  funcCount.RetTp,
	}
	plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
	plan4Agg.schema.Append(colCount)

	if all {
		// All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it
		// should be rewritten to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true).
		innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.NewNull(), expression.NewOne())
		cond = expression.ComposeCNFCondition(er.sctx, cond, innerNullChecker)
		// If the subquery is empty, it should always return true.
		emptyChecker := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.NewZero())
		// If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`.
		outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.NewNull(), expression.NewZero())
		cond = expression.ComposeDNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
	} else {
		// For "any" expression, if the subquery has null and the cond returns false, the result should be NULL.
		// Specifically, `t.id < any (select s.id from s)` would be rewritten to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)`
		innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.NewNull(), expression.NewZero())
		cond = expression.ComposeDNFCondition(er.sctx, cond, innerNullChecker)
		// If the subquery is empty, it should always return false.
		emptyChecker := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.NewZero())
		// If outer key is null, and subquery is not empty, it should return null.
		outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.NewNull(), expression.NewOne())
		cond = expression.ComposeCNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
	}

	// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
	// plan4Agg.buildProjectionIfNecessary()
	if !er.asScalar {
		// For Semi LogicalApply without aux column, the result is no matter false or null. So we can add it to join predicate.
		er.p, er.err = er.b.buildSemiApply(er.p, plan4Agg, []expression.Expression{cond}, false, false, false)
		return
	}
	// If we treat the result as a scalar value, we will add a projection with a extra column to output true, false or null.
	outerSchemaLen := er.p.Schema().Len()
	er.p = er.b.buildApplyWithJoinType(er.p, plan4Agg, InnerJoin)
	joinSchema := er.p.Schema()
	// The projection keeps the outer columns and appends the condition as one more column.
	proj := LogicalProjection{
		Exprs: expression.Column2Exprs(joinSchema.Clone().Columns[:outerSchemaLen]),
	}.Init(er.sctx, er.b.getSelectOffset())
	proj.names = make([]*types.FieldName, outerSchemaLen, outerSchemaLen+1)
	copy(proj.names, er.p.OutputNames())
	proj.SetSchema(expression.NewSchema(joinSchema.Clone().Columns[:outerSchemaLen]...))
	proj.Exprs = append(proj.Exprs, cond)
	proj.schema.Append(&expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  cond.GetType(),
	})
	proj.names = append(proj.names, types.EmptyName)
	proj.SetChildren(er.p)
	er.p = proj
}

// handleNEAny handles `!= any`. For example, t.id != any (select s.id from s) is rewritten to
// t.id != max(s.id) or count(distinct s.id) > 1 or [any checker]: if s.id holds two different
// values, at least one of them is not equal to t.id. The [any checker] part is added by
// buildQuantifierPlan. Errors are reported through er.err.
func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np LogicalPlan) {
	// If s.id contains NULL, the value compared with t.id must be a non-NULL one,
	// so max is used to filter NULL out.
	maxDesc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
	if err != nil {
		er.err = err
		return
	}
	distinctCountDesc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []expression.Expression{rexpr}, true)
	if err != nil {
		er.err = err
		return
	}

	aggPlan := LogicalAggregation{
		AggFuncs: []*aggregation.AggFuncDesc{maxDesc, distinctCountDesc},
	}.Init(er.sctx, er.b.getSelectOffset())
	if tableHints := er.b.TableHints(); tableHints != nil {
		aggPlan.aggHints = tableHints.aggHints
	}
	aggPlan.SetChildren(np)

	// One output column per aggregate, in the same order as AggFuncs.
	maxCol := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  maxDesc.RetTp,
	}
	maxCol.SetCoercibility(rexpr.Coercibility())
	distinctCountCol := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  distinctCountDesc.RetTp,
	}
	aggPlan.names = append(aggPlan.names, types.EmptyName, types.EmptyName)
	aggPlan.SetSchema(expression.NewSchema(maxCol, distinctCountCol))

	tinyType := func() *types.FieldType { return types.NewFieldType(mysql.TypeTiny) }
	severalValues := expression.NewFunctionInternal(er.sctx, ast.GT, tinyType(), distinctCountCol, expression.NewOne())
	differsFromMax := expression.NewFunctionInternal(er.sctx, ast.NE, tinyType(), lexpr, maxCol)
	anyCond := expression.ComposeDNFCondition(er.sctx, severalValues, differsFromMax)
	er.buildQuantifierPlan(aggPlan, anyCond, lexpr, rexpr, false)
}

// handleEQAll handles `= all`. For example, t.id = all (select s.id from s) is rewritten to
// t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]).
// The [all checker] part is added by buildQuantifierPlan. Errors are reported through er.err.
func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np LogicalPlan) {
	firstRowDesc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false)
	if err != nil {
		er.err = err
		return
	}
	distinctCountDesc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []expression.Expression{rexpr}, true)
	if err != nil {
		er.err = err
		return
	}

	aggPlan := LogicalAggregation{
		AggFuncs: []*aggregation.AggFuncDesc{firstRowDesc, distinctCountDesc},
	}.Init(er.sctx, er.b.getSelectOffset())
	if tableHints := er.b.TableHints(); tableHints != nil {
		aggPlan.aggHints = tableHints.aggHints
	}
	aggPlan.SetChildren(np)

	// firstrow is currently treated as an exact stand-in for the aggregate group key, so its
	// return type equals the key's type, NOT NULL flag included. Its result must be nullable
	// though. The flag is cleared here rather than inside NewAggFuncDesc, to keep the existing
	// presumption that firstrow never changes nullability. The type is cloned first because
	// the return type is the very same object as the argument's data type.
	nullableTp := firstRowDesc.RetTp.Clone()
	nullableTp.DelFlag(mysql.NotNullFlag)
	firstRowDesc.RetTp = nullableTp

	// One output column per aggregate, in the same order as AggFuncs.
	firstRowCol := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  firstRowDesc.RetTp,
	}
	firstRowCol.SetCoercibility(rexpr.Coercibility())
	distinctCountCol := &expression.Column{
		UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(),
		RetType:  distinctCountDesc.RetTp,
	}
	aggPlan.names = append(aggPlan.names, types.EmptyName, types.EmptyName)
	aggPlan.SetSchema(expression.NewSchema(firstRowCol, distinctCountCol))

	atMostOneValue := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), distinctCountCol, expression.NewOne())
	equalsFirstRow := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowCol)
	allCond := expression.ComposeCNFCondition(er.sctx, atMostOneValue, equalsFirstRow)
	er.buildQuantifierPlan(aggPlan, allCond, lexpr, rexpr, true)
}

// handleExistSubquery rewrites an `[NOT] EXISTS (subquery)` expression.
//
// The subquery plan is built and trimmed by popExistsSubPlan. When subquery
// preprocessing is disabled or the plan has correlated columns, a semi apply is
// built on top of er.p and, if a scalar result was requested, the last column
// of the new plan is pushed onto the stack. Otherwise the subquery is optimized
// and evaluated immediately and a constant 1 or 0 is pushed.
// Errors are stored in er.err; the return value is always (v, true).
func (er *expressionRewriter) handleExistSubquery(ctx context.Context, v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
	cteCheck := er.b.prepareCTECheckForSubQuery()
	defer resetCTECheckForSubQuery(cteCheck)

	subq, isSubquery := v.Sel.(*ast.SubqueryExpr)
	if !isSubquery {
		er.err = errors.Errorf("Unknown exists type %T", v.Sel)
		return v, true
	}
	subPlan, hasRewriteHint, err := er.buildSubquery(ctx, subq, true)
	if err != nil {
		er.err = err
		return v, true
	}
	subPlan = er.popExistsSubPlan(subPlan)

	if er.b.disableSubQueryPreprocessing || len(ExtractCorrelatedCols4LogicalPlan(subPlan)) > 0 {
		er.p, er.err = er.b.buildSemiApply(er.p, subPlan, nil, er.asScalar, v.Not, hasRewriteHint)
		if er.err == nil && er.asScalar {
			// The scalar result is read from the last output column of the new plan.
			last := er.p.Schema().Len() - 1
			er.ctxStackAppend(er.p.Schema().Columns[last], er.p.OutputNames()[last])
		}
		return v, true
	}

	// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
	savedNthPlan := er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan
	er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1
	physicalPlan, _, err := DoOptimize(ctx, er.sctx, er.b.optFlag, subPlan)
	er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = savedNthPlan
	if err != nil {
		er.err = err
		return v, true
	}
	firstRow, err := EvalSubqueryFirstRow(ctx, physicalPlan, er.b.is, er.b.ctx)
	if err != nil {
		er.err = err
		return v, true
	}
	// EXISTS is true when a row came back, NOT EXISTS when none did.
	if hasRow := firstRow != nil; hasRow != v.Not {
		er.ctxStackAppend(expression.NewOne(), types.EmptyName)
	} else {
		er.ctxStackAppend(expression.NewZero(), types.EmptyName)
	}
	return v, true
}

// popExistsSubPlan strips operators from the top of an EXISTS subquery plan
// that are treated as irrelevant for the existence check and returns the plan
// that remains.
func (er *expressionRewriter) popExistsSubPlan(p LogicalPlan) LogicalPlan {
	for {
		switch node := p.(type) {
		case *LogicalProjection, *LogicalSort:
			// This can be removed when in exists clause,
			// e.g. exists(select count(*) from t order by a) is equal to exists t.
			p = p.Children()[0]
		case *LogicalAggregation:
			if len(node.GroupByItems) == 0 {
				// An aggregation without GROUP BY is replaced by a one-row table dual.
				return LogicalTableDual{RowCount: 1}.Init(er.sctx, er.b.getSelectOffset())
			}
			p = p.Children()[0]
		default:
			return p
		}
	}
}

// handleInSubquery rewrites `expr [NOT] IN (subquery)`.
//
// It rewrites the left operand, builds the subquery plan and an equality
// condition between the two, and then either turns the predicate into an inner
// join over a distinct aggregation of the subquery, or builds a semi apply via
// buildSemiApply. When the caller asked for a scalar result, the last column of
// the resulting plan is pushed onto the stack.
// Errors are stored in er.err; the return value is always (v, true).
func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.PatternInExpr) (ast.Node, bool) {
	ci := er.b.prepareCTECheckForSubQuery()
	defer resetCTECheckForSubQuery(ci)
	// Remember the caller's asScalar setting; the left operand is rewritten with asScalar set.
	asScalar := er.asScalar
	er.asScalar = true
	v.Expr.Accept(er)
	if er.err != nil {
		return v, true
	}
	// The rewritten left operand sits on top of the stack; it is popped once the plan has been built.
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.Sel.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T", v.Sel)
		return v, true
	}
	np, _, err := er.buildSubquery(ctx, subq, false)
	if err != nil {
		er.err = err
		return v, true
	}
	// The left operand must have as many columns as the subquery returns.
	lLen := expression.GetRowLen(lexpr)
	if lLen != np.Schema().Len() {
		er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen)
		return v, true
	}
	var rexpr expression.Expression
	if np.Schema().Len() == 1 {
		rexpr = np.Schema().Columns[0]
		rCol := rexpr.(*expression.Column)
		// For AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, we cannot treat `in` expression as
		// normal column equal condition, so we specially mark the inner operand here.
		if v.Not || asScalar {
			// If both input columns of `in` expression are not null, we can treat the expression
			// as normal column equal condition instead. Otherwise, mark the left and right side.
			// eg: for some optimization, the column substitute in right side in projection elimination
			// will cause case  like <lcol EQ rcol(inOperand)> as <lcol EQ constant> which is not
			// a valid null-aware EQ. (null in lcol still need to be null-aware)
			if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) {
				// Mark a copy so that the column object held by the subquery schema is left unchanged.
				rColCopy := *rCol
				rColCopy.InOperand = true
				rexpr = &rColCopy
				lexpr = expression.SetExprColumnInOperand(lexpr)
			}
		}
	} else {
		// Multi-column case: build a row function over the subquery columns.
		args := make([]expression.Expression, 0, np.Schema().Len())
		for i, col := range np.Schema().Columns {
			if v.Not || asScalar {
				larg := expression.GetFuncArg(lexpr, i)
				// If both input columns of `in` expression are not null, we can treat the expression
				// as normal column equal condition instead. Otherwise, mark the left and right side.
				if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) {
					rarg := *col
					rarg.InOperand = true
					col = &rarg
					if larg != nil {
						// The i-th argument of the left row function is replaced in place.
						lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg)
					}
				}
			}
			args = append(args, col)
		}
		rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
		if er.err != nil {
			return v, true
		}
	}
	checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
	if err != nil {
		er.err = err
		return v, true
	}

	// If the leftKey and the rightKey have different collations, don't convert the sub-query to an inner-join
	// since when converting we will add a distinct-agg upon the right child and this distinct-agg doesn't have the right collation.
	// To keep it simple, we forbid this converting if they have different collations.
	lt, rt := lexpr.GetType(), rexpr.GetType()
	collFlag := collate.CompatibleCollate(lt.GetCollate(), rt.GetCollate())

	// If it's not the form of `not in (SUBQUERY)`,
	// and has no correlated column from the current level plan(if the correlated column is from upper level,
	// we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node),
	// and don't need to append a scalar value, we can rewrite it to inner join.
	if er.sctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(extractCorColumnsBySchema4LogicalPlan(np, er.p.Schema())) == 0 && collFlag {
		// We need to try to eliminate the agg and the projection produced by this operation.
		er.b.optFlag |= flagEliminateAgg
		er.b.optFlag |= flagEliminateProjection
		er.b.optFlag |= flagJoinReOrder
		// Build distinct for the inner query.
		agg, err := er.b.buildDistinct(np, np.Schema().Len())
		if err != nil {
			er.err = err
			return v, true
		}
		// Build inner join above the aggregation.
		join := LogicalJoin{JoinType: InnerJoin}.Init(er.sctx, er.b.getSelectOffset())
		join.SetChildren(er.p, agg)
		join.SetSchema(expression.MergeSchema(er.p.Schema(), agg.schema))
		// Output names are those of the outer plan followed by those of the aggregation.
		join.names = make([]*types.FieldName, er.p.Schema().Len()+agg.Schema().Len())
		copy(join.names, er.p.OutputNames())
		copy(join.names[er.p.Schema().Len():], agg.OutputNames())
		join.AttachOnConds(expression.SplitCNFItems(checkCondition))
		// Set join hint for this join.
		if er.b.TableHints() != nil {
			join.setPreferredJoinTypeAndOrder(er.b.TableHints())
		}
		er.p = join
	} else {
		er.p, er.err = er.b.buildSemiApply(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not, false)
		if er.err != nil {
			return v, true
		}
	}

	// Drop the left operand; for a scalar result push the last column of the new plan instead.
	er.ctxStackPop(1)
	if asScalar {
		col := er.p.Schema().Columns[er.p.Schema().Len()-1]
		er.ctxStackAppend(col, er.p.OutputNames()[er.p.Schema().Len()-1])
	}
	return v, true
}

// handleScalarSubquery rewrites a subquery that is used as a value.
//
// The subquery plan is wrapped by buildMaxOneRow. When subquery preprocessing
// is disabled or the plan has correlated columns, the plan is attached to er.p
// through a left outer apply and the subquery output is pushed onto the stack.
// Otherwise the subquery is optimized and evaluated right away and its first
// row is pushed as constant(s). A result with more than one column is wrapped
// in a row function.
// Errors are stored in er.err; the return value is always (v, true).
func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, v *ast.SubqueryExpr) (ast.Node, bool) {
	cteCheck := er.b.prepareCTECheckForSubQuery()
	defer resetCTECheckForSubQuery(cteCheck)

	subPlan, _, err := er.buildSubquery(ctx, v, false)
	if err != nil {
		er.err = err
		return v, true
	}
	subPlan = er.b.buildMaxOneRow(subPlan)

	if er.b.disableSubQueryPreprocessing || len(ExtractCorrelatedCols4LogicalPlan(subPlan)) > 0 {
		er.p = er.b.buildApplyWithJoinType(er.p, subPlan, LeftOuterJoin)
		if subPlan.Schema().Len() <= 1 {
			// Single column: the value is the last output column of the apply.
			last := er.p.Schema().Len() - 1
			er.ctxStackAppend(er.p.Schema().Columns[last], er.p.OutputNames()[last])
			return v, true
		}
		rowArgs := make([]expression.Expression, 0, subPlan.Schema().Len())
		subCols := subPlan.Schema().Columns
		for i := range subCols {
			rowArgs = append(rowArgs, subCols[i])
		}
		rowExpr, rowErr := er.newFunction(ast.RowFunc, rowArgs[0].GetType(), rowArgs...)
		if rowErr != nil {
			er.err = rowErr
			return v, true
		}
		er.ctxStackAppend(rowExpr, types.EmptyName)
		return v, true
	}

	// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
	savedNthPlan := er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan
	er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1
	physicalPlan, _, err := DoOptimize(ctx, er.sctx, er.b.optFlag, subPlan)
	er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = savedNthPlan
	if err != nil {
		er.err = err
		return v, true
	}
	firstRow, err := EvalSubqueryFirstRow(ctx, physicalPlan, er.b.is, er.b.ctx)
	if err != nil {
		er.err = err
		return v, true
	}

	if subPlan.Schema().Len() <= 1 {
		scalar := &expression.Constant{
			Value:   firstRow[0],
			RetType: subPlan.Schema().Columns[0].GetType(),
		}
		scalar.SetCoercibility(subPlan.Schema().Columns[0].Coercibility())
		er.ctxStackAppend(scalar, types.EmptyName)
		return v, true
	}

	// Each datum of the evaluated row becomes a constant typed like the matching subquery column.
	rowArgs := make([]expression.Expression, 0, subPlan.Schema().Len())
	for i, datum := range firstRow {
		elem := &expression.Constant{
			Value:   datum,
			RetType: subPlan.Schema().Columns[i].GetType(),
		}
		elem.SetCoercibility(subPlan.Schema().Columns[i].Coercibility())
		rowArgs = append(rowArgs, elem)
	}
	rowExpr, rowErr := er.newFunction(ast.RowFunc, rowArgs[0].GetType(), rowArgs...)
	if rowErr != nil {
		er.err = rowErr
		return v, true
	}
	er.ctxStackAppend(rowExpr, types.EmptyName)
	return v, true
}

// Leave implements Visitor interface.
//
// It converts the AST node that has just been visited into an expression and
// updates ctxStack/ctxNameStk accordingly, dispatching on the node's concrete
// type. If an error is already recorded in er.err, or one is produced here, it
// returns (nil, false) because retNode is never assigned; otherwise it returns
// (originInNode, true).
func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) {
	if er.err != nil {
		return retNode, false
	}
	// The optional preprocess hook may substitute the node that is dispatched on;
	// the node returned to the caller is still originInNode.
	var inNode = originInNode
	if er.preprocess != nil {
		inNode = er.preprocess(inNode)
	}
	switch v := inNode.(type) {
	// NOTE(review): these node kinds need no work here — presumably they were
	// already handled when the node was entered; confirm against Enter.
	case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause,
		*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.TableNameExpr:
	case *driver.ValueExpr:
		// set right not null flag for constant value
		retType := v.Type.Clone()
		switch v.Datum.Kind() {
		case types.KindNull:
			retType.DelFlag(mysql.NotNullFlag)
		default:
			retType.AddFlag(mysql.NotNullFlag)
		}
		v.Datum.SetValue(v.Datum.GetValue(), retType)
		value := &expression.Constant{Value: v.Datum, RetType: retType}
		// Start from ASCII and widen to UNICODE only if a string constant has a byte >= 0x80.
		value.SetRepertoire(expression.ASCII)
		if retType.EvalType() == types.ETString {
			for _, b := range v.Datum.GetBytes() {
				// if any character in constant is not ascii, set the repertoire to UNICODE.
				if b >= 0x80 {
					value.SetRepertoire(expression.UNICODE)
					break
				}
			}
		}
		er.ctxStackAppend(value, types.EmptyName)
	case *driver.ParamMarkerExpr:
		var value expression.Expression
		value, er.err = expression.ParamMarkerExpression(er.sctx, v, false)
		if er.err != nil {
			return retNode, false
		}
		er.ctxStackAppend(value, types.EmptyName)
	case *ast.VariableExpr:
		er.rewriteVariable(v)
	case *ast.FuncCallExpr:
		// The fold counters are decremented here for functions listed in the
		// corresponding tables. NOTE(review): the matching increments are
		// presumably done when the node is entered — confirm.
		if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
			er.tryFoldCounter--
		}
		er.funcCallToExpression(v)
		if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
			er.disableFoldCounter--
		}
	case *ast.TableName:
		er.toTable(v)
	case *ast.ColumnName:
		er.toColumn(v)
	case *ast.UnaryOperationExpr:
		er.unaryOpToExpression(v)
	case *ast.BinaryOperationExpr:
		if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr {
			er.tryFoldCounter--
		}
		er.binaryOpToExpression(v)
	case *ast.BetweenExpr:
		er.betweenToExpression(v)
	case *ast.CaseExpr:
		if _, ok := expression.TryFoldFunctions["case"]; ok {
			er.tryFoldCounter--
		}
		er.caseToExpression(v)
		if _, ok := expression.DisableFoldFunctions["case"]; ok {
			er.disableFoldCounter--
		}
	case *ast.FuncCastExpr:
		// The cast operand is the expression on top of the stack; it is replaced in place below.
		arg := er.ctxStack[len(er.ctxStack)-1]
		er.err = expression.CheckArgsNotMultiColumnRow(arg)
		if er.err != nil {
			return retNode, false
		}

		// check the decimal precision of "CAST(AS TIME)".
		er.err = er.checkTimePrecision(v.Tp)
		if er.err != nil {
			return retNode, false
		}

		castFunction := expression.BuildCastFunction(er.sctx, arg, v.Tp)
		// String targets get implicit coercibility and a repertoire chosen by the
		// target charset; every other target is numeric/ASCII.
		if v.Tp.EvalType() == types.ETString {
			castFunction.SetCoercibility(expression.CoercibilityImplicit)
			if v.Tp.GetCharset() == charset.CharsetASCII {
				castFunction.SetRepertoire(expression.ASCII)
			} else {
				castFunction.SetRepertoire(expression.UNICODE)
			}
		} else {
			castFunction.SetCoercibility(expression.CoercibilityNumeric)
			castFunction.SetRepertoire(expression.ASCII)
		}

		er.ctxStack[len(er.ctxStack)-1] = castFunction
		er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName
	case *ast.PatternLikeExpr:
		er.patternLikeToExpression(v)
	case *ast.PatternRegexpExpr:
		er.regexpToScalarFunc(v)
	case *ast.RowExpr:
		er.rowToScalarFunc(v)
	case *ast.PatternInExpr:
		// Only the value-list form is rewritten here; nothing is done when v.Sel is set.
		if v.Sel == nil {
			er.inToExpression(len(v.List), v.Not, &v.Type)
		}
	case *ast.PositionExpr:
		er.positionToScalarFunc(v)
	case *ast.IsNullExpr:
		er.isNullToExpression(v)
	case *ast.IsTruthExpr:
		er.isTrueToScalarFunc(v)
	case *ast.DefaultExpr:
		er.evalDefaultExpr(v)
	// TODO: Perhaps we don't need to transcode these back to generic integers/strings
	case *ast.TrimDirectionExpr:
		er.ctxStackAppend(&expression.Constant{
			Value:   types.NewIntDatum(int64(v.Direction)),
			RetType: types.NewFieldType(mysql.TypeTiny),
		}, types.EmptyName)
	case *ast.TimeUnitExpr:
		er.ctxStackAppend(&expression.Constant{
			Value:   types.NewStringDatum(v.Unit.String()),
			RetType: types.NewFieldType(mysql.TypeVarchar),
		}, types.EmptyName)
	case *ast.GetFormatSelectorExpr:
		er.ctxStackAppend(&expression.Constant{
			Value:   types.NewStringDatum(v.Selector.String()),
			RetType: types.NewFieldType(mysql.TypeVarchar),
		}, types.EmptyName)
	case *ast.SetCollationExpr:
		// In this case `break` leaves the switch; the er.err check after the
		// switch then reports the failure.
		arg := er.ctxStack[len(er.ctxStack)-1]
		if collate.NewCollationEnabled() {
			var collInfo *charset.Collation
			// TODO(bb7133): use charset.ValidCharsetAndCollation when its bug is fixed.
			if collInfo, er.err = collate.GetCollationByName(v.Collate); er.err != nil {
				break
			}
			chs := arg.GetType().GetCharset()
			// if the field is json, the charset is always utf8mb4.
			if arg.GetType().GetType() == mysql.TypeJSON {
				chs = mysql.UTF8MB4Charset
			}
			// The requested collation must belong to the operand's charset.
			if chs != "" && collInfo.CharsetName != chs {
				er.err = charset.ErrCollationCharsetMismatch.GenWithStackByArgs(collInfo.Name, chs)
				break
			}
		}
		// SetCollationExpr sets the collation explicitly, even when the evaluation type of the expression is non-string.
		if _, ok := arg.(*expression.Column); ok || arg.GetType().GetType() == mysql.TypeJSON {
			if arg.GetType().GetType() == mysql.TypeEnum || arg.GetType().GetType() == mysql.TypeSet {
				er.err = ErrNotSupportedYet.GenWithStackByArgs("use collate clause for enum or set")
				break
			}
			// Wrap a cast here to avoid changing the original FieldType of the column expression.
			exprType := arg.GetType().Clone()
			// if arg type is json, we should cast it to longtext if there is collate clause.
			if arg.GetType().GetType() == mysql.TypeJSON {
				exprType = types.NewFieldType(mysql.TypeLongBlob)
				exprType.SetCharset(mysql.UTF8MB4Charset)
			}
			exprType.SetCollate(v.Collate)
			casted := expression.BuildCastFunction(er.sctx, arg, exprType)
			arg = casted
			er.ctxStackPop(1)
			er.ctxStackAppend(casted, types.EmptyName)
		} else {
			// For constant and scalar function, we can set its collate directly.
			arg.GetType().SetCollate(v.Collate)
		}
		// Mark the expression now on top of the stack as explicitly collated.
		er.ctxStack[len(er.ctxStack)-1].SetCoercibility(expression.CoercibilityExplicit)
		er.ctxStack[len(er.ctxStack)-1].SetCharsetAndCollation(arg.GetType().GetCharset(), arg.GetType().GetCollate())
	default:
		er.err = errors.Errorf("UnknownType: %T", v)
		return retNode, false
	}

	// The helpers called above report failures through er.err.
	if er.err != nil {
		return retNode, false
	}
	return originInNode, true
}

// newFunction builds a function expression, picking the constructor by the
// current fold counters: NewFunctionBase while disableFoldCounter is positive,
// NewFunctionTryFold while tryFoldCounter is positive, NewFunction otherwise.
// When the result is a scalar function, its usage is recorded through
// BuiltinFunctionUsageInc.
func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldType, args ...expression.Expression) (ret expression.Expression, err error) {
	switch {
	case er.disableFoldCounter > 0:
		ret, err = expression.NewFunctionBase(er.sctx, funcName, retType, args...)
	case er.tryFoldCounter > 0:
		ret, err = expression.NewFunctionTryFold(er.sctx, funcName, retType, args...)
	default:
		ret, err = expression.NewFunction(er.sctx, funcName, retType, args...)
	}
	if err != nil {
		return ret, err
	}
	if sf, isScalar := ret.(*expression.ScalarFunction); isScalar {
		er.b.ctx.BuiltinFunctionUsageInc(sf.Function.PbCode().String())
	}
	return ret, err
}

// checkTimePrecision returns errTooBigPrecision when ft evaluates as a duration
// and its decimal (fractional seconds precision) exceeds types.MaxFsp;
// otherwise it returns nil.
func (er *expressionRewriter) checkTimePrecision(ft *types.FieldType) error {
	if ft.EvalType() != types.ETDuration {
		return nil
	}
	if fsp := ft.GetDecimal(); fsp > types.MaxFsp {
		return errTooBigPrecision.GenWithStackByArgs(fsp, "CAST", types.MaxFsp)
	}
	return nil
}

// useCache returns the UseCache flag of the current statement context.
func (er *expressionRewriter) useCache() bool {
	stmtCtx := er.sctx.GetSessionVars().StmtCtx
	return stmtCtx.UseCache
}

// rewriteVariable rewrites a user variable or system variable reference.
//
// User variables (v.IsSystem == false):
//   - with an assigned value, the expression on top of the stack is replaced by
//     a SetVar function and the value's type is recorded in the session;
//   - otherwise a GetVar function is pushed, typed by the recorded user
//     variable type or, when none is recorded, by a var-string default.
//
// System variables are resolved to their current value and pushed as a constant
// whose charset and collation are taken from the connection settings.
// Errors are reported through er.err.
func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) {
	stkLen := len(er.ctxStack)
	name := strings.ToLower(v.Name)
	sessionVars := er.b.ctx.GetSessionVars()
	if !v.IsSystem {
		if v.Value != nil {
			tp := er.ctxStack[stkLen-1].GetType()
			setVar, err := er.newFunction(ast.SetVar, tp,
				expression.DatumToConstant(types.NewDatum(name), mysql.TypeString, 0),
				er.ctxStack[stkLen-1])
			if err != nil {
				// Leave the stack and the session untouched when SetVar cannot be built.
				er.err = err
				return
			}
			er.ctxStack[stkLen-1] = setVar
			er.ctxNameStk[stkLen-1] = types.EmptyName
			// Store the field type of the variable into SessionVars.UserVarTypes.
			// Normally we can infer the type from SessionVars.User, but we need SessionVars.UserVarTypes when
			// GetVar has not been executed to fill the SessionVars.Users.
			sessionVars.SetUserVarType(name, tp)
			return
		}
		tp, ok := sessionVars.GetUserVarType(name)
		if !ok {
			tp = types.NewFieldType(mysql.TypeVarString)
			tp.SetFlen(mysql.MaxFieldVarCharLength)
		}
		f, err := er.newFunction(ast.GetVar, tp, expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString, 0))
		if err != nil {
			er.err = err
			return
		}
		f.SetCoercibility(expression.CoercibilityImplicit)
		er.ctxStackAppend(f, types.EmptyName)
		return
	}
	sysVar := variable.GetSysVar(name)
	if sysVar == nil {
		er.err = variable.ErrUnknownSystemVar.FastGenByArgs(name)
		if err := variable.CheckSysVarIsRemoved(name); err != nil {
			// Removed vars still return an error, but we customize it from
			// "unknown" to an explanation of why it is not supported.
			// This is important so users at least know they had the name correct.
			er.err = err
		}
		return
	}
	if sysVar.IsNoop && !variable.EnableNoopVariables.Load() {
		// The variable does nothing, append a warning to the statement output.
		sessionVars.StmtCtx.AppendWarning(ErrGettingNoopVariable.GenWithStackByArgs(sysVar.Name))
	}
	if sem.IsEnabled() && sem.IsInvisibleSysVar(sysVar.Name) {
		// Reading an invisible variable requires the RESTRICTED_VARIABLES_ADMIN dynamic privilege.
		err := ErrSpecificAccessDenied.GenWithStackByArgs("RESTRICTED_VARIABLES_ADMIN")
		er.b.visitInfo = appendDynamicVisitInfo(er.b.visitInfo, "RESTRICTED_VARIABLES_ADMIN", false, err)
	}
	// An explicitly requested scope must be one the variable actually has.
	if v.ExplicitScope && !sysVar.HasNoneScope() {
		if v.IsGlobal && !(sysVar.HasGlobalScope() || sysVar.HasInstanceScope()) {
			er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "SESSION")
			return
		}
		if !v.IsGlobal && !sysVar.HasSessionScope() {
			er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "GLOBAL")
			return
		}
	}
	var val string
	var err error
	switch {
	case sysVar.HasNoneScope():
		val = sysVar.Value
	case v.IsGlobal:
		val, err = sessionVars.GetGlobalSystemVar(name)
	default:
		val, err = sessionVars.GetSessionOrGlobalSystemVar(name)
	}
	if err != nil {
		er.err = err
		return
	}
	nativeVal, nativeType, nativeFlag := sysVar.GetNativeValType(val)
	e := expression.DatumToConstant(nativeVal, nativeType, nativeFlag)
	// The locals are named so that they do not shadow the imported charset and
	// collate packages. The lookup's second result is ignored; when the variable
	// is missing the zero string is used.
	connCharset, _ := sessionVars.GetSystemVar(variable.CharacterSetConnection)
	e.GetType().SetCharset(connCharset)
	connCollation, _ := sessionVars.GetSystemVar(variable.CollationConnection)
	e.GetType().SetCollate(connCollation)
	er.ctxStackAppend(e, types.EmptyName)
}

// unaryOpToExpression replaces the expression on top of the stack with the
// unary operation described by v applied to it.
//
// Unary plus leaves the stack unchanged. An operator that is not handled, or an
// operand whose row length is not 1, is reported through er.err.
func (er *expressionRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) {
	stkLen := len(er.ctxStack)
	var op string
	switch v.Op {
	case opcode.Plus:
		// expression (+ a) is equal to a
		return
	case opcode.Minus:
		op = ast.UnaryMinus
	case opcode.BitNeg:
		op = ast.BitNeg
	case opcode.Not, opcode.Not2:
		op = ast.UnaryNot
	default:
		// Report the operator itself; %T would only print its static type name.
		er.err = errors.Errorf("Unknown Unary Op %v", v.Op)
		return
	}
	if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
		er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
		return
	}
	er.ctxStack[stkLen-1], er.err = er.newFunction(op, &v.Type, er.ctxStack[stkLen-1])
	er.ctxNameStk[stkLen-1] = types.EmptyName
}

// binaryOpToExpression replaces the two expressions on top of the stack with
// the binary operation described by v applied to them.
//
// Comparison operators go through constructBinaryOpFunction. Every other
// operator requires both operands to have row length 1 and is built with
// newFunction. On error the stack is left as it was and er.err is set.
func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) {
	n := len(er.ctxStack)
	var result expression.Expression
	switch v.Op {
	case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
		lhs, rhs := er.ctxStack[n-2], er.ctxStack[n-1]
		result, er.err = er.constructBinaryOpFunction(lhs, rhs, v.Op.String())
	default:
		leftCols := expression.GetRowLen(er.ctxStack[n-2])
		rightCols := expression.GetRowLen(er.ctxStack[n-1])
		if !(leftCols == 1 && rightCols == 1) {
			er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
			return
		}
		retTp := types.NewFieldType(mysql.TypeUnspecified)
		result, er.err = er.newFunction(v.Op.String(), retTp, er.ctxStack[n-2:]...)
	}
	if er.err != nil {
		return
	}
	er.ctxStackPop(2)
	er.ctxStackAppend(result, types.EmptyName)
}

// notToExpression builds the function op(args...) with return type tp and, when
// hasNot is set, wraps it in a unary NOT. On failure it stores the error in
// er.err and returns nil.
func (er *expressionRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType,
	args ...expression.Expression) expression.Expression {
	result, err := er.newFunction(op, tp, args...)
	if err == nil && hasNot {
		result, err = er.newFunction(ast.UnaryNot, tp, result)
	}
	if err != nil {
		er.err = err
		return nil
	}
	return result
}

// isNullToExpression replaces the expression on top of the stack with an
// `IS [NOT] NULL` test on it. An operand whose row length is not 1 is reported
// through er.err.
func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
	operand := er.ctxStack[len(er.ctxStack)-1]
	if expression.GetRowLen(operand) != 1 {
		er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
		return
	}
	nullTest := er.notToExpression(v.Not, ast.IsNull, &v.Type, operand)
	er.ctxStackPop(1)
	er.ctxStackAppend(nullTest, types.EmptyName)
}

// positionToScalarFunc resolves a position expression to the column at that
// 1-based position of er.schema and pushes it onto the stack.
//
// The position is v.N, or, when v.P is set, the integer value of the constant
// on top of the stack (which is then popped). A NULL constant leaves the stack
// unchanged. A position that is out of range, refers to a hidden column, or
// could not be evaluated results in ErrUnknownColumn.
func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) {
	pos, label := v.N, strconv.Itoa(v.N)
	if v.P != nil {
		// The error message shows "?" instead of a number when v.P is set.
		label = "?"
		top := er.ctxStack[len(er.ctxStack)-1]
		n, isNull, err := expression.GetIntFromConstant(er.sctx, top)
		if err == nil && isNull {
			return
		}
		if err == nil {
			pos = n
			er.ctxStackPop(1)
		}
		er.err = err
	}
	valid := er.err == nil && pos > 0 && pos <= er.schema.Len() && !er.schema.Columns[pos-1].IsHidden
	if !valid {
		er.err = ErrUnknownColumn.GenWithStackByArgs(label, clauseMsg[er.b.curClause])
		return
	}
	er.ctxStackAppend(er.schema.Columns[pos-1], er.names[pos-1])
}

// isTrueToScalarFunc replaces the expression on top of the stack with an
// `IS [NOT] TRUE` / `IS [NOT] FALSE` test on it; v.True == 0 selects the
// falsity test. An operand whose row length is not 1 is reported through
// er.err.
func (er *expressionRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) {
	operand := er.ctxStack[len(er.ctxStack)-1]
	if expression.GetRowLen(operand) != 1 {
		er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
		return
	}
	fnName := ast.IsFalsity
	if v.True != 0 {
		fnName = ast.IsTruthWithoutNull
	}
	truthTest := er.notToExpression(v.Not, fnName, &v.Type, operand)
	er.ctxStackPop(1)
	er.ctxStackAppend(truthTest, types.EmptyName)
}

// inToExpression converts in expression to a scalar function. The argument lLen means the length of in list.
// The argument not means if the expression is not in. The tp stands for the expression type, which is always bool.
//
// The left operand and the lLen list elements are taken from the top of ctxStack and replaced by the result:
//   - if the left operand has type NULL, the result is a NULL constant;
//   - if the left operand has row length 1, the list has more than one element and every non-NULL element
//     compares with the left operand using the left operand's eval type, an `in` function is built;
//   - otherwise a in (b, c, d) will be rewritten as `(a = b) or (a = c) or (a = d)`.
//
// Errors are reported through er.err, in which case the stack is not reduced.
func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) {
	stkLen := len(er.ctxStack)
	l := expression.GetRowLen(er.ctxStack[stkLen-lLen-1])
	// Every list element must have the same row length as the left operand.
	for i := 0; i < lLen; i++ {
		if l != expression.GetRowLen(er.ctxStack[stkLen-lLen+i]) {
			er.err = expression.ErrOperandColumns.GenWithStackByArgs(l)
			return
		}
	}
	// args[0] is the left operand and args[1:] are the list elements. args shares
	// its backing array with ctxStack, so assignments to args[i] below also
	// change the stack entries.
	args := er.ctxStack[stkLen-lLen-1:]
	leftFt := args[0].GetType()
	leftEt, leftIsNull := leftFt.EvalType(), leftFt.GetType() == mysql.TypeNull
	if leftIsNull {
		er.ctxStackPop(lLen + 1)
		er.ctxStackAppend(expression.NewNull(), types.EmptyName)
		return
	}
	if leftEt == types.ETInt {
		// For an integer left operand, constant list elements are refined against
		// the left operand's type for an EQ comparison.
		for i := 1; i < len(args); i++ {
			if c, ok := args[i].(*expression.Constant); ok {
				var isExceptional bool
				if expression.MaybeOverOptimized4PlanCache(er.sctx, []expression.Expression{c}) {
					if c.GetType().EvalType() == types.ETString {
						// To keep the result be compatible with MySQL, refine `int non-constant <cmp> str constant`
						// here and skip this refine operation in all other cases for safety.
						er.sctx.GetSessionVars().StmtCtx.SkipPlanCache = true
						expression.RemoveMutableConst(er.sctx, []expression.Expression{c})
					} else {
						continue
					}
				} else if er.sctx.GetSessionVars().StmtCtx.SkipPlanCache {
					// We should remove the mutable constant for correctness, because its value may be changed.
					expression.RemoveMutableConst(er.sctx, []expression.Expression{c})
				}
				args[i], isExceptional = expression.RefineComparedConstant(er.sctx, *leftFt, c, opcode.EQ)
				if isExceptional {
					// Keep the original constant when the refinement is flagged as exceptional.
					args[i] = c
				}
			}
		}
	}
	// allSameType stays true only if every non-NULL element compares with the
	// left operand using the left operand's eval type.
	allSameType := true
	for _, arg := range args[1:] {
		if arg.GetType().GetType() != mysql.TypeNull && expression.GetAccurateCmpType(args[0], arg) != leftEt {
			allSameType = false
			break
		}
	}
	var function expression.Expression
	if allSameType && l == 1 && lLen > 1 {
		function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...)
	} else {
		// If we rewrite IN to EQ, we need to decide what's the collation EQ uses.
		coll := er.deriveCollationForIn(l, lLen, args)
		if er.err != nil {
			return
		}
		er.castCollationForIn(l, lLen, stkLen, coll)
		// Build one equality per list element and OR them together.
		eqFunctions := make([]expression.Expression, 0, lLen)
		for i := stkLen - lLen; i < stkLen; i++ {
			expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
			if err != nil {
				er.err = err
				return
			}
			eqFunctions = append(eqFunctions, expr)
		}
		function = expression.ComposeDNFCondition(er.sctx, eqFunctions...)
		if not {
			var err error
			function, err = er.newFunction(ast.UnaryNot, tp, function)
			if err != nil {
				er.err = err
				return
			}
		}
	}
	er.ctxStackPop(lLen + 1)
	er.ctxStackAppend(function, types.EmptyName)
}

// deriveCollationForIn derives the collation used by an IN expression whose
// left operand has a single column (a in (x, y, z)); the derivation result is
// stored in er.err and nil is returned on failure.
// Tuple elements, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)), are not
// handled: nil is returned and er.err is left untouched.
func (er *expressionRewriter) deriveCollationForIn(colLen int, _ int, args []expression.Expression) *expression.ExprCollation {
	if colLen != 1 {
		return nil
	}
	derived, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...)
	if er.err = err; err != nil {
		return nil
	}
	return derived
}

// castCollationForIn casts collation info for arguments in the `in clause` to make sure the used collation is correct after we
// rewrite it to equal expression.
//
// colLen is the row length of the left operand, elemCnt the number of list elements, stkLen the stack length at
// the time the list elements occupied ctxStack[stkLen-elemCnt:stkLen], and coll the derived charset/collation.
// String-typed list elements whose collation differs from coll are replaced on the stack by a cast to coll.
func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll *expression.ExprCollation) {
	// We don't handle the cases if the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)).
	if colLen != 1 {
		return
	}
	for i := stkLen - elemCnt; i < stkLen; i++ {
		// todo: consider refining the code and reusing expression.BuildCollationFunction here
		if er.ctxStack[i].GetType().EvalType() == types.ETString {
			// Row functions are left as they are.
			rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction)
			if ok && rowFunc.FuncName.String() == ast.RowFunc {
				continue
			}
			// Skip elements whose collation already equals the derived one. (The original note on this
			// check reads: don't convert it if its charset is binary, so that we don't convert 0x12 to a string.)
			if er.ctxStack[i].GetType().GetCollate() == coll.Collation {
				continue
			}
			tp := er.ctxStack[i].GetType().Clone()
			if er.ctxStack[i].GetType().Hybrid() {
				// Hybrid-typed elements are cast only when they are compared with the left operand as strings.
				if expression.GetAccurateCmpType(er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString {
					tp = types.NewFieldType(mysql.TypeVarString)
				} else {
					continue
				}
			} else if coll.Charset == charset.CharsetBin {
				// When cast character string to binary string, if we still use fixed length representation,
				// then 0 padding will be used, which can affect later execution.
				// e.g. https://github.com/pingcap/tidb/pull/35053#pullrequestreview-1008757770 gives an unexpected case.
				// On the other hand, we can not directly return origin expr back,
				// since we need binary collation to do string comparison later.
				// Here we use VarString type of cast, i.e `cast(a as binary)`, to avoid this problem.
				tp.SetType(mysql.TypeVarString)
			}
			tp.SetCharset(coll.Charset)
			tp.SetCollate(coll.Collation)
			er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp)
			er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit)
		}
	}
}

// caseToExpression replaces the operands of a CASE expression on the stack
// with a single case function.
//
// Stack layout on entry (n = len(ctxStack), k = number of when/else operands):
//
//	value (only if v.Value != nil)  -> ctxStack[n-k-1]
//	condition1, result1, ...        -> ctxStack[n-k : ...]
//	else clause (only if present)   -> ctxStack[n-1]
//
// With a value, each condition is turned into `value = condition`. Operands
// that are multi-column rows, or a failure to build a function, are reported
// through er.err and leave the stack unchanged.
func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
	n := len(er.ctxStack)
	branchCnt := 2 * len(v.WhenClauses)
	if v.ElseClause != nil {
		branchCnt++
	}
	if er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[n-branchCnt:]...); er.err != nil {
		return
	}

	popCnt := branchCnt
	var caseArgs []expression.Expression
	if v.Value == nil {
		// Searched CASE: the stack operands are already in the order the case function takes.
		caseArgs = er.ctxStack[n-branchCnt:]
	} else {
		// Simple CASE: compare the value with each when-condition.
		operand := er.ctxStack[n-branchCnt-1]
		caseArgs = make([]expression.Expression, 0, branchCnt)
		for i := n - branchCnt; i < n-1; i += 2 {
			cond, err := er.newFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), operand, er.ctxStack[i])
			if err != nil {
				er.err = err
				return
			}
			caseArgs = append(caseArgs, cond, er.ctxStack[i+1])
		}
		if v.ElseClause != nil {
			caseArgs = append(caseArgs, er.ctxStack[n-1])
		}
		// The value operand has to be removed from the stack as well.
		popCnt++
	}

	caseFn, err := er.newFunction(ast.Case, &v.Type, caseArgs...)
	if err != nil {
		er.err = err
		return
	}
	er.ctxStackPop(popCnt)
	er.ctxStackAppend(caseFn, types.EmptyName)
}

// patternLikeToExpression replaces the two expressions on top of the stack
// (the matched expression and the pattern) with a `[NOT] LIKE` expression.
//
// When the new collation framework is disabled, the pattern is a non-NULL
// constant that compiles to an exact match, and the matched expression
// evaluates as a string, the predicate is built as `=` (or `!=` for NOT LIKE)
// against the compiled pattern value. Otherwise a like function taking the
// expression, the pattern and the escape character is built.
// Multi-column row operands and evaluation failures are reported through er.err.
func (er *expressionRewriter) patternLikeToExpression(v *ast.PatternLikeExpr) {
	l := len(er.ctxStack)
	er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[l-2:]...)
	if er.err != nil {
		return
	}

	// Session charset and collation, used to type the constants created below.
	char, col := er.sctx.GetSessionVars().GetCharsetInfo()
	var function expression.Expression
	fieldType := &types.FieldType{}
	isPatternExactMatch := false
	// Treat predicate 'like' the same way as predicate '=' when it is an exact match and new collation is not enabled.
	if patExpression, ok := er.ctxStack[l-1].(*expression.Constant); ok && !collate.NewCollationEnabled() {
		patString, isNull, err := patExpression.EvalString(nil, chunk.Row{})
		if err != nil {
			er.err = err
			return
		}
		if !isNull {
			patValue, patTypes := stringutil.CompilePattern(patString, v.Escape)
			if stringutil.IsExactMatch(patTypes) && er.ctxStack[l-2].GetType().EvalType() == types.ETString {
				op := ast.EQ
				if v.Not {
					op = ast.NE
				}
				types.DefaultTypeForValue(string(patValue), fieldType, char, col)
				// An error from this call is left in er.err; the stack is still updated below.
				function, er.err = er.constructBinaryOpFunction(er.ctxStack[l-2],
					&expression.Constant{Value: types.NewStringDatum(string(patValue)), RetType: fieldType},
					op)
				isPatternExactMatch = true
			}
		}
	}
	if !isPatternExactMatch {
		// General case: the escape character is passed as an integer constant third argument.
		types.DefaultTypeForValue(int(v.Escape), fieldType, char, col)
		function = er.notToExpression(v.Not, ast.Like, &v.Type,
			er.ctxStack[l-2], er.ctxStack[l-1], &expression.Constant{Value: types.NewIntDatum(int64(v.Escape)), RetType: fieldType})
	}

	er.ctxStackPop(2)
	er.ctxStackAppend(function, types.EmptyName)
}

// regexpToScalarFunc replaces the two topmost stack entries (the REGEXP
// target and pattern) with a regexp function call, negated when v.Not is set.
// Multi-column row operands are rejected through er.err.
func (er *expressionRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) {
	stkLen := len(er.ctxStack)
	operands := er.ctxStack[stkLen-2:]
	if er.err = expression.CheckArgsNotMultiColumnRow(operands...); er.err != nil {
		return
	}
	regexpFunc := er.notToExpression(v.Not, ast.Regexp, &v.Type, operands[0], operands[1])
	er.ctxStackPop(2)
	er.ctxStackAppend(regexpFunc, types.EmptyName)
}

// rowToScalarFunc pops one stack entry per value of the row expression and
// pushes a single row function built from them. The field type of the first
// element is passed as the return type of the row function.
func (er *expressionRewriter) rowToScalarFunc(v *ast.RowExpr) {
	arity := len(v.Values)
	first := len(er.ctxStack) - arity

	// Copy the operands out before popping so they do not alias the stack.
	fields := make([]expression.Expression, arity)
	copy(fields, er.ctxStack[first:])
	er.ctxStackPop(arity)

	rowFunc, err := er.newFunction(ast.RowFunc, fields[0].GetType(), fields...)
	if err != nil {
		er.err = err
		return
	}
	er.ctxStackAppend(rowFunc, types.EmptyName)
}

// wrapExpWithCast reads the three topmost stack entries (the BETWEEN target,
// lower bound and upper bound, in that order) and returns them wrapped with a
// cast to the evaluation type chosen by expression.ResolveType4Between.
// The stack itself is left untouched. For evaluation types not listed below
// the operands are returned unchanged.
func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Expression) {
	top := len(er.ctxStack)
	expr, lexp, rexp = er.ctxStack[top-3], er.ctxStack[top-2], er.ctxStack[top-1]

	var cast func(sessionctx.Context, expression.Expression) expression.Expression
	switch expression.ResolveType4Between([3]expression.Expression{expr, lexp, rexp}) {
	case types.ETInt:
		cast = expression.WrapWithCastAsInt
	case types.ETReal:
		cast = expression.WrapWithCastAsReal
	case types.ETDecimal:
		cast = expression.WrapWithCastAsDecimal
	case types.ETString:
		cast = func(ctx sessionctx.Context, e expression.Expression) expression.Expression {
			// string kind expression do not need cast
			if e.GetType().EvalType().IsStringKind() {
				return e
			}
			return expression.WrapWithCastAsString(ctx, e)
		}
	case types.ETDuration:
		// A fresh field type is created for every operand.
		cast = func(ctx sessionctx.Context, e expression.Expression) expression.Expression {
			return expression.WrapWithCastAsTime(ctx, e, types.NewFieldType(mysql.TypeDuration))
		}
	case types.ETDatetime:
		cast = func(ctx sessionctx.Context, e expression.Expression) expression.Expression {
			return expression.WrapWithCastAsTime(ctx, e, types.NewFieldType(mysql.TypeDatetime))
		}
	default:
		return expr, lexp, rexp
	}

	return cast(er.sctx, expr), cast(er.sctx, lexp), cast(er.sctx, rexp)
}

// betweenToExpression replaces the three topmost stack entries (target, lower
// bound, upper bound) with `target >= lower AND target <= upper`, wrapped in a
// unary NOT when v.Not is set. The operands are first cast to a common type
// and then to the collation derived for the comparison.
func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) {
	top := len(er.ctxStack)
	if er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[top-3:]...); er.err != nil {
		return
	}

	target, lower, upper := er.wrapExpWithCast()

	coll, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "BETWEEN", types.ETInt, target, lower, upper)
	er.err = err
	if err != nil {
		return
	}

	// Handle enum or set. We need to know their real type to decide whether to cast them.
	lowerCmpTp := expression.GetAccurateCmpType(target, lower)
	upperCmpTp := expression.GetAccurateCmpType(target, upper)
	enumOrSetRealTypeIsStr := !(lowerCmpTp == types.ETInt || upperCmpTp == types.ETInt)

	target = expression.BuildCastCollationFunction(er.sctx, target, coll, enumOrSetRealTypeIsStr)
	lower = expression.BuildCastCollationFunction(er.sctx, lower, coll, enumOrSetRealTypeIsStr)
	upper = expression.BuildCastCollationFunction(er.sctx, upper, coll, enumOrSetRealTypeIsStr)

	var geLower, leUpper expression.Expression
	if geLower, er.err = expression.NewFunction(er.sctx, ast.GE, &v.Type, target, lower); er.err != nil {
		return
	}
	if leUpper, er.err = expression.NewFunction(er.sctx, ast.LE, &v.Type, target, upper); er.err != nil {
		return
	}

	result, err := er.newFunction(ast.LogicAnd, &v.Type, geLower, leUpper)
	if err != nil {
		er.err = err
		return
	}
	if v.Not {
		if result, err = er.newFunction(ast.UnaryNot, &v.Type, result); err != nil {
			er.err = err
			return
		}
	}

	er.ctxStackPop(3)
	er.ctxStackAppend(result, types.EmptyName)
}

// rewriteFuncCall gives selected functions a customized translation.
// It returns true when it has fully handled v (including the case where it
// reported an error through er.err), so the caller must skip the generic
// function construction. It returns false when the caller should build the
// function in the usual way.
//
// Handled functions:
//   - IFNULL(col, x): replaced by a clone of col when col is a column carrying
//     the NOT NULL flag and is not of ENUM or SET type.
//   - NULLIF(a, b): rewritten to IF(a = b, NULL, a).
func (er *expressionRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool {
	switch v.FnName.L {
	case ast.Ifnull:
		if len(v.Args) != 2 {
			er.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
			return true
		}
		firstIdx := len(er.ctxStack) - 2
		firstArg := er.ctxStack[firstIdx]
		colArg, isCol := firstArg.(*expression.Column)
		argTp := firstArg.GetType().GetType()
		isEnumOrSet := argTp == mysql.TypeEnum || argTp == mysql.TypeSet

		// ifnull can only be eliminated when its first argument is a
		// non-enum/set column that has the not null flag.
		if !isCol || isEnumOrSet || !mysql.HasNotNullFlag(colArg.RetType.GetFlag()) {
			return false
		}
		fieldName := er.ctxNameStk[firstIdx]
		replacement := colArg.Clone().(*expression.Column)
		er.ctxStackPop(len(v.Args))
		er.ctxStackAppend(replacement, fieldName)
		return true

	case ast.Nullif:
		if len(v.Args) != 2 {
			er.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
			return true
		}
		top := len(er.ctxStack)
		lhs, rhs := er.ctxStack[top-2], er.ctxStack[top-1]

		// lhs = rhs
		eqFunc, err := er.constructBinaryOpFunction(lhs, rhs, ast.EQ)
		if err != nil {
			er.err = err
			return true
		}

		// The NULL constant used as the "then" branch.
		nullTp := types.NewFieldType(mysql.TypeNull)
		flen, decimal := mysql.GetDefaultFieldLengthAndDecimal(mysql.TypeNull)
		nullTp.SetFlen(flen)
		nullTp.SetDecimal(decimal)
		nullConst := &expression.Constant{
			Value:   types.NewDatum(nil),
			RetType: nullTp,
		}

		// if(lhs = rhs, NULL, lhs)
		ifFunc, err := er.newFunction(ast.If, &v.Type, eqFunc, nullConst, lhs)
		if err != nil {
			er.err = err
			return true
		}
		er.ctxStackPop(len(v.Args))
		er.ctxStackAppend(ifFunc, types.EmptyName)
		return true
	}
	return false
}

// funcCallToExpression replaces the len(v.Args) topmost stack entries with
// the expression for the function call v.
//
// Multi-column row arguments are rejected, and functions handled by
// rewriteFuncCall are left to it. For the remaining functions:
//   - if the function is listed in expression.DeferredFunctions and
//     er.useCache() is true, a NULL-valued constant carrying the function as
//     its DeferredExpr is pushed, except for unix_timestamp called with
//     arguments, which is built like any other function;
//   - otherwise the function is built with er.newFunction and pushed.
//
// Any failure is reported through er.err.
func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
	stackLen := len(er.ctxStack)
	args := er.ctxStack[stackLen-len(v.Args):]
	er.err = expression.CheckArgsNotMultiColumnRow(args...)
	if er.err != nil {
		return
	}

	if er.rewriteFuncCall(v) {
		return
	}

	var function expression.Expression
	er.ctxStackPop(len(v.Args))
	_, deferrable := expression.DeferredFunctions[v.FnName.L]
	// When the expression is unix_timestamp and the number of argument is not zero,
	// we deal with it as normal expression.
	unixTimestampWithArgs := v.FnName.L == ast.UnixTimestamp && len(v.Args) != 0
	if er.useCache() && deferrable && !unixTimestampWithArgs {
		function, er.err = expression.NewFunctionBase(er.sctx, v.FnName.L, &v.Type, args...)
		if er.err != nil {
			// function is not usable when construction failed; calling
			// GetType on it would dereference a nil interface.
			return
		}
		c := &expression.Constant{Value: types.NewDatum(nil), RetType: function.GetType().Clone(), DeferredExpr: function}
		er.ctxStackAppend(c, types.EmptyName)
		return
	}

	function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
	er.ctxStackAppend(function, types.EmptyName)
}

// toTable pushes the table name v onto the stack as a string constant of the
// form "schema.name" (or just "name" when no schema is given), using the
// lower-cased identifiers.
// Now TableName in expression only used by sequence function like nextval(seq).
// The function arg should be evaluated as a table name rather than normal column name like mysql does.
func (er *expressionRewriter) toTable(v *ast.TableName) {
	qualified := v.Name.L
	if schema := v.Schema.L; len(schema) != 0 {
		qualified = schema + "." + qualified
	}
	er.ctxStackAppend(&expression.Constant{
		Value:   types.NewDatum(qualified),
		RetType: types.NewFieldType(mysql.TypeString),
	}, types.EmptyName)
}

// toColumn resolves the column name v and pushes the matching column onto the
// stack. The lookup order is:
//  1. the names of the current plan (hidden columns are reported as unknown);
//  2. the outer schemas, innermost first, yielding a correlated column;
//  3. the full schema of a NATURAL/USING join below the current plan.
//
// A failed name lookup in steps 1 and 2 is reported as an ambiguous-column
// error; a name found nowhere is reported as an unknown-column error.
func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
	pos, err := expression.FindFieldName(er.names, v)
	switch {
	case err != nil:
		er.err = ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
		return
	case pos >= 0:
		found := er.schema.Columns[pos]
		if found.IsHidden {
			er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[er.b.curClause])
			return
		}
		er.ctxStackAppend(found, er.names[pos])
		return
	}

	// Walk the enclosing query blocks from the innermost outwards.
	for depth := len(er.b.outerSchemas) - 1; depth >= 0; depth-- {
		outerNames := er.b.outerNames[depth]
		outerPos, outerErr := expression.FindFieldName(outerNames, v)
		if outerPos >= 0 {
			outerCol := er.b.outerSchemas[depth].Columns[outerPos]
			er.ctxStackAppend(&expression.CorrelatedColumn{Column: *outerCol, Data: new(types.Datum)}, outerNames[outerPos])
			return
		}
		if outerErr != nil {
			er.err = ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
			return
		}
	}

	if _, isUnion := er.p.(*LogicalUnionAll); isUnion && v.Table.O != "" {
		er.err = ErrTablenameNotAllowedHere.GenWithStackByArgs(v.Table.O, "SELECT", clauseMsg[er.b.curClause])
		return
	}

	joinCol, joinName, err := findFieldNameFromNaturalUsingJoin(er.p, v)
	if err != nil {
		er.err = err
		return
	}
	if joinCol != nil {
		er.ctxStackAppend(joinCol, joinName)
		return
	}

	// The unknown-column message for a global ORDER BY uses the plain
	// ORDER BY clause; note that this updates the builder's current clause.
	if er.b.curClause == globalOrderByClause {
		er.b.curClause = orderByClause
	}
	er.err = ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause])
}

// findFieldNameFromNaturalUsingJoin looks v up in the full schema of a join.
// Limit, Selection, TopN, Sort and MaxOneRow operators are skipped by
// descending into their first child. If the plan reached this way is a
// LogicalJoin with a non-nil fullSchema, v is searched in its fullNames.
// It returns (nil, nil, nil) when nothing is found.
func findFieldNameFromNaturalUsingJoin(p LogicalPlan, v *ast.ColumnName) (col *expression.Column, name *types.FieldName, err error) {
unwrap:
	for {
		switch p.(type) {
		case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
			p = p.Children()[0]
		default:
			break unwrap
		}
	}

	join, isJoin := p.(*LogicalJoin)
	if !isJoin || join.fullSchema == nil {
		return nil, nil, nil
	}
	pos, findErr := expression.FindFieldName(join.fullNames, v)
	if findErr != nil {
		return nil, nil, findErr
	}
	if pos < 0 {
		return nil, nil, nil
	}
	return join.fullSchema.Columns[pos], join.fullNames[pos], nil
}

// evalDefaultExpr evaluates DEFAULT(col) and pushes the resulting constant
// onto the stack.
//
// The column is resolved to its originating table, whose column definition
// supplies the default. A column with no originating table (for example one
// computed by an expression in a derived table) yields a "no default value"
// error. For DATETIME/TIMESTAMP columns whose default is CURRENT_TIMESTAMP the
// current time is used; for all other columns the stored default is used.
// Every failure is reported through er.err.
func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
	var name *types.FieldName
	// Here we will find the corresponding column for default function. At the same time, we need to consider the issue
	// of subquery and name space.
	// For example, we have two tables t1(a int default 1, b int) and t2(a int default -1, c int). Consider the following SQL:
	// 		select a from t1 where a > (select default(a) from t2)
	// Refer to the behavior of MySQL, we need to find column a in table t2. If table t2 does not have column a, then find it
	// in table t1. If there are none, return an error message.
	// Based on the above description, we need to look in er.b.allNames from back to front.
	for i := len(er.b.allNames) - 1; i >= 0; i-- {
		idx, err := expression.FindFieldName(er.b.allNames[i], v.Name)
		if err != nil {
			er.err = err
			return
		}
		if idx >= 0 {
			name = er.b.allNames[i][idx]
			break
		}
	}
	if name == nil {
		// Fall back to the names of the current plan.
		idx, err := expression.FindFieldName(er.names, v.Name)
		if err != nil {
			er.err = err
			return
		}
		if idx < 0 {
			er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field list")
			return
		}
		name = er.names[idx]
	}

	dbName := name.DBName
	if dbName.O == "" {
		// if database name is not specified, use current database name
		dbName = model.NewCIStr(er.sctx.GetSessionVars().CurrentDB)
	}
	if name.OrigTblName.O == "" {
		// column is evaluated by some expressions, for example:
		// `select default(c) from (select (a+1) as c from t) as t0`
		// in such case, a 'no default' error is returned
		er.err = table.ErrNoDefaultValue.GenWithStackByArgs(name.ColName)
		return
	}
	var tbl table.Table
	tbl, er.err = er.b.is.TableByName(dbName, name.OrigTblName)
	if er.err != nil {
		return
	}
	colName := name.OrigColName.O
	if colName == "" {
		// in some cases, OrigColName is empty, use ColName instead
		colName = name.ColName.O
	}
	col := table.FindCol(tbl.Cols(), colName)
	if col == nil {
		er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name, "field_list")
		return
	}
	isCurrentTimestamp := hasCurrentDatetimeDefault(col)
	var val *expression.Constant
	switch {
	case isCurrentTimestamp && (col.GetType() == mysql.TypeDatetime || col.GetType() == mysql.TypeTimestamp):
		t, err := expression.GetTimeValue(er.sctx, ast.CurrentTimestamp, col.GetType(), col.GetDecimal())
		if err != nil {
			// Record the failure: returning without setting er.err would
			// leave the stack without a value while reporting success.
			er.err = err
			return
		}
		val = &expression.Constant{
			Value:   t,
			RetType: types.NewFieldType(col.GetType()),
		}
	default:
		// for other columns, just use what it is
		val, er.err = er.b.getDefaultValue(col)
	}
	if er.err != nil {
		return
	}
	er.ctxStackAppend(val, types.EmptyName)
}

// hasCurrentDatetimeDefault reports whether the column's default value is a
// string that equals ast.CurrentTimestamp after lower-casing. A default value
// that is not a string yields false.
func hasCurrentDatetimeDefault(col *table.Column) bool {
	if def, isStr := col.DefaultValue.(string); isStr {
		return strings.ToLower(def) == ast.CurrentTimestamp
	}
	return false
}

// decodeKeyFromString decodes the hex-encoded key s into a JSON description
// of the record, index or table key it represents.
//
// Decoding is best effort: on any failure a warning is appended to the
// statement context and s itself is returned unchanged.
func decodeKeyFromString(ctx sessionctx.Context, s string) string {
	sc := ctx.GetSessionVars().StmtCtx
	key, err := hex.DecodeString(s)
	if err != nil {
		sc.AppendWarning(errors.Errorf("invalid key: %X", key))
		return s
	}
	// Auto decode byte if needed.
	if _, raw, decErr := codec.DecodeBytes(key, nil); decErr == nil {
		key = raw
	}
	tableID := tablecodec.DecodeTableID(key)
	if tableID == 0 {
		sc.AppendWarning(errors.Errorf("invalid key: %X", key))
		return s
	}
	dm := domain.GetDomain(ctx)
	if dm == nil {
		sc.AppendWarning(errors.Errorf("domain not found when decoding key: %X", key))
		return s
	}
	is := dm.InfoSchema()
	if is == nil {
		sc.AppendWarning(errors.Errorf("infoschema not found when decoding key: %X", key))
		return s
	}
	// The lookup result flag is ignored on purpose: the decoders below
	// accept a nil table and fall back to a schema-less rendering.
	tbl, _ := is.TableByID(tableID)
	loc := ctx.GetSessionVars().Location()

	var (
		decoded   string
		decodeErr error
	)
	switch {
	case tablecodec.IsRecordKey(key):
		decoded, decodeErr = decodeRecordKey(key, tableID, tbl, loc)
	case tablecodec.IsIndexKey(key):
		decoded, decodeErr = decodeIndexKey(key, tableID, tbl, loc)
	case tablecodec.IsTableKey(key):
		decoded, decodeErr = decodeTableKey(key, tableID)
	default:
		sc.AppendWarning(errors.Errorf("invalid key: %X", key))
		return s
	}
	if decodeErr != nil {
		sc.AppendWarning(decodeErr)
		return s
	}
	return decoded
}

// decodeRecordKey renders a record key as a JSON object string.
//
// Three output shapes are produced, depending on the handle and on whether
// table metadata is available:
//   - integer handle: {"table_id": "<decimal string>", "<name>": <int>}, where
//     <name> is the primary key name if tbl is non-nil and has a clustered
//     index, and "_tidb_rowid" otherwise;
//   - non-integer handle with tbl != nil: {"table_id": <int>, "handle": {...}},
//     where "handle" maps lower-cased column names to decoded values;
//   - non-integer handle with tbl == nil: {"table_id": <int>,
//     "handle": <handle.String()>}.
//
// Note that "table_id" is a string only in the integer-handle shape.
// loc is passed to the handle decoder.
func decodeRecordKey(key []byte, tableID int64, tbl table.Table, loc *time.Location) (string, error) {
	_, handle, err := tablecodec.DecodeRecordKey(key)
	if err != nil {
		return "", errors.Trace(err)
	}
	if handle.IsInt() {
		ret := make(map[string]interface{})
		ret["table_id"] = strconv.FormatInt(tableID, 10)
		// When the clustered index is enabled, we should show the PK name.
		if tbl != nil && tbl.Meta().HasClusteredIndex() {
			ret[tbl.Meta().GetPkName().String()] = handle.IntValue()
		} else {
			ret["_tidb_rowid"] = handle.IntValue()
		}
		retStr, err := json.Marshal(ret)
		if err != nil {
			return "", errors.Trace(err)
		}
		return string(retStr), nil
	}
	if tbl != nil {
		tblInfo := tbl.Meta()
		idxInfo := tables.FindPrimaryIndex(tblInfo)
		if idxInfo == nil {
			return "", errors.Trace(errors.Errorf("primary key not found when decoding record key: %X", key))
		}
		// Field types of all table columns, keyed by column ID.
		cols := make(map[int64]*types.FieldType, len(tblInfo.Columns))
		for _, col := range tblInfo.Columns {
			cols[col.ID] = &(col.FieldType)
		}
		// Column IDs of the primary index columns, in index order.
		handleColIDs := make([]int64, 0, len(idxInfo.Columns))
		for _, col := range idxInfo.Columns {
			handleColIDs = append(handleColIDs, tblInfo.Columns[col.Offset].ID)
		}

		// The handle must carry exactly one value per primary index column.
		if len(handleColIDs) != handle.NumCols() {
			return "", errors.Trace(errors.Errorf("primary key length not match handle columns number in key"))
		}
		datumMap, err := tablecodec.DecodeHandleToDatumMap(handle, handleColIDs, cols, loc, nil)
		if err != nil {
			return "", errors.Trace(err)
		}
		ret := make(map[string]interface{})
		ret["table_id"] = tableID
		handleRet := make(map[string]interface{})
		for colID := range datumMap {
			dt := datumMap[colID]
			dtStr, err := datumToJSONObject(&dt)
			if err != nil {
				return "", errors.Trace(err)
			}
			// Translate the column ID back to its lower-cased name.
			found := false
			for _, colInfo := range tblInfo.Columns {
				if colInfo.ID == colID {
					found = true
					handleRet[colInfo.Name.L] = dtStr
					break
				}
			}
			if !found {
				return "", errors.Trace(errors.Errorf("column not found when decoding record key: %X", key))
			}
		}
		ret["handle"] = handleRet
		retStr, err := json.Marshal(ret)
		if err != nil {
			return "", errors.Trace(err)
		}
		return string(retStr), nil
	}
	// No table metadata: fall back to the handle's own string form.
	ret := make(map[string]interface{})
	ret["table_id"] = tableID
	ret["handle"] = handle.String()
	retStr, err := json.Marshal(ret)
	if err != nil {
		return "", errors.Trace(err)
	}
	return string(retStr), nil
}

// decodeIndexKey renders an index key as a JSON object string with the keys
// "table_id", "index_id" and "index_vals".
//
// Without table metadata (tbl == nil) "index_vals" is the raw index values
// joined with ", ". With table metadata, the index is looked up by ID and
// "index_vals" is an object mapping lower-cased index column names to the
// decoded column values (null for NULL values). loc is passed to the column
// value decoder.
func decodeIndexKey(key []byte, tableID int64, tbl table.Table, loc *time.Location) (string, error) {
	if tbl == nil {
		_, indexID, indexValues, err := tablecodec.DecodeIndexKey(key)
		if err != nil {
			return "", errors.Trace(errors.Errorf("invalid index key: %X", key))
		}
		encoded, err := json.Marshal(map[string]interface{}{
			"table_id":   tableID,
			"index_id":   indexID,
			"index_vals": strings.Join(indexValues, ", "),
		})
		if err != nil {
			return "", errors.Trace(err)
		}
		return string(encoded), nil
	}

	_, indexID, _, err := tablecodec.DecodeKeyHead(key)
	if err != nil {
		return "", errors.Trace(errors.Errorf("invalid record/index key: %X", key))
	}
	tblInfo := tbl.Meta()
	var targetIndex *model.IndexInfo
	for _, candidate := range tblInfo.Indices {
		if candidate.ID == indexID {
			targetIndex = candidate
			break
		}
	}
	if targetIndex == nil {
		return "", errors.Trace(errors.Errorf("index not found when decoding index key: %X", key))
	}

	colInfos := tables.BuildRowcodecColInfoForIndexColumns(targetIndex, tblInfo)
	tps := tables.BuildFieldTypesForIndexColumns(targetIndex, tblInfo)
	values, err := tablecodec.DecodeIndexKV(key, []byte{0}, len(colInfos), tablecodec.HandleNotNeeded, colInfos)
	if err != nil {
		return "", errors.Trace(err)
	}
	datums := make([]types.Datum, 0, len(colInfos))
	for i := range colInfos {
		d, err := tablecodec.DecodeColumnValue(values[i], tps[i], loc)
		if err != nil {
			return "", errors.Trace(err)
		}
		datums = append(datums, d)
	}

	idxVals := make(map[string]interface{}, len(targetIndex.Columns))
	for i, idxCol := range targetIndex.Columns {
		val, err := datumToJSONObject(&datums[i])
		if err != nil {
			return "", errors.Trace(err)
		}
		idxVals[idxCol.Name.L] = val
	}
	encoded, err := json.Marshal(map[string]interface{}{
		"table_id":   tableID,
		"index_id":   indexID,
		"index_vals": idxVals,
	})
	if err != nil {
		return "", errors.Trace(err)
	}
	return string(encoded), nil
}

// decodeTableKey renders a table key as the JSON object {"table_id": <id>}.
// The key bytes themselves are not inspected.
func decodeTableKey(_ []byte, tableID int64) (string, error) {
	encoded, err := json.Marshal(map[string]int64{"table_id": tableID})
	if err != nil {
		return "", errors.Trace(err)
	}
	return string(encoded), nil
}

// datumToJSONObject converts d into a value suitable for JSON encoding:
// nil for a NULL datum, otherwise the datum's string form (together with any
// error reported by the conversion).
func datumToJSONObject(d *types.Datum) (interface{}, error) {
	if !d.IsNull() {
		return d.ToString()
	}
	return nil, nil
}

相关信息

tidb 源码目录

相关文章

tidb access_object 源码

tidb collect_column_stats_usage 源码

tidb common_plans 源码

tidb encode 源码

tidb errors 源码

tidb exhaust_physical_plans 源码

tidb explain 源码

tidb find_best_task 源码

tidb flat_plan 源码

tidb fragment 源码

0  赞