Skip to content

Commit

Permalink
Lock sessionController only on last call to BuildHandshakeState
Browse files Browse the repository at this point in the history
  • Loading branch information
adotkhan committed Jun 19, 2024
1 parent 4f71339 commit ebe5d66
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
29 changes: 21 additions & 8 deletions u_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,27 @@ func (uconn *UConn) BuildHandshakeState() error {
return err
}

err = uconn.uLoadSession()
if err != nil {
return err
}

err = uconn.MarshalClientHello()
if err != nil {
return err
}

uconn.uApplyPatch()
}
return nil
}

uconn.sessionController.finalCheck()
uconn.clientHelloBuildStatus = BuildByUtls
func (uconn *UConn) lockSessionState() error {

err := uconn.uLoadSession()
if err != nil {
return err
}

uconn.uApplyPatch()

uconn.sessionController.finalCheck()
uconn.clientHelloBuildStatus = BuildByUtls

return nil
}

Expand Down Expand Up @@ -358,6 +364,10 @@ func (c *UConn) handshakeContext(ctx context.Context) (ret error) {
if err != nil {
return err
}
err = c.lockSessionState()
if err != nil {
return err
}
}
// [uTLS section ends]
c.handshakeErr = c.handshakeFn(handshakeCtx)
Expand Down Expand Up @@ -983,6 +993,9 @@ func (c *UConn) handleRenegotiation() error {
if err = c.BuildHandshakeState(); err != nil {
return err
}
if err = c.lockSessionState(); err != nil {
return err
}
// [uTLS section ends]
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
Expand Down
5 changes: 5 additions & 0 deletions u_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,11 @@ func (test *clientTest) runUTLS(t *testing.T, write bool, hello helloStrategy, o
t.Errorf("Client.BuildHandshakeState() failed: %s", err)
return
}
err = client.lockSessionState()
if err != nil {
t.Errorf("Client.lockSessionState() failed: %s", err)
return
}
// TODO: fix this name hack if we ever decide to use non-standard testing object
err = client.SetClientRandom([]byte("Custom ClientRandom h^xbw8bf0sn3"))
if err != nil {
Expand Down

0 comments on commit ebe5d66

Please sign in to comment.