|
1 | 1 | package copilot |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bufio" |
4 | 5 | "encoding/json" |
| 6 | + "fmt" |
| 7 | + "io" |
5 | 8 | "os" |
6 | 9 | "path/filepath" |
7 | 10 | "reflect" |
8 | 11 | "regexp" |
9 | 12 | "sync" |
10 | 13 | "testing" |
| 14 | + |
| 15 | + "github.com/github/copilot-sdk/go/internal/jsonrpc2" |
11 | 16 | ) |
12 | 17 |
|
13 | 18 | // This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.go instead |
@@ -569,3 +574,220 @@ func TestClient_StartStopRace(t *testing.T) { |
569 | 574 | t.Fatal(err) |
570 | 575 | } |
571 | 576 | } |
| 577 | + |
| 578 | +// fakeJSONRPCServer reads one JSON-RPC request from r and sends a response to w. |
| 579 | +// onRequest is called with the parsed method and params before the response is sent, |
| 580 | +// allowing the caller to inspect state (e.g. the sessions map) during the RPC. |
| 581 | +func fakeJSONRPCServer(t *testing.T, r io.Reader, w io.Writer, onRequest func(method string, params json.RawMessage)) { |
| 582 | + t.Helper() |
| 583 | + reader := bufio.NewReader(r) |
| 584 | + |
| 585 | + // Read Content-Length header |
| 586 | + var contentLength int |
| 587 | + for { |
| 588 | + line, err := reader.ReadString('\n') |
| 589 | + if err != nil { |
| 590 | + t.Errorf("failed to read header: %v", err) |
| 591 | + return |
| 592 | + } |
| 593 | + if line == "\r\n" || line == "\n" { |
| 594 | + break |
| 595 | + } |
| 596 | + fmt.Sscanf(line, "Content-Length: %d", &contentLength) |
| 597 | + } |
| 598 | + |
| 599 | + // Read body |
| 600 | + body := make([]byte, contentLength) |
| 601 | + if _, err := io.ReadFull(reader, body); err != nil { |
| 602 | + t.Errorf("failed to read body: %v", err) |
| 603 | + return |
| 604 | + } |
| 605 | + |
| 606 | + // Parse request |
| 607 | + var req struct { |
| 608 | + ID json.RawMessage `json:"id"` |
| 609 | + Method string `json:"method"` |
| 610 | + Params json.RawMessage `json:"params"` |
| 611 | + } |
| 612 | + if err := json.Unmarshal(body, &req); err != nil { |
| 613 | + t.Errorf("failed to unmarshal request: %v", err) |
| 614 | + return |
| 615 | + } |
| 616 | + |
| 617 | + onRequest(req.Method, req.Params) |
| 618 | + |
| 619 | + // Send response |
| 620 | + result, _ := json.Marshal(map[string]any{"sessionId": "test", "workspacePath": "/tmp"}) |
| 621 | + resp, _ := json.Marshal(map[string]any{ |
| 622 | + "jsonrpc": "2.0", |
| 623 | + "id": req.ID, |
| 624 | + "result": json.RawMessage(result), |
| 625 | + }) |
| 626 | + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp)) |
| 627 | + w.Write([]byte(header)) |
| 628 | + w.Write(resp) |
| 629 | +} |
| 630 | + |
| 631 | +// fakeJSONRPCErrorServer reads one JSON-RPC request and returns an error response. |
| 632 | +func fakeJSONRPCErrorServer(t *testing.T, r io.Reader, w io.Writer) { |
| 633 | + t.Helper() |
| 634 | + reader := bufio.NewReader(r) |
| 635 | + |
| 636 | + var contentLength int |
| 637 | + for { |
| 638 | + line, err := reader.ReadString('\n') |
| 639 | + if err != nil { |
| 640 | + return |
| 641 | + } |
| 642 | + if line == "\r\n" || line == "\n" { |
| 643 | + break |
| 644 | + } |
| 645 | + fmt.Sscanf(line, "Content-Length: %d", &contentLength) |
| 646 | + } |
| 647 | + |
| 648 | + body := make([]byte, contentLength) |
| 649 | + if _, err := io.ReadFull(reader, body); err != nil { |
| 650 | + return |
| 651 | + } |
| 652 | + |
| 653 | + var req struct { |
| 654 | + ID json.RawMessage `json:"id"` |
| 655 | + } |
| 656 | + json.Unmarshal(body, &req) |
| 657 | + |
| 658 | + resp, _ := json.Marshal(map[string]any{ |
| 659 | + "jsonrpc": "2.0", |
| 660 | + "id": req.ID, |
| 661 | + "error": map[string]any{"code": -32000, "message": "test error"}, |
| 662 | + }) |
| 663 | + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp)) |
| 664 | + w.Write([]byte(header)) |
| 665 | + w.Write(resp) |
| 666 | +} |
| 667 | + |
| 668 | +// newTestClientWithFakeServer creates a Client wired to a fake jsonrpc2.Client |
| 669 | +// backed by the provided io pipes. The caller must call jrpcClient.Stop() when done. |
| 670 | +func newTestClientWithFakeServer(clientWriter io.WriteCloser, clientReader io.ReadCloser) (*Client, *jsonrpc2.Client) { |
| 671 | + jrpcClient := jsonrpc2.NewClient(clientWriter, clientReader) |
| 672 | + jrpcClient.Start() |
| 673 | + |
| 674 | + client := NewClient(nil) |
| 675 | + client.client = jrpcClient |
| 676 | + client.state = StateConnected |
| 677 | + client.sessions = make(map[string]*Session) |
| 678 | + return client, jrpcClient |
| 679 | +} |
| 680 | + |
| 681 | +func TestClient_CreateSession_RegistersSessionBeforeRPC(t *testing.T) { |
| 682 | + // Create pipes: client writes to serverReader, server writes to clientReader |
| 683 | + serverReader, clientWriter := io.Pipe() |
| 684 | + clientReader, serverWriter := io.Pipe() |
| 685 | + client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader) |
| 686 | + defer jrpcClient.Stop() |
| 687 | + |
| 688 | + sessionInMap := false |
| 689 | + go fakeJSONRPCServer(t, serverReader, serverWriter, func(method string, params json.RawMessage) { |
| 690 | + if method != "session.create" { |
| 691 | + t.Errorf("expected session.create, got %s", method) |
| 692 | + } |
| 693 | + var p struct { |
| 694 | + SessionID string `json:"sessionId"` |
| 695 | + } |
| 696 | + json.Unmarshal(params, &p) |
| 697 | + client.sessionsMux.Lock() |
| 698 | + _, sessionInMap = client.sessions[p.SessionID] |
| 699 | + client.sessionsMux.Unlock() |
| 700 | + }) |
| 701 | + |
| 702 | + session, err := client.CreateSession(t.Context(), &SessionConfig{ |
| 703 | + OnPermissionRequest: PermissionHandler.ApproveAll, |
| 704 | + }) |
| 705 | + if err != nil { |
| 706 | + t.Fatalf("CreateSession failed: %v", err) |
| 707 | + } |
| 708 | + if session == nil { |
| 709 | + t.Fatal("expected non-nil session") |
| 710 | + } |
| 711 | + if !sessionInMap { |
| 712 | + t.Error("session was not in sessions map when session.create RPC was issued") |
| 713 | + } |
| 714 | +} |
| 715 | + |
| 716 | +func TestClient_ResumeSession_RegistersSessionBeforeRPC(t *testing.T) { |
| 717 | + serverReader, clientWriter := io.Pipe() |
| 718 | + clientReader, serverWriter := io.Pipe() |
| 719 | + client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader) |
| 720 | + defer jrpcClient.Stop() |
| 721 | + |
| 722 | + sessionInMap := false |
| 723 | + go fakeJSONRPCServer(t, serverReader, serverWriter, func(method string, params json.RawMessage) { |
| 724 | + if method != "session.resume" { |
| 725 | + t.Errorf("expected session.resume, got %s", method) |
| 726 | + } |
| 727 | + var p struct { |
| 728 | + SessionID string `json:"sessionId"` |
| 729 | + } |
| 730 | + json.Unmarshal(params, &p) |
| 731 | + client.sessionsMux.Lock() |
| 732 | + _, sessionInMap = client.sessions[p.SessionID] |
| 733 | + client.sessionsMux.Unlock() |
| 734 | + }) |
| 735 | + |
| 736 | + session, err := client.ResumeSessionWithOptions(t.Context(), "test-session-id", &ResumeSessionConfig{ |
| 737 | + OnPermissionRequest: PermissionHandler.ApproveAll, |
| 738 | + }) |
| 739 | + if err != nil { |
| 740 | + t.Fatalf("ResumeSessionWithOptions failed: %v", err) |
| 741 | + } |
| 742 | + if session == nil { |
| 743 | + t.Fatal("expected non-nil session") |
| 744 | + } |
| 745 | + if !sessionInMap { |
| 746 | + t.Error("session was not in sessions map when session.resume RPC was issued") |
| 747 | + } |
| 748 | +} |
| 749 | + |
| 750 | +func TestClient_CreateSession_CleansUpOnRPCFailure(t *testing.T) { |
| 751 | + serverReader, clientWriter := io.Pipe() |
| 752 | + clientReader, serverWriter := io.Pipe() |
| 753 | + client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader) |
| 754 | + defer jrpcClient.Stop() |
| 755 | + |
| 756 | + // Send a JSON-RPC error response to simulate failure |
| 757 | + go fakeJSONRPCErrorServer(t, serverReader, serverWriter) |
| 758 | + |
| 759 | + _, err := client.CreateSession(t.Context(), &SessionConfig{ |
| 760 | + OnPermissionRequest: PermissionHandler.ApproveAll, |
| 761 | + }) |
| 762 | + if err == nil { |
| 763 | + t.Fatal("expected error from CreateSession") |
| 764 | + } |
| 765 | + client.sessionsMux.Lock() |
| 766 | + count := len(client.sessions) |
| 767 | + client.sessionsMux.Unlock() |
| 768 | + if count != 0 { |
| 769 | + t.Errorf("expected 0 sessions after failed create, got %d", count) |
| 770 | + } |
| 771 | +} |
| 772 | + |
| 773 | +func TestClient_ResumeSession_CleansUpOnRPCFailure(t *testing.T) { |
| 774 | + serverReader, clientWriter := io.Pipe() |
| 775 | + clientReader, serverWriter := io.Pipe() |
| 776 | + client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader) |
| 777 | + defer jrpcClient.Stop() |
| 778 | + |
| 779 | + go fakeJSONRPCErrorServer(t, serverReader, serverWriter) |
| 780 | + |
| 781 | + _, err := client.ResumeSessionWithOptions(t.Context(), "test-session-id", &ResumeSessionConfig{ |
| 782 | + OnPermissionRequest: PermissionHandler.ApproveAll, |
| 783 | + }) |
| 784 | + if err == nil { |
| 785 | + t.Fatal("expected error from ResumeSessionWithOptions") |
| 786 | + } |
| 787 | + client.sessionsMux.Lock() |
| 788 | + count := len(client.sessions) |
| 789 | + client.sessionsMux.Unlock() |
| 790 | + if count != 0 { |
| 791 | + t.Errorf("expected 0 sessions after failed resume, got %d", count) |
| 792 | + } |
| 793 | +} |
0 commit comments