Skip to content

Commit 4a9acba

Browse files
stephentoubCopilot
andcommitted
Add tests for early session registration before RPC
Add unit tests across Node.js, Python, and Go that verify sessions are registered in the client's sessions map before the session.create and session.resume RPC calls are issued. These tests would have failed before the early-registration change and now pass after. Each SDK gets 4 new tests: - Session is in map when session.create RPC is called - Session is in map when session.resume RPC is called - Session is cleaned up from map when session.create RPC fails - Session is cleaned up from map when session.resume RPC fails Go tests use a fake JSON-RPC server via io.Pipe() to verify map state during the RPC without needing a real CLI process. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 427205f commit 4a9acba

File tree

3 files changed

+422
-0
lines changed

3 files changed

+422
-0
lines changed

go/client_test.go

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
package copilot
22

33
import (
4+
"bufio"
45
"encoding/json"
6+
"fmt"
7+
"io"
58
"os"
69
"path/filepath"
710
"reflect"
811
"regexp"
912
"sync"
1013
"testing"
14+
15+
"github.com/github/copilot-sdk/go/internal/jsonrpc2"
1116
)
1217

1318
// This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.go instead
@@ -569,3 +574,220 @@ func TestClient_StartStopRace(t *testing.T) {
569574
t.Fatal(err)
570575
}
571576
}
577+
578+
// fakeJSONRPCServer reads one JSON-RPC request from r and sends a response to w.
579+
// onRequest is called with the parsed method and params before the response is sent,
580+
// allowing the caller to inspect state (e.g. the sessions map) during the RPC.
581+
func fakeJSONRPCServer(t *testing.T, r io.Reader, w io.Writer, onRequest func(method string, params json.RawMessage)) {
582+
t.Helper()
583+
reader := bufio.NewReader(r)
584+
585+
// Read Content-Length header
586+
var contentLength int
587+
for {
588+
line, err := reader.ReadString('\n')
589+
if err != nil {
590+
t.Errorf("failed to read header: %v", err)
591+
return
592+
}
593+
if line == "\r\n" || line == "\n" {
594+
break
595+
}
596+
fmt.Sscanf(line, "Content-Length: %d", &contentLength)
597+
}
598+
599+
// Read body
600+
body := make([]byte, contentLength)
601+
if _, err := io.ReadFull(reader, body); err != nil {
602+
t.Errorf("failed to read body: %v", err)
603+
return
604+
}
605+
606+
// Parse request
607+
var req struct {
608+
ID json.RawMessage `json:"id"`
609+
Method string `json:"method"`
610+
Params json.RawMessage `json:"params"`
611+
}
612+
if err := json.Unmarshal(body, &req); err != nil {
613+
t.Errorf("failed to unmarshal request: %v", err)
614+
return
615+
}
616+
617+
onRequest(req.Method, req.Params)
618+
619+
// Send response
620+
result, _ := json.Marshal(map[string]any{"sessionId": "test", "workspacePath": "/tmp"})
621+
resp, _ := json.Marshal(map[string]any{
622+
"jsonrpc": "2.0",
623+
"id": req.ID,
624+
"result": json.RawMessage(result),
625+
})
626+
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp))
627+
w.Write([]byte(header))
628+
w.Write(resp)
629+
}
630+
631+
// fakeJSONRPCErrorServer reads one JSON-RPC request and returns an error response.
632+
func fakeJSONRPCErrorServer(t *testing.T, r io.Reader, w io.Writer) {
633+
t.Helper()
634+
reader := bufio.NewReader(r)
635+
636+
var contentLength int
637+
for {
638+
line, err := reader.ReadString('\n')
639+
if err != nil {
640+
return
641+
}
642+
if line == "\r\n" || line == "\n" {
643+
break
644+
}
645+
fmt.Sscanf(line, "Content-Length: %d", &contentLength)
646+
}
647+
648+
body := make([]byte, contentLength)
649+
if _, err := io.ReadFull(reader, body); err != nil {
650+
return
651+
}
652+
653+
var req struct {
654+
ID json.RawMessage `json:"id"`
655+
}
656+
json.Unmarshal(body, &req)
657+
658+
resp, _ := json.Marshal(map[string]any{
659+
"jsonrpc": "2.0",
660+
"id": req.ID,
661+
"error": map[string]any{"code": -32000, "message": "test error"},
662+
})
663+
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp))
664+
w.Write([]byte(header))
665+
w.Write(resp)
666+
}
667+
668+
// newTestClientWithFakeServer creates a Client wired to a fake jsonrpc2.Client
669+
// backed by the provided io pipes. The caller must call jrpcClient.Stop() when done.
670+
func newTestClientWithFakeServer(clientWriter io.WriteCloser, clientReader io.ReadCloser) (*Client, *jsonrpc2.Client) {
671+
jrpcClient := jsonrpc2.NewClient(clientWriter, clientReader)
672+
jrpcClient.Start()
673+
674+
client := NewClient(nil)
675+
client.client = jrpcClient
676+
client.state = StateConnected
677+
client.sessions = make(map[string]*Session)
678+
return client, jrpcClient
679+
}
680+
681+
func TestClient_CreateSession_RegistersSessionBeforeRPC(t *testing.T) {
682+
// Create pipes: client writes to serverReader, server writes to clientReader
683+
serverReader, clientWriter := io.Pipe()
684+
clientReader, serverWriter := io.Pipe()
685+
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
686+
defer jrpcClient.Stop()
687+
688+
sessionInMap := false
689+
go fakeJSONRPCServer(t, serverReader, serverWriter, func(method string, params json.RawMessage) {
690+
if method != "session.create" {
691+
t.Errorf("expected session.create, got %s", method)
692+
}
693+
var p struct {
694+
SessionID string `json:"sessionId"`
695+
}
696+
json.Unmarshal(params, &p)
697+
client.sessionsMux.Lock()
698+
_, sessionInMap = client.sessions[p.SessionID]
699+
client.sessionsMux.Unlock()
700+
})
701+
702+
session, err := client.CreateSession(t.Context(), &SessionConfig{
703+
OnPermissionRequest: PermissionHandler.ApproveAll,
704+
})
705+
if err != nil {
706+
t.Fatalf("CreateSession failed: %v", err)
707+
}
708+
if session == nil {
709+
t.Fatal("expected non-nil session")
710+
}
711+
if !sessionInMap {
712+
t.Error("session was not in sessions map when session.create RPC was issued")
713+
}
714+
}
715+
716+
func TestClient_ResumeSession_RegistersSessionBeforeRPC(t *testing.T) {
717+
serverReader, clientWriter := io.Pipe()
718+
clientReader, serverWriter := io.Pipe()
719+
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
720+
defer jrpcClient.Stop()
721+
722+
sessionInMap := false
723+
go fakeJSONRPCServer(t, serverReader, serverWriter, func(method string, params json.RawMessage) {
724+
if method != "session.resume" {
725+
t.Errorf("expected session.resume, got %s", method)
726+
}
727+
var p struct {
728+
SessionID string `json:"sessionId"`
729+
}
730+
json.Unmarshal(params, &p)
731+
client.sessionsMux.Lock()
732+
_, sessionInMap = client.sessions[p.SessionID]
733+
client.sessionsMux.Unlock()
734+
})
735+
736+
session, err := client.ResumeSessionWithOptions(t.Context(), "test-session-id", &ResumeSessionConfig{
737+
OnPermissionRequest: PermissionHandler.ApproveAll,
738+
})
739+
if err != nil {
740+
t.Fatalf("ResumeSessionWithOptions failed: %v", err)
741+
}
742+
if session == nil {
743+
t.Fatal("expected non-nil session")
744+
}
745+
if !sessionInMap {
746+
t.Error("session was not in sessions map when session.resume RPC was issued")
747+
}
748+
}
749+
750+
func TestClient_CreateSession_CleansUpOnRPCFailure(t *testing.T) {
751+
serverReader, clientWriter := io.Pipe()
752+
clientReader, serverWriter := io.Pipe()
753+
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
754+
defer jrpcClient.Stop()
755+
756+
// Send a JSON-RPC error response to simulate failure
757+
go fakeJSONRPCErrorServer(t, serverReader, serverWriter)
758+
759+
_, err := client.CreateSession(t.Context(), &SessionConfig{
760+
OnPermissionRequest: PermissionHandler.ApproveAll,
761+
})
762+
if err == nil {
763+
t.Fatal("expected error from CreateSession")
764+
}
765+
client.sessionsMux.Lock()
766+
count := len(client.sessions)
767+
client.sessionsMux.Unlock()
768+
if count != 0 {
769+
t.Errorf("expected 0 sessions after failed create, got %d", count)
770+
}
771+
}
772+
773+
func TestClient_ResumeSession_CleansUpOnRPCFailure(t *testing.T) {
774+
serverReader, clientWriter := io.Pipe()
775+
clientReader, serverWriter := io.Pipe()
776+
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
777+
defer jrpcClient.Stop()
778+
779+
go fakeJSONRPCErrorServer(t, serverReader, serverWriter)
780+
781+
_, err := client.ResumeSessionWithOptions(t.Context(), "test-session-id", &ResumeSessionConfig{
782+
OnPermissionRequest: PermissionHandler.ApproveAll,
783+
})
784+
if err == nil {
785+
t.Fatal("expected error from ResumeSessionWithOptions")
786+
}
787+
client.sessionsMux.Lock()
788+
count := len(client.sessions)
789+
client.sessionsMux.Unlock()
790+
if count != 0 {
791+
t.Errorf("expected 0 sessions after failed resume, got %d", count)
792+
}
793+
}

