Merge pull request #51157 from corhere/split-stdcopy

api/pkg/stdcopy: move stdWriter to daemon/internal
This commit is contained in:
Sebastiaan van Stijn
2025-10-10 17:33:43 +02:00
committed by GitHub
8 changed files with 123 additions and 167 deletions

View File

@@ -1,12 +1,10 @@
package stdcopy
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
)
// StdType is the type of standard stream
@@ -28,71 +26,6 @@ const (
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
)
var bufPool = &sync.Pool{New: func() any { return bytes.NewBuffer(nil) }}
// stdWriter is wrapper of io.Writer with extra customized info.
type stdWriter struct {
io.Writer
prefix byte
}
// Write sends the buffer to the underlying writer.
// It inserts the prefix header before the buffer,
// so [StdCopy] knows where to multiplex the output.
//
// It implements [io.Writer].
func (w *stdWriter) Write(p []byte) (int, error) {
if w == nil || w.Writer == nil {
return 0, errors.New("writer not instantiated")
}
if p == nil {
return 0, nil
}
header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(p)))
buf := bufPool.Get().(*bytes.Buffer)
buf.Write(header[:])
buf.Write(p)
n, err := w.Writer.Write(buf.Bytes())
n -= stdWriterPrefixLen
if n < 0 {
n = 0
}
buf.Reset()
bufPool.Put(buf)
return n, err
}
// NewStdWriter instantiates a new writer using a custom format to multiplex
// multiple streams to a single writer. All messages written using this writer
// are encapsulated using a custom format, and written to the underlying
// stream "w".
//
// Writers created through NewStdWriter allow for multiple write streams
// (e.g., stdout ([Stdout]) and stderr ([Stderr]) to be multiplexed into a
// single connection. "streamType" indicates the type of stream to encapsulate,
// commonly, [Stdout] or [Stderr]. The [Systemerr] stream can be used to
// include server-side errors in the stream. Information on this stream
// is returned as an error by [StdCopy] and terminates processing the
// stream.
//
// The [Stdin] stream is present for completeness and should generally
// NOT be used. It is output on [Stdout] when reading the stream with
// [StdCopy].
//
// All streams must share the same underlying [io.Writer] to ensure proper
// multiplexing. Each call to NewStdWriter wraps that shared writer with
// a header indicating the target stream.
func NewStdWriter(w io.Writer, streamType StdType) io.Writer {
return &stdWriter{
Writer: w,
prefix: byte(streamType),
}
}
// StdCopy is a modified version of [io.Copy] to de-multiplex messages
// from "multiplexedSource" and copy them to destination streams
// "destOut" and "destErr".

View File

@@ -1,66 +0,0 @@
package stdcopy_test
import (
"errors"
"fmt"
"io"
"os"
"time"
"github.com/moby/moby/api/pkg/stdcopy"
)
func ExampleNewStdWriter() {
muxReader, muxStream := io.Pipe()
defer func() { _ = muxStream.Close() }()
// Start demuxing before the daemon starts writing.
done := make(chan error, 1)
go func() {
// using os.Stdout for both, otherwise output doesn't show up in the example.
osStdout := os.Stdout
osStderr := os.Stdout
_, err := stdcopy.StdCopy(osStdout, osStderr, muxReader)
done <- err
}()
// daemon writing to stdout, stderr, and systemErr.
stdout := stdcopy.NewStdWriter(muxStream, stdcopy.Stdout)
stderr := stdcopy.NewStdWriter(muxStream, stdcopy.Stderr)
systemErr := stdcopy.NewStdWriter(muxStream, stdcopy.Systemerr)
for range 10 {
_, _ = fmt.Fprintln(stdout, "hello from stdout")
_, _ = fmt.Fprintln(stderr, "hello from stderr")
time.Sleep(50 * time.Millisecond)
}
_, _ = fmt.Fprintln(systemErr, errors.New("something went wrong"))
// Wait for the demuxer to finish.
if err := <-done; err != nil {
fmt.Println(err)
}
// Output:
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// hello from stdout
// hello from stderr
// error from daemon in stream: something went wrong
}

View File

