| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- // Copyright 2017 Google Inc. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package tls
- import (
- "bufio"
- "bytes"
- "crypto/cipher"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "net"
- "strconv"
- "sync/atomic"
- )
// UConn is a uTLS client connection. It embeds *Conn and carries the extra
// state uTLS uses to build and send a custom (default or mimicked) ClientHello.
type UConn struct {
	*Conn

	// Extensions is the ordered list of ClientHello extensions. ApplyConfig
	// lets each of them write its settings into the connection, and
	// MarshalClientHello serializes them in this order.
	Extensions    []TLSExtension
	clientHelloID ClientHelloID

	// ClientHelloBuilt is set by BuildHandshakeState after its first
	// successful run; later runs then skip the one-time default/preset step.
	ClientHelloBuilt bool
	// HandshakeState holds the ClientHello (Hello), an optional resumption
	// Session and the rest of the client handshake state.
	HandshakeState ClientHandshakeState

	// GetSessionID, if non-nil, is used by SetSessionState to derive the
	// ClientHello session ID from a non-empty session ticket. When it is nil
	// (or there is no ticket), SetSessionState generates the session ID itself.
	GetSessionID func(ticket []byte) [32]byte

	// NOTE(review): greaseSeed is not used in this part of the file;
	// presumably it seeds GREASE values for mimicked ClientHellos — confirm.
	greaseSeed [ssl_grease_last_index]uint16
}
- // UClient returns a new uTLS client, with behavior depending on clientHelloID.
- // Config CAN be nil, but make sure to eventually specify ServerName.
- func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
- if config == nil {
- config = &Config{}
- }
- tlsConn := Conn{conn: conn, config: config, isClient: true}
- handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
- uconn := UConn{Conn: &tlsConn, clientHelloID: clientHelloID, HandshakeState: handshakeState}
- return &uconn
- }
- // BuildHandshakeState behavior varies based on ClientHelloID and
- // whether it was already called before.
- // If HelloGolang:
- // [only once] make default ClientHello and overwrite existing state
- // If any other mimicking ClientHelloID is used:
- // [only once] make ClientHello based on ID and overwrite existing state
- // [each call] apply uconn.Extensions config to internal crypto/tls structures
- // [each call] marshal ClientHello.
- //
- // BuildHandshakeState is automatically called before uTLS performs handshake,
- // amd should only be called explicitly to inspect/change fields of
- // default/mimicked ClientHello.
- func (uconn *UConn) BuildHandshakeState() error {
- if uconn.clientHelloID == HelloGolang {
- if uconn.ClientHelloBuilt {
- return nil
- }
- // use default Golang ClientHello.
- hello, ecdheParams, err := uconn.makeClientHello()
- if err != nil {
- return err
- }
- uconn.HandshakeState.Hello = hello.getPublicPtr()
- uconn.HandshakeState.State13.EcdheParams = ecdheParams
- uconn.HandshakeState.C = uconn.Conn
- } else {
- if !uconn.ClientHelloBuilt {
- err := uconn.applyPresetByID(uconn.clientHelloID)
- if err != nil {
- return err
- }
- }
- err := uconn.ApplyConfig()
- if err != nil {
- return err
- }
- err = uconn.MarshalClientHello()
- if err != nil {
- return err
- }
- }
- uconn.ClientHelloBuilt = true
- return nil
- }
- // SetSessionState sets the session ticket, which may be preshared or fake.
- // If session is nil, the body of session ticket extension will be unset,
- // but the extension itself still MAY be present for mimicking purposes.
- // Session tickets to be reused - use same cache on following connections.
- func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
- uconn.HandshakeState.Session = session
- var sessionTicket []uint8
- if session != nil {
- sessionTicket = session.sessionTicket
- }
- uconn.HandshakeState.Hello.TicketSupported = true
- uconn.HandshakeState.Hello.SessionTicket = sessionTicket
- for _, ext := range uconn.Extensions {
- st, ok := ext.(*SessionTicketExtension)
- if !ok {
- continue
- }
- st.Session = session
- if session != nil {
- if len(session.SessionTicket()) > 0 {
- if uconn.GetSessionID != nil {
- sid := uconn.GetSessionID(session.SessionTicket())
- uconn.HandshakeState.Hello.SessionId = sid[:]
- return nil
- }
- }
- var sessionID [32]byte
- _, err := io.ReadFull(uconn.config.rand(), uconn.HandshakeState.Hello.SessionId)
- if err != nil {
- return err
- }
- uconn.HandshakeState.Hello.SessionId = sessionID[:]
- }
- return nil
- }
- return nil
- }
- // If you want session tickets to be reused - use same cache on following connections
- func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
- uconn.config.ClientSessionCache = cache
- uconn.HandshakeState.Hello.TicketSupported = true
- }
- // SetClientRandom sets client random explicitly.
- // BuildHandshakeFirst() must be called before SetClientRandom.
- // r must to be 32 bytes long.
- func (uconn *UConn) SetClientRandom(r []byte) error {
- if len(r) != 32 {
- return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
- } else {
- uconn.HandshakeState.Hello.Random = make([]byte, 32)
- copy(uconn.HandshakeState.Hello.Random, r)
- return nil
- }
- }
- func (uconn *UConn) SetSNI(sni string) {
- hname := hostnameInSNI(sni)
- uconn.config.ServerName = hname
- for _, ext := range uconn.Extensions {
- sniExt, ok := ext.(*SNIExtension)
- if ok {
- sniExt.ServerName = hname
- }
- }
- }
// Handshake runs the TLS handshake unless it has already completed or failed.
// A stored handshake error is returned again, and an already completed
// handshake makes this a no-op.
//
// On a client connection it first calls BuildHandshakeState, so the handshake
// uses the uTLS ClientHello in HandshakeState (together with
// HandshakeState.Session, if one was set), and then runs the uTLS
// clientHandshake. An error from BuildHandshakeState is returned without being
// stored in handshakeErr.
func (c *UConn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.handshakeComplete() {
		return nil
	}

	// The input side stays locked until the handshake attempt returns.
	c.in.Lock()
	defer c.in.Unlock()

	if c.isClient {
		// [uTLS section begins]
		// (Re)build the ClientHello right before it is sent.
		err := c.BuildHandshakeState()
		if err != nil {
			return err
		}
		// [uTLS section ends]

		c.handshakeErr = c.clientHandshake()
	} else {
		c.handshakeErr = c.serverHandshake()
	}

	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// If an error occurred during the handshake, try to flush the
		// alert that might be left in the buffer.
		c.flush()
	}

	// A nil error must leave the connection in the completed state.
	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}

	return c.handshakeErr
}
// Write writes data to the connection, running the handshake first if it has
// not completed yet.
//
// This is copied from tls.Conn.Write in its entirety. The only difference is
// that c.Handshake() here resolves to the uTLS UConn.Handshake, not the
// crypto/tls one.
func (c *UConn) Write(b []byte) (int, error) {
	// Interlock with Close: bit 0 of activeCall marks the connection as
	// closed, and every in-flight Write adds 2 for as long as it runs.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	// A previous write error on the output side is sticky.
	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
	// attack when using block mode ciphers due to predictable IVs.
	// This can be prevented by splitting each Application Data
	// record into two records, effectively randomizing the IV.
	//
	// https://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// https://www.imperialviolet.org/2012/01/15/beastfollowup.html

	// m counts the byte already sent in the 1-byte split record, if any.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
- // clientHandshakeWithOneState checks that exactly one expected state is set (1.2 or 1.3)
- // and performs client TLS handshake with that state
- func (c *UConn) clientHandshake() (err error) {
- // [uTLS section begins]
- hello := c.HandshakeState.Hello.getPrivatePtr()
- defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()
- sessionIsAlreadySet := c.HandshakeState.Session != nil
- // after this point exactly 1 out of 2 HandshakeState pointers is non-nil,
- // useTLS13 variable tells which pointer
- // [uTLS section ends]
- if c.config == nil {
- c.config = defaultConfig()
- }
- // This may be a renegotiation handshake, in which case some fields
- // need to be reset.
- c.didResume = false
- // [uTLS section begins]
- // don't make new ClientHello, use hs.hello
- // preserve the checks from beginning and end of makeClientHello()
- if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
- return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
- }
- nextProtosLength := 0
- for _, proto := range c.config.NextProtos {
- if l := len(proto); l == 0 || l > 255 {
- return errors.New("tls: invalid NextProtos value")
- } else {
- nextProtosLength += 1 + l
- }
- }
- if nextProtosLength > 0xffff {
- return errors.New("tls: NextProtos values too large")
- }
- if c.handshakes > 0 {
- hello.secureRenegotiation = c.clientFinished[:]
- }
- // [uTLS section ends]
- cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
- if cacheKey != "" && session != nil {
- defer func() {
- // If we got a handshake failure when resuming a session, throw away
- // the session ticket. See RFC 5077, Section 3.2.
- //
- // RFC 8446 makes no mention of dropping tickets on failure, but it
- // does require servers to abort on invalid binders, so we need to
- // delete tickets to recover from a corrupted PSK.
- if err != nil {
- c.config.ClientSessionCache.Put(cacheKey, nil)
- }
- }()
- }
- if !sessionIsAlreadySet { // uTLS: do not overwrite already set session
- err = c.SetSessionState(session)
- if err != nil {
- return
- }
- }
- if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
- return err
- }
- msg, err := c.readHandshake()
- if err != nil {
- return err
- }
- serverHello, ok := msg.(*serverHelloMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(serverHello, msg)
- }
- if err := c.pickTLSVersion(serverHello); err != nil {
- return err
- }
- // uTLS: do not create new handshakeState, use existing one
- if c.vers == VersionTLS13 {
- hs13 := c.HandshakeState.toPrivate13()
- hs13.serverHello = serverHello
- hs13.hello = hello
- if !sessionIsAlreadySet {
- hs13.earlySecret = earlySecret
- hs13.binderKey = binderKey
- }
- // In TLS 1.3, session tickets are delivered after the handshake.
- err = hs13.handshake()
- c.HandshakeState = *hs13.toPublic13()
- return err
- }
- hs12 := c.HandshakeState.toPrivate12()
- hs12.serverHello = serverHello
- hs12.hello = hello
- err = hs12.handshake()
- c.HandshakeState = *hs12.toPublic13()
- if err != nil {
- return err
- }
- // If we had a successful handshake and hs.session is different from
- // the one already cached - cache a new one.
- if cacheKey != "" && hs12.session != nil && session != hs12.session {
- c.config.ClientSessionCache.Put(cacheKey, hs12.session)
- }
- return nil
- }
- func (uconn *UConn) ApplyConfig() error {
- for _, ext := range uconn.Extensions {
- err := ext.writeToUConn(uconn)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (uconn *UConn) MarshalClientHello() error {
- hello := uconn.HandshakeState.Hello
- headerLength := 2 + 32 + 1 + len(hello.SessionId) +
- 2 + len(hello.CipherSuites)*2 +
- 1 + len(hello.CompressionMethods)
- extensionsLen := 0
- var paddingExt *UtlsPaddingExtension
- for _, ext := range uconn.Extensions {
- if pe, ok := ext.(*UtlsPaddingExtension); !ok {
- // If not padding - just add length of extension to total length
- extensionsLen += ext.Len()
- } else {
- // If padding - process it later
- if paddingExt == nil {
- paddingExt = pe
- } else {
- return errors.New("Multiple padding extensions!")
- }
- }
- }
- if paddingExt != nil {
- // determine padding extension presence and length
- paddingExt.Update(headerLength + 4 + extensionsLen + 2)
- extensionsLen += paddingExt.Len()
- }
- helloLen := headerLength
- if len(uconn.Extensions) > 0 {
- helloLen += 2 + extensionsLen // 2 bytes for extensions' length
- }
- helloBuffer := bytes.Buffer{}
- bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
- // We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
- // Write() will become noop, and error will be accessible via Flush(), which is called once in the end
- binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
- helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
- binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
- binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
- binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
- binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
- binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
- binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
- for _, suite := range hello.CipherSuites {
- binary.Write(bufferedWriter, binary.BigEndian, suite)
- }
- binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
- binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
- if len(uconn.Extensions) > 0 {
- binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
- for _, ext := range uconn.Extensions {
- bufferedWriter.ReadFrom(ext)
- }
- }
- err := bufferedWriter.Flush()
- if err != nil {
- return err
- }
- if helloBuffer.Len() != 4+helloLen {
- return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
- ". Got: " + strconv.Itoa(helloBuffer.Len()))
- }
- hello.Raw = helloBuffer.Bytes()
- return nil
- }
- // get current state of cipher and encrypt zeros to get keystream
- func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
- zeros := make([]byte, length)
- if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
- // AEAD.Seal() does not mutate internal state, other ciphers might
- return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
- }
- return nil, errors.New("Could not convert OutCipher to cipher.AEAD")
- }
- // SetVersCreateState set min and max TLS version in all appropriate places.
- func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16) error {
- if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 {
- return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers)
- }
- if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 {
- return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers)
- }
- uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers)
- uconn.config.MinVersion = minTLSVers
- uconn.config.MaxVersion = maxTLSVers
- return nil
- }
- func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
- uconn.Conn.conn = c
- }
- func (uconn *UConn) GetUnderlyingConn() net.Conn {
- return uconn.Conn.conn
- }
- // MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
- // Major Hack Alert.
- func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
- tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}
- cs := cipherSuiteByID(cipherSuite)
- // This is mostly borrowed from establishKeys()
- clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
- keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
- cs.macLen, cs.keyLen, cs.ivLen)
- var clientCipher, serverCipher interface{}
- var clientHash, serverHash macFunction
- if cs.cipher != nil {
- clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
- clientHash = cs.mac(version, clientMAC)
- serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
- serverHash = cs.mac(version, serverMAC)
- } else {
- clientCipher = cs.aead(clientKey, clientIV)
- serverCipher = cs.aead(serverKey, serverIV)
- }
- if isClient {
- tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
- tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
- } else {
- tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
- tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
- }
- // skip the handshake states
- tlsConn.handshakeStatus = 1
- tlsConn.cipherSuite = cipherSuite
- tlsConn.haveVers = true
- tlsConn.vers = version
- // Update to the new cipher specs
- // and consume the finished messages
- tlsConn.in.changeCipherSpec()
- tlsConn.out.changeCipherSpec()
- tlsConn.in.incSeq()
- tlsConn.out.incSeq()
- return tlsConn
- }
// makeSupportedVersions returns every version from maxVers down to minVers,
// inclusive, highest first. It returns nil when minVers > maxVers.
//
// The arithmetic is done in int because the previous uint16 computation of
// maxVers-minVers+1 wrapped around: an inverted range produced a huge
// garbage slice, and the full 0..0xFFFF range produced an empty one.
func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
	if minVers > maxVers {
		return nil
	}
	versions := make([]uint16, 0, int(maxVers)-int(minVers)+1)
	for v := int(maxVers); v >= int(minVers); v-- {
		versions = append(versions, uint16(v))
	}
	return versions
}
|