nodejs/test/client.test.ts

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,94 @@ describe("CopilotClient", () => {
294294
});
295295
});
296296

297+
describe("session registered before RPC", () => {
298+
it("registers session in sessions map before session.create RPC completes", async () => {
299+
const client = new CopilotClient();
300+
await client.start();
301+
onTestFinished(() => client.forceStop());
302+
303+
let sessionInMapDuringRpc = false;
304+
vi.spyOn((client as any).connection!, "sendRequest").mockImplementation(
305+
async (method: string, params: any) => {
306+
if (method === "session.create") {
307+
sessionInMapDuringRpc = (client as any).sessions.has(params.sessionId);
308+
return { sessionId: params.sessionId };
309+
}
310+
throw new Error(`Unexpected method: ${method}`);
311+
}
312+
);
313+
314+
await client.createSession({ onPermissionRequest: approveAll });
315+
expect(sessionInMapDuringRpc).toBe(true);
316+
});
317+
318+
it("registers session in sessions map before session.resume RPC completes", async () => {
319+
const client = new CopilotClient();
320+
await client.start();
321+
onTestFinished(() => client.forceStop());
322+
323+
const session = await client.createSession({ onPermissionRequest: approveAll });
324+
325+
let sessionInMapDuringRpc = false;
326+
vi.spyOn((client as any).connection!, "sendRequest").mockImplementation(
327+
async (method: string, params: any) => {
328+
if (method === "session.resume") {
329+
sessionInMapDuringRpc = (client as any).sessions.has(params.sessionId);
330+
return { sessionId: params.sessionId };
331+
}
332+
throw new Error(`Unexpected method: ${method}`);
333+
}
334+
);
335+
336+
await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll });
337+
expect(sessionInMapDuringRpc).toBe(true);
338+
});
339+
340+
it("removes session from sessions map when session.create RPC fails", async () => {
341+
const client = new CopilotClient();
342+
await client.start();
343+
onTestFinished(() => client.forceStop());
344+
345+
vi.spyOn((client as any).connection!, "sendRequest").mockImplementation(
346+
async (method: string) => {
347+
if (method === "session.create") {
348+
throw new Error("RPC failed");
349+
}
350+
throw new Error(`Unexpected method: ${method}`);
351+
}
352+
);
353+
354+
await expect(client.createSession({ onPermissionRequest: approveAll })).rejects.toThrow(
355+
"RPC failed"
356+
);
357+
expect((client as any).sessions.size).toBe(0);
358+
});
359+
360+
it("removes session from sessions map when session.resume RPC fails", async () => {
361+
const client = new CopilotClient();
362+
await client.start();
363+
onTestFinished(() => client.forceStop());
364+
365+
const session = await client.createSession({ onPermissionRequest: approveAll });
366+
const sessionCountBefore = (client as any).sessions.size;
367+
368+
vi.spyOn((client as any).connection!, "sendRequest").mockImplementation(
369+
async (method: string) => {
370+
if (method === "session.resume") {
371+
throw new Error("RPC failed");
372+
}
373+
throw new Error(`Unexpected method: ${method}`);
374+
}
375+
);
376+
377+
await expect(
378+
client.resumeSession("other-session-id", { onPermissionRequest: approveAll })
379+
).rejects.toThrow("RPC failed");
380+
expect((client as any).sessions.size).toBe(sessionCountBefore);
381+
expect((client as any).sessions.has(session.sessionId)).toBe(true);
382+
});
383+
});
384+
297385
describe("overridesBuiltInTool in tool definitions", () => {
298386
it("sends overridesBuiltInTool in tool definition on session.create", async () => {
299387
const client = new CopilotClient();

0 commit comments

Comments
 (0)