TiDB aes_layer 源码

  • 2022-09-19
  • 浏览 (441)

TiDB aes_layer 代码

文件路径：/util/encrypt/aes_layer.go

// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package encrypt

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"io"
	"math"
	"math/big"
)

var (
	// errInvalidBlockSize reports an unusable encrypt block size passed to
	// NewCtrCipherWithBlockSize.
	errInvalidBlockSize = errors.New("invalid encrypt block size")
)

// defaultEncryptBlockSize is the encrypt block size, in bytes, that
// NewCtrCipher uses.
const defaultEncryptBlockSize = 1024

// CtrCipher holds the key material for encrypting data with AES in counter
// (CTR) mode. Data is processed in encrypt blocks of encryptBlockSize bytes,
// and each encrypt block consumes aesBlockCount AES blocks of keystream.
type CtrCipher struct {
	// block is the AES block cipher keyed with a random 16-byte key.
	block cipher.Block
	// nonce forms the high 64 bits of every CTR initial counter block.
	nonce uint64
	// encryptBlockSize is the encrypt block size in bytes; a multiple of
	// aes.BlockSize.
	encryptBlockSize int64
	// aesBlockCount is the number of AES blocks in one encrypt block
	// (encryptBlockSize / aes.BlockSize).
	aesBlockCount int64
}

// NewCtrCipher returns a CtrCipher whose encrypt block size is
// defaultEncryptBlockSize (1024 bytes), with a freshly generated random key
// and nonce.
func NewCtrCipher() (ctr *CtrCipher, err error) {
	ctr, err = NewCtrCipherWithBlockSize(defaultEncryptBlockSize)
	return ctr, err
}

