u_conn.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. // Copyright 2017 Google Inc. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bufio"
  7. "bytes"
  8. "crypto/cipher"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "net"
  14. "strconv"
  15. "sync/atomic"
  16. )
  17. type UConn struct {
  18. *Conn
  19. Extensions []TLSExtension
  20. clientHelloID ClientHelloID
  21. ClientHelloBuilt bool
  22. HandshakeState ClientHandshakeState
  23. // sessionID may or may not depend on ticket; nil => random
  24. GetSessionID func(ticket []byte) [32]byte
  25. greaseSeed [ssl_grease_last_index]uint16
  26. }
  27. // UClient returns a new uTLS client, with behavior depending on clientHelloID.
  28. // Config CAN be nil, but make sure to eventually specify ServerName.
  29. func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
  30. if config == nil {
  31. config = &Config{}
  32. }
  33. tlsConn := Conn{conn: conn, config: config, isClient: true}
  34. handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
  35. uconn := UConn{Conn: &tlsConn, clientHelloID: clientHelloID, HandshakeState: handshakeState}
  36. return &uconn
  37. }
  38. // BuildHandshakeState behavior varies based on ClientHelloID and
  39. // whether it was already called before.
  40. // If HelloGolang:
  41. // [only once] make default ClientHello and overwrite existing state
  42. // If any other mimicking ClientHelloID is used:
  43. // [only once] make ClientHello based on ID and overwrite existing state
  44. // [each call] apply uconn.Extensions config to internal crypto/tls structures
  45. // [each call] marshal ClientHello.
  46. //
  47. // BuildHandshakeState is automatically called before uTLS performs handshake,
  48. // amd should only be called explicitly to inspect/change fields of
  49. // default/mimicked ClientHello.
  50. func (uconn *UConn) BuildHandshakeState() error {
  51. if uconn.clientHelloID == HelloGolang {
  52. if uconn.ClientHelloBuilt {
  53. return nil
  54. }
  55. // use default Golang ClientHello.
  56. hello, ecdheParams, err := uconn.makeClientHello()
  57. if err != nil {
  58. return err
  59. }
  60. uconn.HandshakeState.Hello = hello.getPublicPtr()
  61. uconn.HandshakeState.State13.EcdheParams = ecdheParams
  62. uconn.HandshakeState.C = uconn.Conn
  63. } else {
  64. if !uconn.ClientHelloBuilt {
  65. err := uconn.applyPresetByID(uconn.clientHelloID)
  66. if err != nil {
  67. return err
  68. }
  69. }
  70. err := uconn.ApplyConfig()
  71. if err != nil {
  72. return err
  73. }
  74. err = uconn.MarshalClientHello()
  75. if err != nil {
  76. return err
  77. }
  78. }
  79. uconn.ClientHelloBuilt = true
  80. return nil
  81. }
  82. // SetSessionState sets the session ticket, which may be preshared or fake.
  83. // If session is nil, the body of session ticket extension will be unset,
  84. // but the extension itself still MAY be present for mimicking purposes.
  85. // Session tickets to be reused - use same cache on following connections.
  86. func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
  87. uconn.HandshakeState.Session = session
  88. var sessionTicket []uint8
  89. if session != nil {
  90. sessionTicket = session.sessionTicket
  91. }
  92. uconn.HandshakeState.Hello.TicketSupported = true
  93. uconn.HandshakeState.Hello.SessionTicket = sessionTicket
  94. for _, ext := range uconn.Extensions {
  95. st, ok := ext.(*SessionTicketExtension)
  96. if !ok {
  97. continue
  98. }
  99. st.Session = session
  100. if session != nil {
  101. if len(session.SessionTicket()) > 0 {
  102. if uconn.GetSessionID != nil {
  103. sid := uconn.GetSessionID(session.SessionTicket())
  104. uconn.HandshakeState.Hello.SessionId = sid[:]
  105. return nil
  106. }
  107. }
  108. var sessionID [32]byte
  109. _, err := io.ReadFull(uconn.config.rand(), uconn.HandshakeState.Hello.SessionId)
  110. if err != nil {
  111. return err
  112. }
  113. uconn.HandshakeState.Hello.SessionId = sessionID[:]
  114. }
  115. return nil
  116. }
  117. return nil
  118. }
  119. // If you want session tickets to be reused - use same cache on following connections
  120. func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
  121. uconn.config.ClientSessionCache = cache
  122. uconn.HandshakeState.Hello.TicketSupported = true
  123. }
  124. // SetClientRandom sets client random explicitly.
  125. // BuildHandshakeFirst() must be called before SetClientRandom.
  126. // r must to be 32 bytes long.
  127. func (uconn *UConn) SetClientRandom(r []byte) error {
  128. if len(r) != 32 {
  129. return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
  130. } else {
  131. uconn.HandshakeState.Hello.Random = make([]byte, 32)
  132. copy(uconn.HandshakeState.Hello.Random, r)
  133. return nil
  134. }
  135. }
  136. func (uconn *UConn) SetSNI(sni string) {
  137. hname := hostnameInSNI(sni)
  138. uconn.config.ServerName = hname
  139. for _, ext := range uconn.Extensions {
  140. sniExt, ok := ext.(*SNIExtension)
  141. if ok {
  142. sniExt.ServerName = hname
  143. }
  144. }
  145. }
  146. // Handshake runs the client handshake using given clientHandshakeState
  147. // Requires hs.hello, and, optionally, hs.session to be set.
  148. func (c *UConn) Handshake() error {
  149. c.handshakeMutex.Lock()
  150. defer c.handshakeMutex.Unlock()
  151. if err := c.handshakeErr; err != nil {
  152. return err
  153. }
  154. if c.handshakeComplete() {
  155. return nil
  156. }
  157. c.in.Lock()
  158. defer c.in.Unlock()
  159. if c.isClient {
  160. // [uTLS section begins]
  161. err := c.BuildHandshakeState()
  162. if err != nil {
  163. return err
  164. }
  165. // [uTLS section ends]
  166. c.handshakeErr = c.clientHandshake()
  167. } else {
  168. c.handshakeErr = c.serverHandshake()
  169. }
  170. if c.handshakeErr == nil {
  171. c.handshakes++
  172. } else {
  173. // If an error occurred during the hadshake try to flush the
  174. // alert that might be left in the buffer.
  175. c.flush()
  176. }
  177. if c.handshakeErr == nil && !c.handshakeComplete() {
  178. c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
  179. }
  180. return c.handshakeErr
  181. }
  182. // Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls.
  183. // Write writes data to the connection.
  184. func (c *UConn) Write(b []byte) (int, error) {
  185. // interlock with Close below
  186. for {
  187. x := atomic.LoadInt32(&c.activeCall)
  188. if x&1 != 0 {
  189. return 0, errClosed
  190. }
  191. if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
  192. defer atomic.AddInt32(&c.activeCall, -2)
  193. break
  194. }
  195. }
  196. if err := c.Handshake(); err != nil {
  197. return 0, err
  198. }
  199. c.out.Lock()
  200. defer c.out.Unlock()
  201. if err := c.out.err; err != nil {
  202. return 0, err
  203. }
  204. if !c.handshakeComplete() {
  205. return 0, alertInternalError
  206. }
  207. if c.closeNotifySent {
  208. return 0, errShutdown
  209. }
  210. // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
  211. // attack when using block mode ciphers due to predictable IVs.
  212. // This can be prevented by splitting each Application Data
  213. // record into two records, effectively randomizing the IV.
  214. //
  215. // https://www.openssl.org/~bodo/tls-cbc.txt
  216. // https://bugzilla.mozilla.org/show_bug.cgi?id=665814
  217. // https://www.imperialviolet.org/2012/01/15/beastfollowup.html
  218. var m int
  219. if len(b) > 1 && c.vers <= VersionTLS10 {
  220. if _, ok := c.out.cipher.(cipher.BlockMode); ok {
  221. n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
  222. if err != nil {
  223. return n, c.out.setErrorLocked(err)
  224. }
  225. m, b = 1, b[1:]
  226. }
  227. }
  228. n, err := c.writeRecordLocked(recordTypeApplicationData, b)
  229. return n + m, c.out.setErrorLocked(err)
  230. }
  231. // clientHandshakeWithOneState checks that exactly one expected state is set (1.2 or 1.3)
  232. // and performs client TLS handshake with that state
  233. func (c *UConn) clientHandshake() (err error) {
  234. // [uTLS section begins]
  235. hello := c.HandshakeState.Hello.getPrivatePtr()
  236. defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()
  237. sessionIsAlreadySet := c.HandshakeState.Session != nil
  238. // after this point exactly 1 out of 2 HandshakeState pointers is non-nil,
  239. // useTLS13 variable tells which pointer
  240. // [uTLS section ends]
  241. if c.config == nil {
  242. c.config = defaultConfig()
  243. }
  244. // This may be a renegotiation handshake, in which case some fields
  245. // need to be reset.
  246. c.didResume = false
  247. // [uTLS section begins]
  248. // don't make new ClientHello, use hs.hello
  249. // preserve the checks from beginning and end of makeClientHello()
  250. if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
  251. return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
  252. }
  253. nextProtosLength := 0
  254. for _, proto := range c.config.NextProtos {
  255. if l := len(proto); l == 0 || l > 255 {
  256. return errors.New("tls: invalid NextProtos value")
  257. } else {
  258. nextProtosLength += 1 + l
  259. }
  260. }
  261. if nextProtosLength > 0xffff {
  262. return errors.New("tls: NextProtos values too large")
  263. }
  264. if c.handshakes > 0 {
  265. hello.secureRenegotiation = c.clientFinished[:]
  266. }
  267. // [uTLS section ends]
  268. cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
  269. if cacheKey != "" && session != nil {
  270. defer func() {
  271. // If we got a handshake failure when resuming a session, throw away
  272. // the session ticket. See RFC 5077, Section 3.2.
  273. //
  274. // RFC 8446 makes no mention of dropping tickets on failure, but it
  275. // does require servers to abort on invalid binders, so we need to
  276. // delete tickets to recover from a corrupted PSK.
  277. if err != nil {
  278. c.config.ClientSessionCache.Put(cacheKey, nil)
  279. }
  280. }()
  281. }
  282. if !sessionIsAlreadySet { // uTLS: do not overwrite already set session
  283. err = c.SetSessionState(session)
  284. if err != nil {
  285. return
  286. }
  287. }
  288. if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
  289. return err
  290. }
  291. msg, err := c.readHandshake()
  292. if err != nil {
  293. return err
  294. }
  295. serverHello, ok := msg.(*serverHelloMsg)
  296. if !ok {
  297. c.sendAlert(alertUnexpectedMessage)
  298. return unexpectedMessageError(serverHello, msg)
  299. }
  300. if err := c.pickTLSVersion(serverHello); err != nil {
  301. return err
  302. }
  303. // uTLS: do not create new handshakeState, use existing one
  304. if c.vers == VersionTLS13 {
  305. hs13 := c.HandshakeState.toPrivate13()
  306. hs13.serverHello = serverHello
  307. hs13.hello = hello
  308. if !sessionIsAlreadySet {
  309. hs13.earlySecret = earlySecret
  310. hs13.binderKey = binderKey
  311. }
  312. // In TLS 1.3, session tickets are delivered after the handshake.
  313. err = hs13.handshake()
  314. c.HandshakeState = *hs13.toPublic13()
  315. return err
  316. }
  317. hs12 := c.HandshakeState.toPrivate12()
  318. hs12.serverHello = serverHello
  319. hs12.hello = hello
  320. err = hs12.handshake()
  321. c.HandshakeState = *hs12.toPublic13()
  322. if err != nil {
  323. return err
  324. }
  325. // If we had a successful handshake and hs.session is different from
  326. // the one already cached - cache a new one.
  327. if cacheKey != "" && hs12.session != nil && session != hs12.session {
  328. c.config.ClientSessionCache.Put(cacheKey, hs12.session)
  329. }
  330. return nil
  331. }
  332. func (uconn *UConn) ApplyConfig() error {
  333. for _, ext := range uconn.Extensions {
  334. err := ext.writeToUConn(uconn)
  335. if err != nil {
  336. return err
  337. }
  338. }
  339. return nil
  340. }
  341. func (uconn *UConn) MarshalClientHello() error {
  342. hello := uconn.HandshakeState.Hello
  343. headerLength := 2 + 32 + 1 + len(hello.SessionId) +
  344. 2 + len(hello.CipherSuites)*2 +
  345. 1 + len(hello.CompressionMethods)
  346. extensionsLen := 0
  347. var paddingExt *UtlsPaddingExtension
  348. for _, ext := range uconn.Extensions {
  349. if pe, ok := ext.(*UtlsPaddingExtension); !ok {
  350. // If not padding - just add length of extension to total length
  351. extensionsLen += ext.Len()
  352. } else {
  353. // If padding - process it later
  354. if paddingExt == nil {
  355. paddingExt = pe
  356. } else {
  357. return errors.New("Multiple padding extensions!")
  358. }
  359. }
  360. }
  361. if paddingExt != nil {
  362. // determine padding extension presence and length
  363. paddingExt.Update(headerLength + 4 + extensionsLen + 2)
  364. extensionsLen += paddingExt.Len()
  365. }
  366. helloLen := headerLength
  367. if len(uconn.Extensions) > 0 {
  368. helloLen += 2 + extensionsLen // 2 bytes for extensions' length
  369. }
  370. helloBuffer := bytes.Buffer{}
  371. bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
  372. // We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
  373. // Write() will become noop, and error will be accessible via Flush(), which is called once in the end
  374. binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
  375. helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
  376. binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
  377. binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
  378. binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
  379. binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
  380. binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
  381. binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
  382. for _, suite := range hello.CipherSuites {
  383. binary.Write(bufferedWriter, binary.BigEndian, suite)
  384. }
  385. binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
  386. binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
  387. if len(uconn.Extensions) > 0 {
  388. binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
  389. for _, ext := range uconn.Extensions {
  390. bufferedWriter.ReadFrom(ext)
  391. }
  392. }
  393. err := bufferedWriter.Flush()
  394. if err != nil {
  395. return err
  396. }
  397. if helloBuffer.Len() != 4+helloLen {
  398. return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
  399. ". Got: " + strconv.Itoa(helloBuffer.Len()))
  400. }
  401. hello.Raw = helloBuffer.Bytes()
  402. return nil
  403. }
  404. // get current state of cipher and encrypt zeros to get keystream
  405. func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
  406. zeros := make([]byte, length)
  407. if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
  408. // AEAD.Seal() does not mutate internal state, other ciphers might
  409. return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
  410. }
  411. return nil, errors.New("Could not convert OutCipher to cipher.AEAD")
  412. }
  413. // SetVersCreateState set min and max TLS version in all appropriate places.
  414. func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16) error {
  415. if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 {
  416. return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers)
  417. }
  418. if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 {
  419. return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers)
  420. }
  421. uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers)
  422. uconn.config.MinVersion = minTLSVers
  423. uconn.config.MaxVersion = maxTLSVers
  424. return nil
  425. }
  426. func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
  427. uconn.Conn.conn = c
  428. }
  429. func (uconn *UConn) GetUnderlyingConn() net.Conn {
  430. return uconn.Conn.conn
  431. }
  432. // MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
  433. // Major Hack Alert.
  434. func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
  435. tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}
  436. cs := cipherSuiteByID(cipherSuite)
  437. // This is mostly borrowed from establishKeys()
  438. clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
  439. keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
  440. cs.macLen, cs.keyLen, cs.ivLen)
  441. var clientCipher, serverCipher interface{}
  442. var clientHash, serverHash macFunction
  443. if cs.cipher != nil {
  444. clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
  445. clientHash = cs.mac(version, clientMAC)
  446. serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
  447. serverHash = cs.mac(version, serverMAC)
  448. } else {
  449. clientCipher = cs.aead(clientKey, clientIV)
  450. serverCipher = cs.aead(serverKey, serverIV)
  451. }
  452. if isClient {
  453. tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
  454. tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
  455. } else {
  456. tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
  457. tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
  458. }
  459. // skip the handshake states
  460. tlsConn.handshakeStatus = 1
  461. tlsConn.cipherSuite = cipherSuite
  462. tlsConn.haveVers = true
  463. tlsConn.vers = version
  464. // Update to the new cipher specs
  465. // and consume the finished messages
  466. tlsConn.in.changeCipherSpec()
  467. tlsConn.out.changeCipherSpec()
  468. tlsConn.in.incSeq()
  469. tlsConn.out.incSeq()
  470. return tlsConn
  471. }
  // makeSupportedVersions returns every version from maxVers down to minVers
  // (both inclusive), highest first. It returns nil when maxVers is lower than
  // minVers.
  func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
  	if maxVers < minVers {
  		return nil
  	}
  	// Count in int: in uint16 arithmetic maxVers-minVers+1 wraps around for
  	// an inverted range and overflows to 0 for the full range.
  	count := int(maxVers) - int(minVers) + 1
  	versions := make([]uint16, count)
  	for i := range versions {
  		versions[i] = maxVers - uint16(i)
  	}
  	return versions
  }
  478. }