Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ const requiredFrontendCaps = pnet.ClientProtocol41
const defRequiredBackendCaps = pnet.ClientDeprecateEOF
const ER_INVALID_SEQUENCE = 8052

// maxHandshakePacketSize limits client handshake packets to avoid OOM.
// Most known clients send ~1-2KB handshakes: 4B capability + 4B max packet size +
// 1B charset + 23B reserved + NUL-terminated user + auth data (lenenc/len) +
// optional NUL-terminated db name + optional auth plugin name +
// optional connection attributes (lenenc length + key/value pairs).
// The SSLRequest packet is fixed 32 bytes. These stay far below 1MB in practice.
//
// Ref https://dev.mysql.com/doc/dev/mysql-server/9.5.0/page_protocol_connection_phase_packets_protocol_handshake_response.html
const maxHandshakePacketSize = 1 << 20

// SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported.
// TiDB supports ClientDeprecateEOF since v6.3.0.
// TiDB supports ClientCompress and ClientZstdCompressionAlgorithm since v7.2.0.
Expand Down Expand Up @@ -89,6 +99,11 @@ type backendIOGetter func(ctx context.Context, cctx ConnContext, resp *pnet.Hand
func (auth *Authenticator) handshakeFirstTime(ctx context.Context, logger *zap.Logger, cctx ConnContext, clientIO pnet.PacketIO, handshakeHandler HandshakeHandler,
getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error {
clientIO.ResetSequence()
clientIO.ApplyOpts(pnet.WithReadPacketLimit(maxHandshakePacketSize))
// TODO: now we only limit the size of the handshake packet, we assume that all clients with proper
// user / password will send reasonable packets. However, it's not true when the TiProxy is shared by
// many TiDBs keyspaces and customers.
defer clientIO.ApplyOpts(pnet.WithReadPacketLimit(0))

proxyCapability := handshakeHandler.GetCapability()
if frontendTLSConfig == nil {
Expand Down
24 changes: 24 additions & 0 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,27 @@ func TestMalformedHandshakePacket(t *testing.T) {

clean()
}

func TestHandshakePacketSizeLimit(t *testing.T) {
tc := newTCPConnSuite(t)
ts, clean := newTestSuite(t, tc)

customClientRunner := func(packetIO pnet.PacketIO) error {
if _, err := packetIO.ReadPacket(); err != nil {
return err
}

oversizedPacket := make([]byte, maxHandshakePacketSize+1)
binary.LittleEndian.PutUint32(oversizedPacket[0:4], ts.mc.capability.Uint32())

return packetIO.WritePacket(oversizedPacket, true)
}

ts.runAndCheck(t, func(t *testing.T, ts *testSuite) {
require.Error(t, ts.mp.err)
require.ErrorIs(t, ts.mp.err, pnet.ErrPacketTooLarge)
require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err))
}, customClientRunner, ts.mb.authenticate, ts.mp.authenticateFirstTime)

clean()
}
4 changes: 3 additions & 1 deletion pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func ErrToClient(err error) error {
return ErrBackendNoTLS
case errors.Is(err, ErrBackendPPV2):
return ErrBackendPPV2
case errors.Is(err, pnet.ErrPacketTooLarge):
return mysql.NewDefaultError(mysql.ER_NET_PACKET_TOO_LARGE)
case errors.Is(err, ErrProxyErr):
// The error is returned by HandshakeHandler/BackendFetcher and wrapped with ErrProxyErr.
return errors.Unwrap(err)
Expand Down Expand Up @@ -112,7 +114,7 @@ func Error2Source(err error) ErrorSource {
case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, mysql.ErrMalformPacket):
// We assume the clients and TiDB are right and treat it as TiProxy bugs.
return SrcProxyMalformed
case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap):
case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap), errors.Is(err, pnet.ErrPacketTooLarge):
return SrcClientHandshake
case errors.Is(err, ErrClientAuthFail):
return SrcClientAuthFail
Expand Down
11 changes: 6 additions & 5 deletions pkg/proxy/net/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ import (
)

var (
ErrReadConn = errors.New("failed to read the connection")
ErrWriteConn = errors.New("failed to write the connection")
ErrFlushConn = errors.New("failed to flush the connection")
ErrCloseConn = errors.New("failed to close the connection")
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
ErrReadConn = errors.New("failed to read the connection")
ErrWriteConn = errors.New("failed to write the connection")
ErrFlushConn = errors.New("failed to flush the connection")
ErrCloseConn = errors.New("failed to close the connection")
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
ErrPacketTooLarge = errors.New("packet size exceeds limit")
)

// IsDisconnectError returns whether the error is caused by peer disconnection.
Expand Down
10 changes: 7 additions & 3 deletions pkg/proxy/net/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,17 @@ func MakeEOFPacket(status uint16) []byte {
return data
}

// WriteUserError writes an unknown error to the client.
// MakeUserError builds an error packet for the client.
func MakeUserError(err error) []byte {
if err == nil {
return nil
}
myErr := gomysql.NewError(gomysql.ER_UNKNOWN_ERROR, err.Error())
return MakeErrPacket(myErr)
var myErr *gomysql.MyError
if errors.As(err, &myErr) {
return MakeErrPacket(myErr)
}
baseErr := gomysql.NewError(gomysql.ER_UNKNOWN_ERROR, err.Error())
return MakeErrPacket(baseErr)
}

type InitialHandshake struct {
Expand Down
44 changes: 32 additions & 12 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,17 @@ type PacketIO interface {

// PacketIO is a helper to read and write sql and proxy protocol.
type packetIO struct {
lastKeepAlive config.KeepAlive
rawConn net.Conn
readWriter packetReadWriter
limitReader io.LimitedReader // reuse memory to reduce allocation
logger *zap.Logger
remoteAddr net.Addr
wrap error
header [4]byte // reuse memory to reduce allocation
inPackets uint64
outPackets uint64
lastKeepAlive config.KeepAlive
rawConn net.Conn
readWriter packetReadWriter
limitReader io.LimitedReader // reuse memory to reduce allocation
logger *zap.Logger
remoteAddr net.Addr
wrap error
header [4]byte // reuse memory to reduce allocation
readPacketLimit int
inPackets uint64
outPackets uint64
}

func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO {
Expand Down Expand Up @@ -286,7 +287,10 @@ func (p *packetIO) GetSequence() uint8 {
return p.readWriter.Sequence()
}

func (p *packetIO) readOnePacket() ([]byte, bool, error) {
// readOnePacket reads one packet and returns the data without header, whether there are more packets and error if any.
// If limit >= 0, it returns an error if the packet size exceeds the limit.
// The caller may read a trailing zero-length packet when the previous packet length equals MaxPayloadLen.
func (p *packetIO) readOnePacket(limit int) ([]byte, bool, error) {
if err := ReadFull(p.readWriter, p.header[:]); err != nil {
return nil, false, errors.Wrap(err, ErrReadConn)
}
Expand All @@ -297,6 +301,9 @@ func (p *packetIO) readOnePacket() ([]byte, bool, error) {
p.readWriter.SetSequence(sequence + 1)

length := int(p.header[0]) | int(p.header[1])<<8 | int(p.header[2])<<16
if limit >= 0 && length > limit {
return nil, false, errors.Wrapf(ErrPacketTooLarge, "packet size %d exceeds limit %d", length, limit)
}
data := make([]byte, length)
if err := ReadFull(p.readWriter, data); err != nil {
return nil, false, errors.Wrap(err, ErrReadConn)
Expand All @@ -308,13 +315,26 @@ func (p *packetIO) readOnePacket() ([]byte, bool, error) {
// ReadPacket reads data and removes the header
func (p *packetIO) ReadPacket() (data []byte, err error) {
p.readWriter.BeginRW(rwRead)
checkPacketLimit := p.readPacketLimit > 0
remaining := p.readPacketLimit
for more := true; more; {
var buf []byte
buf, more, err = p.readOnePacket()
limit := -1
if checkPacketLimit {
limit = remaining
}
buf, more, err = p.readOnePacket(limit)
if err != nil {
err = p.wrapErr(err)
return
}
if checkPacketLimit {
remaining -= len(buf)
if remaining < 0 {
err = p.wrapErr(errors.Wrapf(ErrPacketTooLarge, "packet size exceeds limit %d", p.readPacketLimit))
return
}
}
if data == nil {
data = buf
} else {
Expand Down
8 changes: 8 additions & 0 deletions pkg/proxy/net/packetio_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ func WithWrapError(err error) func(pi *packetIO) {
}
}

// WithReadPacketLimit limits the total size of one ReadPacket call.
// A zero or negative limit means no limit.
func WithReadPacketLimit(limit int) func(pi *packetIO) {
return func(pi *packetIO) {
pi.readPacketLimit = limit
}
}

// WithRemoteAddr
var _ proxyprotocol.AddressWrapper = &originAddr{}

Expand Down
44 changes: 44 additions & 0 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,50 @@ func TestPacketIO(t *testing.T) {
)
}

func TestReadPacketLimitOption(t *testing.T) {
cases := []struct {
name string
limit int
size int
ok bool
}{
{name: "single-under", limit: 8, size: 4, ok: true},
{name: "single-at", limit: 8, size: 8, ok: true},
{name: "single-over", limit: 8, size: 9, ok: false},
{name: "max-payload", limit: MaxPayloadLen, size: MaxPayloadLen, ok: true},
{name: "multi-ok", limit: MaxPayloadLen + 20, size: MaxPayloadLen + 20, ok: true},
{name: "multi-over", limit: MaxPayloadLen + 10, size: MaxPayloadLen + 20, ok: false},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
testPipeConn(t,
func(t *testing.T, cli *packetIO) {
if !tc.ok {
_ = cli.readWriter.SetWriteDeadline(time.Now().Add(time.Second))
}
data := make([]byte, tc.size)
err := cli.WritePacket(data, true)
if tc.ok {
require.NoError(t, err)
}
},
func(t *testing.T, srv *packetIO) {
srv.ApplyOpts(WithReadPacketLimit(tc.limit))
data, err := srv.ReadPacket()
if tc.ok {
require.NoError(t, err)
require.Len(t, data, tc.size)
} else {
require.ErrorIs(t, err, ErrPacketTooLarge)
}
},
1,
)
})
}
}

func TestTLS(t *testing.T) {
stls, ctls, err := security.CreateTLSConfigForTest()
require.NoError(t, err)
Expand Down