Skip to content

Commit e1cfb91

Browse files
feat: poll for linked PR after assigning Copilot to issue
Enhances the assign_copilot_to_issue tool to automatically poll for the PR created by the Copilot coding agent after assignment. Changes: - Add findLinkedCopilotPR() to query issue timeline for CrossReferencedEvent items from PRs authored by copilot-swe-agent - Add polling loop (9 attempts, 1s delay) matching remote server latency - Return structured JSON with PR details when found, or helpful note otherwise - Add PollConfig for configurable polling (used in tests to disable) - Add GraphQLFeaturesTransport for feature flag header support The returned response now includes: - issue_number, issue_url, owner, repo - pull_request object (if found during polling) - Note with instructions to use get_copilot_job_status if PR not yet created
1 parent 31b541e commit e1cfb91

File tree

4 files changed

+349
-3
lines changed

4 files changed

+349
-3
lines changed

pkg/github/issues.go

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,97 @@ func (d *mvpDescription) String() string {
16091609
return sb.String()
16101610
}
16111611

1612+
// linkedPullRequest represents a PR linked to an issue by Copilot.
1613+
type linkedPullRequest struct {
1614+
Number int
1615+
URL string
1616+
Title string
1617+
State string
1618+
}
1619+
1620+
// pollConfigKey is a context key for polling configuration.
1621+
type pollConfigKey struct{}
1622+
1623+
// PollConfig configures the PR polling behavior.
1624+
type PollConfig struct {
1625+
MaxAttempts int
1626+
Delay time.Duration
1627+
}
1628+
1629+
// ContextWithPollConfig returns a context with polling configuration.
1630+
// Use this in tests to reduce or disable polling.
1631+
func ContextWithPollConfig(ctx context.Context, config PollConfig) context.Context {
1632+
return context.WithValue(ctx, pollConfigKey{}, config)
1633+
}
1634+
1635+
// getPollConfig returns the polling configuration from context, or defaults.
1636+
func getPollConfig(ctx context.Context) PollConfig {
1637+
if config, ok := ctx.Value(pollConfigKey{}).(PollConfig); ok {
1638+
return config
1639+
}
1640+
// Default: 9 attempts with 1s delay = 8s max wait
1641+
// Based on observed latency in remote server: p50 ~5s, p90 ~7s
1642+
return PollConfig{MaxAttempts: 9, Delay: 1 * time.Second}
1643+
}
1644+
1645+
// findLinkedCopilotPR searches for a PR created by the copilot-swe-agent bot that references the given issue.
1646+
// It queries the issue's timeline for CrossReferencedEvent items from PRs authored by copilot-swe-agent.
1647+
func findLinkedCopilotPR(ctx context.Context, client *githubv4.Client, owner, repo string, issueNumber int) (*linkedPullRequest, error) {
1648+
// Query timeline items looking for CrossReferencedEvent from PRs by copilot-swe-agent
1649+
var query struct {
1650+
Repository struct {
1651+
Issue struct {
1652+
TimelineItems struct {
1653+
Nodes []struct {
1654+
TypeName string `graphql:"__typename"`
1655+
CrossReferencedEvent struct {
1656+
Source struct {
1657+
PullRequest struct {
1658+
Number int
1659+
URL string
1660+
Title string
1661+
State string
1662+
Author struct {
1663+
Login string
1664+
}
1665+
} `graphql:"... on PullRequest"`
1666+
}
1667+
} `graphql:"... on CrossReferencedEvent"`
1668+
}
1669+
} `graphql:"timelineItems(first: 20, itemTypes: [CROSS_REFERENCED_EVENT])"`
1670+
} `graphql:"issue(number: $number)"`
1671+
} `graphql:"repository(owner: $owner, name: $name)"`
1672+
}
1673+
1674+
variables := map[string]any{
1675+
"owner": githubv4.String(owner),
1676+
"name": githubv4.String(repo),
1677+
"number": githubv4.Int(issueNumber), //nolint:gosec // Issue numbers are always small positive integers
1678+
}
1679+
1680+
if err := client.Query(ctx, &query, variables); err != nil {
1681+
return nil, err
1682+
}
1683+
1684+
// Look for a PR from copilot-swe-agent
1685+
for _, node := range query.Repository.Issue.TimelineItems.Nodes {
1686+
if node.TypeName != "CrossReferencedEvent" {
1687+
continue
1688+
}
1689+
pr := node.CrossReferencedEvent.Source.PullRequest
1690+
if pr.Number > 0 && pr.Author.Login == "copilot-swe-agent" {
1691+
return &linkedPullRequest{
1692+
Number: pr.Number,
1693+
URL: pr.URL,
1694+
Title: pr.Title,
1695+
State: pr.State,
1696+
}, nil
1697+
}
1698+
}
1699+
1700+
return nil, nil
1701+
}
1702+
16121703
func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.ServerTool {
16131704
description := mvpDescription{
16141705
summary: "Assign Copilot to a specific issue in a GitHub repository.",
@@ -1804,7 +1895,55 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server
18041895
return nil, nil, fmt.Errorf("failed to update issue with agent assignment: %w", err)
18051896
}
18061897

1807-
return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil
1898+
// Poll for a linked PR created by Copilot
1899+
pollConfig := getPollConfig(ctx)
1900+
1901+
var linkedPR *linkedPullRequest
1902+
for attempt := range pollConfig.MaxAttempts {
1903+
if attempt > 0 {
1904+
time.Sleep(pollConfig.Delay)
1905+
}
1906+
1907+
pr, err := findLinkedCopilotPR(ctx, client, params.Owner, params.Repo, int(params.IssueNumber))
1908+
if err != nil {
1909+
// Log but don't fail - polling errors are non-fatal
1910+
continue
1911+
}
1912+
if pr != nil {
1913+
linkedPR = pr
1914+
break
1915+
}
1916+
}
1917+
1918+
// Build the result
1919+
result := map[string]any{
1920+
"message": "successfully assigned copilot to issue",
1921+
"issue_number": updateIssueMutation.UpdateIssue.Issue.Number,
1922+
"issue_url": updateIssueMutation.UpdateIssue.Issue.URL,
1923+
"owner": params.Owner,
1924+
"repo": params.Repo,
1925+
}
1926+
1927+
// Add PR info if found during polling
1928+
if linkedPR != nil {
1929+
result["pull_request"] = map[string]any{
1930+
"number": linkedPR.Number,
1931+
"url": linkedPR.URL,
1932+
"title": linkedPR.Title,
1933+
"state": linkedPR.State,
1934+
}
1935+
result["message"] = "successfully assigned copilot to issue - pull request created"
1936+
} else {
1937+
result["message"] = "successfully assigned copilot to issue - pull request pending"
1938+
result["note"] = "The pull request may still be in progress. Use get_copilot_job_status with the pull request number once created, or check the issue timeline for updates."
1939+
}
1940+
1941+
r, err := json.Marshal(result)
1942+
if err != nil {
1943+
return utils.NewToolResultError(fmt.Sprintf("failed to marshal response: %s", err)), nil, nil
1944+
}
1945+
1946+
return utils.NewToolResultText(string(r)), result, nil
18081947
})
18091948
}
18101949

