增加websocket支持

This commit is contained in:
2025-09-08 13:47:13 +08:00
parent 9caedd697d
commit 7e0fd53dd3
336 changed files with 9020 additions and 20356 deletions

View File

@@ -20,44 +20,41 @@ import (
// TODO: Benchmark to determine if the pools are necessary. The GC may have
// improved enough that we can instead allocate chunks like this:
// make([]byte, max(16<<10, expectedBytesRemaining))
var dataChunkPools = [...]sync.Pool{
{New: func() interface{} { return new([1 << 10]byte) }},
{New: func() interface{} { return new([2 << 10]byte) }},
{New: func() interface{} { return new([4 << 10]byte) }},
{New: func() interface{} { return new([8 << 10]byte) }},
{New: func() interface{} { return new([16 << 10]byte) }},
}
var (
dataChunkSizeClasses = []int{
1 << 10,
2 << 10,
4 << 10,
8 << 10,
16 << 10,
}
dataChunkPools = [...]sync.Pool{
{New: func() interface{} { return make([]byte, 1<<10) }},
{New: func() interface{} { return make([]byte, 2<<10) }},
{New: func() interface{} { return make([]byte, 4<<10) }},
{New: func() interface{} { return make([]byte, 8<<10) }},
{New: func() interface{} { return make([]byte, 16<<10) }},
}
)
func getDataBufferChunk(size int64) []byte {
switch {
case size <= 1<<10:
return dataChunkPools[0].Get().(*[1 << 10]byte)[:]
case size <= 2<<10:
return dataChunkPools[1].Get().(*[2 << 10]byte)[:]
case size <= 4<<10:
return dataChunkPools[2].Get().(*[4 << 10]byte)[:]
case size <= 8<<10:
return dataChunkPools[3].Get().(*[8 << 10]byte)[:]
default:
return dataChunkPools[4].Get().(*[16 << 10]byte)[:]
i := 0
for ; i < len(dataChunkSizeClasses)-1; i++ {
if size <= int64(dataChunkSizeClasses[i]) {
break
}
}
return dataChunkPools[i].Get().([]byte)
}
func putDataBufferChunk(p []byte) {
switch len(p) {
case 1 << 10:
dataChunkPools[0].Put((*[1 << 10]byte)(p))
case 2 << 10:
dataChunkPools[1].Put((*[2 << 10]byte)(p))
case 4 << 10:
dataChunkPools[2].Put((*[4 << 10]byte)(p))
case 8 << 10:
dataChunkPools[3].Put((*[8 << 10]byte)(p))
case 16 << 10:
dataChunkPools[4].Put((*[16 << 10]byte)(p))
default:
panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
for i, n := range dataChunkSizeClasses {
if len(p) == n {
dataChunkPools[i].Put(p)
return
}
}
panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
}
// dataBuffer is an io.ReadWriter backed by a list of data chunks.

View File

@@ -490,9 +490,6 @@ func terminalReadFrameError(err error) bool {
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
//
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
fr.errDetail = nil
if fr.lastFrame != nil {
@@ -1513,18 +1510,19 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
}
func (fr *Framer) maxHeaderStringLen() int {
v := int(fr.maxHeaderListSize())
if v < 0 {
// If maxHeaderListSize overflows an int, use no limit (0).
return 0
v := fr.maxHeaderListSize()
if uint32(int(v)) == v {
return int(v)
}
return v
// They had a crazy big number for MaxHeaderBytes anyway,
// so give them unlimited header lengths:
return 0
}
// readMetaFrame returns 0 or more CONTINUATION frames from fr and
// merge them into the provided hf and returns a MetaHeadersFrame
// with the decoded hpack values.
func (fr *Framer) readMetaFrame(hf *HeadersFrame) (Frame, error) {
func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if fr.AllowIllegalReads {
return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
}
@@ -1567,7 +1565,6 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (Frame, error) {
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
remainSize = 0
return
}
remainSize -= size
@@ -1580,38 +1577,8 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (Frame, error) {
var hc headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
// Avoid parsing large amounts of headers that we will then discard.
// If the sender exceeds the max header list size by too much,
// skip parsing the fragment and close the connection.
//
// "Too much" is either any CONTINUATION frame after we've already
// exceeded the max header list size (in which case remainSize is 0),
// or a frame whose encoded size is more than twice the remaining
// header list bytes we're willing to accept.
if int64(len(frag)) > int64(2*remainSize) {
if VerboseLogs {
log.Printf("http2: header list too large")
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return mh, ConnectionError(ErrCodeProtocol)
}
// Also close the connection after any CONTINUATION frame following an
// invalid header, since we stop tracking the size of the headers after
// an invalid one.
if invalid != nil {
if VerboseLogs {
log.Printf("http2: invalid header: %v", invalid)
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return mh, ConnectionError(ErrCodeProtocol)
}
if _, err := hdec.Write(frag); err != nil {
return mh, ConnectionError(ErrCodeCompression)
return nil, ConnectionError(ErrCodeCompression)
}
if hc.HeadersEnded() {
@@ -1628,7 +1595,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (Frame, error) {
mh.HeadersFrame.invalidate()
if err := hdec.Close(); err != nil {
return mh, ConnectionError(ErrCodeCompression)
return nil, ConnectionError(ErrCodeCompression)
}
if invalid != nil {
fr.errDetail = invalid

View File

@@ -44,7 +44,7 @@ func init() {
// HTTP/1, but unlikely to occur in practice and (2) Upgrading from HTTP/1 to
// h2c - this works by using the HTTP/1 Upgrade header to request an upgrade to
// h2c. When either of those situations occur we hijack the HTTP/1 connection,
// convert it to an HTTP/2 connection and pass the net.Conn to http2.ServeConn.
// convert it to a HTTP/2 connection and pass the net.Conn to http2.ServeConn.
type h2cHandler struct {
Handler http.Handler
s *http2.Server

View File

@@ -77,10 +77,7 @@ func (p *pipe) Read(d []byte) (n int, err error) {
}
}
var (
errClosedPipeWrite = errors.New("write on closed buffer")
errUninitializedPipeWrite = errors.New("write on uninitialized buffer")
)
var errClosedPipeWrite = errors.New("write on closed buffer")
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
@@ -94,12 +91,6 @@ func (p *pipe) Write(d []byte) (n int, err error) {
if p.err != nil || p.breakErr != nil {
return 0, errClosedPipeWrite
}
// pipe.setBuffer is never invoked, leaving the buffer uninitialized.
// We shouldn't try to write to an uninitialized pipe,
// but returning an error is better than panicking.
if p.b == nil {
return 0, errUninitializedPipeWrite
}
return p.b.Write(d)
}

View File

@@ -124,7 +124,6 @@ type Server struct {
// IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout.
// If zero or negative, there is no timeout.
IdleTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow
@@ -435,14 +434,14 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here.
if sc.hs.WriteTimeout > 0 {
if sc.hs.WriteTimeout != 0 {
sc.conn.SetWriteDeadline(time.Time{})
}
if s.NewWriteScheduler != nil {
sc.writeSched = s.NewWriteScheduler()
} else {
sc.writeSched = newRoundRobinWriteScheduler()
sc.writeSched = NewPriorityWriteScheduler(nil)
}
// These start at the RFC-specified defaults. If there is a higher
@@ -582,11 +581,9 @@ type serverConn struct {
advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
curClientStreams uint32 // number of open streams initiated by the client
curPushedStreams uint32 // number of open streams initiated by server push
curHandlers uint32 // number of running handler goroutines
maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests
maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes
streams map[uint32]*stream
unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32
maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default)
@@ -732,7 +729,11 @@ func isClosedConnError(err error) bool {
return false
}
if errors.Is(err, net.ErrClosed) {
// TODO: remove this string search and be more like the Windows
// case below. That might involve modifying the standard library
// to return better error types.
str := err.Error()
if strings.Contains(str, "use of closed network connection") {
return true
}
@@ -921,7 +922,7 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateActive)
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
if sc.srv.IdleTimeout != 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
@@ -980,8 +981,6 @@ func (sc *serverConn) serve() {
return
case gracefulShutdownMsg:
sc.startGracefulShutdownInternal()
case handlerDoneMsg:
sc.handlerDone()
default:
panic("unknown timer")
}
@@ -1013,6 +1012,14 @@ func (sc *serverConn) serve() {
}
}
func (sc *serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) {
select {
case <-sc.doneServing:
case <-sharedCh:
close(privateCh)
}
}
type serverMessage int
// Message values sent to serveMsgCh.
@@ -1021,7 +1028,6 @@ var (
idleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage)
)
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
@@ -1478,11 +1484,6 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
sc.goAway(ErrCodeFlowControl)
return true
case ConnectionError:
if res.f != nil {
if id := res.f.Header().StreamID; id > sc.maxClientStreamID {
sc.maxClientStreamID = id
}
}
sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
sc.goAway(ErrCode(ev))
return true // goAway will handle shutdown
@@ -1639,7 +1640,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
if sc.srv.IdleTimeout != 0 {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
@@ -1899,11 +1900,9 @@ func (st *stream) copyTrailersToHandlerRequest() {
// onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired.
func (st *stream) onReadTimeout() {
if st.body != nil {
// Wrap the ErrDeadlineExceeded to avoid callers depending on us
// returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
// Wrap the ErrDeadlineExceeded to avoid callers depending on us
// returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
// onWriteTimeout is run on its own goroutine (from time.AfterFunc)
@@ -2019,12 +2018,15 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout > 0 {
if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
if st.body != nil {
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
}
return sc.scheduleHandler(id, rw, req, handler)
go sc.runHandler(rw, req, handler)
return nil
}
func (sc *serverConn) upgradeRequest(req *http.Request) {
@@ -2040,14 +2042,10 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
// Disable any read deadline set by the net/http package
// prior to the upgrade.
if sc.hs.ReadTimeout > 0 {
if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{})
}
// This is the first request on the connection,
// so start the handler directly rather than going
// through scheduleHandler.
sc.curHandlers++
go sc.runHandler(rw, req, sc.handler.ServeHTTP)
}
@@ -2118,7 +2116,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout > 0 {
if sc.hs.WriteTimeout != 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
@@ -2288,62 +2286,8 @@ func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *response
return &responseWriter{rws: rws}
}
type unstartedHandler struct {
streamID uint32
rw *responseWriter
req *http.Request
handler func(http.ResponseWriter, *http.Request)
}
// scheduleHandler starts a handler goroutine,
// or schedules one to start as soon as an existing handler finishes.
func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error {
sc.serveG.check()
maxHandlers := sc.advMaxStreams
if sc.curHandlers < maxHandlers {
sc.curHandlers++
go sc.runHandler(rw, req, handler)
return nil
}
if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
}
sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
streamID: streamID,
rw: rw,
req: req,
handler: handler,
})
return nil
}
func (sc *serverConn) handlerDone() {
sc.serveG.check()
sc.curHandlers--
i := 0
maxHandlers := sc.advMaxStreams
for ; i < len(sc.unstartedHandlers); i++ {
u := sc.unstartedHandlers[i]
if sc.streams[u.streamID] == nil {
// This stream was reset before its goroutine had a chance to start.
continue
}
if sc.curHandlers >= maxHandlers {
break
}
sc.curHandlers++
go sc.runHandler(u.rw, u.req, u.handler)
sc.unstartedHandlers[i] = unstartedHandler{} // don't retain references
}
sc.unstartedHandlers = sc.unstartedHandlers[i:]
if len(sc.unstartedHandlers) == 0 {
sc.unstartedHandlers = nil
}
}
// Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true
defer func() {
rw.rws.stream.cancelCtx()
@@ -2485,7 +2429,7 @@ type requestBody struct {
conn *serverConn
closeOnce sync.Once // for use by Close only
sawEOF bool // for use by Read only
pipe *pipe // non-nil if we have an HTTP entity message body
pipe *pipe // non-nil if we have a HTTP entity message body
needsContinue bool // need to send a 100-continue
}
@@ -2551,6 +2495,7 @@ type responseWriterState struct {
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished
dirty bool // a Write failed; don't reuse this responseWriterState
sentContentLen int64 // non-zero if handler set a Content-Length header
wroteBytes int64
@@ -2624,8 +2569,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
clen = ""
}
}
_, hasContentLength := rws.snapHeader["Content-Length"]
if !hasContentLength && clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
if clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
clen = strconv.Itoa(len(p))
}
_, hasContentType := rws.snapHeader["Content-Type"]
@@ -2670,6 +2614,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
date: date,
})
if err != nil {
rws.dirty = true
return 0, err
}
if endStream {
@@ -2690,6 +2635,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
rws.dirty = true
return 0, err
}
}
@@ -2701,6 +2647,9 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
trailers: rws.trailers,
endStream: true,
})
if err != nil {
rws.dirty = true
}
return len(p), err
}
return len(p), nil
@@ -2825,7 +2774,7 @@ func (w *responseWriter) FlushError() error {
err = rws.bw.Flush()
} else {
// The bufio.Writer won't call chunkWriter.Write
// (writeChunk with zero bytes), so we have to do it
// (writeChunk with zero bytes, so we have to do it
// ourselves to force the HTTP response header and/or
// final DATA frame (with END_STREAM) to be sent.
_, err = chunkWriter{rws}.Write(nil)
@@ -2916,12 +2865,14 @@ func (rws *responseWriterState) writeHeader(code int) {
h.Del("Transfer-Encoding")
}
rws.conn.writeHeaders(rws.stream, &writeResHeaders{
if rws.conn.writeHeaders(rws.stream, &writeResHeaders{
streamID: rws.stream.id,
httpResCode: code,
h: h,
endStream: rws.handlerDone && !rws.hasTrailers(),
})
}) != nil {
rws.dirty = true
}
return
}
@@ -2986,10 +2937,19 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int,
func (w *responseWriter) handlerDone() {
rws := w.rws
dirty := rws.dirty
rws.handlerDone = true
w.Flush()
w.rws = nil
responseWriterStatePool.Put(rws)
if !dirty {
// Only recycle the pool if all prior Write calls to
// the serverConn goroutine completed successfully. If
// they returned earlier due to resets from the peer
// there might still be write goroutines outstanding
// from the serverConn referencing the rws memory. See
// issue 20704.
responseWriterStatePool.Put(rws)
}
}
// Push errors.
@@ -3172,7 +3132,6 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
}
sc.curHandlers++
go sc.runHandler(rw, req, sc.handler.ServeHTTP)
return promisedID, nil
}

View File

@@ -19,7 +19,6 @@ import (
"io/fs"
"log"
"math"
"math/bits"
mathrand "math/rand"
"net"
"net/http"
@@ -147,12 +146,6 @@ type Transport struct {
// waiting for their turn.
StrictMaxConcurrentStreams bool
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
// Zero means no limit.
IdleConnTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response will is considered a received frame, so if
@@ -184,8 +177,6 @@ type Transport struct {
connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool
syncHooks *testSyncHooks
}
func (t *Transport) maxHeaderListSize() uint32 {
@@ -299,7 +290,8 @@ func (t *Transport) initConnPool() {
// HTTP/2 server.
type ClientConn struct {
t *Transport
tconn net.Conn // usually *tls.Conn, except specialized impls
tconn net.Conn // usually *tls.Conn, except specialized impls
tconnClosed bool
tlsState *tls.ConnectionState // nil only for specialized impls
reused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
@@ -310,7 +302,7 @@ type ClientConn struct {
readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never
idleTimer timer
idleTimer *time.Timer
mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes
@@ -352,60 +344,6 @@ type ClientConn struct {
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
syncHooks *testSyncHooks // can be nil
}
// Hook points used for testing.
// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
// Inside tests, see the testSyncHooks function docs.
// goRun starts a new goroutine.
func (cc *ClientConn) goRun(f func()) {
if cc.syncHooks != nil {
cc.syncHooks.goRun(f)
return
}
go f()
}
// condBroadcast is cc.cond.Broadcast.
func (cc *ClientConn) condBroadcast() {
if cc.syncHooks != nil {
cc.syncHooks.condBroadcast(cc.cond)
}
cc.cond.Broadcast()
}
// condWait is cc.cond.Wait.
func (cc *ClientConn) condWait() {
if cc.syncHooks != nil {
cc.syncHooks.condWait(cc.cond)
}
cc.cond.Wait()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (cc *ClientConn) newTimer(d time.Duration) timer {
if cc.syncHooks != nil {
return cc.syncHooks.newTimer(d)
}
return newTimeTimer(d)
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
if cc.syncHooks != nil {
return cc.syncHooks.afterFunc(d, f)
}
return newTimeAfterFunc(d, f)
}
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if cc.syncHooks != nil {
return cc.syncHooks.contextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
@@ -487,7 +425,7 @@ func (cs *clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
cs.cc.condBroadcast()
cs.cc.cond.Broadcast()
}
}
@@ -497,7 +435,7 @@ func (cs *clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
cc.condBroadcast()
cc.cond.Broadcast()
}
}
@@ -507,10 +445,10 @@ func (cs *clientStream) closeReqBodyLocked() {
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
cs.cc.goRun(func() {
go func() {
cs.reqBody.Close()
close(reqBodyClosed)
})
}()
}
type stickyErrWriter struct {
@@ -580,14 +518,11 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
host = authority
port = ""
}
if port == "" { // authority's port was empty
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
@@ -599,6 +534,15 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port)
}
var retryBackoffHook func(time.Duration) *time.Timer
func backoffNewTimer(d time.Duration) *time.Timer {
if retryBackoffHook != nil {
return retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@@ -626,27 +570,13 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
var tm timer
if t.syncHooks != nil {
tm = t.syncHooks.newTimer(d)
t.syncHooks.blockUntil(func() bool {
select {
case <-tm.C():
case <-req.Context().Done():
default:
return false
}
return true
})
} else {
tm = newTimeTimer(d)
}
timer := backoffNewTimer(d)
select {
case <-tm.C():
case <-timer.C:
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue
case <-req.Context().Done():
tm.Stop()
timer.Stop()
err = req.Context().Err()
}
}
@@ -725,9 +655,6 @@ func canRetryError(err error) bool {
}
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
if t.syncHooks != nil {
return t.newClientConn(nil, singleUse, t.syncHooks)
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@@ -736,7 +663,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil {
return nil, err
}
return t.newClientConn(tconn, singleUse, nil)
return t.newClientConn(tconn, singleUse)
}
func (t *Transport) newTLSConfig(host string) *tls.Config {
@@ -802,10 +729,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 {
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives(), nil)
return t.newClientConn(c, t.disableKeepAlives())
}
func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) {
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
cc := &ClientConn{
t: t,
tconn: c,
@@ -820,15 +747,10 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
syncHooks: hooks,
}
if hooks != nil {
hooks.newclientconn(cc)
c = cc.tconn
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -893,7 +815,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHoo
return nil, cc.werr
}
cc.goRun(cc.readLoop)
go cc.readLoop()
return cc, nil
}
@@ -901,7 +823,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -936,20 +858,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
}
last := f.LastStreamID
for streamID, cs := range cc.streams {
if streamID <= last {
// The server's GOAWAY indicates that it received this stream.
// It will either finish processing it, or close the connection
// without doing so. Either way, leave the stream alone for now.
continue
}
if streamID == 1 && cc.goAway.ErrCode != ErrCodeNo {
// Don't retry the first stream on a connection if we get a non-NO error.
// If the server is sending an error on a new connection,
// retrying the request on a new one probably isn't going to work.
cs.abortStreamLocked(fmt.Errorf("http2: Transport received GOAWAY from server ErrCode:%v", cc.goAway.ErrCode))
} else {
// Aborting the stream with errClentConnGotGoAway indicates that
// the request should be retried on a new connection.
if streamID > last {
cs.abortStreamLocked(errClientConnGotGoAway)
}
}
@@ -1106,7 +1015,7 @@ func (cc *ClientConn) forceCloseConn() {
if !ok {
return
}
if nc := tc.NetConn(); nc != nil {
if nc := tlsUnderlyingConn(tc); nc != nil {
nc.Close()
}
}
@@ -1144,7 +1053,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
cc.goRun(func() {
go func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@@ -1156,9 +1065,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
if cancelled {
break
}
cc.condWait()
cc.cond.Wait()
}
})
}()
shutdownEnterWaitStateHook()
select {
case <-done:
@@ -1168,7 +1077,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock()
// Free the goroutine above
cancelled = true
cc.condBroadcast()
cc.cond.Broadcast()
cc.mu.Unlock()
return ctx.Err()
}
@@ -1206,7 +1115,7 @@ func (cc *ClientConn) closeForError(err error) {
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
cc.condBroadcast()
cc.cond.Broadcast()
cc.mu.Unlock()
cc.closeConn()
}
@@ -1303,10 +1212,6 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
}
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return cc.roundTrip(req, nil)
}
func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) {
ctx := req.Context()
cs := &clientStream{
cc: cc,
@@ -1321,23 +1226,9 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
cc.goRun(func() {
cs.doRequest(req)
})
go cs.doRequest(req)
waitDone := func() error {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.donec:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.donec:
return nil
@@ -1377,45 +1268,26 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
cancelRequest := func(cs *clientStream, err error) error {
cs.cc.mu.Lock()
bodyClosed := cs.reqBodyClosed
cs.cc.mu.Unlock()
// Wait for the request body to be closed.
//
// If nothing closed the body before now, abortStreamLocked
// will have started a goroutine to close it.
//
// Closing the body before returning avoids a race condition
// with net/http checking its readTrackingBody to see if the
// body was read from or closed. See golang/go#60041.
//
// The body is closed in a separate goroutine without the
// connection mutex held, but dropping the mutex before waiting
// will keep us from holding it indefinitely if the body
// close is slow for some reason.
if bodyClosed != nil {
<-bodyClosed
defer cs.cc.mu.Unlock()
cs.abortStreamLocked(err)
if cs.ID != 0 {
// This request may have failed because of a problem with the connection,
// or for some unrelated reason. (For example, the user might have canceled
// the request without waiting for a response.) Mark the connection as
// not reusable, since trying to reuse a dead connection is worse than
// unnecessarily creating a new one.
//
// If cs.ID is 0, then the request was never allocated a stream ID and
// whatever went wrong was unrelated to the connection. We might have
// timed out waiting for a stream slot when StrictMaxConcurrentStreams
// is set, for example, in which case retrying on a different connection
// will not help.
cs.cc.doNotReuse = true
}
return err
}
if streamf != nil {
streamf(cs)
}
for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
@@ -1429,14 +1301,11 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
return handleResponseHeaders()
default:
waitDone()
return nil, cs.abortErr
return nil, cancelRequest(cs, cs.abortErr)
}
case <-ctx.Done():
err := ctx.Err()
cs.abortStream(err)
return nil, cancelRequest(cs, err)
return nil, cancelRequest(cs, ctx.Err())
case <-cs.reqCancel:
cs.abortStream(errRequestCanceled)
return nil, cancelRequest(cs, errRequestCanceled)
}
}
@@ -1471,21 +1340,6 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
var newStreamHook func(*clientStream)
if cc.syncHooks != nil {
newStreamHook = cc.syncHooks.newstream
cc.syncHooks.blockUntil(func() bool {
select {
case cc.reqHeaderMu <- struct{}{}:
<-cc.reqHeaderMu
case <-cs.reqCancel:
case <-ctx.Done():
default:
return false
}
return true
})
}
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -1510,10 +1364,6 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
}
cc.mu.Unlock()
if newStreamHook != nil {
newStreamHook(cs)
}
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
@@ -1594,30 +1444,15 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
timer := cc.newTimer(d)
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C()
respHeaderTimer = timer.C
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.peerClosed:
case <-respHeaderTimer:
case <-respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select {
case <-cs.peerClosed:
return nil
@@ -1766,7 +1601,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
return nil
}
cc.pendingRequests++
cc.condWait()
cc.cond.Wait()
cc.pendingRequests--
select {
case <-cs.abort:
@@ -1837,27 +1672,7 @@ func (cs *clientStream) frameScratchBufferLen(maxFrameSize int) int {
return int(n) // doesn't truncate; max is 512K
}
// Seven bufPools manage different frame sizes. This helps to avoid scenarios where long-running
// streaming requests using small frame sizes occupy large buffers initially allocated for prior
// requests needing big buffers. The size ranges are as follows:
// {0 KB, 16 KB], {16 KB, 32 KB], {32 KB, 64 KB], {64 KB, 128 KB], {128 KB, 256 KB],
// {256 KB, 512 KB], {512 KB, infinity}
// In practice, the maximum scratch buffer size should not exceed 512 KB due to
// frameScratchBufferLen(maxFrameSize), thus the "infinity pool" should never be used.
// It exists mainly as a safety measure, for potential future increases in max buffer size.
var bufPools [7]sync.Pool // of *[]byte
func bufPoolIndex(size int) int {
if size <= 16384 {
return 0
}
size -= 1
bits := bits.Len(uint(size))
index := bits - 14
if index >= len(bufPools) {
return len(bufPools) - 1
}
return index
}
var bufPool sync.Pool // of *[]byte
func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
cc := cs.cc
@@ -1875,13 +1690,12 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
// Scratch buffer for reading into & writing from.
scratchLen := cs.frameScratchBufferLen(maxFrameSize)
var buf []byte
index := bufPoolIndex(scratchLen)
if bp, ok := bufPools[index].Get().(*[]byte); ok && len(*bp) >= scratchLen {
defer bufPools[index].Put(bp)
if bp, ok := bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen {
defer bufPool.Put(bp)
buf = *bp
} else {
buf = make([]byte, scratchLen)
defer bufPools[index].Put(&buf)
defer bufPool.Put(&buf)
}
var sawEOF bool
@@ -2028,26 +1842,10 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
cs.flow.take(take)
return take, nil
}
cc.condWait()
cc.cond.Wait()
}
}
func validateHeaders(hdrs http.Header) string {
for k, vv := range hdrs {
if !httpguts.ValidHeaderFieldName(k) {
return fmt.Sprintf("name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error,
// because it may be sensitive.
return fmt.Sprintf("value for header %q", k)
}
}
}
return ""
}
var errNilRequestURL = errors.New("http2: Request.URI is nil")
// requires cc.wmu be held.
@@ -2065,9 +1863,6 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
if err != nil {
return nil, err
}
if !httpguts.ValidHostHeader(host) {
return nil, errors.New("http2: invalid Host header")
}
var path string
if req.Method != "CONNECT" {
@@ -2085,21 +1880,26 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
}
}
// Check for any invalid headers+trailers and return an error before we
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
if err := validateHeaders(req.Header); err != "" {
return nil, fmt.Errorf("invalid HTTP header %s", err)
}
if err := validateHeaders(req.Trailer); err != "" {
return nil, fmt.Errorf("invalid HTTP trailer %s", err)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error, because it may be sensitive.
return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
}
}
}
enumerateHeaders := func(f func(name, value string)) {
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production, see Sections 3.3 and 3.4 of
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
m := req.Method
@@ -2311,7 +2111,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
cc.condBroadcast()
cc.cond.Broadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@@ -2399,7 +2199,7 @@ func (rl *clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err)
}
}
cc.condBroadcast()
cc.cond.Broadcast()
cc.mu.Unlock()
}
@@ -2434,9 +2234,10 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t timer
var t *time.Timer
if readIdleTimeout != 0 {
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
}
for {
f, err := cc.fr.ReadFrame()
@@ -2851,7 +2652,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
})
return nil
}
if !cs.pastHeaders {
if !cs.firstByte {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
@@ -3034,7 +2835,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
for _, cs := range cc.streams {
cs.flow.add(delta)
}
cc.condBroadcast()
cc.cond.Broadcast()
cc.initialWindowSize = s.Val
case SettingHeaderTableSize:
@@ -3078,18 +2879,9 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
fl = &cs.flow
}
if !fl.add(int32(f.Increment)) {
// For stream, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR
if cs != nil {
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
Code: ErrCodeFlowControl,
})
return nil
}
return ConnectionError(ErrCodeFlowControl)
}
cc.condBroadcast()
cc.cond.Broadcast()
return nil
}
@@ -3131,38 +2923,24 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
}
cc.mu.Unlock()
}
var pingError error
errc := make(chan struct{})
cc.goRun(func() {
errc := make(chan error, 1)
go func() {
cc.wmu.Lock()
defer cc.wmu.Unlock()
if pingError = cc.fr.WritePing(false, p); pingError != nil {
close(errc)
if err := cc.fr.WritePing(false, p); err != nil {
errc <- err
return
}
if pingError = cc.bw.Flush(); pingError != nil {
close(errc)
if err := cc.bw.Flush(); err != nil {
errc <- err
return
}
})
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-c:
case <-errc:
case <-ctx.Done():
case <-cc.readerDone:
default:
return false
}
return true
})
}
}()
select {
case <-c:
return nil
case <-errc:
return pingError
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:
@@ -3331,17 +3109,9 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
}
func (t *Transport) idleConnTimeout() time.Duration {
// to keep things backwards compatible, we use non-zero values of
// IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
// http1 transport, followed by 0
if t.IdleConnTimeout != 0 {
return t.IdleConnTimeout
}
if t.t1 != nil {
return t.t1.IdleConnTimeout
}
return 0
}
@@ -3399,34 +3169,3 @@ func traceFirstResponseByte(trace *httptrace.ClientTrace) {
trace.GotFirstResponseByte()
}
}
func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
return trace != nil && trace.WroteHeaderField != nil
}
func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(k, []string{v})
}
}
func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
if trace != nil {
return trace.Got1xxResponse
}
return nil
}
// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
return tlsCn, nil
}

View File

@@ -184,8 +184,7 @@ func (wr *FrameWriteRequest) replyToWriter(err error) {
// writeQueue is used by implementations of WriteScheduler.
type writeQueue struct {
s []FrameWriteRequest
prev, next *writeQueue
s []FrameWriteRequest
}
func (q *writeQueue) empty() bool { return len(q.s) == 0 }