tidb format_yacc 源码
tidb format_yacc 代码
文件路径:/parser/goyacc/format_yacc.go
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bufio"
"fmt"
gofmt "go/format"
"go/token"
"io/ioutil"
"os"
"regexp"
"strings"
"github.com/cznic/strutil"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/format"
parser "modernc.org/parser/yacc"
)
func Format(inputFilename string, goldenFilename string) (err error) {
spec, err := parseFileToSpec(inputFilename)
if err != nil {
return err
}
yFmt := &OutputFormatter{}
if err = yFmt.Setup(goldenFilename); err != nil {
return err
}
defer func() {
teardownErr := yFmt.Teardown()
if err == nil {
err = teardownErr
}
}()
if err = printDefinitions(yFmt, spec.Defs); err != nil {
return err
}
return printRules(yFmt, spec.Rules)
}
func parseFileToSpec(inputFilename string) (*parser.Specification, error) {
src, err := ioutil.ReadFile(inputFilename)
if err != nil {
return nil, err
}
return parser.Parse(token.NewFileSet(), inputFilename, src)
}
// Definition represents data reduced by productions:
//
// Definition:
// START IDENTIFIER
// | UNION // Case 1
// | LCURL RCURL // Case 2
// | ReservedWord Tag NameList // Case 3
// | ReservedWord Tag // Case 4
// | ERROR_VERBOSE // Case 5
const (
StartIdentifierCase = iota
UnionDefinitionCase
LCURLRCURLCase
ReservedWordTagNameListCase
ReservedWordTagCase
)
func printDefinitions(formatter format.Formatter, definitions []*parser.Definition) error {
for _, def := range definitions {
var err error
switch def.Case {
case StartIdentifierCase:
err = handleStart(formatter, def)
case UnionDefinitionCase:
err = handleUnion(formatter, def)
case LCURLRCURLCase:
err = handleProlog(formatter, def)
case ReservedWordTagNameListCase, ReservedWordTagCase:
err = handleReservedWordTagNameList(formatter, def)
}
if err != nil {
return err
}
}
_, err := formatter.Format("\n%%%%")
return err
}
func handleStart(f format.Formatter, definition *parser.Definition) error {
if err := Ensure(definition).
and(definition.Token2).
and(definition.Token2).NotNil(); err != nil {
return err
}
cmt1 := strings.Join(definition.Token.Comments, "\n")
cmt2 := strings.Join(definition.Token2.Comments, "\n")
_, err := f.Format("\n%s%s\t%s%s\n", cmt1, definition.Token.Val, cmt2, definition.Token2.Val)
return err
}
func handleUnion(f format.Formatter, definition *parser.Definition) error {
if err := Ensure(definition).
and(definition.Value).NotNil(); err != nil {
return err
}
if len(definition.Value) != 0 {
_, err := f.Format("%%union%i%s%u\n\n", definition.Value)
if err != nil {
return err
}
}
return nil
}
func handleProlog(f format.Formatter, definition *parser.Definition) error {
if err := Ensure(definition).
and(definition.Value).NotNil(); err != nil {
return err
}
_, err := f.Format("%%{%s%%}\n\n", definition.Value)
return err
}
func handleReservedWordTagNameList(f format.Formatter, def *parser.Definition) error {
if err := Ensure(def).
and(def.ReservedWord).
and(def.ReservedWord.Token).NotNil(); err != nil {
return err
}
comment := getTokenComment(def.ReservedWord.Token, divNewLineStringLayout)
directive := def.ReservedWord.Token.Val
hasTag := def.Tag != nil
var wordAfterDirective string
if hasTag {
wordAfterDirective = joinTag(def.Tag)
} else {
wordAfterDirective = joinNames(def.Nlist)
}
if _, err := f.Format("%s%s%s%i", comment, directive, wordAfterDirective); err != nil {
return err
}
if hasTag {
if _, err := f.Format("\n"); err != nil {
return err
}
if err := printNameListVertical(f, def.Nlist); err != nil {
return err
}
}
_, err := f.Format("%u\n")
return err
}
func joinTag(tag *parser.Tag) string {
var sb strings.Builder
sb.WriteString("\t")
if tag.Token != nil {
sb.WriteString(tag.Token.Val)
}
if tag.Token2 != nil {
sb.WriteString(tag.Token2.Val)
}
if tag.Token3 != nil {
sb.WriteString(tag.Token3.Val)
}
return sb.String()
}
type stringLayout int8
const (
spanStringLayout stringLayout = iota
divStringLayout
divNewLineStringLayout
)
func getTokenComment(token *parser.Token, layout stringLayout) string {
if len(token.Comments) == 0 {
return ""
}
var splitter, beforeComment string
switch layout {
case spanStringLayout:
splitter, beforeComment = " ", ""
case divStringLayout:
splitter, beforeComment = "\n", ""
case divNewLineStringLayout:
splitter, beforeComment = "\n", "\n"
default:
panic(errors.Errorf("unsupported stringLayout: %v", layout))
}
var sb strings.Builder
sb.WriteString(beforeComment)
for _, comment := range token.Comments {
sb.WriteString(comment)
sb.WriteString(splitter)
}
return sb.String()
}
func printNameListVertical(f format.Formatter, names NameArr) (err error) {
rest := names
for len(rest) != 0 {
var processing NameArr
processing, rest = rest[:1], rest[1:]
var noComments NameArr
noComments, rest = rest.span(noComment)
processing = append(processing, noComments...)
maxCharLength := processing.findMaxLength()
for _, name := range processing {
if err := printSingleName(f, name, maxCharLength); err != nil {
return err
}
}
}
return nil
}
func joinNames(names NameArr) string {
var sb strings.Builder
for _, name := range names {
sb.WriteString(" ")
sb.WriteString(getTokenComment(name.Token, spanStringLayout))
sb.WriteString(name.Token.Val)
}
return sb.String()
}
func printSingleName(f format.Formatter, name *parser.Name, maxCharLength int) error {
cmt := getTokenComment(name.Token, divNewLineStringLayout)
if _, err := f.Format(escapePercent(cmt)); err != nil {
return err
}
strLit := name.LiteralStringOpt
if strLit != nil && strLit.Token != nil {
_, err := f.Format("%-*s %s\n", maxCharLength, name.Token.Val, strLit.Token.Val)
return err
}
_, err := f.Format("%s\n", name.Token.Val)
return err
}
type NameArr []*parser.Name
func (ns NameArr) span(pred func(*parser.Name) bool) (first NameArr, second NameArr) {
first = ns.takeWhile(pred)
second = ns[len(first):]
return first, second
}
func (ns NameArr) takeWhile(pred func(*parser.Name) bool) NameArr {
for i, def := range ns {
if pred(def) {
continue
}
return ns[:i]
}
return ns
}
func (ns NameArr) findMaxLength() int {
maxLen := -1
for _, s := range ns {
if len(s.Token.Val) > maxLen {
maxLen = len(s.Token.Val)
}
}
return maxLen
}
func hasComments(n *parser.Name) bool {
return len(n.Token.Comments) != 0
}
func noComment(n *parser.Name) bool {
return !hasComments(n)
}
func containsActionInRule(rule *parser.Rule) bool {
for _, b := range rule.Body {
if _, ok := b.(*parser.Action); ok {
return true
}
}
return false
}
type RuleArr []*parser.Rule
func printRules(f format.Formatter, rules RuleArr) (err error) {
var lastRuleName string
for _, rule := range rules {
if rule.Name.Val == lastRuleName {
cmt := getTokenComment(rule.Token, divStringLayout)
_, err = f.Format("\n%s|\t%i", cmt)
} else {
cmt := getTokenComment(rule.Name, divStringLayout)
_, err = f.Format("\n\n%s%s:%i\n", cmt, rule.Name.Val)
}
if err != nil {
return err
}
lastRuleName = rule.Name.Val
if err = printRuleBody(f, rule); err != nil {
return err
}
if _, err = f.Format("%u"); err != nil {
return err
}
}
_, err = f.Format("\n%%%%\n")
return err
}
type ruleItemType int8
const (
identRuleItemType ruleItemType = 1
actionRuleItemType ruleItemType = 2
strLiteralRuleItemType ruleItemType = 3
)
func printRuleBody(f format.Formatter, rule *parser.Rule) error {
firstRuleItem, counter := rule.RuleItemList, 0
for ri := rule.RuleItemList; ri != nil; ri = ri.RuleItemList {
switch ruleItemType(ri.Case) {
case identRuleItemType, strLiteralRuleItemType:
term := fmt.Sprintf(" %s", ri.Token.Val)
if ri == firstRuleItem {
term = term[1:]
}
cmt := getTokenComment(ri.Token, divStringLayout)
if _, err := f.Format(escapePercent(cmt)); err != nil {
return err
}
if _, err := f.Format("%s", term); err != nil {
return err
}
case actionRuleItemType:
isFirstRuleItem := ri == firstRuleItem
if err := handlePrecedence(f, rule.Precedence, isFirstRuleItem); err != nil {
return err
}
if err := handleAction(f, rule, ri.Action, isFirstRuleItem); err != nil {
return err
}
}
counter++
}
if err := checkInconsistencyInYaccParser(f, rule, counter); err != nil {
return err
}
if !containsActionInRule(rule) {
if err := handlePrecedence(f, rule.Precedence, counter == 0); err != nil {
return err
}
}
return nil
}
func handleAction(f format.Formatter, rule *parser.Rule, action *parser.Action, isFirstItem bool) error {
if !isFirstItem || rule.Precedence != nil {
if _, err := f.Format("\n"); err != nil {
return err
}
}
cmt := getTokenComment(action.Token, divStringLayout)
if _, err := f.Format(escapePercent(cmt)); err != nil {
return err
}
goSnippet, err := formatGoSnippet(action.Values)
goSnippet = escapePercent(goSnippet)
if err != nil {
return err
}
snippet := "{}"
if len(goSnippet) != 0 {
snippet = fmt.Sprintf("{%%i\n%s%%u\n}", goSnippet)
}
_, err = f.Format(snippet)
return err
}
func handlePrecedence(f format.Formatter, p *parser.Precedence, isFirstItem bool) error {
if p == nil {
return nil
}
if err := Ensure(p.Token).
and(p.Token2).NotNil(); err != nil {
return err
}
cmt := getTokenComment(p.Token, spanStringLayout)
if !isFirstItem {
if _, err := f.Format(" "); err != nil {
return err
}
}
_, err := f.Format("%s%s %s", cmt, p.Token.Val, p.Token2.Val)
return err
}
func formatGoSnippet(actVal []*parser.ActionValue) (string, error) {
tran := &SpecialActionValTransformer{
store: map[string]string{},
}
goSnippet := collectGoSnippet(tran, actVal)
formatted, err := gofmt.Source([]byte(goSnippet))
if err != nil {
return "", err
}
formattedSnippet := tran.restore(string(formatted))
return strings.TrimSpace(formattedSnippet), nil
}
func collectGoSnippet(tran *SpecialActionValTransformer, actionValArr []*parser.ActionValue) string {
var sb strings.Builder
for _, value := range actionValArr {
trimTab := removeLineBeginBlanks(value.Src)
sb.WriteString(tran.transform(trimTab))
}
snipWithPar := strings.TrimSpace(sb.String())
if strings.HasPrefix(snipWithPar, "{") && strings.HasSuffix(snipWithPar, "}") {
return snipWithPar[1 : len(snipWithPar)-1]
}
return ""
}
var lineBeginBlankRegex = regexp.MustCompile("(?m)^[\t ]+")
func removeLineBeginBlanks(src string) string {
return lineBeginBlankRegex.ReplaceAllString(src, "")
}
type SpecialActionValTransformer struct {
store map[string]string
}
const yaccFmtVar = "_yaccfmt_var_"
var yaccFmtVarRegex = regexp.MustCompile("_yaccfmt_var_[0-9]{1,5}")
func (s *SpecialActionValTransformer) transform(val string) string {
if strings.HasPrefix(val, "$") {
generated := fmt.Sprintf("%s%d", yaccFmtVar, len(s.store))
s.store[generated] = val
return generated
}
return val
}
func (s *SpecialActionValTransformer) restore(src string) string {
return yaccFmtVarRegex.ReplaceAllStringFunc(src, func(matched string) string {
origin, ok := s.store[matched]
if !ok {
panic(errors.Errorf("mismatch in SpecialActionValTransformer"))
}
return origin
})
}
type OutputFormatter struct {
file *os.File
out *bufio.Writer
formatter strutil.Formatter
}
func (y *OutputFormatter) Setup(filename string) (err error) {
if y.file, err = os.Create(filename); err != nil {
return
}
y.out = bufio.NewWriter(y.file)
y.formatter = strutil.IndentFormatter(y.out, "\t")
return
}
func (y *OutputFormatter) Teardown() error {
if y.out != nil {
if err := y.out.Flush(); err != nil {
return err
}
}
if y.file != nil {
if err := y.file.Close(); err != nil {
return err
}
}
return nil
}
func (y *OutputFormatter) Format(format string, args ...interface{}) (int, error) {
return y.formatter.Format(format, args...)
}
func (y *OutputFormatter) Write(bytes []byte) (int, error) {
return y.formatter.Write(bytes)
}
type NotNilAssert struct {
idx int
err error
}
func (n *NotNilAssert) and(target interface{}) *NotNilAssert {
if n.err != nil {
return n
}
if target == nil {
n.err = errors.Errorf("encounter nil, index: %d", n.idx)
}
n.idx++
return n
}
func (n *NotNilAssert) NotNil() error {
return n.err
}
func Ensure(target interface{}) *NotNilAssert {
return (&NotNilAssert{}).and(target)
}
func escapePercent(src string) string {
return strings.ReplaceAll(src, "%", "%%")
}
func checkInconsistencyInYaccParser(f format.Formatter, rule *parser.Rule, counter int) error {
if counter == len(rule.Body) {
return nil
}
// pickup rule item in ruleBody
for i := counter; i < len(rule.Body); i++ {
body := rule.Body[i]
switch b := body.(type) {
case string, int:
if bInt, ok := b.(int); ok {
b = fmt.Sprintf("'%c'", bInt)
}
term := fmt.Sprintf(" %s", b)
if i == 0 {
term = term[1:]
}
_, err := f.Format("%s", term)
return err
case *parser.Action:
isFirstRuleItem := i == 0
if err := handlePrecedence(f, rule.Precedence, isFirstRuleItem); err != nil {
return err
}
if err := handleAction(f, rule, b, isFirstRuleItem); err != nil {
return err
}
}
}
return nil
}
相关信息
相关文章
0
赞
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
8、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