// NewCtrCipherWithBlockSize returns a CtrCipher that processes data in
// encrypt blocks of encryptBlockSize bytes. It uses a freshly generated
// random AES-128 key and a random nonce in [0, math.MaxInt64).
//
// encryptBlockSize must be a positive multiple of aes.BlockSize (16).
// Otherwise errInvalidBlockSize is returned. Errors from the system random
// source or from aes.NewCipher are returned unchanged.
func NewCtrCipherWithBlockSize(encryptBlockSize int64) (ctr *CtrCipher, err error) {
	// Validate before touching the random source. A zero size would make a
	// Writer with an empty buffer loop forever in Write and make ReadAt
	// divide by zero. A negative size would panic in make().
	if encryptBlockSize <= 0 || encryptBlockSize%aes.BlockSize != 0 {
		return nil, errInvalidBlockSize
	}
	key := make([]byte, aes.BlockSize)
	if _, err = rand.Read(key); err != nil {
		return nil, err
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	nonce, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
	if err != nil {
		return nil, err
	}
	return &CtrCipher{
		block:            block,
		nonce:            nonce.Uint64(),
		encryptBlockSize: encryptBlockSize,
		aesBlockCount:    encryptBlockSize / aes.BlockSize,
	}, nil
}

// stream returns a CTR-mode cipher.Stream that can both encrypt and decrypt,
// positioned at the given counter. The 16-byte initial counter block is the
// big-endian nonce followed by the big-endian counter.
func (ctr *CtrCipher) stream(counter uint64) cipher.Stream {
	var iv [aes.BlockSize]byte
	binary.BigEndian.PutUint64(iv[:8], ctr.nonce)
	binary.BigEndian.PutUint64(iv[8:], counter)
	return cipher.NewCTR(ctr.block, iv[:])
}

// Writer implements io.WriteCloser. It encrypts data with AES in counter mode
// before writing it to the underlying object.
//
// Incoming bytes are staged in buf, which is one encrypt block
// (CtrCipher.encryptBlockSize bytes) long. They are encrypted in place and
// written out when a Write does not fit in the remaining space, on Flush, or
// on Close.
//
// Fields:
//   - err: first error hit while flushing; once set, Write and Flush return it.
//   - w: destination for the ciphertext.
//   - cipherStream: CTR keystream, started at counter 0 by NewWriter and
//     advanced by every byte that is flushed.
//   - buf: staging buffer; after a failed Flush it holds ciphertext.
//   - flushedUserDataCnt: total bytes accepted by w so far.
//   - n: number of bytes currently staged in buf.
type Writer struct {
	err                error
	w                  io.WriteCloser
	cipherStream       cipher.Stream
	buf                []byte
	flushedUserDataCnt int64
	n                  int
}

// NewWriter returns a Writer that encrypts data with ctrCipher before writing
// it to w. Its staging buffer is one encrypt block long, and its keystream
// starts at counter 0.
func NewWriter(w io.WriteCloser, ctrCipher *CtrCipher) *Writer {
	return &Writer{
		w:            w,
		buf:          make([]byte, ctrCipher.encryptBlockSize),
		cipherStream: ctrCipher.stream(0),
	}
}

// AvailableSize reports how many more bytes fit in the staging buffer
// before it has to be flushed.
func (w *Writer) AvailableSize() int {
	used, capacity := w.n, len(w.buf)
	return capacity - used
}

// Write implements the io.Writer interface.
//
// It stages p in the buffer. Whenever the remaining bytes of p do not fit,
// it fills the buffer and flushes it. The returned count excludes bytes
// that were staged just before a failed flush. After any flush error, every
// call returns that error.
func (w *Writer) Write(p []byte) (n int, err error) {
	if w.err != nil {
		return 0, w.err
	}
	for w.err == nil && len(p) > w.AvailableSize() {
		filled := copy(w.buf[w.n:], p)
		w.n += filled
		if err = w.Flush(); err != nil {
			return n, err
		}
		p = p[filled:]
		n += filled
	}
	// The rest fits in the buffer; keep it there until the next flush.
	filled := copy(w.buf[w.n:], p)
	w.n += filled
	return n + filled, nil
}

// Buffered reports how many bytes are staged in the buffer and have not yet
// been flushed.
func (w *Writer) Buffered() int {
	staged := w.n
	return staged
}

// Flush encrypts the staged bytes in place and writes them to the underlying
// object.
//
// A write that accepts fewer bytes than staged without an error is reported
// as io.ErrShortWrite. Any error is remembered and returned by every later
// Write and Flush. Flushing an empty buffer is a no-op.
func (w *Writer) Flush() error {
	switch {
	case w.err != nil:
		return w.err
	case w.n == 0:
		return nil
	}
	pending := w.buf[:w.n]
	w.cipherStream.XORKeyStream(pending, pending)
	written, err := w.w.Write(pending)
	w.flushedUserDataCnt += int64(written)
	if err == nil && written < len(pending) {
		err = io.ErrShortWrite
	}
	if err != nil {
		w.err = err
		return err
	}
	w.n = 0
	return nil
}

// GetCache returns the staged bytes that have not yet been flushed to the
// underlying object. The slice aliases the Writer's internal buffer. After
// a failed Flush the bytes are already encrypted.
func (w *Writer) GetCache() []byte {
	pending := w.buf[:w.n]
	return pending
}

// GetCacheDataOffset returns how many bytes the underlying object has
// accepted so far. CTR encryption preserves length, so this is the
// user-data offset of the first byte held in the cache.
func (w *Writer) GetCacheDataOffset() int64 {
	offset := w.flushedUserDataCnt
	return offset
}

// Close implements the io.Closer interface. It flushes any staged data and
// then closes the underlying object.
//
// The underlying object is closed even when the flush fails. Otherwise it
// would leak, because a failed flush makes every later Flush (and so every
// retried Close) fail before reaching it. If the flush failed, its error is
// returned; otherwise the error from closing the underlying object is
// returned.
func (w *Writer) Close() (err error) {
	flushErr := w.Flush()
	closeErr := w.w.Close()
	if flushErr != nil {
		return flushErr
	}
	return closeErr
}

// Reader implements io.ReaderAt. It reads ciphertext from the input source
// and returns it decrypted.
//
// Fields:
//   - r: source of the ciphertext.
//   - cipher: the CtrCipher that produced the ciphertext. Decryption XORs
//     with the same keystream the Writer used, so the key and nonce must
//     match.
type Reader struct {
	r      io.ReaderAt
	cipher *CtrCipher
}

// NewReader returns a Reader that decrypts data read from r using ctrCipher.
func NewReader(r io.ReaderAt, ctrCipher *CtrCipher) *Reader {
	return &Reader{cipher: ctrCipher, r: r}
}

// ReadAt implements the io.ReaderAt interface. It fills p with plaintext
// starting at plaintext offset off.
//
// The source is read one whole encrypt block at a time, starting with the
// block that contains off. The CTR keystream is positioned at that block's
// first AES counter, so decryption lines up with what Writer produced.
// When fewer than len(p) bytes are returned, err is non-nil: io.EOF if the
// source ends first (including when off lies beyond the end of the data),
// or the error from the source. A negative off is rejected with an error.
func (r *Reader) ReadAt(p []byte, off int64) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	if off < 0 {
		return 0, errors.New("encrypt.Reader.ReadAt: negative offset")
	}
	// offset is the position of off inside its encrypt block. Only the first
	// block read can start mid-block; later blocks are used from byte 0.
	offset := off % r.cipher.encryptBlockSize
	// Each encrypt block covers aesBlockCount AES blocks of keystream, so the
	// block that contains off starts at this counter value.
	counter := (off / r.cipher.encryptBlockSize) * r.cipher.aesBlockCount
	cursor := off - offset

	buf := make([]byte, r.cipher.encryptBlockSize)
	var readNum int
	cipherStream := r.cipher.stream(uint64(counter))
	for len(p) > 0 && err == nil {
		readNum, err = r.r.ReadAt(buf, cursor)
		if err != nil {
			if readNum == 0 || err != io.EOF {
				return n, err
			}
			// A short read that ends in io.EOF still carries data. Decrypt
			// and deliver it; the next iteration sees the end of input.
			err = nil
		}
		if int64(readNum) <= offset {
			// No bytes exist at or after off: the data ends inside this block
			// before off (or the source made no progress). Slicing
			// buf[offset:readNum] would panic, so report end of input.
			return n, io.EOF
		}
		// The keystream advances by exactly readNum bytes, so it stays aligned
		// with cursor even after a short read.
		cipherStream.XORKeyStream(buf[:readNum], buf[:readNum])
		cursor += int64(readNum)
		copiedNum := copy(p, buf[offset:readNum])
		n += copiedNum
		p = p[copiedNum:]
		offset = 0
	}
	return n, err
}

相关信息

TiDB 源码目录

相关文章

tidb aes 源码

tidb crypt 源码

0  赞