Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 97 additions & 73 deletions src/SwaggerProvider.DesignTime/OperationCompiler.fs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace SwaggerProvider.Internal.Compilers
open System
open System.Collections.Generic
open System.Net.Http
open System.Reflection
open System.Text.Json

open Microsoft.FSharp.Quotations
Expand Down Expand Up @@ -60,6 +61,44 @@ type PayloadType =

/// Object for compiling operations.
type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, ignoreControllerPrefix, ignoreOperationId, asAsync: bool) =
let toParamMethod =
match <@@ RuntimeHelpers.toParam(null) @@> with
| Call(None, m, _) -> m
| _ -> failwith "Cannot extract toParam MethodInfo"

let toQueryParamsMethod =
match <@@ RuntimeHelpers.toQueryParams "" null Unchecked.defaultof<ProvidedApiClientBase> @@> with
| Call(None, m, _) -> m
| _ -> failwith "Cannot extract toQueryParams MethodInfo"

let resolveCastMethod(ownerType: Type) =
ownerType.GetMethods(BindingFlags.Public ||| BindingFlags.Static)
|> Array.tryFind(fun m ->
m.Name = "cast"
&& m.IsGenericMethodDefinition
&& m.GetGenericArguments().Length = 1
&& m.GetParameters().Length = 1)
|> Option.defaultWith(fun () -> failwithf "Cannot extract %s.cast<'T> MethodInfo" ownerType.FullName)

let taskCastMethod = resolveCastMethod typeof<TaskExtensions>
let asyncCastMethod = resolveCastMethod typeof<AsyncExtensions>

let stringPairListExpr(items: (string * string) list) : Expr<(string * string) list> =
let empty = <@ [] @>

(empty, List.rev items)
||> List.fold(fun acc (name, value) ->
let nameExpr = Expr.Value(name, typeof<string>) |> Expr.Cast<string>
let valueExpr = Expr.Value(value, typeof<string>) |> Expr.Cast<string>

<@ (%nameExpr, %valueExpr) :: %acc @>)

