// tidb compare source code
// tidb compare code
// File path: /util/chunk/compare.go
// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package chunk
import (
"bytes"
"sort"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
)
// CompareFunc is a function to compare the two values in Row, the two columns must have the same type.
// It compares column lCol of row l with column rCol of row r and reports the
// ordering as an int. The comparators in this file return 0 when both cells
// are NULL, -1 when only the left cell is NULL and 1 when only the right cell
// is NULL.
type CompareFunc = func(l Row, lCol int, r Row, rCol int) int
// GetCompareFunc picks the comparator that matches the MySQL type of tp.
//
// Integer types (including YEAR) use the unsigned comparator when tp carries
// the unsigned flag and the signed one otherwise. String and blob types get a
// comparator bound to tp's collation. nil is returned for any type that has
// no comparator in this file, so callers must check the result.
func GetCompareFunc(tp *types.FieldType) CompareFunc {
	fieldTp := tp.GetType()
	switch fieldTp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		if !mysql.HasUnsignedFlag(tp.GetFlag()) {
			return cmpInt64
		}
		return cmpUint64
	case mysql.TypeFloat:
		return cmpFloat32
	case mysql.TypeDouble:
		return cmpFloat64
	case mysql.TypeNewDecimal:
		return cmpMyDecimal
	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar,
		mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		// The collation is captured once here rather than on every comparison.
		return genCmpStringFunc(tp.GetCollate())
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		return cmpTime
	case mysql.TypeDuration:
		return cmpDuration
	case mysql.TypeEnum, mysql.TypeSet:
		return cmpNameValue
	case mysql.TypeBit:
		return cmpBit
	case mysql.TypeJSON:
		return cmpJSON
	default:
		return nil
	}
}
// cmpNull orders two cells by their NULL-ness alone: 0 when both are NULL,
// -1 when only the left one is NULL, and 1 in every other case. Callers in
// this file invoke it only when at least one side is NULL.
func cmpNull(lNull, rNull bool) int {
	switch {
	case lNull && rNull:
		return 0
	case lNull:
		return -1
	default:
		return 1
	}
}
// cmpInt64 compares column lCol of l with column rCol of r as signed 64-bit
// integers. When either cell is NULL the result comes from cmpNull instead.
func cmpInt64(l Row, lCol int, r Row, rCol int) int {
	if leftNull, rightNull := l.IsNull(lCol), r.IsNull(rCol); leftNull || rightNull {
		return cmpNull(leftNull, rightNull)
	}
	lhs := l.GetInt64(lCol)
	rhs := r.GetInt64(rCol)
	return types.CompareInt64(lhs, rhs)
}
// cmpUint64 compares column lCol of l with column rCol of r as unsigned
// 64-bit integers. When either cell is NULL the result comes from cmpNull.
func cmpUint64(l Row, lCol int, r Row, rCol int) int {
	leftNull := l.IsNull(lCol)
	rightNull := r.IsNull(rCol)
	if !leftNull && !rightNull {
		return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol))
	}
	return cmpNull(leftNull, rightNull)
}
// genCmpStringFunc builds a string comparator that always compares under the
// given collation by delegating to cmpStringWithCollationInfo.
func genCmpStringFunc(collation string) func(l Row, lCol int, r Row, rCol int) int {
	compareWithCollation := func(l Row, lCol int, r Row, rCol int) int {
		return cmpStringWithCollationInfo(l, lCol, r, rCol, collation)
	}
	return compareWithCollation
}
// cmpStringWithCollationInfo compares column lCol of l with column rCol of r
// as strings under the given collation. When either cell is NULL the result
// comes from cmpNull and the collation is not consulted.
func cmpStringWithCollationInfo(l Row, lCol int, r Row, rCol int, collation string) int {
	if leftNull, rightNull := l.IsNull(lCol), r.IsNull(rCol); leftNull || rightNull {
		return cmpNull(leftNull, rightNull)
	}
	lhs := l.GetString(lCol)
	rhs := r.GetString(rCol)
	return types.CompareString(lhs, rhs, collation)
}
// cmpFloat32 compares column lCol of l with column rCol of r as float32
// values, widening both to float64 before comparing. When either cell is NULL
// the result comes from cmpNull.
func cmpFloat32(l Row, lCol int, r Row, rCol int) int {
	leftNull := l.IsNull(lCol)
	rightNull := r.IsNull(rCol)
	if !leftNull && !rightNull {
		lhs := float64(l.GetFloat32(lCol))
		rhs := float64(r.GetFloat32(rCol))
		return types.CompareFloat64(lhs, rhs)
	}
	return cmpNull(leftNull, rightNull)
}
// cmpFloat64 compares column lCol of l with column rCol of r as float64
// values. When either cell is NULL the result comes from cmpNull.
func cmpFloat64(l Row, lCol int, r Row, rCol int) int {
	if leftNull, rightNull := l.IsNull(lCol), r.IsNull(rCol); leftNull || rightNull {
		return cmpNull(leftNull, rightNull)
	}
	lhs := l.GetFloat64(lCol)
	rhs := r.GetFloat64(rCol)
	return types.CompareFloat64(lhs, rhs)
}
// cmpMyDecimal compares column lCol of l with column rCol of r as decimals,
// using the decimal's own Compare method. When either cell is NULL the result
// comes from cmpNull.
func cmpMyDecimal(l Row, lCol int, r Row, rCol int) int {
	leftNull := l.IsNull(lCol)
	rightNull := r.IsNull(rCol)
	if !leftNull && !rightNull {
		lhs := l.GetMyDecimal(lCol)
		rhs := r.GetMyDecimal(rCol)
		return lhs.Compare(rhs)
	}
	return cmpNull(leftNull, rightNull)
}
// cmpTime compares column lCol of l with column rCol of r as time values,
// using the time value's own Compare method. When either cell is NULL the
// result comes from cmpNull.
func cmpTime(l Row, lCol int, r Row, rCol int) int {
	if leftNull, rightNull := l.IsNull(lCol), r.IsNull(rCol); leftNull || rightNull {
		return cmpNull(leftNull, rightNull)
	}
	lhs := l.GetTime(lCol)
	rhs := r.GetTime(rCol)
	return lhs.Compare(rhs)
}
// cmpDuration compares column lCol of l with column rCol of r as durations.
// Only the Duration field of each value is compared, as an int64; both values
// are read with an fsp argument of 0. When either cell is NULL the result
// comes from cmpNull.
func cmpDuration(l Row, lCol int, r Row, rCol int) int {
	leftNull := l.IsNull(lCol)
	rightNull := r.IsNull(rCol)
	if !leftNull && !rightNull {
		lhs := int64(l.GetDuration(lCol, 0).Duration)
		rhs := int64(r.GetDuration(rCol, 0).Duration)
		return types.CompareInt64(lhs, rhs)
	}
	return cmpNull(leftNull, rightNull)
}
// cmpNameValue compares column lCol of l with column rCol of r for SET and
// ENUM columns. Only the numeric value returned by getNameValue is compared;
// the name is discarded. When either cell is NULL the result comes from
// cmpNull.
func cmpNameValue(l Row, lCol int, r Row, rCol int) int {
	if leftNull, rightNull := l.IsNull(lCol), r.IsNull(rCol); leftNull || rightNull {
		return cmpNull(leftNull, rightNull)
	}
	_, lhs := l.getNameValue(lCol)
	_, rhs := r.getNameValue(rCol)
	return types.CompareUint64(lhs, rhs)
}
// cmpBit compares column lCol of l with column rCol of r for BIT columns.
// Each cell's bytes are viewed as a types.BinaryLiteral and compared with its
// Compare method. When either cell is NULL the result comes from cmpNull.
func cmpBit(l Row, lCol int, r Row, rCol int) int {
	leftNull := l.IsNull(lCol)
	rightNull := r.IsNull(rCol)
	if !leftNull && !rightNull {
		lhs := types.BinaryLiteral(l.GetBytes(lCol))
		rhs := types.BinaryLiteral(r.GetBytes(rCol))
		return lhs.Compare(rhs)
	}
	return cmpNull(leftNull, rightNull)
}
// cmpJSON compares column lCol of l with column rCol of r as binary JSON
// values via types.CompareBinaryJSON. When either cell is NULL the result
// comes from cmpNull.
func cmpJSON(l Row, lCol int, r Row, rCol int) int {
	if leftNull, rightNull := l.IsNull(lCol), r.IsNull(rCol); leftNull || rightNull {
		return cmpNull(leftNull, rightNull)
	}
	lhs := l.GetJSON(lCol)
	rhs := r.GetJSON(rCol)
	return types.CompareBinaryJSON(lhs, rhs)
}
// Compare compares the value at column colIdx of row with the datum ad.
// We assume that the collation information of the column is the same with the datum.
//
// The cell is read with the accessor that matches ad's kind, so the column
// must hold that kind of value. Ordering rules:
//   - A NULL cell equals a NULL datum (0) and sorts before every other datum
//     (-1), matching how cmpNull orders NULL for the row-to-row comparators.
//   - A non-NULL cell sorts after a NULL or MinNotNull datum (1) and before a
//     MaxValue datum (-1).
//   - Otherwise the result is that of the kind-specific comparison.
//   - A non-NULL cell compared with a datum kind not handled here yields 0.
func Compare(row Row, colIdx int, ad *types.Datum) int {
	kind := ad.Kind()

	// Decide NULL cells up front: the typed accessors below are only called
	// for cells that actually hold a value, so a NULL cell is never compared
	// by whatever payload its accessor would return.
	if row.IsNull(colIdx) {
		if kind == types.KindNull {
			return 0
		}
		return -1
	}

	switch kind {
	case types.KindNull, types.KindMinNotNull:
		// Any real value is greater than NULL and than the MinNotNull sentinel.
		return 1
	case types.KindMaxValue:
		return -1
	case types.KindInt64:
		return types.CompareInt64(row.GetInt64(colIdx), ad.GetInt64())
	case types.KindUint64:
		return types.CompareUint64(row.GetUint64(colIdx), ad.GetUint64())
	case types.KindFloat32:
		return types.CompareFloat64(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32()))
	case types.KindFloat64:
		return types.CompareFloat64(row.GetFloat64(colIdx), ad.GetFloat64())
	case types.KindString:
		return types.CompareString(row.GetString(colIdx), ad.GetString(), ad.Collation())
	case types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit:
		return bytes.Compare(row.GetBytes(colIdx), ad.GetBytes())
	case types.KindMysqlDecimal:
		l, r := row.GetMyDecimal(colIdx), ad.GetMysqlDecimal()
		return l.Compare(r)
	case types.KindMysqlDuration:
		l, r := row.GetDuration(colIdx, 0).Duration, ad.GetMysqlDuration().Duration
		return types.CompareInt64(int64(l), int64(r))
	case types.KindMysqlEnum:
		l, r := row.GetEnum(colIdx).Value, ad.GetMysqlEnum().Value
		return types.CompareUint64(l, r)
	case types.KindMysqlSet:
		l, r := row.GetSet(colIdx).Value, ad.GetMysqlSet().Value
		return types.CompareUint64(l, r)
	case types.KindMysqlJSON:
		l, r := row.GetJSON(colIdx), ad.GetMysqlJSON()
		return types.CompareBinaryJSON(l, r)
	case types.KindMysqlTime:
		l, r := row.GetTime(colIdx), ad.GetMysqlTime()
		return l.Compare(r)
	default:
		return 0
	}
}
// LowerBound binary-searches column colIdx, which must be sorted in
// non-decreasing order, for the datum d.
//
// index is the smallest row index whose value is not less than d, or
// c.NumRows() when every row is less than d. match is true when a row probed
// during the search compared equal to d.
func (c *Chunk) LowerBound(colIdx int, d *types.Datum) (index int, match bool) {
	notLess := func(rowIdx int) bool {
		order := Compare(c.GetRow(rowIdx), colIdx, d)
		// Remember any exact hit seen while probing.
		match = match || order == 0
		return order >= 0
	}
	index = sort.Search(c.NumRows(), notLess)
	return index, match
}
// UpperBound binary-searches column colIdx, which must be sorted in
// non-decreasing order, and returns the smallest row index whose value is
// strictly greater than d, or c.NumRows() when no row is greater.
func (c *Chunk) UpperBound(colIdx int, d *types.Datum) int {
	greater := func(rowIdx int) bool {
		return Compare(c.GetRow(rowIdx), colIdx, d) > 0
	}
	return sort.Search(c.NumRows(), greater)
}
相关信息
相关文章
0
赞
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
8、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