Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions frontend/app/store/wshclientapi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,16 @@ class RpcApiType {
return client.wshRpcStream("streamcpudata", data, opts);
}

// command "streamdata" [call]
StreamDataCommand(client: WshClient, data: CommandStreamData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("streamdata", data, opts);
}

// command "streamdataack" [call]
StreamDataAckCommand(client: WshClient, data: CommandStreamAckData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("streamdataack", data, opts);
}

// command "streamtest" [responsestream]
StreamTestCommand(client: WshClient, opts?: RpcOpts): AsyncGenerator<number, void, boolean> {
return client.wshRpcStream("streamtest", null, opts);
Expand Down
21 changes: 21 additions & 0 deletions frontend/types/gotypes.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,26 @@ declare global {
builderid: string;
};

// wshrpc.CommandStreamAckData
type CommandStreamAckData = {
id: number;
seq: number;
rwnd: number;
fin?: boolean;
delay?: number;
cancel?: boolean;
error?: string;
};

// wshrpc.CommandStreamData
type CommandStreamData = {
id: number;
seq: number;
data64?: string;
eof?: boolean;
error?: string;
};

// wshrpc.CommandTermGetScrollbackLinesData
type CommandTermGetScrollbackLinesData = {
linestart: number;
Expand Down Expand Up @@ -1240,6 +1260,7 @@ declare global {
"wsh:cmd"?: string;
"wsh:haderror"?: boolean;
"conn:conntype"?: string;
"conn:wsherrorcode"?: string;
"onboarding:feature"?: "waveai" | "magnify" | "wsh";
"onboarding:version"?: string;
"onboarding:githubstar"?: "already" | "star" | "later";
Expand Down
267 changes: 267 additions & 0 deletions pkg/streamclient/stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package streamclient

import (
"bytes"
"io"
"testing"
"time"

"github.com/wavetermdev/waveterm/pkg/wshrpc"
)

type fakeTransport struct {
dataChan chan wshrpc.CommandStreamData
ackChan chan wshrpc.CommandStreamAckData
}

func newFakeTransport() *fakeTransport {
return &fakeTransport{
dataChan: make(chan wshrpc.CommandStreamData, 10),
ackChan: make(chan wshrpc.CommandStreamAckData, 10),
}
}

func (ft *fakeTransport) SendData(dataPk wshrpc.CommandStreamData) {
ft.dataChan <- dataPk
}

func (ft *fakeTransport) SendAck(ackPk wshrpc.CommandStreamAckData) {
ft.ackChan <- ackPk
}

func TestBasicReadWrite(t *testing.T) {
transport := newFakeTransport()

reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)

go func() {
for dataPk := range transport.dataChan {
reader.RecvData(dataPk)
}
}()

go func() {
for ackPk := range transport.ackChan {
writer.RecvAck(ackPk)
}
}()

testData := []byte("Hello, World!")
n, err := writer.Write(testData)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(testData) {
t.Fatalf("Write returned %d, expected %d", n, len(testData))
}

buf := make([]byte, 1024)
n, err = reader.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != len(testData) {
t.Fatalf("Read returned %d, expected %d", n, len(testData))
}
if !bytes.Equal(buf[:n], testData) {
t.Fatalf("Read data %q doesn't match written data %q", buf[:n], testData)
}
}

func TestEOF(t *testing.T) {
transport := newFakeTransport()

reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)

go func() {
for dataPk := range transport.dataChan {
reader.RecvData(dataPk)
}
}()

go func() {
for ackPk := range transport.ackChan {
writer.RecvAck(ackPk)
}
}()

testData := []byte("Test data")
writer.Write(testData)
writer.Close()

buf := make([]byte, 1024)
n, err := reader.Read(buf)
if err != nil {
t.Fatalf("First read failed: %v", err)
}
if !bytes.Equal(buf[:n], testData) {
t.Fatalf("Read data doesn't match")
}

_, err = reader.Read(buf)
if err != io.EOF {
t.Fatalf("Expected EOF, got %v", err)
}
}

func TestFlowControl(t *testing.T) {
smallWindow := int64(10)
transport := newFakeTransport()

reader := NewReader(1, smallWindow, transport)
writer := NewWriter(1, smallWindow, transport)

go func() {
for dataPk := range transport.dataChan {
reader.RecvData(dataPk)
}
}()

go func() {
for ackPk := range transport.ackChan {
writer.RecvAck(ackPk)
}
}()

largeData := make([]byte, 100)
for i := range largeData {
largeData[i] = byte(i)
}

writeDone := make(chan error)
go func() {
_, err := writer.Write(largeData)
writeDone <- err
}()

received := make([]byte, 0, 100)
buf := make([]byte, 20)
for len(received) < len(largeData) {
n, err := reader.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
received = append(received, buf[:n]...)
}

select {
case err := <-writeDone:
if err != nil {
t.Fatalf("Write failed: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Write didn't complete in time")
}

if !bytes.Equal(received, largeData) {
t.Fatal("Received data doesn't match sent data")
}
}

func TestError(t *testing.T) {
transport := newFakeTransport()

reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)

go func() {
for dataPk := range transport.dataChan {
reader.RecvData(dataPk)
}
}()

go func() {
for ackPk := range transport.ackChan {
writer.RecvAck(ackPk)
}
}()

testErr := io.ErrUnexpectedEOF
writer.CloseWithError(testErr)

buf := make([]byte, 1024)
_, err := reader.Read(buf)
if err == nil {
t.Fatal("Expected error from read")
}
if err.Error() != "stream error: unexpected EOF" {
t.Fatalf("Expected stream error, got: %v", err)
}
}

func TestCancel(t *testing.T) {
transport := newFakeTransport()

reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)

go func() {
for dataPk := range transport.dataChan {
reader.RecvData(dataPk)
}
}()

go func() {
for ackPk := range transport.ackChan {
writer.RecvAck(ackPk)
}
}()

reader.Close()

select {
case <-writer.GetCanceledChan():
// Success
case <-time.After(1 * time.Second):
t.Fatal("Writer not notified of cancellation")
}

_, _, canceled := writer.GetAckState()
if !canceled {
t.Fatal("Writer should be in canceled state")
}
}

func TestMultipleWrites(t *testing.T) {
transport := newFakeTransport()

reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)

go func() {
for dataPk := range transport.dataChan {
reader.RecvData(dataPk)
}
}()

go func() {
for ackPk := range transport.ackChan {
writer.RecvAck(ackPk)
}
}()

messages := []string{"First", "Second", "Third"}
for _, msg := range messages {
_, err := writer.Write([]byte(msg))
if err != nil {
t.Fatalf("Write failed: %v", err)
}
}

expected := "FirstSecondThird"
buf := make([]byte, len(expected))
totalRead := 0
for totalRead < len(expected) {
n, err := reader.Read(buf[totalRead:])
if err != nil {
t.Fatalf("Read failed: %v", err)
}
totalRead += n
}

if string(buf) != expected {
t.Fatalf("Expected %q, got %q", expected, string(buf))
}
}
Loading
Loading