diff --git a/github/github.go b/github/github.go index c4da5391195..2577d1def69 100644 --- a/github/github.go +++ b/github/github.go @@ -40,10 +40,8 @@ const ( HeaderRequestID = "X-Github-Request-Id" // https://docs.github.com/en/rest/about-the-rest-api/api-versions#about-api-versioning - defaultAPIVersion = api20221128 - latestAPIVersion = api20260310 - api20221128 = "2022-11-28" - api20260310 = "2026-03-10" + api20221128 = "2022-11-28" + api20260310 = "2026-03-10" defaultBaseURL = "https://api.github.com/" defaultUserAgent = "go-github" + "/" + Version @@ -178,6 +176,13 @@ type Client struct { // Base URL for uploading files. uploadURL *url.URL + // Default API version to set in the X-Github-Api-Version header. + apiVersionDefault string + // Minimum API version that the client can use. + apiVersionMin string + // Maximum API version that the client can use. + apiVersionMax string + // User agent used when communicating with the GitHub API. userAgent string @@ -347,6 +352,8 @@ type clientOptions struct { httpClient *http.Client transport http.RoundTripper timeout *time.Duration + apiVersionMin *string + apiVersionMax *string userAgent *string envProxy bool token *string @@ -558,7 +565,11 @@ func NewClient(opts ...ClientOptionsFunc) (*Client, error) { // newClient creates a new Client with the provided options. This is an internal // helper function that is called by [NewClient] and [Client.Clone]. func newClient(opts clientOptions) (*Client, error) { - c := &Client{} + c := &Client{ + apiVersionDefault: api20221128, + apiVersionMin: api20221128, + apiVersionMax: api20260310, + } if opts.httpClient != nil { c.client = opts.httpClient @@ -609,6 +620,14 @@ func newClient(opts clientOptions) (*Client, error) { CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }, } + if opts.apiVersionMin != nil { + c.apiVersionMin = *opts.apiVersionMin + } + + if opts.apiVersionMax != nil { + c.apiVersionMax = *opts.apiVersionMax + } + if opts.userAgent != nil { c.userAgent = *opts.userAgent } else { @@ -718,6 +737,8 @@ func (c *Client) Clone(opts ...ClientOptionsFunc) (*Client, error) { } o := clientOptions{ + apiVersionMin: &c.apiVersionMin, + apiVersionMax: &c.apiVersionMax, userAgent: &c.userAgent, baseURL: Ptr(*c.baseURL), uploadURL: Ptr(*c.uploadURL), @@ -814,7 +835,7 @@ func (c *Client) NewRequest(ctx context.Context, method, urlStr string, body any if c.userAgent != "" { req.Header.Set("User-Agent", c.userAgent) } - req.Header.Set(headerAPIVersion, defaultAPIVersion) + req.Header.Set(headerAPIVersion, c.apiVersionDefault) for _, opt := range opts { opt(req) @@ -851,7 +872,7 @@ func (c *Client) NewFormRequest(ctx context.Context, urlStr string, body io.Read if c.userAgent != "" { req.Header.Set("User-Agent", c.userAgent) } - req.Header.Set(headerAPIVersion, defaultAPIVersion) + req.Header.Set(headerAPIVersion, c.apiVersionDefault) for _, opt := range opts { opt(req) @@ -922,7 +943,7 @@ func (c *Client) NewUploadRequest(ctx context.Context, urlStr string, reader io. req.Header.Set("Content-Type", mediaType) req.Header.Set("Accept", mediaTypeV3) req.Header.Set("User-Agent", c.userAgent) - req.Header.Set(headerAPIVersion, defaultAPIVersion) + req.Header.Set(headerAPIVersion, c.apiVersionDefault) for _, opt := range opts { opt(req) @@ -1143,6 +1164,28 @@ const ( // unexpectedly large error body. const maxErrorBodySize = 1 * 1024 * 1024 // 1 MiB +// ErrUnsupportedAPIVersion is returned when the API version specified in the +// request is not supported by the client. +var ErrUnsupportedAPIVersion = errors.New("unsupported api version") + +// checkRequestAPIVersionBeforeDo checks if the API version specified in the +// request is supported by the client before making the API call. If the +// version is not supported, it returns [ErrUnsupportedAPIVersion]. If the +// version is empty it returns nil. +func (c *Client) checkRequestAPIVersionBeforeDo(req *http.Request) error { + reqAPIVersion := req.Header.Get(headerAPIVersion) + + if reqAPIVersion == "" { + return nil + } + + if reqAPIVersion < c.apiVersionMin || reqAPIVersion > c.apiVersionMax { + return ErrUnsupportedAPIVersion + } + + return nil +} + // bareDo sends an API request using `caller` http.Client passed in the parameters // and lets you handle the api response. If an error or API Error occurs, the error // will contain more information. Otherwise, you are supposed to read and close the @@ -1151,6 +1194,10 @@ const maxErrorBodySize = 1 * 1024 * 1024 // 1 MiB func (c *Client) bareDo(caller *http.Client, req *http.Request) (*Response, error) { ctx := req.Context() + if err := c.checkRequestAPIVersionBeforeDo(req); err != nil { + return nil, err + } + rateLimitCategory := CoreCategory if !c.disableRateLimitCheck { diff --git a/github/github_test.go b/github/github_test.go index 2beedbddef3..79f8d1b795d 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -1010,6 +1010,8 @@ func Test_newClient(t *testing.T) { httpClient: &http.Client{Transport: &http.Transport{IdleConnTimeout: 5 * time.Second}}, transport: &http.Transport{IdleConnTimeout: 10 * time.Second}, timeout: Ptr(15 * time.Second), + apiVersionMin: Ptr(api20221128), + apiVersionMax: Ptr(api20221128), userAgent: Ptr("CustomUserAgent/1.0"), baseURL: mustParseURL(t, "https://custom-url/api/v3/"), uploadURL: mustParseURL(t, "https://custom-upload-url/api/uploads/"), @@ -1092,6 +1094,20 @@ func Test_newClient(t *testing.T) { t.Error("newClient http.Client used for redirects should have a CheckRedirect function") } + if tt.opts.apiVersionMin != nil && c.apiVersionMin != *tt.opts.apiVersionMin { + t.Errorf("newClient apiVersionMin is %v, want %v", c.apiVersionMin, *tt.opts.apiVersionMin) + } + if tt.opts.apiVersionMin == nil && c.apiVersionMin != api20221128 { + t.Errorf("newClient apiVersionMin is %v, want %v", c.apiVersionMin, api20221128) + } + + if tt.opts.apiVersionMax != nil && c.apiVersionMax != *tt.opts.apiVersionMax { + t.Errorf("newClient apiVersionMax is %v, want %v", c.apiVersionMax, *tt.opts.apiVersionMax) + } + if tt.opts.apiVersionMax == nil && c.apiVersionMax != api20260310 { + t.Errorf("newClient apiVersionMax is %v, want %v", c.apiVersionMax, api20260310) + } + if tt.opts.userAgent != nil && c.userAgent != *tt.opts.userAgent { t.Errorf("newClient userAgent is %v, want %v", c.userAgent, *tt.opts.userAgent) } @@ -1274,6 +1290,8 @@ func TestClient_Clone(t *testing.T) { t.Parallel() c := mustNewClient(t) + c.apiVersionMin = api20221128 + c.apiVersionMax = api20221128 c.userAgent = "CustomUserAgent/1.0" c.baseURL.Path = "/custom/" c.uploadURL.Path = "/custom-upload/" @@ -1581,7 +1599,7 @@ func TestNewRequest(t *testing.T) { } apiVersion := req.Header.Get(headerAPIVersion) - if got, want := apiVersion, defaultAPIVersion; got != want { + if got, want := apiVersion, api20221128; got != want { t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) } @@ -1667,13 +1685,13 @@ func TestNewRequest_errorForNoTrailingSlash(t *testing.T) { for _, test := range tests { u, err := url.Parse(test.rawurl) if err != nil { - t.Fatalf("url.Parse returned unexpected error: %v.", err) + t.Fatalf("url.Parse returned unexpected error: %v", err) } c.baseURL = u if _, err := c.NewRequest(t.Context(), "GET", "test", nil); test.wantError && err == nil { t.Fatal("Expected error to be returned.") } else if !test.wantError && err != nil { - t.Fatalf("NewRequest returned unexpected error: %v.", err) + t.Fatalf("NewRequest returned unexpected error: %v", err) } } } @@ -1788,7 +1806,7 @@ func TestNewFormRequest(t *testing.T) { } apiVersion := req.Header.Get(headerAPIVersion) - if got, want := apiVersion, defaultAPIVersion; got != want { + if got, want := apiVersion, api20221128; got != want { t.Errorf("NewFormRequest() %v header is %v, want %v", headerAPIVersion, got, want) } @@ -1847,13 +1865,13 @@ func TestNewFormRequest_errorForNoTrailingSlash(t *testing.T) { for _, test := range tests { u, err := url.Parse(test.rawURL) if err != nil { - t.Fatalf("url.Parse returned unexpected error: %v.", err) + t.Fatalf("url.Parse returned unexpected error: %v", err) } c.baseURL = u if _, err := c.NewFormRequest(t.Context(), "test", nil); test.wantError && err == nil { t.Fatal("Expected error to be returned.") } else if !test.wantError && err != nil { - t.Fatalf("NewFormRequest returned unexpected error: %v.", err) + t.Fatalf("NewFormRequest returned unexpected error: %v", err) } } } @@ -1864,7 +1882,7 @@ func TestNewUploadRequest_WithVersion(t *testing.T) { req, _ := c.NewUploadRequest(t.Context(), "https://example.com/", nil, 0, "") apiVersion := req.Header.Get(headerAPIVersion) - if got, want := apiVersion, defaultAPIVersion; got != want { + if got, want := apiVersion, api20221128; got != want { t.Errorf("NewRequest() %v header is %v, want %v", headerAPIVersion, got, want) } @@ -1901,13 +1919,13 @@ func TestNewUploadRequest_errorForNoTrailingSlash(t *testing.T) { for _, test := range tests { u, err := url.Parse(test.rawurl) if err != nil { - t.Fatalf("url.Parse returned unexpected error: %v.", err) + t.Fatalf("url.Parse returned unexpected error: %v", err) } c.uploadURL = u if _, err = c.NewUploadRequest(t.Context(), "test", nil, 0, ""); test.wantError && err == nil { t.Fatal("Expected error to be returned.") } else if !test.wantError && err != nil { - t.Fatalf("NewUploadRequest returned unexpected error: %v.", err) + t.Fatalf("NewUploadRequest returned unexpected error: %v", err) } } } @@ -2225,7 +2243,7 @@ func TestDo_redirectLoop(t *testing.T) { t.Error("Expected error to be returned.") } if !errors.As(err, new(*url.Error)) { - t.Errorf("Expected a URL error; got %#v.", err) + t.Errorf("Expected a URL error; got %#v", err) } } @@ -2464,7 +2482,7 @@ func TestDo_rateLimit_errorResponse(t *testing.T) { t.Error("Expected error to be returned.") } if errors.As(err, new(*RateLimitError)) { - t.Errorf("Did not expect a *RateLimitError error; got %#v.", err) + t.Errorf("Did not expect a *RateLimitError error; got %#v", err) } if got, want := resp.Rate.Limit, 60; got != want { t.Errorf("Client rate limit = %v, want %v", got, want) @@ -2511,7 +2529,7 @@ func TestDo_rateLimit_rateLimitError(t *testing.T) { } var rateLimitErr *RateLimitError if !errors.As(err, &rateLimitErr) { - t.Fatalf("Expected a *RateLimitError error; got %#v.", err) + t.Fatalf("Expected a *RateLimitError error; got %#v", err) } if got, want := rateLimitErr.Rate.Limit, 60; got != want { t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) @@ -2578,7 +2596,7 @@ func TestDo_rateLimit_noNetworkCall(t *testing.T) { } var rateLimitErr *RateLimitError if !errors.As(err, &rateLimitErr) { - t.Fatalf("Expected a *RateLimitError error; got %#v.", err) + t.Fatalf("Expected a *RateLimitError error; got %#v", err) } if got, want := rateLimitErr.Rate.Limit, 60; got != want { t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) @@ -2813,7 +2831,7 @@ func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) { _, err := client.Do(req, nil) var rateLimitError *RateLimitError if !errors.As(err, &rateLimitError) { - t.Fatalf("Expected a *rateLimitError error; got %#v.", err) + t.Fatalf("Expected a *rateLimitError error; got %#v", err) } if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) { t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got) @@ -2848,7 +2866,7 @@ func TestDo_rateLimit_abuseRateLimitError(t *testing.T) { } var abuseRateLimitErr *AbuseRateLimitError if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) @@ -2882,7 +2900,7 @@ func TestDo_rateLimit_abuseRateLimitErrorEnterprise(t *testing.T) { } var abuseRateLimitErr *AbuseRateLimitError if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if got, want := abuseRateLimitErr.RetryAfter, (*time.Duration)(nil); got != want { t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) @@ -2913,7 +2931,7 @@ func TestDo_rateLimit_abuseRateLimitError_retryAfter(t *testing.T) { } var abuseRateLimitErr *AbuseRateLimitError if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if abuseRateLimitErr.RetryAfter == nil { t.Fatal("abuseRateLimitErr RetryAfter is nil, expected not-nil") @@ -2927,7 +2945,7 @@ func TestDo_rateLimit_abuseRateLimitError_retryAfter(t *testing.T) { t.Error("Expected error to be returned.") } if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if abuseRateLimitErr.RetryAfter == nil { t.Fatal("abuseRateLimitErr RetryAfter is nil, expected not-nil") @@ -2969,7 +2987,7 @@ func TestDo_rateLimit_abuseRateLimitError_xRateLimitReset(t *testing.T) { } var abuseRateLimitErr *AbuseRateLimitError if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if abuseRateLimitErr.RetryAfter == nil { t.Fatal("abuseRateLimitErr RetryAfter is nil, expected not-nil") @@ -2984,7 +3002,7 @@ func TestDo_rateLimit_abuseRateLimitError_xRateLimitReset(t *testing.T) { t.Error("Expected error to be returned.") } if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if abuseRateLimitErr.RetryAfter == nil { t.Fatal("abuseRateLimitErr RetryAfter is nil, expected not-nil") @@ -3027,7 +3045,7 @@ func TestDo_rateLimit_abuseRateLimitError_maxDuration(t *testing.T) { } var abuseRateLimitErr *AbuseRateLimitError if !errors.As(err, &abuseRateLimitErr) { - t.Fatalf("Expected a *AbuseRateLimitError error; got %#v.", err) + t.Fatalf("Expected a *AbuseRateLimitError error; got %#v", err) } if abuseRateLimitErr.RetryAfter == nil { t.Fatal("abuseRateLimitErr RetryAfter is nil, expected not-nil") @@ -3126,6 +3144,106 @@ func TestDo_noContent(t *testing.T) { } } +func TestClient_checkRequestAPIVersionBeforeDo(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + version string + versionMin string + versionMax string + wantErr bool + }{ + { + name: "version_not_set", + version: "", + versionMin: api20221128, + versionMax: api20260310, + wantErr: false, + }, + { + name: "version_less_than_min", + version: "2022-01-01", + versionMin: api20221128, + versionMax: api20260310, + wantErr: true, + }, + { + name: "version_equal_to_min", + version: api20221128, + versionMin: api20221128, + versionMax: api20260310, + wantErr: false, + }, + { + name: "version_between_min_and_max", + version: "2023-01-01", + versionMin: api20221128, + versionMax: api20260310, + wantErr: false, + }, + { + name: "version_equal_to_max", + version: api20260310, + versionMin: api20221128, + versionMax: api20260310, + wantErr: false, + }, + { + name: "version_greater_than_max", + version: api20260310, + versionMin: api20221128, + versionMax: api20221128, + wantErr: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client := mustNewClient(t) + client.apiVersionMin = tt.versionMin + client.apiVersionMax = tt.versionMax + + req, _ := http.NewRequestWithContext(t.Context(), "GET", ".", nil) + req.Header.Set(headerAPIVersion, tt.version) + + err := client.checkRequestAPIVersionBeforeDo(req) + if tt.wantErr { + if err == nil { + t.Fatal("Expected error to be returned, got nil") + } + if !errors.Is(err, ErrUnsupportedAPIVersion) { + t.Errorf("Expected ErrUnsupportedAPIVersion; got %#v", err) + } + return + } + + if err != nil { + t.Fatalf("Expected no error to be returned, got: %v", err) + } + }) + } +} + +func TestClient_bareDo_errors_with_unsupported_api_version(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + c.apiVersionMin = api20221128 + c.apiVersionMax = api20221128 + + req, _ := http.NewRequestWithContext(t.Context(), "GET", ".", nil) + req.Header.Set(headerAPIVersion, api20260310) + + _, err := c.bareDo(c.client, req) + if err == nil { + t.Fatal("Expected error to be returned, got nil") + } + if !errors.Is(err, ErrUnsupportedAPIVersion) { + t.Errorf("Expected ErrUnsupportedAPIVersion; got %#v", err) + } +} + func TestBareDoUntilFound_redirectLoop(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -3141,7 +3259,7 @@ func TestBareDoUntilFound_redirectLoop(t *testing.T) { t.Error("Expected error to be returned.") } if !errors.As(err, new(*RedirectionError)) { - t.Errorf("Expected a Redirection error; got %#v.", err) + t.Errorf("Expected a Redirection error; got %#v", err) } } @@ -3160,7 +3278,7 @@ func TestBareDoUntilFound_UnexpectedRedirection(t *testing.T) { t.Error("Expected error to be returned.") } if !errors.As(err, new(*RedirectionError)) { - t.Errorf("Expected a Redirection error; got %#v.", err) + t.Errorf("Expected a Redirection error; got %#v", err) } }