From ced5598c05d8325e41de8f881f71349951aebaa9 Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Mon, 2 Mar 2026 19:16:18 +0800 Subject: [PATCH] limit the packet size of packets before handshake Signed-off-by: Yang Keao --- pkg/proxy/backend/authenticator.go | 15 +++++++++ pkg/proxy/backend/authenticator_test.go | 24 ++++++++++++++ pkg/proxy/backend/error.go | 4 ++- pkg/proxy/net/error.go | 11 ++++--- pkg/proxy/net/mysql.go | 10 ++++-- pkg/proxy/net/packetio.go | 44 ++++++++++++++++++------- pkg/proxy/net/packetio_options.go | 8 +++++ pkg/proxy/net/packetio_test.go | 44 +++++++++++++++++++++++++ 8 files changed, 139 insertions(+), 21 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 88037065..19ffb1c8 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -24,6 +24,16 @@ const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF const ER_INVALID_SEQUENCE = 8052 +// maxHandshakePacketSize limits client handshake packets to avoid OOM. +// Most known clients send ~1-2KB handshakes: 4B capability + 4B max packet size + +// 1B charset + 23B reserved + NUL-terminated user + auth data (lenenc/len) + +// optional NUL-terminated db name + optional auth plugin name + +// optional connection attributes (lenenc length + key/value pairs). +// The SSLRequest packet is fixed 32 bytes. These stay far below 1MB in practice. +// +// Ref https://dev.mysql.com/doc/dev/mysql-server/9.5.0/page_protocol_connection_phase_packets_protocol_handshake_response.html +const maxHandshakePacketSize = 1 << 20 + // SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported. // TiDB supports ClientDeprecateEOF since v6.3.0. // TiDB supports ClientCompress and ClientZstdCompressionAlgorithm since v7.2.0. @@ -89,6 +99,11 @@ type backendIOGetter func(ctx context.Context, cctx ConnContext, resp *pnet.Hand func (auth *Authenticator) handshakeFirstTime(ctx context.Context, logger *zap.Logger, cctx ConnContext, clientIO pnet.PacketIO, handshakeHandler HandshakeHandler, getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() + clientIO.ApplyOpts(pnet.WithReadPacketLimit(maxHandshakePacketSize)) + // TODO: now we only limit the size of the handshake packet, we assume that all clients with proper + // user / password will send reasonable packets. However, it's not true when the TiProxy is shared by + // many TiDBs keyspaces and customers. + defer clientIO.ApplyOpts(pnet.WithReadPacketLimit(0)) proxyCapability := handshakeHandler.GetCapability() if frontendTLSConfig == nil { diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index d05574be..002f77a4 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -625,3 +625,27 @@ func TestMalformedHandshakePacket(t *testing.T) { clean() } + +func TestHandshakePacketSizeLimit(t *testing.T) { + tc := newTCPConnSuite(t) + ts, clean := newTestSuite(t, tc) + + customClientRunner := func(packetIO pnet.PacketIO) error { + if _, err := packetIO.ReadPacket(); err != nil { + return err + } + + oversizedPacket := make([]byte, maxHandshakePacketSize+1) + binary.LittleEndian.PutUint32(oversizedPacket[0:4], ts.mc.capability.Uint32()) + + return packetIO.WritePacket(oversizedPacket, true) + } + + ts.runAndCheck(t, func(t *testing.T, ts *testSuite) { + require.Error(t, ts.mp.err) + require.ErrorIs(t, ts.mp.err, pnet.ErrPacketTooLarge) + require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err)) + }, customClientRunner, ts.mb.authenticate, ts.mp.authenticateFirstTime) + + clean() +} diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 41411f5d..05c64a1e 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -49,6 +49,8 @@ func ErrToClient(err error) error { return ErrBackendNoTLS case errors.Is(err, ErrBackendPPV2): return ErrBackendPPV2 + case errors.Is(err, pnet.ErrPacketTooLarge): + return mysql.NewDefaultError(mysql.ER_NET_PACKET_TOO_LARGE) case errors.Is(err, ErrProxyErr): // The error is returned by HandshakeHandler/BackendFetcher and wrapped with ErrProxyErr. return errors.Unwrap(err) @@ -112,7 +114,7 @@ func Error2Source(err error) ErrorSource { case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, mysql.ErrMalformPacket): // We assume the clients and TiDB are right and treat it as TiProxy bugs. return SrcProxyMalformed - case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap): + case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap), errors.Is(err, pnet.ErrPacketTooLarge): return SrcClientHandshake case errors.Is(err, ErrClientAuthFail): return SrcClientAuthFail diff --git a/pkg/proxy/net/error.go b/pkg/proxy/net/error.go index 6d3939c3..1d731754 100644 --- a/pkg/proxy/net/error.go +++ b/pkg/proxy/net/error.go @@ -14,11 +14,12 @@ import ( ) var ( - ErrReadConn = errors.New("failed to read the connection") - ErrWriteConn = errors.New("failed to write the connection") - ErrFlushConn = errors.New("failed to flush the connection") - ErrCloseConn = errors.New("failed to close the connection") - ErrHandshakeTLS = errors.New("failed to complete tls handshake") + ErrReadConn = errors.New("failed to read the connection") + ErrWriteConn = errors.New("failed to write the connection") + ErrFlushConn = errors.New("failed to flush the connection") + ErrCloseConn = errors.New("failed to close the connection") + ErrHandshakeTLS = errors.New("failed to complete tls handshake") + ErrPacketTooLarge = errors.New("packet size exceeds limit") ) // IsDisconnectError returns whether the error is caused by peer disconnection. diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 8dbf786a..6005a8dd 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -84,13 +84,17 @@ func MakeEOFPacket(status uint16) []byte { return data } -// WriteUserError writes an unknown error to the client. +// MakeUserError builds an error packet for the client. func MakeUserError(err error) []byte { if err == nil { return nil } - myErr := gomysql.NewError(gomysql.ER_UNKNOWN_ERROR, err.Error()) - return MakeErrPacket(myErr) + var myErr *gomysql.MyError + if errors.As(err, &myErr) { + return MakeErrPacket(myErr) + } + baseErr := gomysql.NewError(gomysql.ER_UNKNOWN_ERROR, err.Error()) + return MakeErrPacket(baseErr) } type InitialHandshake struct { diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 72342ee7..3b97c6cd 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -230,16 +230,17 @@ type PacketIO interface { // PacketIO is a helper to read and write sql and proxy protocol. type packetIO struct { - lastKeepAlive config.KeepAlive - rawConn net.Conn - readWriter packetReadWriter - limitReader io.LimitedReader // reuse memory to reduce allocation - logger *zap.Logger - remoteAddr net.Addr - wrap error - header [4]byte // reuse memory to reduce allocation - inPackets uint64 - outPackets uint64 + lastKeepAlive config.KeepAlive + rawConn net.Conn + readWriter packetReadWriter + limitReader io.LimitedReader // reuse memory to reduce allocation + logger *zap.Logger + remoteAddr net.Addr + wrap error + header [4]byte // reuse memory to reduce allocation + readPacketLimit int + inPackets uint64 + outPackets uint64 } func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO { @@ -286,7 +287,10 @@ func (p *packetIO) GetSequence() uint8 { return p.readWriter.Sequence() } -func (p *packetIO) readOnePacket() ([]byte, bool, error) { +// readOnePacket reads one packet and returns the data without header, whether there are more packets and error if any. +// If limit >= 0, it returns an error if the packet size exceeds the limit. +// The caller may read a trailing zero-length packet when the previous packet length equals MaxPayloadLen. +func (p *packetIO) readOnePacket(limit int) ([]byte, bool, error) { if err := ReadFull(p.readWriter, p.header[:]); err != nil { return nil, false, errors.Wrap(err, ErrReadConn) } @@ -297,6 +301,9 @@ func (p *packetIO) readOnePacket() ([]byte, bool, error) { p.readWriter.SetSequence(sequence + 1) length := int(p.header[0]) | int(p.header[1])<<8 | int(p.header[2])<<16 + if limit >= 0 && length > limit { + return nil, false, errors.Wrapf(ErrPacketTooLarge, "packet size %d exceeds limit %d", length, limit) + } data := make([]byte, length) if err := ReadFull(p.readWriter, data); err != nil { return nil, false, errors.Wrap(err, ErrReadConn) @@ -308,13 +315,26 @@ func (p *packetIO) readOnePacket() ([]byte, bool, error) { // ReadPacket reads data and removes the header func (p *packetIO) ReadPacket() (data []byte, err error) { p.readWriter.BeginRW(rwRead) + checkPacketLimit := p.readPacketLimit > 0 + remaining := p.readPacketLimit for more := true; more; { var buf []byte - buf, more, err = p.readOnePacket() + limit := -1 + if checkPacketLimit { + limit = remaining + } + buf, more, err = p.readOnePacket(limit) if err != nil { err = p.wrapErr(err) return } + if checkPacketLimit { + remaining -= len(buf) + if remaining < 0 { + err = p.wrapErr(errors.Wrapf(ErrPacketTooLarge, "packet size exceeds limit %d", p.readPacketLimit)) + return + } + } if data == nil { data = buf } else { diff --git a/pkg/proxy/net/packetio_options.go b/pkg/proxy/net/packetio_options.go index bb12a050..a7014e89 100644 --- a/pkg/proxy/net/packetio_options.go +++ b/pkg/proxy/net/packetio_options.go @@ -21,6 +21,14 @@ func WithWrapError(err error) func(pi *packetIO) { } } +// WithReadPacketLimit limits the total size of one ReadPacket call. +// A zero or negative limit means no limit. +func WithReadPacketLimit(limit int) func(pi *packetIO) { + return func(pi *packetIO) { + pi.readPacketLimit = limit + } +} + // WithRemoteAddr var _ proxyprotocol.AddressWrapper = &originAddr{} diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 15110884..ff20c155 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -111,6 +111,50 @@ func TestPacketIO(t *testing.T) { ) } +func TestReadPacketLimitOption(t *testing.T) { + cases := []struct { + name string + limit int + size int + ok bool + }{ + {name: "single-under", limit: 8, size: 4, ok: true}, + {name: "single-at", limit: 8, size: 8, ok: true}, + {name: "single-over", limit: 8, size: 9, ok: false}, + {name: "max-payload", limit: MaxPayloadLen, size: MaxPayloadLen, ok: true}, + {name: "multi-ok", limit: MaxPayloadLen + 20, size: MaxPayloadLen + 20, ok: true}, + {name: "multi-over", limit: MaxPayloadLen + 10, size: MaxPayloadLen + 20, ok: false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + testPipeConn(t, + func(t *testing.T, cli *packetIO) { + if !tc.ok { + _ = cli.readWriter.SetWriteDeadline(time.Now().Add(time.Second)) + } + data := make([]byte, tc.size) + err := cli.WritePacket(data, true) + if tc.ok { + require.NoError(t, err) + } + }, + func(t *testing.T, srv *packetIO) { + srv.ApplyOpts(WithReadPacketLimit(tc.limit)) + data, err := srv.ReadPacket() + if tc.ok { + require.NoError(t, err) + require.Len(t, data, tc.size) + } else { + require.ErrorIs(t, err, ErrPacketTooLarge) + } + }, + 1, + ) + }) + } +} + func TestTLS(t *testing.T) { stls, ctls, err := security.CreateTLSConfigForTest() require.NoError(t, err)