let typedListExpr(items: Expr<'T> list) : Expr<'T list> =
let empty = <@ [] @>

(empty, List.rev items)
||> List.fold(fun acc item -> <@ %item :: %acc @>)

let compileOperation (providedMethodName: string) (apiCall: ApiCall) =
let path, pathItem, opTy = apiCall
let operation = pathItem.Operations[opTy]
Expand Down Expand Up @@ -96,7 +135,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
let (|NoMediaType|_|)(content: IDictionary<string, OpenApiMediaType>) =
if isNull content || content.Count = 0 then Some() else None

let payloadTy, payloadMime, parameters, ctArgIndex =
let payloadTy, payloadMime, parameters, ctArgIndex, apiParamByProvidedName =
/// handles de-duplicating Swagger parameter names if the same parameter name
/// appears in multiple locations in a given operation definition.
let uniqueParamName usedNames (param: IOpenApiParameter) =
Expand Down Expand Up @@ -156,8 +195,8 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
|> List.partition(_.Required)

let buildProvidedParameters usedNames (paramList: IOpenApiParameter list) =
((usedNames, []), paramList)
||> List.fold(fun (names, parameters) current ->
((usedNames, [], []), paramList)
||> List.fold(fun (names, parameters, lookup) current ->
let names, paramName = uniqueParamName names current

let paramType =
Expand All @@ -170,15 +209,20 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
let paramDefaultValue = defCompiler.GetDefaultValue paramType
ProvidedParameter(paramName, paramType, false, paramDefaultValue)

(names, providedParam :: parameters))
|> fun (finalNames, ps) -> finalNames, List.rev ps
(names, providedParam :: parameters, (paramName, current) :: lookup))
|> fun (finalNames, ps, lookup) -> finalNames, List.rev ps, List.rev lookup

let namesAfterRequired, requiredProvidedParams =
let namesAfterRequired, requiredProvidedParams, requiredLookup =
buildProvidedParameters Set.empty requiredOpenApiParams

let _, optionalProvidedParams =
let _, optionalProvidedParams, optionalLookup =
buildProvidedParameters namesAfterRequired optionalOpenApiParams

let apiParamByProvidedName =
requiredLookup @ optionalLookup
|> List.choose(fun (paramName, param) -> if param.In.HasValue then Some(paramName, param) else None)
|> Map.ofList

let ctArgIndex, parameters =
let scope = UniqueNameGenerator()

Expand All @@ -196,7 +240,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,

ctArgIndex, requiredProvidedParams @ optionalProvidedParams @ [ ctParam ]

payloadTy, payloadTy.ToMediaType(), parameters, ctArgIndex
payloadTy, payloadTy.ToMediaType(), parameters, ctArgIndex, apiParamByProvidedName

// find the inner type value
let okResponse =
Expand Down Expand Up @@ -258,6 +302,12 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
|> Seq.toArray
|> Array.unzip

let fixedHeaders =
[ if not(isNull payloadMime) then
"Content-Type", payloadMime
if not(isNull retMime) then
"Accept", retMime ]

let m =
ProvidedMethod(
providedMethodName,
Expand All @@ -271,13 +321,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,

let httpMethod = opTy.ToString()

let headers =
<@
[ if not(isNull payloadMime) then
"Content-Type", payloadMime
if not(isNull retMime) then
"Accept", retMime ]
@>
let headers = stringPairListExpr fixedHeaders

// Locates parameters matching the arguments
let mutable payloadExp = None
Expand All @@ -298,14 +342,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
apiArgs
|> List.choose (function
| ShapeVar sVar as expr ->
let param =
openApiParameters
|> Seq.tryFind(fun x ->
// pain point: we have to make sure that the set of names we search for here are the same as the set of names generated when we make `parameters` above
let baseName = niceCamelName x.Name
baseName = sVar.Name || (unambiguousName x) = sVar.Name)

match param with
match apiParamByProvidedName |> Map.tryFind sVar.Name with
| Some(par) -> Some(par, expr)
| _ ->
let payloadType = PayloadType.Parse sVar.Name
Expand All @@ -324,31 +361,21 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
// object across all calls, causing "duplicate key" exceptions in ProvidedTypes
// when the same helper is called for multiple parameters in one operation.
// Instead, build the call expression directly without an intermediate binding.
let toParamMethod =
match <@@ RuntimeHelpers.toParam(null) @@> with
| Call(None, m, _) -> m
| _ -> failwith "Cannot extract toParam MethodInfo"

let coerceString exp =
let obj = Expr.Coerce(exp, typeof<obj>)
Expr.Call(toParamMethod, [ obj ]) |> Expr.Cast<string>

let toQueryParamsMethod =
match <@@ RuntimeHelpers.toQueryParams "" null (%this) @@> with
| Call(None, m, _) -> m
| _ -> failwith "Cannot extract toQueryParams MethodInfo"

let rec coerceQueryString name expr =
let obj = Expr.Coerce(expr, typeof<obj>)

Expr.Call(toQueryParamsMethod, [ Expr.Value name; obj; this ])
|> Expr.Cast<(string * string) list>

// Partitions arguments based on their locations
let path, queryParams, headers =
let path, queryParams, headers, cookies =
((<@ path @>, <@ [] @>, headers, <@ [] @>), parameters)
||> List.fold(fun (path, query, headers, cookies) (param: IOpenApiParameter, valueExpr) ->
let path, queryParamLists, headers, cookies =
let path, queryParamLists, headers, cookies =
((<@ path @>, [], headers, <@ [] @>), parameters)
||> List.fold(fun (path, queryParamLists, headers, cookies) (param: IOpenApiParameter, valueExpr) ->
if param.In.HasValue then
let name = param.Name

Expand All @@ -357,44 +384,32 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
let value = coerceString valueExpr
let pattern = $"{{%s{name}}}"
let path' = <@ (%path).Replace(pattern, %value) @>
(path', query, headers, cookies)
(path', queryParamLists, headers, cookies)
| ParameterLocation.Query ->
let listValues = coerceQueryString name valueExpr
let query' = <@ List.append %query %listValues @>
(path, query', headers, cookies)
(path, listValues :: queryParamLists, headers, cookies)
| ParameterLocation.Header ->
let value = coerceString valueExpr
let headers' = <@ (name, %value) :: (%headers) @>
(path, query, headers', cookies)
(path, queryParamLists, headers', cookies)
| ParameterLocation.Cookie ->
let value = coerceString valueExpr
let cookies' = <@ (name, %value) :: (%cookies) @>
(path, query, headers, cookies')
(path, queryParamLists, headers, cookies')
| x -> failwithf $"Unsupported parameter location '%A{x}'"
else
failwithf "This should not happen, payload expression is already parsed")

let headers' =
<@
let cookieHeader =
%cookies
|> Seq.filter(snd >> isNull >> not)
|> Seq.map(fun (name, value) -> $"{name}={value}")
|> String.concat ";"

if String.IsNullOrEmpty cookieHeader then
%headers
else
("Cookie", cookieHeader) :: (%headers)
@>

(path, queryParams, headers')
path, List.rev queryParamLists, headers, cookies

let queryParamLists = typedListExpr queryParamLists

let httpRequestMessage =
<@
let msg = RuntimeHelpers.createHttpRequest httpMethod %path %queryParams
RuntimeHelpers.fillHeaders msg %headers
let msg =
RuntimeHelpers.createHttpRequestFromQueryLists httpMethod %path %queryParamLists

RuntimeHelpers.fillHeadersAndCookies msg %headers %cookies
msg
@>

Expand Down Expand Up @@ -441,7 +456,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
let action =
<@ (%this).CallAsync(%httpRequestMessageWithPayload, errorCodes, errorDescriptions, %ct) @>

let responseObj =
let responseObj() =
let innerReturnType = defaultArg retTy null

<@
Expand All @@ -455,7 +470,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
}
@>

let responseStream =
let responseStream() =
<@
let x = %action
let ct = %ct
Expand All @@ -467,7 +482,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
}
@>

let responseString =
let responseString() =
<@
let x = %action
let ct = %ct
Expand All @@ -479,7 +494,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
}
@>

let responseUnit =
let responseUnit() =
<@
let x = %action

Expand All @@ -489,27 +504,36 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler,
}
@>

// if we're an async method, then we can just return the above, coerced to the overallReturnType.
// if we're not async, then run that^ through Async.RunSynchronously before doing the coercion.
// Build only the response quotation needed for this operation's return shape.
// For typed JSON responses, emit direct generic cast calls so generated clients
// do not pay MethodInfo.Invoke costs on every API call.
if not asAsync then
match retTy with
| None -> responseUnit.Raw
| Some t when t = typeof<IO.Stream> -> <@ %responseStream @>.Raw
| None -> (responseUnit()).Raw
| Some t when t = typeof<IO.Stream> -> <@ %(responseStream()) @>.Raw
| Some t ->
match retMime with
| TextReturn _ -> <@ %responseString @>.Raw
| _ -> Expr.Coerce(<@ RuntimeHelpers.taskCast t %responseObj @>, overallReturnType)
| TextReturn _ -> <@ %(responseString()) @>.Raw
| _ ->
let castMethod = ProvidedTypeBuilder.MakeGenericMethod(taskCastMethod, [ t ])

Expr.Call(castMethod, [ responseObj() ])
|> fun e -> Expr.Coerce(e, overallReturnType)
else
let awaitTask t =
<@ Async.AwaitTask(%t) @>

match retTy with
| None -> (awaitTask responseUnit).Raw
| Some t when t = typeof<IO.Stream> -> <@ %(awaitTask responseStream) @>.Raw
| None -> (awaitTask(responseUnit())).Raw
| Some t when t = typeof<IO.Stream> -> <@ %(awaitTask(responseStream())) @>.Raw
| Some t ->
match retMime with
| TextReturn _ -> <@ %(awaitTask responseString) @>.Raw
| _ -> Expr.Coerce(<@ RuntimeHelpers.asyncCast t %(awaitTask responseObj) @>, overallReturnType)
| TextReturn _ -> <@ %(awaitTask(responseString())) @>.Raw
| _ ->
let castMethod = ProvidedTypeBuilder.MakeGenericMethod(asyncCastMethod, [ t ])

Expr.Call(castMethod, [ awaitTask(responseObj()) ])
|> fun e -> Expr.Coerce(e, overallReturnType)
)

let xmlDoc =
Expand Down
11 changes: 6 additions & 5 deletions src/SwaggerProvider.DesignTime/Provider.OpenApiClient.fs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ type public OpenApiClientTypeProvider(cfg: TypeProviderConfig) as this =
// check we contain a copy of runtime files, and are not referencing the runtime DLL
do assert (typeof<ProvidedApiClientBase>.Assembly.GetName().Name = asm.GetName().Name)

let buildStringListExpr(items: string list) : Expr =
let stringListNilCase, stringListConsCase =
let cases = FSharpType.GetUnionCases typeof<string list>
let nilCase = cases |> Array.find(fun c -> c.Name = "Empty")
let consCase = cases |> Array.find(fun c -> c.Name = "Cons")
let nil = Expr.NewUnionCase(nilCase, [])
cases |> Array.find(fun c -> c.Name = "Empty"), cases |> Array.find(fun c -> c.Name = "Cons")

let buildStringListExpr(items: string list) : Expr =
let nil = Expr.NewUnionCase(stringListNilCase, [])

List.foldBack (fun (s: string) acc -> Expr.NewUnionCase(consCase, [ Expr.Value(s, typeof<string>); acc ])) items nil
List.foldBack (fun (s: string) acc -> Expr.NewUnionCase(stringListConsCase, [ Expr.Value(s, typeof<string>); acc ])) items nil

let myParamType =
let t =
Expand Down
Loading
Loading