@@ -1,292 +0,0 @@
package stdcopy
import (
"bytes"
"errors"
"io"
"strings"
"testing"
)
func TestNewStdWriter(t *testing.T) {
writer := NewStdWriter(io.Discard, Stdout)
if writer == nil {
t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
}
}
func TestWriteWithUninitializedStdWriter(t *testing.T) {
writer := stdWriter{
Writer: nil,
prefix: byte(Stdout),
}
n, err := writer.Write([]byte("Something here"))
if n != 0 || err == nil {
t.Fatalf("Should fail when given an incomplete or uninitialized StdWriter")
}
}
func TestWriteWithNilBytes(t *testing.T) {
writer := NewStdWriter(io.Discard, Stdout)
n, err := writer.Write(nil)
if err != nil {
t.Fatalf("Shouldn't have fail when given no data")
}
if n > 0 {
t.Fatalf("Write should have written 0 byte, but has written %d", n)
}
}
func TestWrite(t *testing.T) {
writer := NewStdWriter(io.Discard, Stdout)
data := []byte("Test StdWrite.Write")
n, err := writer.Write(data)
if err != nil {
t.Fatalf("Error while writing with StdWrite")
}
if n != len(data) {
t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
}
}
type errWriter struct {
n int
err error
}
func (f *errWriter) Write(buf []byte) (int, error) {
return f.n, f.err
}
func TestWriteWithWriterError(t *testing.T) {
expectedError := errors.New("expected")
expectedReturnedBytes := 10
writer := NewStdWriter(&errWriter{
n: stdWriterPrefixLen + expectedReturnedBytes,
err: expectedError,
}, Stdout)
data := []byte("This won't get written, sigh")
n, err := writer.Write(data)
if !errors.Is(err, expectedError) {
t.Fatalf("Didn't get expected error.")
}
if n != expectedReturnedBytes {
t.Fatalf("Didn't get expected written bytes %d, got %d.",
expectedReturnedBytes, n)
}
}
func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
writer := NewStdWriter(&errWriter{n: -1}, Stdout)
data := []byte("This won't get written, sigh")
actual, _ := writer.Write(data)
if actual != 0 {
t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
}
}
func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (*bytes.Buffer, error) {
buffer := new(bytes.Buffer)
dstOut := NewStdWriter(buffer, Stdout)
_, err := dstOut.Write(stdOutBytes)
if err != nil {
return buffer, err
}
dstErr := NewStdWriter(buffer, Stderr)
_, err = dstErr.Write(stdErrBytes)
return buffer, err
}
func TestStdCopyWriteAndRead(t *testing.T) {
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
if err != nil {
t.Fatal(err)
}
written, err := StdCopy(io.Discard, io.Discard, buffer)
if err != nil {
t.Fatal(err)
}
expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
if written != int64(expectedTotalWritten) {
t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
}
}
type customReader struct {
n int
err error
totalCalls int
correctCalls int
src *bytes.Buffer
}
func (f *customReader) Read(buf []byte) (int, error) {
f.totalCalls++
if f.totalCalls <= f.correctCalls {
return f.src.Read(buf)
}
return f.n, f.err
}
func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
expectedError := errors.New("error")
reader := &customReader{
err: expectedError,
}
written, err := StdCopy(io.Discard, io.Discard, reader)
if written != 0 {
t.Fatalf("Expected 0 bytes read, got %d", written)
}
if !errors.Is(err, expectedError) {
t.Fatalf("Didn't get expected error")
}
}
func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
expectedError := errors.New("error")
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
if err != nil {
t.Fatal(err)
}
reader := &customReader{
correctCalls: 1,
n: stdWriterPrefixLen + 1,
err: expectedError,
src: buffer,
}
written, err := StdCopy(io.Discard, io.Discard, reader)
if written != 0 {
t.Fatalf("Expected 0 bytes read, got %d", written)
}
if !errors.Is(err, expectedError) {
t.Fatalf("Didn't get expected error")
}
}
func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
if err != nil {
t.Fatal(err)
}
reader := &customReader{
correctCalls: 1,
n: stdWriterPrefixLen + 1,
err: io.EOF,
src: buffer,
}
written, err := StdCopy(io.Discard, io.Discard, reader)
if written != startingBufLen {
t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
}
if err != nil {
t.Fatal("Didn't get nil error")
}
}
func TestStdCopyWithInvalidInputHeader(t *testing.T) {
dstOut := NewStdWriter(io.Discard, Stdout)
dstErr := NewStdWriter(io.Discard, Stderr)
src := strings.NewReader("Invalid input")
_, err := StdCopy(dstOut, dstErr, src)
if err == nil {
t.Fatal("StdCopy with invalid input header should fail.")
}
}
func TestStdCopyWithCorruptedPrefix(t *testing.T) {
data := []byte{0x01, 0x02, 0x03}
src := bytes.NewReader(data)
written, err := StdCopy(nil, nil, src)
if err != nil {
t.Fatalf("StdCopy should not return an error with corrupted prefix.")
}
if written != 0 {
t.Fatalf("StdCopy should have written 0, but has written %d", written)
}
}
func TestStdCopyReturnsWriteErrors(t *testing.T) {
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
if err != nil {
t.Fatal(err)
}
expectedError := errors.New("expected")
dstOut := &errWriter{err: expectedError}
written, err := StdCopy(dstOut, io.Discard, buffer)
if written != 0 {
t.Fatalf("StdCopy should have written 0, but has written %d", written)
}
if !errors.Is(err, expectedError) {
t.Fatalf("Didn't get expected error, got %v", err)
}
}
func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
if err != nil {
t.Fatal(err)
}
dstOut := &errWriter{n: startingBufLen - 10}
written, err := StdCopy(dstOut, io.Discard, buffer)
if written != 0 {
t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
}
if !errors.Is(err, io.ErrShortWrite) {
t.Fatalf("Didn't get expected io.ErrShortWrite error")
}
}
// TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
// error, when that error is muxed into the Systemerr stream.
func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
// write in the basic messages, just so there's some fluff in there
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
if err != nil {
t.Fatal(err)
}
// add in an error message on the Systemerr stream
systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
systemWriter := NewStdWriter(buffer, Systemerr)
_, err = systemWriter.Write(systemErrBytes)
if err != nil {
t.Fatal(err)
}
// now copy and demux. we should expect an error containing the string we
// wrote out
_, err = StdCopy(io.Discard, io.Discard, buffer)
if err == nil {
t.Fatal("expected error, got none")
}
if !strings.Contains(err.Error(), string(systemErrBytes)) {
t.Fatal("expected error to contain message")
}
}
func BenchmarkWrite(b *testing.B) {
w := NewStdWriter(io.Discard, Stdout)
data := []byte("Test line for testing stdwriter performance\n")
data = bytes.Repeat(data, 100)
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := w.Write(data); err != nil {
b.Fatal(err)
}
}
}