pkg/github/issues_test.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2654,8 +2654,12 @@ func TestAssignCopilotToIssue(t *testing.T) {
26542654
// Create call request
26552655
request := createMCPRequest(tc.requestArgs)
26562656

2657+
// Disable polling in tests to avoid timeouts
2658+
ctx := ContextWithPollConfig(context.Background(), PollConfig{MaxAttempts: 0})
2659+
ctx = ContextWithDeps(ctx, deps)
2660+
26572661
// Call handler
2658-
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
2662+
result, err := handler(ctx, &request)
26592663
require.NoError(t, err)
26602664

26612665
textContent := getTextResult(t, result)
@@ -2667,7 +2671,16 @@ func TestAssignCopilotToIssue(t *testing.T) {
26672671
}
26682672

26692673
require.False(t, result.IsError, fmt.Sprintf("expected there to be no tool error, text was %s", textContent.Text))
2670-
require.Equal(t, textContent.Text, "successfully assigned copilot to issue")
2674+
2675+
// Verify the JSON response contains expected fields
2676+
var response map[string]any
2677+
err = json.Unmarshal([]byte(textContent.Text), &response)
2678+
require.NoError(t, err, "response should be valid JSON")
2679+
assert.Equal(t, float64(123), response["issue_number"])
2680+
assert.Equal(t, "https://github.com/owner/repo/issues/123", response["issue_url"])
2681+
assert.Equal(t, "owner", response["owner"])
2682+
assert.Equal(t, "repo", response["repo"])
2683+
assert.Contains(t, response["message"], "successfully assigned copilot to issue")
26712684
})
26722685
}
26732686
}

pkg/github/transport.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package github
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
)
7+
8+
// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features
9+
// header to requests based on context values. This is required for using
10+
// non-GA GraphQL API features like the agent assignment API.
11+
//
12+
// Usage:
13+
//
14+
// httpClient := &http.Client{
15+
// Transport: &github.GraphQLFeaturesTransport{
16+
// Transport: http.DefaultTransport,
17+
// },
18+
// }
19+
// gqlClient := githubv4.NewClient(httpClient)
20+
//
21+
// Then use withGraphQLFeatures(ctx, "feature_name") when calling GraphQL operations.
22+
type GraphQLFeaturesTransport struct {
23+
// Transport is the underlying HTTP transport. If nil, http.DefaultTransport is used.
24+
Transport http.RoundTripper
25+
}
26+
27+
// RoundTrip implements http.RoundTripper.
28+
func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) {
29+
transport := t.Transport
30+
if transport == nil {
31+
transport = http.DefaultTransport
32+
}
33+
34+
// Clone the request to avoid mutating the original
35+
req = req.Clone(req.Context())
36+
37+
// Check for GraphQL-Features in context and add header if present
38+
if features := GetGraphQLFeatures(req.Context()); len(features) > 0 {
39+
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
40+
}
41+
42+
return transport.RoundTrip(req)
43+
}

pkg/github/transport_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestGraphQLFeaturesTransport(t *testing.T) {
14+
t.Parallel()
15+
16+
tests := []struct {
17+
name string
18+
features []string
19+
expectedHeader string
20+
hasHeader bool
21+
}{
22+
{
23+
name: "no features in context",
24+
features: nil,
25+
expectedHeader: "",
26+
hasHeader: false,
27+
},
28+
{
29+
name: "single feature in context",
30+
features: []string{"issues_copilot_assignment_api_support"},
31+
expectedHeader: "issues_copilot_assignment_api_support",
32+
hasHeader: true,
33+
},
34+
{
35+
name: "multiple features in context",
36+
features: []string{"feature1", "feature2", "feature3"},
37+
expectedHeader: "feature1, feature2, feature3",
38+
hasHeader: true,
39+
},
40+
{
41+
name: "empty features slice",
42+
features: []string{},
43+
expectedHeader: "",
44+
hasHeader: false,
45+
},
46+
}
47+
48+
for _, tc := range tests {
49+
t.Run(tc.name, func(t *testing.T) {
50+
t.Parallel()
51+
52+
var capturedHeader string
53+
var headerExists bool
54+
55+
// Create a test server that captures the request header
56+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57+
capturedHeader = r.Header.Get("GraphQL-Features")
58+
headerExists = r.Header.Get("GraphQL-Features") != ""
59+
w.WriteHeader(http.StatusOK)
60+
}))
61+
defer server.Close()
62+
63+
// Create the transport
64+
transport := &GraphQLFeaturesTransport{
65+
Transport: http.DefaultTransport,
66+
}
67+
68+
// Create a request
69+
ctx := context.Background()
70+
if tc.features != nil {
71+
ctx = withGraphQLFeatures(ctx, tc.features...)
72+
}
73+
74+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil)
75+
require.NoError(t, err)
76+
77+
// Execute the request
78+
resp, err := transport.RoundTrip(req)
79+
require.NoError(t, err)
80+
defer resp.Body.Close()
81+
82+
// Verify the header
83+
assert.Equal(t, tc.hasHeader, headerExists)
84+
if tc.hasHeader {
85+
assert.Equal(t, tc.expectedHeader, capturedHeader)
86+
}
87+
})
88+
}
89+
}
90+
91+
func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) {
92+
t.Parallel()
93+
94+
var capturedHeader string
95+
96+
// Create a test server
97+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98+
capturedHeader = r.Header.Get("GraphQL-Features")
99+
w.WriteHeader(http.StatusOK)
100+
}))
101+
defer server.Close()
102+
103+
// Create the transport with nil Transport (should use DefaultTransport)
104+
transport := &GraphQLFeaturesTransport{
105+
Transport: nil,
106+
}
107+
108+
// Create a request with features
109+
ctx := withGraphQLFeatures(context.Background(), "test_feature")
110+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil)
111+
require.NoError(t, err)
112+
113+
// Execute the request
114+
resp, err := transport.RoundTrip(req)
115+
require.NoError(t, err)
116+
defer resp.Body.Close()
117+
118+
// Verify the header was added
119+
assert.Equal(t, "test_feature", capturedHeader)
120+
}
121+
122+
func TestGraphQLFeaturesTransport_DoesNotMutateOriginalRequest(t *testing.T) {
123+
t.Parallel()
124+
125+
// Create a test server
126+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
127+
w.WriteHeader(http.StatusOK)
128+
}))
129+
defer server.Close()
130+
131+
// Create the transport
132+
transport := &GraphQLFeaturesTransport{
133+
Transport: http.DefaultTransport,
134+
}
135+
136+
// Create a request with features
137+
ctx := withGraphQLFeatures(context.Background(), "test_feature")
138+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil)
139+
require.NoError(t, err)
140+
141+
// Store the original header value
142+
originalHeader := req.Header.Get("GraphQL-Features")
143+
144+
// Execute the request
145+
resp, err := transport.RoundTrip(req)
146+
require.NoError(t, err)
147+
defer resp.Body.Close()
148+
149+
// Verify the original request was not mutated
150+
assert.Equal(t, originalHeader, req.Header.Get("GraphQL-Features"))
151+
}

0 commit comments

Comments
 (0)