diff --git a/.gitignore b/.gitignore index a893d80..e10166f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ dist/ client/go/modcdp client/go/modcdp-go __pycache__/ +testdata/codegen/stagehand_alias_manifest.json +js/test/stagehand_client_generated/stagehand_client_gen.ts +python/tests/stagehand_client_generated/stagehand_client_gen.py +testdata/codegen/stagehand_client_generated_go/stagehand_client_gen.go diff --git a/go/modcdp/client/CDPTypes.go b/go/modcdp/client/CDPTypes.go index bcf0b59..afb200b 100644 --- a/go/modcdp/client/CDPTypes.go +++ b/go/modcdp/client/CDPTypes.go @@ -29,6 +29,7 @@ type CDPTypes struct { CustomCommands map[string]CustomCommand CustomEvents map[string]CustomEvent CustomMiddlewares []CustomMiddleware + CustomAliasObjects map[string]CustomAliasObject commandSchemas map[string]CDPCommandSchema commandParamsSchemas map[string]map[string]any commandResultSchemas map[string]map[string]any @@ -105,6 +106,47 @@ var modMiddlewareRegistrationSchema = map[string]any{ "additionalProperties": false, } +var modAliasReturnSchema = map[string]any{ + "type": "object", + "properties": map[string]any{ + "object": map[string]any{"type": "string"}, + "unwrap": map[string]any{"type": "string"}, + "array": map[string]any{"type": "boolean"}, + "nullable": map[string]any{"type": "boolean"}, + }, + "additionalProperties": false, +} + +var modAliasMethodRegistrationSchema = map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "command": map[string]any{"type": "string"}, + "params_schema": map[string]any{"type": []any{"object", "null"}}, + "result_schema": map[string]any{"type": []any{"object", "null"}}, + "sticky_param": map[string]any{"type": "string"}, + "sticky_params": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}, + "sticky_fields": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}, + "return": modAliasReturnSchema, + }, + "required": []any{"name", "command"}, + "additionalProperties": false, +} + +var modAliasObjectRegistrationSchema = map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "type_name": map[string]any{"type": "string"}, + "root": map[string]any{"type": "boolean"}, + "sticky_schema": map[string]any{"type": []any{"object", "null"}}, + "sticky_fields": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}, + "methods": map[string]any{"type": "array", "items": modAliasMethodRegistrationSchema}, + }, + "required": []any{"name"}, + "additionalProperties": false, +} + var modConfigureParamsSchema = map[string]any{ "type": "object", "properties": map[string]any{ @@ -149,6 +191,7 @@ var modConfigureParamsSchema = map[string]any{ "custom_commands": map[string]any{"type": "array", "items": modCommandRegistrationSchema}, "custom_events": map[string]any{"type": "array", "items": modEventRegistrationSchema}, "custom_middlewares": map[string]any{"type": "array", "items": modMiddlewareRegistrationSchema}, + "custom_alias_objects": map[string]any{"type": "array", "items": modAliasObjectRegistrationSchema}, }, "additionalProperties": false, } @@ -327,11 +370,12 @@ var defaultBuiltinEvents = []CustomEvent{ }, } -func NewCDPTypes(customCommands []CustomCommand, customEvents []CustomEvent, customMiddlewares []CustomMiddleware) *CDPTypes { +func NewCDPTypes(customCommands []CustomCommand, customEvents []CustomEvent, customMiddlewares []CustomMiddleware, customAliasObjectGroups ...[]CustomAliasObject) *CDPTypes { types := &CDPTypes{ CustomCommands: map[string]CustomCommand{}, CustomEvents: map[string]CustomEvent{}, CustomMiddlewares: []CustomMiddleware{}, + CustomAliasObjects: map[string]CustomAliasObject{}, commandSchemas: map[string]CDPCommandSchema{}, commandParamsSchemas: map[string]map[string]any{}, commandResultSchemas: map[string]map[string]any{}, @@ -375,6 +419,13 @@ func NewCDPTypes(customCommands []CustomCommand, customEvents []CustomEvent, cus panic(err) } } + for _, customAliasObjects := range customAliasObjectGroups { + for _, object := range customAliasObjects { + if _, err := types.AddCustomAliasObject(object); err != nil { + panic(err) + } + } + } return types } @@ -389,11 +440,16 @@ func (types *CDPTypes) Update(config CDPTypesConfig) *CDPTypes { customEvents = append(customEvents, event) } customMiddlewares := append([]CustomMiddleware{}, types.CustomMiddlewares...) + customAliasObjects := make([]CustomAliasObject, 0, len(types.CustomAliasObjects)+len(config.CustomAliasObjects)) + for _, object := range types.CustomAliasObjects { + customAliasObjects = append(customAliasObjects, object) + } types.mu.RUnlock() customCommands = append(customCommands, config.CustomCommands...) customEvents = append(customEvents, config.CustomEvents...) customMiddlewares = append(customMiddlewares, config.CustomMiddlewares...) - return NewCDPTypes(customCommands, customEvents, customMiddlewares) + customAliasObjects = append(customAliasObjects, config.CustomAliasObjects...) + return NewCDPTypes(customCommands, customEvents, customMiddlewares, customAliasObjects) } func (types *CDPTypes) ToJSON() map[string]any { @@ -410,11 +466,13 @@ func (types *CDPTypes) ToJSON() map[string]any { } customMiddlewares = append(customMiddlewares, registration) } + customAliasObjects := types.CustomAliasObjectWireRegistrations() types.mu.RLock() state := map[string]any{ "custom_commands": len(types.CustomCommands), "custom_events": len(types.CustomEvents), "custom_middlewares": len(types.CustomMiddlewares), + "custom_alias_objects": len(types.CustomAliasObjects), "command_params_schemas": len(types.commandParamsSchemas), "command_result_schemas": len(types.commandResultSchemas), "event_schemas": len(types.eventSchemas), @@ -422,9 +480,10 @@ func (types *CDPTypes) ToJSON() map[string]any { types.mu.RUnlock() return modtypes.ModCDPToJSON(types, modtypes.ModCDPJSONConfig{ Config: map[string]any{ - "custom_commands": customCommands, - "custom_events": types.CustomEventWireRegistrations(), - "custom_middlewares": customMiddlewares, + "custom_commands": customCommands, + "custom_events": types.CustomEventWireRegistrations(), + "custom_middlewares": customMiddlewares, + "custom_alias_objects": customAliasObjects, }, State: state, }) @@ -695,15 +754,21 @@ func (types *CDPTypes) AddCustomMiddleware(middleware CustomMiddleware) (string, return "", fmt.Errorf("phase must be request, response, or event") } middleware.Name = name + types.mu.Lock() + defer types.mu.Unlock() types.CustomMiddlewares = append(types.CustomMiddlewares, middleware) return name, nil } func (types *CDPTypes) CustomMiddlewareWireRegistrations() []CustomMiddleware { + types.mu.RLock() + defer types.mu.RUnlock() return append([]CustomMiddleware{}, types.CustomMiddlewares...) } func (types *CDPTypes) CustomMiddlewareRegistrations(phase string, name string) []CustomMiddleware { + types.mu.RLock() + defer types.mu.RUnlock() middlewares := []CustomMiddleware{} for _, middleware := range types.CustomMiddlewares { middlewareName := middleware.Name @@ -717,6 +782,129 @@ func (types *CDPTypes) CustomMiddlewareRegistrations(phase string, name string) return middlewares } +func (types *CDPTypes) AddCustomAliasObject(object CustomAliasObject) (string, error) { + name := strings.TrimSpace(object.Name) + if name == "" { + return "", fmt.Errorf("custom alias object name is required") + } + normalized := CustomAliasObject{ + Name: name, + TypeName: strings.TrimSpace(object.TypeName), + StickyFields: append([]string{}, object.StickyFields...), + Methods: make([]AliasMethod, 0, len(object.Methods)), + } + if object.StickySchema != nil { + normalized.StickySchema = cloneSchema(object.StickySchema) + } + for _, method := range object.Methods { + methodName := strings.TrimSpace(method.Name) + if methodName == "" { + return "", fmt.Errorf("%s alias method name is required", name) + } + commandName := "" + if strings.TrimSpace(method.Command) != "" { + normalizedCommandName, err := normalizeModCDPName(method.Command) + if err != nil { + return "", fmt.Errorf("%s.%s command: %w", name, methodName, err) + } + commandName = normalizedCommandName + } + normalizedMethod := AliasMethod{ + Name: methodName, + Command: commandName, + SDKMethodName: strings.TrimSpace(method.SDKMethodName), + StickyParam: strings.TrimSpace(method.StickyParam), + StickyParams: append([]string{}, method.StickyParams...), + StickyFields: append([]string{}, method.StickyFields...), + } + if method.ParamsSchema != nil { + normalizedMethod.ParamsSchema = cloneSchema(method.ParamsSchema) + } + if method.ResultSchema != nil { + normalizedMethod.ResultSchema = cloneSchema(method.ResultSchema) + } + if method.Return != nil { + normalizedMethod.Return = &AliasReturn{ + Object: strings.TrimSpace(method.Return.Object), + Unwrap: strings.TrimSpace(method.Return.Unwrap), + Array: method.Return.Array, + Nullable: method.Return.Nullable, + } + } + normalized.Methods = append(normalized.Methods, normalizedMethod) + } + types.mu.Lock() + defer types.mu.Unlock() + for _, method := range normalized.Methods { + if method.Command == "" { + continue + } + if method.ParamsSchema != nil { + if schema := cloneSchema(method.ParamsSchema); schema != nil { + types.commandParamsSchemas[method.Command] = schema + } + } + if method.ResultSchema != nil { + if _, exists := types.commandResultSchemas[method.Command]; exists { + continue + } + if schema := cloneSchema(method.ResultSchema); schema != nil { + types.commandResultSchemas[method.Command] = schema + } + } + } + types.CustomAliasObjects[name] = normalized + return name, nil +} + +func (types *CDPTypes) CustomAliasObjectWireRegistrations() []CustomAliasObject { + types.mu.RLock() + objects := make([]CustomAliasObject, 0, len(types.CustomAliasObjects)) + for _, object := range types.CustomAliasObjects { + objects = append(objects, object) + } + types.mu.RUnlock() + registrations := make([]CustomAliasObject, 0, len(objects)) + for _, object := range objects { + registration := CustomAliasObject{ + Name: object.Name, + TypeName: object.TypeName, + StickyFields: append([]string{}, object.StickyFields...), + Methods: make([]AliasMethod, 0, len(object.Methods)), + } + if object.StickySchema != nil { + registration.StickySchema = cloneSchema(object.StickySchema) + } + for _, method := range object.Methods { + methodRegistration := AliasMethod{ + Name: method.Name, + Command: method.Command, + SDKMethodName: method.SDKMethodName, + StickyParam: method.StickyParam, + StickyParams: append([]string{}, method.StickyParams...), + StickyFields: append([]string{}, method.StickyFields...), + } + if method.ParamsSchema != nil { + methodRegistration.ParamsSchema = cloneSchema(method.ParamsSchema) + } + if method.ResultSchema != nil { + methodRegistration.ResultSchema = cloneSchema(method.ResultSchema) + } + if method.Return != nil { + methodRegistration.Return = &AliasReturn{ + Object: method.Return.Object, + Unwrap: method.Return.Unwrap, + Array: method.Return.Array, + Nullable: method.Return.Nullable, + } + } + registration.Methods = append(registration.Methods, methodRegistration) + } + registrations = append(registrations, registration) + } + return registrations +} + func (types *CDPTypes) ServiceWorkerCommandStep(method string, params map[string]any, cdpSessionID string, executionContextID int) (modtypes.TranslatedStep, error) { if params == nil { params = map[string]any{} diff --git a/go/modcdp/client/ModCDPClient.go b/go/modcdp/client/ModCDPClient.go index 19b6245..9a48028 100644 --- a/go/modcdp/client/ModCDPClient.go +++ b/go/modcdp/client/ModCDPClient.go @@ -88,6 +88,9 @@ type AutoSessionRouter = router.AutoSessionRouter type CustomCommand = types.ModCDPAddCustomCommandParams type CustomEvent = types.ModCDPAddCustomEventObjectParams type CustomMiddleware = types.ModCDPAddMiddlewareParams +type CustomAliasObject = types.ModCDPAliasObject +type AliasMethod = types.ModCDPAliasMethod +type AliasReturn = types.ModCDPAliasReturn var NewLocalBrowserLauncher = launcher.NewLocalBrowserLauncher var NewRemoteBrowserLauncher = launcher.NewRemoteBrowserLauncher @@ -163,9 +166,10 @@ func freePort() (int, error) { } type CDPTypesConfig struct { - CustomCommands []CustomCommand `json:"custom_commands,omitempty"` - CustomEvents []CustomEvent `json:"custom_events,omitempty"` - CustomMiddlewares []CustomMiddleware `json:"custom_middlewares,omitempty"` + CustomCommands []CustomCommand `json:"custom_commands,omitempty"` + CustomEvents []CustomEvent `json:"custom_events,omitempty"` + CustomMiddlewares []CustomMiddleware `json:"custom_middlewares,omitempty"` + CustomAliasObjects []CustomAliasObject `json:"custom_alias_objects,omitempty"` } type ServerConfig struct { @@ -177,6 +181,7 @@ type ServerConfig struct { CustomCommands []CustomCommand `json:"custom_commands,omitempty"` CustomEvents []CustomEvent `json:"custom_events,omitempty"` CustomMiddlewares []CustomMiddleware `json:"custom_middlewares,omitempty"` + CustomAliasObjects []CustomAliasObject `json:"custom_alias_objects,omitempty"` disabled bool } @@ -406,7 +411,7 @@ func New(config Config) *ModCDPClient { upstream := NewWSUpstreamTransport(config.Upstream) client := &ModCDPClient{ Config: config, - Types: NewCDPTypes(typesConfig.CustomCommands, typesConfig.CustomEvents, typesConfig.CustomMiddlewares), + Types: NewCDPTypes(typesConfig.CustomCommands, typesConfig.CustomEvents, typesConfig.CustomMiddlewares, typesConfig.CustomAliasObjects), Upstream: upstream, handlers: map[string][]handlerEntry{}, } diff --git a/go/modcdp/client/alias.go b/go/modcdp/client/alias.go new file mode 100644 index 0000000..2f5dd98 --- /dev/null +++ b/go/modcdp/client/alias.go @@ -0,0 +1,208 @@ +package client + +import ( + "encoding/json" + "fmt" + "strings" +) + +type AliasSticky = map[string]any +type AliasJSONObject = map[string]any + +type ModCDPAliasObject struct { + Client *ModCDPClient + Sticky AliasSticky +} + +func NewModCDPAliasObject(client *ModCDPClient, sticky AliasSticky) ModCDPAliasObject { + return ModCDPAliasObject{Client: client, Sticky: cloneAliasMap(sticky)} +} + +func SendAliasCommand[T any](client *ModCDPClient, method string, params any, sticky AliasSticky, stickyParam string, stickyFields []string, unwrap string) (T, error) { + stickyParams := []string{} + if stickyParam != "" { + stickyParams = []string{stickyParam} + } + return SendAliasCommandWithStickyParams[T](client, method, params, sticky, stickyParam, stickyParams, stickyFields, unwrap) +} + +func SendAliasCommandWithStickyParams[T any](client *ModCDPClient, method string, params any, sticky AliasSticky, stickyParam string, stickyParams []string, stickyFields []string, unwrap string) (T, error) { + var typed T + if client == nil { + return typed, fmt.Errorf("alias command %s requires a ModCDPClient", method) + } + rawParams, err := cdpParamsMap(params) + if err != nil { + return typed, err + } + mergeAliasStickyParams(rawParams, sticky, stickyParam, stickyParams, stickyFields) + result, err := client.Send(method, rawParams) + if err != nil { + return typed, err + } + unwrapped, err := UnwrapAliasResult(result, unwrap) + if err != nil { + return typed, err + } + body, err := json.Marshal(unwrapped) + if err != nil { + return typed, err + } + if err := json.Unmarshal(body, &typed); err != nil { + return typed, fmt.Errorf("%s result did not match typed alias result shape: %w", method, err) + } + return typed, nil +} + +func OptionalAliasParams[T any](method string, params []T) (T, error) { + var typed T + if len(params) > 1 { + return typed, fmt.Errorf("%s accepts at most one params object", method) + } + if len(params) == 1 { + return params[0], nil + } + return typed, nil +} + +func mergeAliasStickyParams(params map[string]any, sticky map[string]any, primaryStickyParam string, stickyParams []string, stickyFields []string) { + if len(stickyParams) == 0 { + mergeAliasSticky(params, sticky, primaryStickyParam, stickyFields) + return + } + seen := map[string]bool{} + for _, stickyParam := range stickyParams { + if stickyParam == "" || seen[stickyParam] { + continue + } + seen[stickyParam] = true + if stickyParam != primaryStickyParam { + if _, exists := params[stickyParam]; !exists { + continue + } + } + mergeAliasSticky(params, sticky, stickyParam, stickyFields) + } +} + +func AliasStickyFromResult(result any, unwrap string, stickyFields []string) (AliasSticky, error) { + unwrapped, err := UnwrapAliasResult(result, unwrap) + if err != nil { + return nil, err + } + sticky, ok := unwrapped.(map[string]any) + if !ok { + body, err := json.Marshal(unwrapped) + if err != nil { + return nil, err + } + if err := json.Unmarshal(body, &sticky); err != nil { + return nil, fmt.Errorf("alias sticky result did not decode to object: %w", err) + } + } + if len(stickyFields) == 0 { + return cloneAliasMap(sticky), nil + } + filtered := map[string]any{} + for _, field := range stickyFields { + if value, ok := sticky[field]; ok { + filtered[field] = value + } + } + return filtered, nil +} + +func AliasArrayFromResult(result any, unwrap string) ([]any, error) { + unwrapped, err := UnwrapAliasResult(result, unwrap) + if err != nil { + return nil, err + } + items, ok := unwrapped.([]any) + if ok { + return items, nil + } + body, err := json.Marshal(unwrapped) + if err != nil { + return nil, err + } + if err := json.Unmarshal(body, &items); err != nil { + return nil, fmt.Errorf("alias result did not decode to array: %w", err) + } + return items, nil +} + +func UnwrapAliasResult(result any, unwrap string) (any, error) { + if unwrap == "" { + return result, nil + } + current := result + for _, part := range strings.Split(unwrap, ".") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + object, ok := current.(map[string]any) + if !ok { + body, err := json.Marshal(current) + if err != nil { + return nil, err + } + object = map[string]any{} + if err := json.Unmarshal(body, &object); err != nil { + return nil, fmt.Errorf("alias unwrap %q expected object at %q: %w", unwrap, part, err) + } + } + value, ok := object[part] + if !ok { + return nil, fmt.Errorf("alias unwrap %q missing %q", unwrap, part) + } + current = value + } + return current, nil +} + +func mergeAliasSticky(params map[string]any, sticky map[string]any, stickyParam string, stickyFields []string) { + if params == nil || sticky == nil || len(sticky) == 0 { + return + } + target := params + if stickyParam != "" { + existing, ok := params[stickyParam].(map[string]any) + if !ok { + existing = map[string]any{} + if raw, exists := params[stickyParam]; exists && raw != nil { + body, err := json.Marshal(raw) + if err == nil { + _ = json.Unmarshal(body, &existing) + } + } + params[stickyParam] = existing + } + target = existing + } + fields := stickyFields + if len(fields) == 0 { + for key := range sticky { + fields = append(fields, key) + } + } + for _, field := range fields { + if _, exists := target[field]; exists { + continue + } + if value, ok := sticky[field]; ok { + target[field] = value + } + } +} + +func cloneAliasMap(source AliasSticky) AliasSticky { + if source == nil { + return AliasSticky{} + } + cloned := make(AliasSticky, len(source)) + for key, value := range source { + cloned[key] = value + } + return cloned +} diff --git a/go/modcdp/injector/extension.zip b/go/modcdp/injector/extension.zip index e68ec9f..9c35c8d 100644 Binary files a/go/modcdp/injector/extension.zip and b/go/modcdp/injector/extension.zip differ diff --git a/go/modcdp/types/types.go b/go/modcdp/types/types.go index d18acf0..928dbea 100644 --- a/go/modcdp/types/types.go +++ b/go/modcdp/types/types.go @@ -117,6 +117,33 @@ type ModCDPAddMiddlewareParams struct { Expression string `json:"expression"` } +type ModCDPAliasReturn struct { + Object string `json:"object,omitempty"` + Unwrap string `json:"unwrap,omitempty"` + Array bool `json:"array,omitempty"` + Nullable bool `json:"nullable,omitempty"` +} + +type ModCDPAliasMethod struct { + Name string `json:"name"` + Command string `json:"command,omitempty"` + SDKMethodName string `json:"sdk_method_name,omitempty"` + ParamsSchema map[string]any `json:"params_schema,omitempty"` + ResultSchema map[string]any `json:"result_schema,omitempty"` + StickyParam string `json:"sticky_param,omitempty"` + StickyParams []string `json:"sticky_params,omitempty"` + StickyFields []string `json:"sticky_fields,omitempty"` + Return *ModCDPAliasReturn `json:"return,omitempty"` +} + +type ModCDPAliasObject struct { + Name string `json:"name"` + TypeName string `json:"type_name,omitempty"` + StickySchema map[string]any `json:"sticky_schema,omitempty"` + StickyFields []string `json:"sticky_fields,omitempty"` + Methods []ModCDPAliasMethod `json:"methods,omitempty"` +} + type ModCDPPingParams struct { SentAt int `json:"sent_at,omitempty"` } @@ -163,6 +190,7 @@ type ModCDPServerConfig struct { CustomCommands []ModCDPAddCustomCommandParams `json:"custom_commands,omitempty"` CustomEvents []ModCDPAddCustomEventObjectParams `json:"custom_events,omitempty"` CustomMiddlewares []ModCDPAddMiddlewareParams `json:"custom_middlewares,omitempty"` + CustomAliasObjects []ModCDPAliasObject `json:"custom_alias_objects,omitempty"` } type ModCDPGetTopologyParams struct { diff --git a/js/src/client/ModCDPClient.ts b/js/src/client/ModCDPClient.ts index 09e31da..1e46052 100644 --- a/js/src/client/ModCDPClient.ts +++ b/js/src/client/ModCDPClient.ts @@ -67,6 +67,7 @@ import type { ProtocolParams, ProtocolResult, } from "../types/modcdp.js"; +import { installAliasMethods } from "./alias.js"; type ModCDPClientConfig = { launcher?: LauncherConfig; @@ -224,6 +225,7 @@ export class ModCDPClient< }); if (this.config.client_hydrate_aliases) this.types.installAliases(this, (method, params) => this.send(method, params)); + installAliasMethods(this); } toJSON() { diff --git a/js/src/client/alias.ts b/js/src/client/alias.ts new file mode 100644 index 0000000..3e0a67c --- /dev/null +++ b/js/src/client/alias.ts @@ -0,0 +1,264 @@ +import type { ModCDPClient } from "./ModCDPClient.js"; +import type { ModCDPAliasObject as ModCDPAliasObjectRegistration } from "../types/modcdp.js"; + +type AliasSticky = Record; +type AliasJSONObject = Record; + +class ModCDPAliasObject { + readonly client: TClient; + readonly sticky: AliasSticky; + + constructor(client: TClient, sticky: AliasSticky | null = null) { + this.client = client; + this.sticky = cloneAliasObject(sticky ?? {}); + } +} + +function optionalAliasParams(method: string, params: readonly T[]): T { + if (params.length > 1) throw new Error(`${method} accepts at most one params object`); + return params.length === 1 ? params[0]! : ({} as T); +} + +async function sendAliasCommandWithStickyParams( + client: ModCDPClient, + method: string, + params: unknown, + sticky: AliasSticky, + sticky_param: string, + sticky_params: readonly string[], + sticky_fields: readonly string[], + unwrap: string, +): Promise { + const raw_params = paramsToAliasObject(params); + mergeAliasStickyParams(raw_params, sticky, sticky_param, sticky_params, sticky_fields); + return unwrapAliasResult(await client.send(method, raw_params), unwrap) as T; +} + +async function sendAliasCommand( + client: ModCDPClient, + method: string, + params: unknown, + sticky: AliasSticky, + sticky_param: string, + sticky_fields: readonly string[], + unwrap: string, +): Promise { + return sendAliasCommandWithStickyParams( + client, + method, + params, + sticky, + sticky_param, + sticky_param ? [sticky_param] : [], + sticky_fields, + unwrap, + ); +} + +function aliasStickyFromResult(result: unknown, unwrap: string, sticky_fields: readonly string[]): AliasSticky { + const unwrapped = unwrapAliasResult(result, unwrap); + const source = paramsToAliasObject(unwrapped); + if (sticky_fields.length === 0) return cloneAliasObject(source); + const filtered: AliasSticky = {}; + for (const field of sticky_fields) { + if (field in source) filtered[field] = source[field]; + } + return filtered; +} + +function aliasArrayFromResult(result: unknown, unwrap: string): unknown[] { + const unwrapped = unwrapAliasResult(result, unwrap); + if (Array.isArray(unwrapped)) return unwrapped; + throw new Error(`alias unwrap ${JSON.stringify(unwrap)} expected array`); +} + +function unwrapAliasResult(result: unknown, unwrap: string): unknown { + if (!unwrap) return result; + let current = result; + for (const part of unwrap.split(".")) { + if (!part) continue; + const object = paramsToAliasObject(current); + if (!(part in object)) throw new Error(`alias unwrap ${JSON.stringify(unwrap)} missing ${JSON.stringify(part)}`); + current = object[part]; + } + return current; +} + +function mergeAliasStickyParams( + params: AliasJSONObject, + sticky: AliasSticky, + primary_sticky_param: string, + sticky_params: readonly string[], + sticky_fields: readonly string[], +) { + if (sticky_params.length === 0) { + mergeAliasSticky(params, sticky, primary_sticky_param, sticky_fields); + return; + } + const seen = new Set(); + for (const sticky_param of sticky_params) { + if (!sticky_param || seen.has(sticky_param)) continue; + seen.add(sticky_param); + if (sticky_param !== primary_sticky_param && !(sticky_param in params)) continue; + mergeAliasSticky(params, sticky, sticky_param, sticky_fields); + } +} + +function mergeAliasSticky( + params: AliasJSONObject, + sticky: AliasSticky, + sticky_param: string, + sticky_fields: readonly string[], +) { + if (Object.keys(sticky).length === 0) return; + let target = params; + if (sticky_param) { + const current = params[sticky_param]; + target = current != null && typeof current === "object" && !Array.isArray(current) ? { ...current } : {}; + params[sticky_param] = target; + } + const fields = sticky_fields.length > 0 ? sticky_fields : Object.keys(sticky); + for (const field of fields) { + if (field in target) continue; + if (field in sticky) target[field] = sticky[field]; + } +} + +function paramsToAliasObject(params: unknown): AliasJSONObject { + if (params == null) return {}; + if (typeof params !== "object" || Array.isArray(params)) return {}; + return JSON.parse(JSON.stringify(params)) as AliasJSONObject; +} + +function cloneAliasObject(source: AliasSticky): AliasSticky { + return JSON.parse(JSON.stringify(source)) as AliasSticky; +} + +function installAliasMethods(client: ModCDPClient) { + for (const object of client.types.custom_alias_objects.values()) { + for (const method of object.methods ?? []) { + const path = sdkMethodPath(object.name, method); + if (path.length !== 1) continue; + Object.defineProperty(client, path[0]!, { + configurable: true, + value: (...params: unknown[]) => invokeAliasMethod(client, object.name, {}, method.name, params), + }); + } + } +} + +function createRuntimeAliasObject(client: ModCDPClient, object_name: string, sticky: AliasSticky): T { + const object = aliasObjectRegistration(client, object_name); + const receiver = new ModCDPAliasObject(client, sticky) as ModCDPAliasObject & AliasJSONObject; + Object.assign(receiver, sticky); + for (const method of object.methods ?? []) { + const path = sdkMethodPath(object.name, method); + if (path.length !== 2) continue; + Object.defineProperty(receiver, path[1]!, { + configurable: true, + value: (...params: unknown[]) => invokeAliasMethod(client, object.name, receiver.sticky, method.name, params), + }); + } + return receiver as T; +} + +function invokeAliasMethod( + client: ModCDPClient, + object_name: string, + sticky: AliasSticky, + method_name: string, + params: unknown[], +) { + const object = aliasObjectRegistration(client, object_name); + const method = (object.methods ?? []).find((candidate) => candidate.name === method_name); + if (!method) throw new Error(`Unknown alias method ${object_name}.${method_name}`); + if (params.length > 1) throw new Error(`${object.name}.${method.name} accepts at most one params object`); + const request = params.length === 1 ? params[0] : {}; + const return_object = method.return?.object; + if (!method.command) { + if (!return_object) throw new Error(`${object.name}.${method.name} has no command or return object`); + return createRuntimeAliasObject(client, return_object, {}); + } + return invokeCommandAliasMethod(client, object, method, sticky, request, return_object); +} + +async function invokeCommandAliasMethod( + client: ModCDPClient, + object: ModCDPAliasObjectRegistration, + method: NonNullable[number], + sticky: AliasSticky, + request: unknown, + return_object: string | undefined, +) { + if (return_object) { + const raw = await sendAliasCommandWithStickyParams( + client, + method.command, + request, + sticky, + method.sticky_param ?? "", + method.sticky_params ?? [], + method.sticky_fields ?? [], + "", + ); + if (method.return?.array === true) { + return aliasArrayFromResult(raw, method.return.unwrap ?? "").map((item) => + createRuntimeAliasObject( + client, + return_object, + aliasStickyFromResult(item, "", aliasObjectRegistration(client, return_object).sticky_fields ?? []), + ), + ); + } + const unwrapped = aliasStickyFromResult( + raw, + method.return?.unwrap ?? "", + aliasObjectRegistration(client, return_object).sticky_fields ?? [], + ); + if (method.return?.nullable === true && Object.keys(unwrapped).length === 0) return null; + return createRuntimeAliasObject(client, return_object, unwrapped); + } + return sendAliasCommandWithStickyParams( + client, + method.command, + request, + sticky, + method.sticky_param ?? "", + method.sticky_params ?? [], + method.sticky_fields ?? [], + method.return?.unwrap ?? "", + ); +} + +function aliasObjectRegistration(client: ModCDPClient, object_name: string): ModCDPAliasObjectRegistration { + const object = client.types.custom_alias_objects.get(object_name); + if (!object) throw new Error(`Unknown alias object ${JSON.stringify(object_name)}`); + return object; +} + +function pascal(value: string) { + return value + .split(/[^A-Za-z0-9]+/g) + .filter(Boolean) + .map((part) => part.slice(0, 1).toUpperCase() + part.slice(1)) + .join(""); +} + +function sdkMethodPath(object_name: string, method: { sdk_method_name?: string; name: string }) { + const raw = method.sdk_method_name; + if (typeof raw === "string" && raw.trim()) return raw.split(".").filter(Boolean); + return [object_name, method.name]; +} + +export { + ModCDPAliasObject, + installAliasMethods, + createRuntimeAliasObject, + optionalAliasParams, + sendAliasCommand, + sendAliasCommandWithStickyParams, + aliasStickyFromResult, + aliasArrayFromResult, + unwrapAliasResult, +}; +export type { AliasSticky, AliasJSONObject }; diff --git a/js/src/types/CDPTypes.ts b/js/src/types/CDPTypes.ts index 0ede5c7..717ae5a 100644 --- a/js/src/types/CDPTypes.ts +++ b/js/src/types/CDPTypes.ts @@ -24,6 +24,8 @@ import { type ModCDPAddCustomCommandParams, type ModCDPAddCustomEventObjectParams, type ModCDPAddMiddlewareParams, + type ModCDPAliasObject, + type ModCDPAliasMethod, type ModCDPNamedValue, type ModCDPPayloadSchemaSpec, type ProtocolEventParams, @@ -57,6 +59,7 @@ type CDPTypesConfig; custom_events?: CDPTypesCustomEvents; custom_middlewares?: ModCDPAddMiddlewareParams[]; + custom_alias_objects?: ModCDPAliasObject[]; }; type CDPTypesCommandRegistration = ModCDPAddCustomCommandParams & { params_schema?: z.ZodType | null; @@ -225,6 +228,7 @@ class CDPTypes; readonly custom_events: Map; readonly custom_middlewares: ModCDPAddMiddlewareParams[]; + readonly custom_alias_objects: Map; readonly event_schemas = new Map(); readonly command_params_schemas = new Map(); readonly command_result_schemas = new Map(); @@ -236,12 +240,14 @@ class CDPTypes { const parsed = Mod.EvaluateParams.parse(params); return ` @@ -260,6 +266,7 @@ class CDPTypes middleware, ), + custom_alias_objects: this.customAliasObjectWireRegistrations(), }, state: { custom_commands: this.custom_commands.size, custom_events: this.custom_events.size, custom_middlewares: this.custom_middlewares.length, + custom_alias_objects: this.custom_alias_objects.size, command_params_schemas: this.command_params_schemas.size, command_result_schemas: this.command_result_schemas.size, event_schemas: this.event_schemas.size, @@ -486,6 +495,72 @@ class CDPTypes ({ + name: object.name, + ...(object.type_name == null || object.type_name === "" ? {} : { type_name: object.type_name }), + ...(object.sticky_schema == null ? {} : { sticky_schema: serializablePayloadSchema(object.sticky_schema) }), + sticky_fields: [...(object.sticky_fields ?? [])], + methods: (object.methods ?? []).map((method) => ({ + name: method.name, + ...(method.command == null || method.command === "" ? {} : { command: method.command }), + ...(method.sdk_method_name == null || method.sdk_method_name === "" + ? {} + : { sdk_method_name: method.sdk_method_name }), + ...(method.params_schema == null ? {} : { params_schema: serializablePayloadSchema(method.params_schema) }), + ...(method.result_schema == null ? {} : { result_schema: serializablePayloadSchema(method.result_schema) }), + ...(method.sticky_param == null || method.sticky_param === "" ? {} : { sticky_param: method.sticky_param }), + sticky_params: [...(method.sticky_params ?? [])], + sticky_fields: [...(method.sticky_fields ?? [])], + ...(method.return == null ? {} : { return: method.return }), + })), + })); + } + addCustomEvent(registration: ModCDPAddCustomEventObjectParams) { const parsed = Mod.AddCustomEventObjectParams.parse(registration); const name = normalizeModCDPName(parsed.name); diff --git a/js/src/types/modcdp.ts b/js/src/types/modcdp.ts index 66abfd5..1de777b 100644 --- a/js/src/types/modcdp.ts +++ b/js/src/types/modcdp.ts @@ -162,6 +162,42 @@ const ModCDPAddMiddlewareParamsSchema = z.object({ }); type ModCDPAddMiddlewareParams = z.infer; +const ModCDPAliasReturnSchema = z + .object({ + object: z.string().optional(), + unwrap: z.string().optional(), + array: z.boolean().optional(), + nullable: z.boolean().optional(), + }) + .strict(); +type ModCDPAliasReturn = z.infer; + +const ModCDPAliasMethodSchema = z + .object({ + name: z.string(), + command: z.string().optional(), + sdk_method_name: z.string().optional(), + params_schema: ModCDPPayloadSchemaSpecSchema.optional().nullable(), + result_schema: ModCDPPayloadSchemaSpecSchema.optional().nullable(), + sticky_param: z.string().optional(), + sticky_params: z.array(z.string()).optional(), + sticky_fields: z.array(z.string()).optional(), + return: ModCDPAliasReturnSchema.optional().nullable(), + }) + .strict(); +type ModCDPAliasMethod = z.infer; + +const ModCDPAliasObjectSchema = z + .object({ + name: z.string(), + type_name: z.string().optional(), + sticky_schema: ModCDPPayloadSchemaSpecSchema.optional().nullable(), + sticky_fields: z.array(z.string()).optional(), + methods: z.array(ModCDPAliasMethodSchema).optional(), + }) + .strict(); +type ModCDPAliasObject = z.infer; + const BrowserbaseBrowserSettingsSchema = z .object({ extensionId: z.string().optional(), @@ -257,6 +293,7 @@ const ModCDPServerConfigSchema = z custom_commands: z.array(ModCDPAddCustomCommandParamsSchema).optional(), custom_events: z.array(ModCDPAddCustomEventObjectParamsSchema).optional(), custom_middlewares: z.array(ModCDPAddMiddlewareParamsSchema).optional(), + custom_alias_objects: z.array(ModCDPAliasObjectSchema).optional(), }) .strict(); type ModCDPServerConfig = z.infer; @@ -441,6 +478,9 @@ type ModCDPCustomEventRegistration = z.infer; +const ModCDPAliasObjectRegistrationSchema = ModCDPAliasObjectSchema; +type ModCDPAliasObjectRegistration = z.infer; + const CdpErrorSchema = z.object({ code: z.number().optional().nullable(), message: z.string(), @@ -509,6 +549,9 @@ const Mod = { AddCustomEventObjectParams: ModCDPAddCustomEventObjectParamsSchema, AddCustomEventParams: ModCDPAddCustomEventParamsSchema, AddMiddlewareParams: ModCDPAddMiddlewareParamsSchema, + AliasReturn: ModCDPAliasReturnSchema, + AliasMethod: ModCDPAliasMethodSchema, + AliasObject: ModCDPAliasObjectSchema, LauncherConfig: ModCDPLauncherConfigSchema, UpstreamConfig: ModCDPUpstreamConfigSchema, ClientConfig: ModCDPClientConfigSchema, @@ -536,6 +579,7 @@ const Mod = { CustomCommandRegistration: ModCDPCustomCommandRegistrationSchema, CustomEventRegistration: ModCDPCustomEventRegistrationSchema, MiddlewareRegistration: ModCDPMiddlewareRegistrationSchema, + AliasObjectRegistration: ModCDPAliasObjectRegistrationSchema, } as const; export { @@ -569,6 +613,9 @@ export { ModCDPAddCustomEventObjectParamsSchema, ModCDPAddCustomEventParamsSchema, ModCDPAddMiddlewareParamsSchema, + ModCDPAliasReturnSchema, + ModCDPAliasMethodSchema, + ModCDPAliasObjectSchema, ModCDPLauncherConfigSchema, ModCDPUpstreamConfigSchema, ModCDPClientConfigSchema, @@ -602,6 +649,7 @@ export { ModCDPCustomCommandRegistrationSchema, ModCDPCustomEventRegistrationSchema, ModCDPMiddlewareRegistrationSchema, + ModCDPAliasObjectRegistrationSchema, CdpErrorSchema, CdpCommandMessageSchema, CdpResponseMessageSchema, @@ -631,6 +679,9 @@ export type { ModCDPAddCustomEventObjectParams, ModCDPAddCustomEventParams, ModCDPAddMiddlewareParams, + ModCDPAliasReturn, + ModCDPAliasMethod, + ModCDPAliasObject, ModCDPLauncherConfig, ModCDPUpstreamConfig, ModCDPClientConfig, @@ -664,6 +715,7 @@ export type { ModCDPCustomCommandRegistration, ModCDPCustomEventRegistration, ModCDPMiddlewareRegistration, + ModCDPAliasObjectRegistration, CdpError, CdpCommandMessage, CdpResponseMessage, diff --git a/js/test/stagehand_client_generated/stagehand_client_live.test.ts b/js/test/stagehand_client_generated/stagehand_client_live.test.ts new file mode 100644 index 0000000..12d1541 --- /dev/null +++ b/js/test/stagehand_client_generated/stagehand_client_live.test.ts @@ -0,0 +1,463 @@ +// Before running this live test, build a local Stagehand alias manifest and +// regenerate the TypeScript client: +// +// pnpm --dir ../stagehand-server exec node src/protocol/generate_modcdp_alias_manifest.mjs ../modcdp2/testdata/codegen/stagehand_alias_manifest.json +// pnpm run build:ts +// pnpm run test:e2e:ts +// +// The manifest is intentionally untracked; do not commit it. + +import assert from "node:assert/strict"; +import { fileURLToPath } from "node:url"; +import { test } from "vitest"; + +import { type Page, StagehandClient } from "./stagehand_client_gen.js"; + +const CHALLENGE_URL = "https://pirate.github.io/stress-tests/challenge.html"; + +test("generated Stagehand client solves challenge suite", async () => { + const client = new StagehandClient(); + await client.connect(); + try { + let page = await client.browser.new_page({ url: "about:blank" }); + page = await page.goto({ url: CHALLENGE_URL, waitUntil: "load" }); + await page.wait_for_expression({ + expression: "document.querySelectorAll('.task').length === 45", + timeout_ms: 20_000, + }); + + const fixtureFile = fileURLToPath(import.meta.url); + + const challenges: [string, (page: Page) => Promise][] = [ + [ + "simple-button", + async (page) => { + const button = await page.locate({ locator: { css: "#start-button" } }); + await button.click(); + }, + ], + [ + "radio-selection", + async (page) => { + await page.click({ locator: { css: `input[name="entity"][value="dont-choose-this-one"]` } }); + }, + ], + [ + "checkbox-check", + async (page) => { + await page.click({ locator: { css: "input.challenge-checkbox", idx: 0 } }); + await page.click({ locator: { css: "input.challenge-checkbox", idx: 1 } }); + const first = await page.locate({ locator: { css: "input.challenge-checkbox" } }); + const third = await first.nth({ index: 2 }); + await third.click(); + }, + ], + [ + "search-squash", + async (page) => { + const locator = { css: "#search-input" }; + await page.fill({ locator, value: "squash" }); + await page.key_press({ locator, key: "Enter" }); + }, + ], + [ + "date-time-input", + async (page) => { + const now = new Date(); + await page.fill({ locator: { css: "#date-picker" }, value: now.toISOString().slice(0, 10) }); + await page.fill({ + locator: { css: "#time-picker" }, + value: `${now.getHours().toString().padStart(2, "0")}:${now.getMinutes().toString().padStart(2, "0")}`, + }); + }, + ], + [ + "copy-text", + async (page) => { + await page.fill({ locator: { css: "#copy-input" }, value: "abc" }); + }, + ], + [ + "slider-drag", + async (page) => { + await page.click({ locator: { css: "#range-slider" }, offsetX: 295, offsetY: 8 }); + }, + ], + [ + "hover-element", + async (page) => { + await page.hover({ locator: { css: "#hover-target" } }); + await page.wait_for_timeout({ ms: 1100 }); + }, + ], + [ + "drag-drop", + async (page) => { + await page.drag_and_drop({ start: { css: "#drag-element" }, end: { css: "#drop-target" } }); + }, + ], + [ + "file-drop", + async (page) => { + await page.drop_files({ files: [fixtureFile], locator: { css: "#file-drop-area" } }); + }, + ], + [ + "multi-select", + async (page) => { + await page.select_option({ + locator: { css: "#multi-select-element" }, + values: ["option1", "option2", "option3"], + }); + }, + ], + [ + "canvas-captcha", + async (page) => { + await page.fill({ locator: { css: "#canvas-text-input" }, value: "CAPTCHA123" }); + }, + ], + [ + "iframe-slider", + async (page) => { + await page.click({ locator: { css: "#iframe-slider", idx: 2 }, offsetX: 295, offsetY: 8 }); + await page.click({ locator: { css: "#iframe-slider-popover-button" } }); + }, + ], + [ + "oopif-form-submit", + async (page) => { + await page.wait_for_locator({ locator: { css: "#first-name" }, timeout_ms: 20_000 }); + await page.fill({ locator: { css: "#first-name" }, value: "Ada" }); + await page.fill({ locator: { css: "#last-name" }, value: "Lovelace" }); + await page.fill({ locator: { css: "#email" }, value: "ada@example.com" }); + await page.click({ locator: { css: "#contact-form button[type='submit']" } }); + }, + ], + [ + "oopif-open-shadow-button", + async (page) => { + await page.click({ + locator: { + xpath: "/html[1]/body[1]/div[2]/div[15]/iframe[1]/html[1]/body[1]/shadow-demo[1]//div[1]/button[1]", + }, + }); + await page.click({ locator: { css: "#oopif-open-shadow-confirm" } }); + }, + ], + [ + "oopif-closed-shadow-button", + async (page) => { + await page.click({ + locator: { + xpath: "/html[1]/body[1]/div[2]/div[16]/iframe[1]/html[1]/body[1]/shadow-demo[1]//div[1]/button[1]", + }, + }); + await page.click({ locator: { css: "#oopif-closed-shadow-confirm" } }); + }, + ], + [ + "shadow-dom-dblclick", + async (page) => { + await page.double_click({ locator: { text: "Double-click me" } }); + }, + ], + [ + "right-click-component", + async (page) => { + const locator = await page.locate({ locator: { text: "Right-click me" } }); + await locator.click({ button: "secondary" }); + }, + ], + [ + "scroll-accept", + async (page) => { + await page.scroll({ locator: { css: "#scroll-container" }, percent: 100 }); + await page.click({ locator: { css: "#accept-button" } }); + }, + ], + [ + "draw-circle", + async (page) => { + await page.click({ locator: { css: "#draw-canvas" }, offsetX: 150, offsetY: 150 }); + await page.scroll({ deltaY: 320 }); + const canvas = await page.locate({ locator: { css: "#draw-canvas" } }); + assert.ok(canvas.coordinates?.left != null, "canvas locator is missing left coordinate"); + assert.ok(canvas.coordinates.top != null, "canvas locator is missing top coordinate"); + const left = canvas.coordinates.left; + const top = canvas.coordinates.top; + for (let i = 0; i < 24; i++) { + const angle = (Math.PI * 2 * i) / 24; + await page.hover({ + locator: { coordinates: { x: left + 150 + Math.cos(angle) * 100, y: top + 150 + Math.sin(angle) * 100 } }, + }); + } + }, + ], + [ + "cancel-dialog", + async (page) => { + await page.click({ dialog: { accept: false }, locator: { css: "#confirm-button" } }); + }, + ], + [ + "geolocation-permission", + async (page) => { + await page.set_geolocation({ + accuracy: 25, + latitude: 37.7749, + longitude: -122.4194, + origin: "https://pirate.github.io", + }); + await page.click({ locator: { css: "#request-location-button" } }); + }, + ], + [ + "arrow-key-presses", + async (page) => { + const locator = { css: "#key-press-area" }; + await page.click({ locator }); + await page.key_press({ key: "ArrowRight", repeat: 3, locator }); + }, + ], + [ + "alert-secret", + async (page) => { + const result = await page.click({ dialog: { accept: true }, locator: { css: "#secret-alert-button" } }); + assert.match(result.message ?? "", /avocado/); + await page.fill({ locator: { css: "#secret-word-input" }, value: "avocado" }); + }, + ], + [ + "press-hold-button", + async (page) => { + await page.click_and_hold({ durationMs: 800, locator: { css: "#hold-button" } }); + }, + ], + [ + "tooltip-secret", + async (page) => { + await page.hover({ locator: { css: "#tooltip-element" } }); + await page.fill({ locator: { css: "#tooltip-secret-input" }, value: "octopus" }); + }, + ], + [ + "resize-textarea", + async (page) => { + const textarea = await page.locate({ locator: { css: "#resizable-textarea" } }); + assert.ok(textarea.coordinates?.right != null, "textarea locator is missing right coordinate"); + assert.ok(textarea.coordinates.bottom != null, "textarea locator is missing bottom coordinate"); + const right = textarea.coordinates.right; + const bottom = textarea.coordinates.bottom; + await page.drag_and_drop({ + start: { coordinates: { x: right - 2, y: bottom - 2 } }, + end: { coordinates: { x: right + 160, y: bottom + 90 } }, + }); + await page.fill({ locator: { css: "#resize-secret-input" }, value: "giraffe" }); + }, + ], + [ + "file-upload", + async (page) => { + await page.set_input_files({ files: [fixtureFile], locator: { css: "#file-upload-input" } }); + }, + ], + [ + "phone-input", + async (page) => { + const locator = { css: "#phone-input-field" }; + await page.click({ locator }); + await page.key_press({ locator, key: "Backspace" }); + await page.key_press({ method: "type", locator, value: "555-1234" }); + }, + ], + [ + "expand-details", + async (page) => { + const locator = await page.locate({ locator: { css: "#details-element summary" } }); + await locator.click(); + }, + ], + [ + "drag-square-to-circle", + async (page) => { + await page.click({ locator: { css: "#drag-canvas" }, offsetX: 175, offsetY: 125 }); + await page.scroll({ deltaY: 220 }); + const canvas = await page.locate({ locator: { css: "#drag-canvas" } }); + assert.ok(canvas.coordinates?.left != null, "drag canvas locator is missing left coordinate"); + assert.ok(canvas.coordinates.top != null, "drag canvas locator is missing top coordinate"); + const left = canvas.coordinates.left; + const top = canvas.coordinates.top; + await page.drag_and_drop({ + start: { coordinates: { x: left + 75, y: top + 125 } }, + end: { coordinates: { x: left + 250, y: top + 125 } }, + }); + }, + ], + [ + "audio-transcription", + async (page) => { + await page.fill({ locator: { css: "#transcription-input" }, value: "everything" }); + }, + ], + [ + "dropdown-selections", + async (page) => { + await page.select_option({ locator: { css: "#color-select" }, values: ["red"] }); + await page.select_option({ locator: { css: "#object-select" }, values: ["ball"] }); + }, + ], + [ + "contenteditable-div", + async (page) => { + const locator = { css: "#editable-content" }; + await page.click({ locator }); + await page.key_press({ method: "type", locator, value: "banana" }); + }, + ], + [ + "nested-tiny-button", + async (page) => { + await page.click({ + dialog: { accept: true }, + expect_timeout_ms: 20_000, + locator: { css: "#deep-tiny-button" }, + }); + }, + ], + [ + "wrapped-word-click", + async (page) => { + await page.hover({ locator: { css: "#wrapped-word-paragraph" } }); + const pointResult = await page.evaluate({ + expression: `(() => { + const paragraph = document.getElementById("wrapped-word-paragraph"); + if (!(paragraph?.firstChild instanceof Text)) throw new Error("wrapped word paragraph text is missing"); + const paragraphRect = paragraph.getBoundingClientRect(); + const start = paragraph.textContent.indexOf("ox"); + const range = document.createRange(); + range.setStart(paragraph.firstChild, start); + range.setEnd(paragraph.firstChild, start + 2); + const rect = Array.from(range.getClientRects()).at(-1); + range.detach(); + if (rect == null) throw new Error("wrapped word range has no visible rect"); + return { x: rect.left + rect.width / 2 - paragraphRect.left, y: rect.top + rect.height / 2 - paragraphRect.top }; + })()`, + }); + const point = pointResult.value as { x?: unknown; y?: unknown }; + const x = point.x; + const y = point.y; + if (typeof x !== "number" || typeof y !== "number") + throw new Error("wrapped word point is missing numeric coordinates"); + await page.click({ locator: { css: "#wrapped-word-paragraph" }, offsetX: x, offsetY: y }); + }, + ], + [ + "long-link-maze", + async (page) => { + await page.click({ locator: { text: "TARGET-LINK::orion-needle-1847" } }); + }, + ], + [ + "wall-secret-word", + async (page) => { + await page.fill({ locator: { css: "#wall-secret-input" }, value: "cobaltglass" }); + }, + ], + [ + "closed-shadow-aria-word", + async (page) => { + await page.fill({ locator: { css: "#closed-shadow-aria-input" }, value: "violetcircuit" }); + }, + ], + [ + "dynamic-frame-ordinal-trap", + async (page) => { + await page.click({ locator: { css: "#frame-trap-start" } }); + await page.wait_for_locator({ locator: { css: "#frame-trap-create-third" }, timeout_ms: 10_000 }); + await page.click({ locator: { css: "#frame-trap-create-third" } }); + await page.wait_for_locator({ locator: { css: "#frame-trap-insert-final" }, timeout_ms: 10_000 }); + await page.click({ locator: { css: "#frame-trap-insert-final" } }); + await page.wait_for_locator({ locator: { css: "#frame-trap-final-button" }, timeout_ms: 10_000 }); + await page.click({ locator: { css: "#frame-trap-final-button" } }); + }, + ], + [ + "rotated-transform-click", + async (page) => { + await page.hover({ locator: { css: "#rotated-transform-button" } }); + const pointResult = await page.evaluate({ + expression: `(() => { + const button = document.getElementById("rotated-transform-button"); + if (button == null) throw new Error("rotated transform button is missing"); + const rect = button.getBoundingClientRect(); + return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 }; + })()`, + }); + const point = pointResult.value as { x?: unknown; y?: unknown }; + const x = point.x; + const y = point.y; + if (typeof x !== "number" || typeof y !== "number") + throw new Error("rotated transform point is missing numeric coordinates"); + await page.click({ locator: { coordinates: { x, y } } }); + }, + ], + [ + "transparent-overlay-button", + async (page) => { + const locator = await page.locate({ locator: { css: "#transparent-overlay-target" } }); + await locator.click(); + }, + ], + [ + "realistic-input-sequence", + async (page) => { + await page.scroll({ locator: { css: "#realistic-scroll-panel" }, deltaY: 450 }); + await page.scroll({ locator: { css: "#realistic-scroll-panel" }, deltaY: 450 }); + const locator = { css: "#realistic-type-input" }; + await page.click({ locator }); + for (const key of ["o", "r", "c", "h", "i", "d"]) await page.key_press({ locator, key }); + await page.click({ locator: { css: "#realistic-submit-button" } }); + }, + ], + [ + "css-transform-text", + async (page) => { + await page.fill({ locator: { css: "#css-transform-text-input" }, value: "skyline" }); + }, + ], + [ + "google-docs", + async (page) => { + const locator = { css: "#google-docs-answer-input" }; + await page.click({ locator }); + await page.key_press({ method: "type", locator, value: "snooker" }); + }, + ], + ]; + + for (const [taskId, run] of challenges) { + await run(page); + try { + await page.wait_for_expression({ + expression: `document.getElementById(${JSON.stringify(taskId)})?.classList.contains('completed') === true`, + timeout_ms: 20_000, + }); + } catch (error) { + throw new Error(`${taskId} did not complete`, { cause: error }); + } + } + const scoreResult = await page.evaluate({ expression: "document.querySelectorAll('.task.completed').length" }); + assert.equal(typeof scoreResult.value, "number"); + const score = scoreResult.value; + const countsResult = await page.evaluate({ + expression: `(() => ({ completed: document.querySelectorAll('.task.completed').length, total: document.querySelectorAll('.task').length }))()`, + }); + const counts = countsResult.value as { completed: number; total: number }; + assert.equal(score, counts.completed); + assert.equal(counts.total, 45); + assert.equal(score, 45); + } finally { + await client.close(); + } +}, 180_000); diff --git a/package.json b/package.json index 90e974e..6a38942 100644 --- a/package.json +++ b/package.json @@ -158,6 +158,9 @@ "clean": "rm -rf dist", "build": "pnpm run build:package && pnpm run build:extension-assets && pnpm run build:python-readme", "build:clean": "pnpm run clean && pnpm run build", + "build:ts": "node --experimental-strip-types src/codegen/codegen_ts.ts", + "build:python": "node --experimental-strip-types src/codegen/codegen_python.ts", + "build:go": "node --experimental-strip-types src/codegen/codegen_go.ts && gofmt -w testdata/codegen/stagehand_client_generated_go/stagehand_client_gen.go", "build:package": "tsc -p tsconfig.json && rm -rf dist/js/test && rm -f dist/.gitignore", "build:extension-assets": "node js/scripts/build-extension-assets.mjs", "build:python-readme": "test -f python/README.md", @@ -172,6 +175,9 @@ "demo:proxy:puppeteer": "pnpm run build && node dist/js/examples/puppeteer.js", "proxy": "pnpm run build && node dist/js/src/proxy/cli.js", "test": "pnpm run build && vitest run --fileParallelism=false --maxWorkers=1", + "test:e2e:ts": "vitest run js/test/stagehand_client_generated/stagehand_client_live.test.ts --fileParallelism=false --maxWorkers=1", + "test:e2e:python": "cd python && uv run python -m unittest tests.stagehand_client_generated.stagehand_client_live_test", + "test:e2e:go": "cd testdata/codegen/stagehand_client_generated_go && go test -count=1 -run TestGeneratedStagehandClientChallengeSuite", "test:launchers": "vitest run js/test/test.LocalBrowserLauncher.ts js/test/test.RemoteBrowserLauncher.ts js/test/test.BBBrowserLauncher.ts", "pack:npm": "pnpm run build && npm pack --ignore-scripts", "prepare": "pnpm run build" diff --git a/python/modcdp/client/alias.py b/python/modcdp/client/alias.py new file mode 100644 index 0000000..ff3afc4 --- /dev/null +++ b/python/modcdp/client/alias.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import json +from collections.abc import Mapping, Sequence +from typing import Any, TypeVar, cast + +from pydantic import BaseModel + +from .ModCDPClient import ModCDPClient + +AliasSticky = dict[str, object] +AliasJSONObject = dict[str, object] +T = TypeVar("T") + + +class ModCDPAliasObject: + def __init__(self, client: ModCDPClient, sticky: Mapping[str, object] | None = None) -> None: + self.client = client + self.sticky = _clone_alias_object(sticky or {}) + + +def optional_alias_params(method: str, params: Sequence[T]) -> T | None: + if len(params) > 1: + raise ValueError(f"{method} accepts at most one params object") + return params[0] if params else None + + +def send_alias_command_with_sticky_params( + client: ModCDPClient, + method: str, + params: object, + sticky: Mapping[str, object], + sticky_param: str, + sticky_params: Sequence[str], + sticky_fields: Sequence[str], + unwrap: str, +) -> object: + raw_params = _params_to_alias_object(params) + _merge_alias_sticky_params(raw_params, sticky, sticky_param, sticky_params, sticky_fields) + return unwrap_alias_result(client._send_command(method, raw_params), unwrap) + + +def send_alias_command( + client: ModCDPClient, + method: str, + params: object, + sticky: Mapping[str, object], + sticky_param: str, + sticky_fields: Sequence[str], + unwrap: str, +) -> object: + return send_alias_command_with_sticky_params( + client, + method, + params, + sticky, + sticky_param, + [sticky_param] if sticky_param else [], + sticky_fields, + unwrap, + ) + + +def alias_sticky_from_result(result: object, unwrap: str, sticky_fields: Sequence[str]) -> AliasSticky: + source = _params_to_alias_object(unwrap_alias_result(result, unwrap)) + if not sticky_fields: + return _clone_alias_object(source) + return {field: source[field] for field in sticky_fields if field in source} + + +def alias_array_from_result(result: object, unwrap: str) -> list[object]: + unwrapped = unwrap_alias_result(result, unwrap) + if isinstance(unwrapped, list): + return cast(list[object], unwrapped) + raise TypeError(f"alias unwrap {unwrap!r} expected array") + + +def unwrap_alias_result(result: object, unwrap: str) -> object: + if not unwrap: + return result + current = result + for part in unwrap.split("."): + if not part: + continue + current_object = _params_to_alias_object(current) + if part not in current_object: + raise KeyError(f"alias unwrap {unwrap!r} missing {part!r}") + current = current_object[part] + return current + + +def _merge_alias_sticky_params( + params: AliasJSONObject, + sticky: Mapping[str, object], + primary_sticky_param: str, + sticky_params: Sequence[str], + sticky_fields: Sequence[str], +) -> None: + if not sticky_params: + _merge_alias_sticky(params, sticky, primary_sticky_param, sticky_fields) + return + seen: set[str] = set() + for sticky_param in sticky_params: + if not sticky_param or sticky_param in seen: + continue + seen.add(sticky_param) + if sticky_param != primary_sticky_param and sticky_param not in params: + continue + _merge_alias_sticky(params, sticky, sticky_param, sticky_fields) + + +def _merge_alias_sticky( + params: AliasJSONObject, + sticky: Mapping[str, object], + sticky_param: str, + sticky_fields: Sequence[str], +) -> None: + if not sticky: + return + target = params + if sticky_param: + raw_target = params.get(sticky_param) + target = dict(raw_target) if isinstance(raw_target, Mapping) else {} + params[sticky_param] = target + fields = sticky_fields or list(sticky.keys()) + for field in fields: + if field in target: + continue + if field in sticky: + target[field] = sticky[field] + + +def _params_to_alias_object(params: object) -> AliasJSONObject: + if params is None: + return {} + if isinstance(params, BaseModel): + return dict(params.model_dump(mode="json", exclude_none=True, by_alias=True)) + if isinstance(params, Mapping): + return {str(key): value for key, value in params.items()} + return {} + + +def _clone_alias_object(source: Mapping[str, object]) -> AliasSticky: + return cast(AliasSticky, json.loads(json.dumps(dict(source)))) diff --git a/python/modcdp/extension.zip b/python/modcdp/extension.zip index e68ec9f..9c35c8d 100644 Binary files a/python/modcdp/extension.zip and b/python/modcdp/extension.zip differ diff --git a/python/modcdp/types/CDPTypes.py b/python/modcdp/types/CDPTypes.py index bd6acc1..9c1c610 100644 --- a/python/modcdp/types/CDPTypes.py +++ b/python/modcdp/types/CDPTypes.py @@ -24,6 +24,8 @@ ModCDPAddCustomEventObjectParams, ModCDPAddCustomEventParams, ModCDPAddMiddlewareParams, + ModCDPAliasObject, + ModCDPAliasMethod, ModCDPPayloadSchemaSpec, ProtocolParams, ProtocolPayload, @@ -39,9 +41,11 @@ CustomCommandRegistration: TypeAlias = dict[str, object] CustomEventRegistration: TypeAlias = dict[str, object] CustomMiddlewareRegistration: TypeAlias = dict[str, object] +CustomAliasObjectRegistration: TypeAlias = dict[str, object] CustomCommandRegistrations: TypeAlias = Sequence[Mapping[str, object]] | dict[str, object] CustomEventRegistrations: TypeAlias = Sequence[str | Mapping[str, object]] | dict[str, object] CustomMiddlewareRegistrations: TypeAlias = Sequence[Mapping[str, object]] +CustomAliasObjectRegistrations: TypeAlias = Sequence[Mapping[str, object]] class _ModCDPAddCustomCommand(BaseModel): @@ -115,6 +119,7 @@ class CDPTypesConfig(BaseModel): custom_commands: CustomCommandRegistrations | None = None custom_events: CustomEventRegistrations | None = None custom_middlewares: CustomMiddlewareRegistrations | None = None + custom_alias_objects: CustomAliasObjectRegistrations | None = None JSON_SCHEMA_OBJECT: JsonSchema = {"type": "object"} @@ -448,7 +453,7 @@ def _json_value(value: object) -> JsonValue: if value is None or isinstance(value, bool | int | float | str): return value if isinstance(value, type) and issubclass(value, BaseModel): - return _json_object(value.model_json_schema()) + return _normalize_json_schema(_json_object(value.model_json_schema())) if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray): return [_json_value(item) for item in value] if isinstance(value, Mapping): @@ -456,6 +461,30 @@ def _json_value(value: object) -> JsonValue: raise TypeError(f"expected a JSON value, got {type(value).__name__}") +def _normalize_json_schema(value: JsonValue) -> JsonValue: + if isinstance(value, list): + return [_normalize_json_schema(item) for item in value] + if not isinstance(value, dict): + return value + normalized: dict[str, JsonValue] = {} + for key, raw_value in value.items(): + if key == "type" and raw_value == "None": + normalized[key] = "null" + elif key == "type" and isinstance(raw_value, list): + normalized[key] = ["null" if item == "None" else _normalize_json_schema(item) for item in raw_value] + else: + normalized[key] = _normalize_json_schema(raw_value) + return normalized + + +def _json_schema_object(value: object) -> dict[str, JsonValue]: + json_schema = _json_object(value) + normalized = _normalize_json_schema(json_schema) + if not isinstance(normalized, dict): + raise TypeError("expected a JSON Schema object") + return normalized + + def _model_or_json_object(value: object) -> ProtocolResult: if isinstance(value, BaseModel): return _json_object(value.model_dump(mode="json", exclude_none=True, by_alias=True)) @@ -470,6 +499,7 @@ def __init__(self, config: CDPTypesConfig | Mapping[str, object] | None = None, self.custom_commands: dict[str, CustomCommandRegistration] = {} self.custom_events: dict[str, CustomEventRegistration] = {} self.custom_middlewares: list[CustomMiddlewareRegistration] = [] + self.custom_alias_objects: dict[str, CustomAliasObjectRegistration] = {} self.command_schemas: dict[str, CommandSchema] = {} self.event_schemas: dict[str, TypeAdapter[object]] = {} self.native_command_names: set[str] = set() @@ -487,6 +517,8 @@ def __init__(self, config: CDPTypesConfig | Mapping[str, object] | None = None, self.addCustomEvent({"name": event} if isinstance(event, str) else event) for middleware in parsed_config.custom_middlewares or []: self.addCustomMiddleware(middleware) + for alias_object in parsed_config.custom_alias_objects or []: + self.addCustomAliasObject(alias_object) self.service_worker_expression_builders["Mod.evaluate"] = lambda params, _cdp_session_id: ( "\n async ({ params = {}, cdpSessionId = null }) => {\n" f" const value = ({params['expression']});\n" @@ -505,11 +537,13 @@ def update( commands = [*self.custom_commands.values(), *_custom_command_entries(parsed_config.custom_commands)] events = [*self.custom_events.values(), *_custom_event_entries(parsed_config.custom_events)] middlewares = [*self.custom_middlewares, *(parsed_config.custom_middlewares or [])] + alias_objects = [*self.custom_alias_objects.values(), *(parsed_config.custom_alias_objects or [])] return CDPTypes( { "custom_commands": commands, "custom_events": events, "custom_middlewares": middlewares, + "custom_alias_objects": alias_objects, } ) @@ -531,11 +565,13 @@ def toJSON(self) -> dict[str, object]: "custom_commands": custom_commands, "custom_events": self.customEventWireRegistrations(), "custom_middlewares": custom_middlewares, + "custom_alias_objects": self.customAliasObjectWireRegistrations(), }, "state": { "custom_commands": len(self.custom_commands), "custom_events": len(self.custom_events), "custom_middlewares": len(self.custom_middlewares), + "custom_alias_objects": len(self.custom_alias_objects), "command_params_schemas": len(self.command_schemas), "command_result_schemas": len(self.command_schemas), "event_schemas": len(self.event_schemas), @@ -803,6 +839,62 @@ def customMiddlewareRegistrations(self, phase: str, name: str) -> list[CustomMid if middleware["phase"] == phase and (middleware.get("name") in (None, "*", name)) ] + def addCustomAliasObject(self, registration: Mapping[str, object]) -> str: + parsed = ModCDPAliasObject.model_validate(registration) + name = parsed.name.strip() + if not name: + raise ValueError("custom alias object name is required") + methods: list[dict[str, object]] = [] + for method in parsed.methods or []: + method_name = method.name.strip() + if not method_name: + raise ValueError(f"{name} alias method name is required") + command_name = "" if not method.command else normalizeModCDPName(method.command) + if command_name and not re.match(r"^[^.]+\.[^.]+$", command_name): + raise ValueError(f"{name}.{method_name} command must be in Domain.method form") + params_schema = self._adapterFromOptionalSchema(method.params_schema, "params_schema") + result_schema = self._adapterFromOptionalSchema(method.result_schema, "result_schema") + if command_name: + with self._lock: + existing = self.command_schemas.get(command_name, CommandSchema()) + if params_schema.adapter is not None: + existing = CommandSchema(params=params_schema.adapter, result=existing.result) + if result_schema.adapter is not None and existing.result is None: + existing = CommandSchema(params=existing.params, result=result_schema.adapter) + self.command_schemas[command_name] = existing + method_registration: dict[str, object] = {"name": method_name} + if command_name: + method_registration["command"] = command_name + if method.sdk_method_name: + method_registration["sdk_method_name"] = method.sdk_method_name.strip() + if params_schema.json_schema: + method_registration["params_schema"] = params_schema.json_schema + if result_schema.json_schema: + method_registration["result_schema"] = result_schema.json_schema + if method.sticky_param: + method_registration["sticky_param"] = method.sticky_param.strip() + method_registration["sticky_params"] = list(method.sticky_params or []) + method_registration["sticky_fields"] = list(method.sticky_fields or []) + if method.return_ is not None: + method_registration["return"] = method.return_.model_dump(mode="json", exclude_none=True, by_alias=True) + methods.append(method_registration) + alias_object: CustomAliasObjectRegistration = { + "name": name, + "sticky_fields": list(parsed.sticky_fields or []), + "methods": methods, + } + if parsed.type_name: + alias_object["type_name"] = parsed.type_name.strip() + if parsed.sticky_schema is not None: + alias_object["sticky_schema"] = _json_value(parsed.sticky_schema) + with self._lock: + self.custom_alias_objects[name] = alias_object + return name + + def customAliasObjectWireRegistrations(self) -> list[CustomAliasObjectRegistration]: + with self._lock: + return [dict(alias_object) for alias_object in self.custom_alias_objects.values()] + def serviceWorkerCommandStep( self, method: str, @@ -925,10 +1017,10 @@ def _adapterFromOptionalSchema(self, schema: object, field_name: str) -> _Adapte if schema is None: return _AdapterRegistration() if isinstance(schema, type) and issubclass(schema, BaseModel): - return _AdapterRegistration(adapter=TypeAdapter(schema), json_schema=schema.model_json_schema()) + return _AdapterRegistration(adapter=TypeAdapter(schema), json_schema=_json_schema_object(schema.model_json_schema())) if not isinstance(schema, Mapping): raise TypeError(f"{field_name} must be a JSON Schema object") - json_schema = _json_object(schema) + json_schema = _json_schema_object(schema) return _AdapterRegistration(adapter=type_adapter_from_json_schema(json_schema), json_schema=json_schema) diff --git a/python/modcdp/types/modcdp.py b/python/modcdp/types/modcdp.py index 0af7dd6..0cf7997 100644 --- a/python/modcdp/types/modcdp.py +++ b/python/modcdp/types/modcdp.py @@ -88,6 +88,33 @@ class ModCDPAddMiddlewareParams(ModCDPModel): name: str | None = None +class ModCDPAliasReturn(ModCDPModel): + object: str | None = None + unwrap: str | None = None + array: bool | None = None + nullable: bool | None = None + + +class ModCDPAliasMethod(ModCDPModel): + name: str + command: str | None = None + sdk_method_name: str | None = None + params_schema: ModCDPPayloadSchemaSpec | None = None + result_schema: ModCDPPayloadSchemaSpec | None = None + sticky_param: str | None = None + sticky_params: list[str] | None = None + sticky_fields: list[str] | None = None + return_: ModCDPAliasReturn | None = Field(default=None, alias="return") + + +class ModCDPAliasObject(ModCDPModel): + name: str + type_name: str | None = None + sticky_schema: ModCDPPayloadSchemaSpec | None = None + sticky_fields: list[str] | None = None + methods: list[ModCDPAliasMethod] | None = None + + class ModCDPEvaluateParams(ModCDPModel): expression: str params: dict[str, object] | None = None @@ -260,6 +287,7 @@ class ModCDPServerConfig(ModCDPModel): custom_commands: list[ModCDPAddCustomCommandParams] | None = None custom_events: list[ModCDPAddCustomEventObjectParams] | None = None custom_middlewares: list[ModCDPAddMiddlewareParams] | None = None + custom_alias_objects: list[ModCDPAliasObject] | None = None ModCDPConfigureParams: TypeAlias = ModCDPServerConfig diff --git a/python/tests/stagehand_client_generated/stagehand_client_live_test.py b/python/tests/stagehand_client_generated/stagehand_client_live_test.py new file mode 100644 index 0000000..43a8c50 --- /dev/null +++ b/python/tests/stagehand_client_generated/stagehand_client_live_test.py @@ -0,0 +1,504 @@ +# Before running this live test, build a local Stagehand alias manifest and +# regenerate the Python client: +# +# pnpm --dir ../stagehand-server exec node src/protocol/generate_modcdp_alias_manifest.mjs ../modcdp2/testdata/codegen/stagehand_alias_manifest.json +# pnpm run build:python +# pnpm run test:e2e:python +# +# The manifest is intentionally untracked; do not commit it. + +from __future__ import annotations + +import json +import unittest +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +from tests.stagehand_client_generated import stagehand_client_gen as sh + + +CHALLENGE_URL = "https://pirate.github.io/stress-tests/challenge.html" + + +class GeneratedStagehandClientChallengeSuite(unittest.TestCase): + def test_generated_stagehand_client_solves_challenge_suite(self) -> None: + client = sh.StagehandClient() + client.connect() + try: + page = client.browser.new_page(sh.BrowserNewPageParams(url="about:blank")) + page = page.goto(sh.PageGotoParams(url=CHALLENGE_URL, wait_until="load")) + page.wait_for_expression( + sh.PageWaitForExpressionParams( + expression="document.querySelectorAll('.task').length === 45", + timeout_ms=20_000, + ) + ) + + fixture_file = str(Path(__file__).resolve()) + + challenges: list[tuple[str, Callable[[sh.Page], object]]] = [ + ( + "simple-button", + lambda page: page.locate(sh.PageLocateParams(locator=sh.Locator(css="#start-button"))).click(), + ), + ( + "radio-selection", + lambda page: page.click( + sh.PageClickParams(locator=sh.Locator(css='input[name="entity"][value="dont-choose-this-one"]')) + ), + ), + ( + "checkbox-check", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="input.challenge-checkbox", idx=0))), + page.click(sh.PageClickParams(locator=sh.Locator(css="input.challenge-checkbox", idx=1))), + page.locate(sh.PageLocateParams(locator=sh.Locator(css="input.challenge-checkbox"))) + .nth(sh.LocatorNthParams(index=2)) + .click(), + ), + ), + ( + "search-squash", + lambda page: ( + page.fill(sh.PageFillParams(locator=sh.Locator(css="#search-input"), value="squash")), + page.key_press(sh.PageKeyPressParams(locator=sh.Locator(css="#search-input"), key="Enter")), + ), + ), + ( + "date-time-input", + lambda page: ( + page.fill( + sh.PageFillParams( + locator=sh.Locator(css="#date-picker"), + value=datetime.now().strftime("%Y-%m-%d"), + ) + ), + page.fill( + sh.PageFillParams( + locator=sh.Locator(css="#time-picker"), + value=datetime.now().strftime("%H:%M"), + ) + ), + ), + ), + ( + "copy-text", + lambda page: page.fill(sh.PageFillParams(locator=sh.Locator(css="#copy-input"), value="abc")), + ), + ( + "slider-drag", + lambda page: page.click( + sh.PageClickParams(locator=sh.Locator(css="#range-slider"), offset_x=295, offset_y=8) + ), + ), + ( + "hover-element", + lambda page: ( + page.hover(sh.PageHoverParams(locator=sh.Locator(css="#hover-target"))), + page.wait_for_timeout(sh.PageWaitForTimeoutParams(ms=1100)), + ), + ), + ( + "drag-drop", + lambda page: page.drag_and_drop( + sh.PageDragAndDropParams( + start=sh.Locator(css="#drag-element"), + end=sh.Locator(css="#drop-target"), + ) + ), + ), + ( + "file-drop", + lambda page: page.drop_files( + sh.PageDropFilesParams(files=[fixture_file], locator=sh.Locator(css="#file-drop-area")) + ), + ), + ( + "multi-select", + lambda page: page.select_option( + sh.PageSelectOptionParams( + locator=sh.Locator(css="#multi-select-element"), + values=["option1", "option2", "option3"], + ) + ), + ), + ( + "canvas-captcha", + lambda page: page.fill( + sh.PageFillParams(locator=sh.Locator(css="#canvas-text-input"), value="CAPTCHA123") + ), + ), + ( + "iframe-slider", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="#iframe-slider", idx=2), offset_x=295, offset_y=8)), + page.click(sh.PageClickParams(locator=sh.Locator(css="#iframe-slider-popover-button"))), + ), + ), + ( + "oopif-form-submit", + lambda page: ( + page.wait_for_locator( + sh.PageWaitForLocatorParams(locator=sh.Locator(css="#first-name"), timeout_ms=20_000) + ), + page.fill(sh.PageFillParams(locator=sh.Locator(css="#first-name"), value="Ada")), + page.fill(sh.PageFillParams(locator=sh.Locator(css="#last-name"), value="Lovelace")), + page.fill(sh.PageFillParams(locator=sh.Locator(css="#email"), value="ada@example.com")), + page.click(sh.PageClickParams(locator=sh.Locator(css="#contact-form button[type='submit']"))), + ), + ), + ( + "oopif-open-shadow-button", + lambda page: ( + page.click( + sh.PageClickParams( + locator=sh.Locator( + xpath="/html[1]/body[1]/div[2]/div[15]/iframe[1]/html[1]/body[1]/shadow-demo[1]//div[1]/button[1]" + ) + ) + ), + page.click(sh.PageClickParams(locator=sh.Locator(css="#oopif-open-shadow-confirm"))), + ), + ), + ( + "oopif-closed-shadow-button", + lambda page: ( + page.click( + sh.PageClickParams( + locator=sh.Locator( + xpath="/html[1]/body[1]/div[2]/div[16]/iframe[1]/html[1]/body[1]/shadow-demo[1]//div[1]/button[1]" + ) + ) + ), + page.click(sh.PageClickParams(locator=sh.Locator(css="#oopif-closed-shadow-confirm"))), + ), + ), + ( + "shadow-dom-dblclick", + lambda page: page.double_click(sh.PageDoubleClickParams(locator=sh.Locator(text="Double-click me"))), + ), + ( + "right-click-component", + lambda page: page.locate(sh.PageLocateParams(locator=sh.Locator(text="Right-click me"))).click( + sh.LocatorClickParams(button="secondary") + ), + ), + ( + "scroll-accept", + lambda page: ( + page.scroll(sh.PageScrollParams(locator=sh.Locator(css="#scroll-container"), percent=100)), + page.click(sh.PageClickParams(locator=sh.Locator(css="#accept-button"))), + ), + ), + ("draw-circle", self._draw_circle), + ( + "cancel-dialog", + lambda page: page.click( + sh.PageClickParams(dialog={"accept": False}, locator=sh.Locator(css="#confirm-button")) + ), + ), + ( + "geolocation-permission", + lambda page: ( + page.set_geolocation( + sh.PageSetGeolocationParams( + accuracy=25, + latitude=37.7749, + longitude=-122.4194, + origin="https://pirate.github.io", + ) + ), + page.click(sh.PageClickParams(locator=sh.Locator(css="#request-location-button"))), + ), + ), + ( + "arrow-key-presses", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="#key-press-area"))), + page.key_press( + sh.PageKeyPressParams(locator=sh.Locator(css="#key-press-area"), key="ArrowRight", repeat=3) + ), + ), + ), + ("alert-secret", self._alert_secret), + ( + "press-hold-button", + lambda page: page.click_and_hold( + sh.PageClickAndHoldParams(locator=sh.Locator(css="#hold-button"), duration_ms=800) + ), + ), + ( + "tooltip-secret", + lambda page: ( + page.hover(sh.PageHoverParams(locator=sh.Locator(css="#tooltip-element"))), + page.fill(sh.PageFillParams(locator=sh.Locator(css="#tooltip-secret-input"), value="octopus")), + ), + ), + ("resize-textarea", self._resize_textarea), + ( + "file-upload", + lambda page: page.set_input_files( + sh.PageSetInputFilesParams(files=[fixture_file], locator=sh.Locator(css="#file-upload-input")) + ), + ), + ( + "phone-input", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="#phone-input-field"))), + page.key_press(sh.PageKeyPressParams(locator=sh.Locator(css="#phone-input-field"), key="Backspace")), + page.key_press( + sh.PageKeyPressParams( + method="type", + locator=sh.Locator(css="#phone-input-field"), + value="555-1234", + ) + ), + ), + ), + ( + "expand-details", + lambda page: page.locate(sh.PageLocateParams(locator=sh.Locator(css="#details-element summary"))).click(), + ), + ("drag-square-to-circle", self._drag_square_to_circle), + ( + "audio-transcription", + lambda page: page.fill( + sh.PageFillParams(locator=sh.Locator(css="#transcription-input"), value="everything") + ), + ), + ( + "dropdown-selections", + lambda page: ( + page.select_option(sh.PageSelectOptionParams(locator=sh.Locator(css="#color-select"), values=["red"])), + page.select_option(sh.PageSelectOptionParams(locator=sh.Locator(css="#object-select"), values=["ball"])), + ), + ), + ( + "contenteditable-div", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="#editable-content"))), + page.key_press( + sh.PageKeyPressParams(method="type", locator=sh.Locator(css="#editable-content"), value="banana") + ), + ), + ), + ( + "nested-tiny-button", + lambda page: page.click( + sh.PageClickParams( + dialog={"accept": True}, + expect_timeout_ms=20_000, + locator=sh.Locator(css="#deep-tiny-button"), + ) + ), + ), + ("wrapped-word-click", self._wrapped_word_click), + ( + "long-link-maze", + lambda page: page.click(sh.PageClickParams(locator=sh.Locator(text="TARGET-LINK::orion-needle-1847"))), + ), + ( + "wall-secret-word", + lambda page: page.fill(sh.PageFillParams(locator=sh.Locator(css="#wall-secret-input"), value="cobaltglass")), + ), + ( + "closed-shadow-aria-word", + lambda page: page.fill( + sh.PageFillParams(locator=sh.Locator(css="#closed-shadow-aria-input"), value="violetcircuit") + ), + ), + ( + "dynamic-frame-ordinal-trap", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="#frame-trap-start"))), + page.wait_for_locator( + sh.PageWaitForLocatorParams(locator=sh.Locator(css="#frame-trap-create-third"), timeout_ms=10_000) + ), + page.click(sh.PageClickParams(locator=sh.Locator(css="#frame-trap-create-third"))), + page.wait_for_locator( + sh.PageWaitForLocatorParams(locator=sh.Locator(css="#frame-trap-insert-final"), timeout_ms=10_000) + ), + page.click(sh.PageClickParams(locator=sh.Locator(css="#frame-trap-insert-final"))), + page.wait_for_locator( + sh.PageWaitForLocatorParams(locator=sh.Locator(css="#frame-trap-final-button"), timeout_ms=10_000) + ), + page.click(sh.PageClickParams(locator=sh.Locator(css="#frame-trap-final-button"))), + ), + ), + ("rotated-transform-click", self._rotated_transform_click), + ( + "transparent-overlay-button", + lambda page: page.locate(sh.PageLocateParams(locator=sh.Locator(css="#transparent-overlay-target"))).click(), + ), + ( + "realistic-input-sequence", + lambda page: ( + page.scroll(sh.PageScrollParams(locator=sh.Locator(css="#realistic-scroll-panel"), delta_y=450)), + page.scroll(sh.PageScrollParams(locator=sh.Locator(css="#realistic-scroll-panel"), delta_y=450)), + page.click(sh.PageClickParams(locator=sh.Locator(css="#realistic-type-input"))), + [ + page.key_press(sh.PageKeyPressParams(locator=sh.Locator(css="#realistic-type-input"), key=key)) + for key in ["o", "r", "c", "h", "i", "d"] + ], + page.click(sh.PageClickParams(locator=sh.Locator(css="#realistic-submit-button"))), + ), + ), + ( + "css-transform-text", + lambda page: page.fill( + sh.PageFillParams(locator=sh.Locator(css="#css-transform-text-input"), value="skyline") + ), + ), + ( + "google-docs", + lambda page: ( + page.click(sh.PageClickParams(locator=sh.Locator(css="#google-docs-answer-input"))), + page.key_press( + sh.PageKeyPressParams( + method="type", + locator=sh.Locator(css="#google-docs-answer-input"), + value="snooker", + ) + ), + ), + ), + ] + + for task_id, run in challenges: + run(page) + page.wait_for_expression( + sh.PageWaitForExpressionParams( + expression=f"document.getElementById({json.dumps(task_id)})?.classList.contains('completed') === true", + timeout_ms=20_000, + ) + ) + + score = page.evaluate( + sh.PageEvaluateParams(expression="document.querySelectorAll('.task.completed').length") + ).value + counts = page.evaluate( + sh.PageEvaluateParams( + expression="(() => ({ completed: document.querySelectorAll('.task.completed').length, total: document.querySelectorAll('.task').length }))()" + ) + ).value + self.assertEqual(score, counts["completed"]) + self.assertEqual(counts["total"], 45) + self.assertEqual(score, 45) + finally: + client.close() + + def _draw_circle(self, page: sh.Page) -> None: + page.click(sh.PageClickParams(locator=sh.Locator(css="#draw-canvas"), offset_x=150, offset_y=150)) + page.scroll(sh.PageScrollParams(delta_y=320)) + canvas = page.locate(sh.PageLocateParams(locator=sh.Locator(css="#draw-canvas"))) + self.assertIsNotNone(canvas.coordinates) + assert canvas.coordinates is not None + self.assertIsNotNone(canvas.coordinates.left) + self.assertIsNotNone(canvas.coordinates.top) + left = canvas.coordinates.left + top = canvas.coordinates.top + assert left is not None and top is not None + for index in range(24): + import math + + angle = math.pi * 2 * index / 24 + page.hover( + sh.PageHoverParams( + locator=sh.Locator( + coordinates=sh.LocatorCoordinates( + x=left + 150 + math.cos(angle) * 100, + y=top + 150 + math.sin(angle) * 100, + ) + ) + ) + ) + + def _alert_secret(self, page: sh.Page) -> None: + result = page.click(sh.PageClickParams(dialog={"accept": True}, locator=sh.Locator(css="#secret-alert-button"))) + self.assertRegex(result.message or "", "avocado") + page.fill(sh.PageFillParams(locator=sh.Locator(css="#secret-word-input"), value="avocado")) + + def _resize_textarea(self, page: sh.Page) -> None: + textarea = page.locate(sh.PageLocateParams(locator=sh.Locator(css="#resizable-textarea"))) + self.assertIsNotNone(textarea.coordinates) + assert textarea.coordinates is not None + self.assertIsNotNone(textarea.coordinates.right) + self.assertIsNotNone(textarea.coordinates.bottom) + right = textarea.coordinates.right + bottom = textarea.coordinates.bottom + assert right is not None and bottom is not None + page.drag_and_drop( + sh.PageDragAndDropParams( + start=sh.Locator(coordinates=sh.LocatorCoordinates(x=right - 2, y=bottom - 2)), + end=sh.Locator(coordinates=sh.LocatorCoordinates(x=right + 160, y=bottom + 90)), + ) + ) + page.fill(sh.PageFillParams(locator=sh.Locator(css="#resize-secret-input"), value="giraffe")) + + def _drag_square_to_circle(self, page: sh.Page) -> None: + page.click(sh.PageClickParams(locator=sh.Locator(css="#drag-canvas"), offset_x=175, offset_y=125)) + page.scroll(sh.PageScrollParams(delta_y=220)) + canvas = page.locate(sh.PageLocateParams(locator=sh.Locator(css="#drag-canvas"))) + self.assertIsNotNone(canvas.coordinates) + assert canvas.coordinates is not None + self.assertIsNotNone(canvas.coordinates.left) + self.assertIsNotNone(canvas.coordinates.top) + left = canvas.coordinates.left + top = canvas.coordinates.top + assert left is not None and top is not None + page.drag_and_drop( + sh.PageDragAndDropParams( + start=sh.Locator(coordinates=sh.LocatorCoordinates(x=left + 75, y=top + 125)), + end=sh.Locator(coordinates=sh.LocatorCoordinates(x=left + 250, y=top + 125)), + ) + ) + + def _wrapped_word_click(self, page: sh.Page) -> None: + page.hover(sh.PageHoverParams(locator=sh.Locator(css="#wrapped-word-paragraph"))) + point = page.evaluate( + sh.PageEvaluateParams( + expression="""(() => { + const paragraph = document.getElementById("wrapped-word-paragraph"); + if (!(paragraph?.firstChild instanceof Text)) throw new Error("wrapped word paragraph text is missing"); + const paragraphRect = paragraph.getBoundingClientRect(); + const start = paragraph.textContent.indexOf("ox"); + const range = document.createRange(); + range.setStart(paragraph.firstChild, start); + range.setEnd(paragraph.firstChild, start + 2); + const rect = Array.from(range.getClientRects()).at(-1); + range.detach(); + if (rect == null) throw new Error("wrapped word range has no visible rect"); + return { x: rect.left + rect.width / 2 - paragraphRect.left, y: rect.top + rect.height / 2 - paragraphRect.top }; + })()""" + ) + ).value + self.assertIsInstance(point, dict) + x = point["x"] + y = point["y"] + self.assertIsInstance(x, (int, float)) + self.assertIsInstance(y, (int, float)) + page.click(sh.PageClickParams(locator=sh.Locator(css="#wrapped-word-paragraph"), offset_x=x, offset_y=y)) + + def _rotated_transform_click(self, page: sh.Page) -> None: + page.hover(sh.PageHoverParams(locator=sh.Locator(css="#rotated-transform-button"))) + point = page.evaluate( + sh.PageEvaluateParams( + expression="""(() => { + const button = document.getElementById("rotated-transform-button"); + if (button == null) throw new Error("rotated transform button is missing"); + const rect = button.getBoundingClientRect(); + return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 }; + })()""" + ) + ).value + self.assertIsInstance(point, dict) + x = point["x"] + y = point["y"] + self.assertIsInstance(x, (int, float)) + self.assertIsInstance(y, (int, float)) + page.click(sh.PageClickParams(locator=sh.Locator(coordinates=sh.LocatorCoordinates(x=x, y=y)))) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/codegen/codegen.ts b/src/codegen/codegen.ts new file mode 100644 index 0000000..7f9a6a0 --- /dev/null +++ b/src/codegen/codegen.ts @@ -0,0 +1,34 @@ +import { mkdirSync, writeFileSync } from "node:fs"; +import { dirname } from "node:path"; + +import { generateClientWithAliasesForGo } from "./codegen_go.ts"; +import { generateClientWithAliasesForPython } from "./codegen_python.ts"; +import { generateClientWithAliasesForTs } from "./codegen_ts.ts"; + +type GenerateClientWithAliasesLanguage = "ts" | "python" | "go"; + +type GenerateClientWithAliasesOptions = { + language: GenerateClientWithAliasesLanguage; + name: string; + custom_commands?: unknown[]; + custom_alias_objects?: unknown[]; + default_config?: unknown; + output?: string; +}; + +export type { GenerateClientWithAliasesLanguage, GenerateClientWithAliasesOptions }; +export { generateClientWithAliases }; + +function generateClientWithAliases(options: GenerateClientWithAliasesOptions): string { + const objects = (options.custom_alias_objects ?? []) as any[]; + const commands = (options.custom_commands ?? []) as any[]; + let generated: string; + if (options.language === "ts") generated = generateClientWithAliasesForTs(options, objects, commands); + else if (options.language === "python") generated = generateClientWithAliasesForPython(options, objects, commands); + else generated = generateClientWithAliasesForGo(options, objects, commands); + if (options.output) { + mkdirSync(dirname(options.output), { recursive: true }); + writeFileSync(options.output, generated); + } + return generated; +} diff --git a/src/codegen/codegen_go.ts b/src/codegen/codegen_go.ts new file mode 100644 index 0000000..8f3784b --- /dev/null +++ b/src/codegen/codegen_go.ts @@ -0,0 +1,450 @@ +import { readFileSync } from "node:fs"; +import { mkdirSync, writeFileSync } from "node:fs"; +import { dirname } from "node:path"; +import { pathToFileURL } from "node:url"; + +import type { GenerateClientWithAliasesOptions } from "./codegen.ts"; +import { + collectSchemaDefinitions, + lowerFirst, + paramsCanBeOmitted, + pascal, + receiverStickyParamFields, + requiredSet, + safeTypeName, + sanitizeCommandRegistry, + sanitizeRegistry, + schemaPropertiesObject, + sdkPath, + sortedEntries, + statelessAliasObjects, + unionSchemas, +} from "./codegen_shared.ts"; + +export { generateClientWithAliasesForGo }; + +function generateClientWithAliasesForGo(options: GenerateClientWithAliasesOptions, objects, commands) { + const clientName = safeTypeName(options.name); + const packageName = goPackageName(clientName); + const clientRef = "modcdpclient."; + const aliasObjectsName = `${clientName}AliasObjects`; + const customCommandsName = `${clientName}CustomCommands`; + const hasDefaultConfig = options.default_config != null; + const statelessObjects = statelessAliasObjects(objects); + const objectsByName = new Map(objects.map((object) => [object.name, object])); + let out = `// Code generated by ModCDP codegen. DO NOT EDIT.\n`; + out += `package ${packageName}\n\n`; + out += `import (\n "encoding/json"\n`; + if (hasDefaultConfig) out += ` "os"\n`; + out += ` modcdpclient "github.com/browserbase/modcdp/go/modcdp/client"\n`; + out += `)\n\n`; + out += emitGoSchemaDefinitions(objects); + out += emitGoClientType( + clientName, + clientRef, + aliasObjectsName, + customCommandsName, + sortedObjectsArray(objects), + statelessObjects, + options.default_config, + ); + for (const object of sortedObjectsArray(objects)) { + const typeName = object.type_name || pascal(object.name); + out += `type ${typeName} struct {\n`; + out += ` alias ${clientRef}ModCDPAliasObject \`json:"-"\`\n`; + out += emitGoSchemaFields(object.sticky_schema, typeName, new Set()); + out += `}\n\n`; + out += emitGoAliasObjectConstructor(clientRef, object); + for (const method of object.methods ?? []) { + if (sdkPath(object, method).length === 2) + out += emitGoMethod( + clientRef, + typeName, + "receiver.alias.Client", + "receiver.alias.Sticky", + object, + method, + objectsByName, + ); + } + } + out += emitGoAliasObjectsFunction(clientRef, aliasObjectsName, sortedObjectsArray(objects)); + out += emitGoCustomCommandsFunction(clientRef, customCommandsName, commands); + return out; +} + +function emitGoClientType( + clientName, + clientRef, + aliasObjectsName, + customCommandsName, + objects, + statelessObjects, + defaultConfig, +) { + const concreteName = lowerFirst(clientName); + const optionsName = `${clientName}Options`; + const hasDefaultConfig = defaultConfig != null; + let out = `type ${optionsName} ${clientRef}Config\n\n`; + out += `type ${concreteName} struct {\n *${clientRef}ModCDPClient\n`; + for (const object of objects) { + if (statelessObjects.has(object.name)) + out += ` ${object.type_name || pascal(object.name)} ${object.type_name || pascal(object.name)} \`json:"-"\`\n`; + } + out += `}\n\n`; + if (hasDefaultConfig) out += emitGoDefaultConfig(clientName, clientRef, defaultConfig); + out += `func ${clientName}(options ...${optionsName}) *${concreteName} {\n`; + out += ` config := ${clientRef}Config(${hasDefaultConfig ? `${clientName}DefaultOptions()` : `${optionsName}{}`})\n`; + out += ` if len(options) > 0 {\n`; + out += ` config = ${clientRef}Config(options[0])\n`; + out += ` }\n`; + out += ` generatedTypes := ${clientRef}CDPTypesConfig{\n`; + out += ` CustomCommands: ${customCommandsName}(),\n`; + out += ` CustomAliasObjects: ${aliasObjectsName}(),\n`; + out += ` }\n`; + out += ` if config.Types == nil {\n`; + out += ` config.Types = &generatedTypes\n`; + out += ` } else {\n`; + out += ` config.Types.CustomCommands = append(config.Types.CustomCommands, generatedTypes.CustomCommands...)\n`; + out += ` config.Types.CustomAliasObjects = append(config.Types.CustomAliasObjects, generatedTypes.CustomAliasObjects...)\n`; + out += ` }\n`; + out += ` receiver := &${concreteName}{ModCDPClient: ${clientRef}New(config)}\n`; + for (const object of objects) { + if (statelessObjects.has(object.name)) + out += ` receiver.${object.type_name || pascal(object.name)} = new${object.type_name || pascal(object.name)}Alias(receiver.ModCDPClient, nil)\n`; + } + out += ` return receiver\n`; + out += `}\n\n`; + return out; +} + +function emitGoDefaultConfig(clientName, clientRef, config) { + return ( + `const ${lowerFirst(clientName)}DefaultOptionsJSON = ${JSON.stringify(JSON.stringify(config))}\n\n` + + `func ${clientName}DefaultOptions() ${clientName}Options {\n` + + ` var config ${clientRef}Config\n` + + ` if err := json.Unmarshal([]byte(os.ExpandEnv(${lowerFirst(clientName)}DefaultOptionsJSON)), &config); err != nil {\n` + + ` panic(err)\n` + + ` }\n` + + ` return ${clientName}Options(config)\n` + + `}\n\n` + ); +} + +function emitGoSchemaDefinitions(objects) { + const emitted = new Set(objects.map((object) => object.type_name || pascal(object.name))); + let out = ""; + for (const object of sortedObjectsArray(objects)) { + const defs = collectDefinitionsFromSchemas([ + object.sticky_schema, + ...(object.methods ?? []).flatMap((method) => [method.params_schema, method.result_schema]), + ]); + for (const [name, schema] of sortedEntries(defs)) { + const typeName = goExportName(name); + if (emitted.has(typeName)) continue; + emitted.add(typeName); + out += emitGoSchemaDefinition(typeName, schema); + } + } + return out; +} + +function emitGoSchemaDefinition(name, schema) { + if (Object.keys(schemaPropertiesObject(schema)).length > 0) return emitGoSchemaType(name, schema, new Set()); + let fieldType = goTypeForSchemaWithName(schema, name); + if (fieldType === name || fieldType.startsWith(`*${name}`) || fieldType.startsWith(`[]${name}`)) + fieldType = "json.RawMessage"; + return `type ${name} ${fieldType}\n\n`; +} + +function emitGoAliasObjectConstructor(clientRef, object) { + const typeName = object.type_name || pascal(object.name); + let out = `func new${typeName}Alias(base *${clientRef}ModCDPClient, sticky ${clientRef}AliasSticky) ${typeName} {\n`; + out += ` value := ${typeName}{alias: ${clientRef}NewModCDPAliasObject(base, sticky)}\n`; + if (Object.keys(schemaPropertiesObject(object.sticky_schema)).length > 0) + out += ` _ = value.applyAliasSticky(sticky)\n`; + out += ` return value\n`; + out += `}\n\n`; + if (Object.keys(schemaPropertiesObject(object.sticky_schema)).length === 0) return out; + out += `func (receiver *${typeName}) applyAliasSticky(sticky ${clientRef}AliasSticky) error {\n`; + out += ` if sticky == nil {\n return nil\n }\n`; + out += ` body, err := json.Marshal(sticky)\n`; + out += ` if err != nil {\n return err\n }\n`; + out += ` return json.Unmarshal(body, receiver)\n`; + out += `}\n\n`; + return out; +} + +function emitGoMethod(clientRef, receiverType, clientExpression, stickyExpression, object, method, objectsByName) { + const typeName = object.type_name || pascal(object.name); + const methodName = goExportName(method.name); + const paramsType = `${typeName}${methodName}Params`; + const resultType = `${typeName}${methodName}Result`; + const optionalFields = receiverStickyParamFields(object, method); + const paramsOptional = paramsCanBeOmitted(method.params_schema, optionalFields); + const paramsSignature = paramsOptional ? `params ...${paramsType}` : `params ${paramsType}`; + const paramsValue = paramsOptional ? "request" : "params"; + let out = emitGoSchemaType(paramsType, method.params_schema, optionalFields); + out += emitGoSchemaType(resultType, method.result_schema, new Set()); + const stickyParam = JSON.stringify(method.sticky_param || ""); + const stickyParams = goStringSliceLiteral(method.sticky_params || []); + const stickyFields = goStringSliceLiteral(method.sticky_fields || []); + const unwrap = method.return?.unwrap || ""; + const returnObject = method.return?.object || ""; + const optionalResolver = (zeroValue) => + paramsOptional + ? ` request, err := ${clientRef}OptionalAliasParams(${JSON.stringify(`${object.name}.${method.name}`)}, params)\n if err != nil {\n return ${zeroValue}, err\n }\n` + : ""; + if (returnObject) { + const target = objectsByName.get(returnObject); + if (!target) throw new Error(`${object.name}.${method.name} returns unknown alias object ${returnObject}`); + const targetType = target.type_name || pascal(target.name); + if (!method.command) { + out += `func (receiver ${receiverType}) ${methodName}(${paramsSignature}) (${targetType}, error) {\n`; + out += optionalResolver(`${targetType}{}`); + out += ` return new${targetType}Alias(${clientExpression}, nil), nil\n`; + out += `}\n\n`; + return out; + } + if (method.return?.array) { + out += `func (receiver ${receiverType}) ${methodName}(${paramsSignature}) ([]${targetType}, error) {\n`; + out += optionalResolver("nil"); + out += ` raw, err := ${clientRef}SendAliasCommandWithStickyParams[json.RawMessage](${clientExpression}, ${JSON.stringify(method.command)}, ${paramsValue}, ${stickyExpression}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "")\n`; + out += ` if err != nil {\n return nil, err\n }\n`; + out += ` items, err := ${clientRef}AliasArrayFromResult(raw, ${JSON.stringify(unwrap)})\n`; + out += ` if err != nil {\n return nil, err\n }\n`; + out += ` output := make([]${targetType}, 0, len(items))\n`; + out += ` for _, item := range items {\n`; + out += ` sticky, err := ${clientRef}AliasStickyFromResult(item, "", ${goStringSliceLiteral(target.sticky_fields || [])})\n`; + out += ` if err != nil {\n return nil, err\n }\n`; + out += ` output = append(output, new${targetType}Alias(${clientExpression}, sticky))\n`; + out += ` }\n`; + out += ` return output, nil\n`; + out += `}\n\n`; + return out; + } + if (method.return?.nullable) { + out += `func (receiver ${receiverType}) ${methodName}(${paramsSignature}) (*${targetType}, error) {\n`; + out += optionalResolver("nil"); + out += ` raw, err := ${clientRef}SendAliasCommandWithStickyParams[json.RawMessage](${clientExpression}, ${JSON.stringify(method.command)}, ${paramsValue}, ${stickyExpression}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "")\n`; + out += ` if err != nil {\n return nil, err\n }\n`; + out += ` unwrapped, err := ${clientRef}UnwrapAliasResult(raw, ${JSON.stringify(unwrap)})\n`; + out += ` if err != nil {\n return nil, err\n }\n`; + out += ` if unwrapped == nil {\n return nil, nil\n }\n`; + out += ` sticky, err := ${clientRef}AliasStickyFromResult(unwrapped, "", ${goStringSliceLiteral(target.sticky_fields || [])})\n`; + out += ` if err != nil {\n return nil, err\n }\n`; + out += ` value := new${targetType}Alias(${clientExpression}, sticky)\n`; + out += ` return &value, nil\n`; + out += `}\n\n`; + return out; + } + out += `func (receiver ${receiverType}) ${methodName}(${paramsSignature}) (${targetType}, error) {\n`; + out += optionalResolver(`${targetType}{}`); + out += ` raw, err := ${clientRef}SendAliasCommandWithStickyParams[${clientRef}AliasJSONObject](${clientExpression}, ${JSON.stringify(method.command)}, ${paramsValue}, ${stickyExpression}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "")\n`; + out += ` if err != nil {\n return ${targetType}{}, err\n }\n`; + out += ` sticky, err := ${clientRef}AliasStickyFromResult(raw, ${JSON.stringify(unwrap)}, ${goStringSliceLiteral(target.sticky_fields || [])})\n`; + out += ` if err != nil {\n return ${targetType}{}, err\n }\n`; + out += ` return new${targetType}Alias(${clientExpression}, sticky), nil\n`; + out += `}\n\n`; + return out; + } + out += `func (receiver ${receiverType}) ${methodName}(${paramsSignature}) (${resultType}, error) {\n`; + out += optionalResolver(`${resultType}{}`); + out += ` return ${clientRef}SendAliasCommandWithStickyParams[${resultType}](${clientExpression}, ${JSON.stringify(method.command)}, ${paramsValue}, ${stickyExpression}, ${stickyParam}, ${stickyParams}, ${stickyFields}, ${JSON.stringify(unwrap)})\n`; + out += `}\n\n`; + return out; +} + +function emitGoSchemaType(name, schema, optionalFields) { + let out = emitGoInlineSchemaTypes(name, schema, new Set([name])); + out += `type ${name} struct {\n`; + out += emitGoSchemaFields(schema, name, optionalFields); + out += `}\n\n`; + return out; +} + +function emitGoInlineSchemaTypes(parentName, schema, emitted) { + let out = ""; + for (const [property, propertySchema] of sortedEntries(schemaPropertiesObject(schema))) { + out += emitGoInlineSchemaTypeForSchema(`${parentName}${goExportName(property)}`, propertySchema, emitted); + } + return out; +} + +function emitGoInlineSchemaTypeForSchema(name, schema, emitted) { + if (!schema || typeof schema !== "object" || schema.$ref) return ""; + const nonNull = nullableSchema(schema); + if (nonNull) return emitGoInlineSchemaTypeForSchema(name, nonNull, emitted); + if (schema.type === "array") return emitGoInlineSchemaTypeForSchema(`${name}Item`, schema.items, emitted); + let out = ""; + if (schema.additionalProperties && typeof schema.additionalProperties === "object") { + out += emitGoInlineSchemaTypeForSchema(`${name}Value`, schema.additionalProperties, emitted); + } + if (Object.keys(schemaPropertiesObject(schema)).length === 0 || emitted.has(name)) return out; + emitted.add(name); + out += emitGoInlineSchemaTypes(name, schema, emitted); + out += `type ${name} struct {\n`; + out += emitGoSchemaFields(schema, name, new Set()); + out += `}\n\n`; + return out; +} + +function emitGoSchemaFields(schema, parentName, optionalFields) { + let out = ""; + const properties = schemaPropertiesObject(schema); + const required = requiredSet(schema); + for (const [property, propertySchema] of sortedEntries(properties)) { + const receiverStickyOptional = optionalFields.has(property); + if (receiverStickyOptional) required.delete(property); + let fieldType = goTypeForSchemaWithName(propertySchema, `${parentName}${goExportName(property)}`); + if (!required.has(property) && !receiverStickyOptional && goIsObjectLikeSchema(propertySchema)) + fieldType = goPointerType(fieldType); + let tag = property; + if (!required.has(property)) tag += ",omitempty"; + out += ` ${goExportName(property)} ${fieldType} \`json:"${tag}"\`\n`; + } + return out; +} + +function goTypeForSchemaWithName(schema, suggestedName) { + if (!schema || typeof schema !== "object") return "json.RawMessage"; + if (schema.$ref) return goTypeNameForRef(schema.$ref); + const union = unionSchemas(schema); + if (union) { + let nullable = false; + const nonNullTypes = []; + for (const item of union) { + if (item?.type === "null") { + nullable = true; + continue; + } + const itemType = goTypeForSchemaWithName(item, suggestedName); + if (!nonNullTypes.includes(itemType)) nonNullTypes.push(itemType); + } + if (nonNullTypes.length > 0) { + let nonNullType = nonNullTypes[0]; + if (nonNullTypes.length > 1) { + nonNullType = + nonNullTypes.includes("int") && nonNullTypes.includes("float64") && nonNullTypes.length === 2 + ? "float64" + : "json.RawMessage"; + } + if (nullable) return nonNullType === "string" ? nonNullType : goPointerType(nonNullType); + return nonNullType; + } + } + if (Array.isArray(schema.type)) { + const nonNull = schema.type.find((item) => item !== "null"); + if (typeof nonNull === "string") { + const scalar = goScalarType(nonNull, schema, suggestedName); + return scalar === "string" ? scalar : goPointerType(scalar); + } + } + if (typeof schema.type === "string") return goScalarType(schema.type, schema, suggestedName); + if (schema.properties) return suggestedName; + return "json.RawMessage"; +} + +function goScalarType(value, schema, suggestedName) { + if (value === "string") return "string"; + if (value === "boolean") return "bool"; + if (value === "integer") return "int"; + if (value === "number") return "float64"; + if (value === "object") { + if (Object.keys(schemaPropertiesObject(schema)).length > 0) return suggestedName; + if (schema.additionalProperties && typeof schema.additionalProperties === "object") { + return `map[string]${goTypeForSchemaWithName(schema.additionalProperties, `${suggestedName}Value`).replace(/^\*/, "")}`; + } + return "json.RawMessage"; + } + if (value === "array") { + if (schema.items && typeof schema.items === "object") + return `[]${goTypeForSchemaWithName(schema.items, `${suggestedName}Item`).replace(/^\*/, "")}`; + return "[]json.RawMessage"; + } + return "json.RawMessage"; +} + +function nullableSchema(schema) { + const union = unionSchemas(schema); + if (!union) return null; + let nonNull = null; + let nullable = false; + for (const item of union) { + if (item?.type === "null") nullable = true; + else if (nonNull == null) nonNull = item; + } + return nullable ? nonNull : null; +} + +function goIsObjectLikeSchema(schema) { + if (!schema || typeof schema !== "object") return false; + if (schema.$ref) return true; + if (schema.type === "object" || schema.properties) return true; + const nonNull = nullableSchema(schema); + return nonNull ? goIsObjectLikeSchema(nonNull) : false; +} + +function goPointerType(value) { + if (value === "any" || value.startsWith("*") || value.startsWith("[]")) return value; + return `*${value}`; +} + +function goTypeNameForRef(ref) { + const prefix = "#/$defs/"; + return String(ref).startsWith(prefix) ? goExportName(String(ref).slice(prefix.length)) : "json.RawMessage"; +} + +function emitGoAliasObjectsFunction(clientRef, name, objects) { + return `func ${name}() []${clientRef}CustomAliasObject {\n var values []${clientRef}CustomAliasObject\n if err := json.Unmarshal([]byte(${JSON.stringify(JSON.stringify(sanitizeRegistry(objects)))}), &values); err != nil {\n panic(err)\n }\n return values\n}\n\n`; +} + +function emitGoCustomCommandsFunction(clientRef, name, commands) { + return `func ${name}() []${clientRef}CustomCommand {\n var values []${clientRef}CustomCommand\n if err := json.Unmarshal([]byte(${JSON.stringify(JSON.stringify(sanitizeCommandRegistry(commands)))}), &values); err != nil {\n panic(err)\n }\n return values\n}\n\n`; +} + +function collectDefinitionsFromSchemas(schemas) { + const defs = new Map(); + for (const schema of schemas) collectSchemaDefinitions(schema, defs); + return defs; +} + +function sortedObjectsArray(objects) { + return [...objects].sort((left, right) => String(left.name).localeCompare(String(right.name))); +} + +function goStringSliceLiteral(values) { + if (!values || values.length === 0) return "nil"; + return `[]string{${values.map((value) => JSON.stringify(value)).join(", ")}}`; +} + +function goExportName(value) { + return pascal(value); +} + +function goPackageName(value) { + return ( + String(value) + .replace(/[^A-Za-z0-9]/g, "") + .toLowerCase() || "generatedclient" + ); +} + +if (process.argv[1] && import.meta.url === pathToFileURL(process.argv[1]).href) { + const manifest = JSON.parse(readFileSync("testdata/codegen/stagehand_alias_manifest.json", "utf8")); + const output = "testdata/codegen/stagehand_client_generated_go/stagehand_client_gen.go"; + const generated = generateClientWithAliasesForGo( + { + language: "go", + name: "StagehandClient", + custom_commands: manifest.custom_commands, + custom_alias_objects: manifest.custom_alias_objects, + default_config: manifest.modcdp_config, + output, + }, + manifest.custom_alias_objects ?? [], + manifest.custom_commands ?? [], + ); + mkdirSync(dirname(output), { recursive: true }); + writeFileSync(output, generated); +} diff --git a/src/codegen/codegen_python.ts b/src/codegen/codegen_python.ts new file mode 100644 index 0000000..11c405f --- /dev/null +++ b/src/codegen/codegen_python.ts @@ -0,0 +1,214 @@ +import { readFileSync } from "node:fs"; +import { mkdirSync, writeFileSync } from "node:fs"; +import { dirname } from "node:path"; +import { pathToFileURL } from "node:url"; + +import type { GenerateClientWithAliasesOptions } from "./codegen.ts"; +import { + collectDefinitions, + paramsCanBeOmitted, + pascal, + receiverStickyParamFields, + refName, + requiredSet, + safeTypeName, + sanitizeCommandRegistry, + sanitizeRegistry, + schemaPropertiesObject, + sdkMethodName, + sdkPath, + sortedEntries, + statelessAliasObjects, + unionSchemas, +} from "./codegen_shared.ts"; + +export { generateClientWithAliasesForPython }; + +function generateClientWithAliasesForPython(options: GenerateClientWithAliasesOptions, objects, commands) { + const clientName = safeTypeName(options.name); + const functionPrefix = pyFieldName(clientName); + const objectNames = new Set(objects.map((object) => object.type_name || pascal(object.name))); + const statelessObjects = statelessAliasObjects(objects); + const defs = collectDefinitions(objects); + const hasDefaultConfig = options.default_config != null; + let out = `# Code generated by ModCDP codegen. DO NOT EDIT.\nfrom __future__ import annotations\n\nimport json\nimport os\nfrom collections.abc import Mapping\nfrom typing import Any, cast\nfrom pydantic import AliasChoices, BaseModel, ConfigDict, Field, PrivateAttr\n\nfrom modcdp import ModCDPClient\nfrom modcdp.client.alias import ModCDPAliasObject, alias_array_from_result, alias_sticky_from_result, optional_alias_params, send_alias_command_with_sticky_params\nfrom modcdp.types.CDPTypes import CDPTypes\n\nJsonValue = Any\n\nclass AliasModel(BaseModel):\n model_config = ConfigDict(extra="forbid", populate_by_name=True, arbitrary_types_allowed=True)\n\n def alias_params(self) -> dict[str, object]:\n return cast(dict[str, object], self.model_dump(mode="json", exclude_none=True, by_alias=True))\n\n`; + if (hasDefaultConfig) out += emitPythonDefaultConfig(functionPrefix, options.default_config); + for (const [name, schema] of sortedEntries(defs)) { + if (!objectNames.has(name)) out += emitPyModel(name, schema, objectNames); + } + for (const object of objects) { + const typeName = object.type_name || pascal(object.name); + for (const method of object.methods ?? []) { + out += emitPyModel( + `${typeName}${pascal(method.name)}Params`, + method.params_schema, + objectNames, + receiverStickyParamFields(object, method), + ); + out += emitPyModel(`${typeName}${pascal(method.name)}Result`, method.result_schema, objectNames); + } + } + out += `class ${clientName}(ModCDPClient):\n`; + for (const object of objects) { + if (!statelessObjects.has(object.name)) continue; + const property = sdkPath(object, object.methods?.[0] ?? null)[0] ?? object.name; + out += ` ${pyFieldName(property)}: ${object.type_name || pascal(object.name)}\n`; + } + out += `\n def __init__(self, config: Mapping[str, Any] | None = None, **kwargs: Any) -> None:\n payload: dict[str, Any] = {**dict(${hasDefaultConfig ? `${functionPrefix}_default_options()` : "{}"} if config is None else config), **kwargs}\n generated_types = ${functionPrefix}_types()\n existing_types = payload.get("types")\n if isinstance(existing_types, CDPTypes):\n payload["types"] = existing_types.update(generated_types)\n elif isinstance(existing_types, Mapping):\n payload["types"] = CDPTypes(existing_types).update(generated_types)\n else:\n payload["types"] = CDPTypes(generated_types)\n super().__init__(**payload)\n`; + for (const object of objects) { + if (!statelessObjects.has(object.name)) continue; + const property = sdkPath(object, object.methods?.[0] ?? null)[0] ?? object.name; + out += ` self.${pyFieldName(property)} = ${object.type_name || pascal(object.name)}(self, {})\n`; + } + out += `\n`; + for (const object of objects) out += emitPyClass(object, objects); + out += `for _model in list(AliasModel.__subclasses__()):\n _model.model_rebuild()\n\n`; + out += `def ${functionPrefix}_types() -> dict[str, object]:\n return {\n "custom_commands": ${functionPrefix}_custom_commands(),\n "custom_alias_objects": ${functionPrefix}_alias_objects(),\n }\n\n`; + out += `def ${functionPrefix}_custom_commands() -> list[dict[str, object]]:\n return ${pyLiteral(sanitizeCommandRegistry(commands))}\n\n`; + out += `def ${functionPrefix}_alias_objects() -> list[dict[str, object]]:\n return ${pyLiteral(sanitizeRegistry(objects))}\n`; + return out; +} + +function emitPythonDefaultConfig(functionPrefix, config) { + return ( + `${functionPrefix.toUpperCase()}_DEFAULT_OPTIONS_JSON = ${JSON.stringify(JSON.stringify(config))}\n\n` + + `def ${functionPrefix}_default_options() -> dict[str, Any]:\n` + + ` return cast(dict[str, Any], _${functionPrefix}_replace_env(json.loads(${functionPrefix.toUpperCase()}_DEFAULT_OPTIONS_JSON)))\n\n` + + `def _${functionPrefix}_replace_env(value: object) -> object:\n` + + ` if isinstance(value, str):\n` + + ` return value.replace("\${CHROME_PATH}", os.environ.get("CHROME_PATH", "")).replace("\${STAGEHAND_EXTENSION_PATH}", os.environ.get("STAGEHAND_EXTENSION_PATH", ""))\n` + + ` if isinstance(value, list):\n` + + ` return [_${functionPrefix}_replace_env(item) for item in value]\n` + + ` if isinstance(value, dict):\n` + + ` return {key: _${functionPrefix}_replace_env(item) for key, item in value.items()}\n` + + ` return value\n\n` + ); +} + +function emitPyClass(object, objects) { + const typeName = object.type_name || pascal(object.name); + let out = `class ${typeName}(AliasModel, ModCDPAliasObject):\n _client: ModCDPClient | None = PrivateAttr(default=None)\n _sticky: dict[str, object] = PrivateAttr(default_factory=dict)\n`; + for (const [name, schema] of sortedEntries(schemaPropertiesObject(object.sticky_schema))) { + out += ` ${pyFieldName(name)}: ${pyType(schema, `${typeName}${pascal(name)}`, new Set(objects.map((item) => item.type_name || pascal(item.name))))} | None = ${pyFieldDefault(name)}\n`; + } + out += `\n def __init__(self, client: ModCDPClient | None = None, sticky: dict[str, object] | None = None, **data: object) -> None:\n payload = {**(sticky or {}), **data}\n AliasModel.__init__(self, **payload)\n self._client = client\n self._sticky = self.alias_params()\n\n @property\n def client(self) -> ModCDPClient:\n if self._client is None:\n raise RuntimeError("${typeName} is not bound to a ModCDPClient")\n return self._client\n\n @property\n def sticky(self) -> dict[str, object]:\n return dict(self._sticky)\n`; + for (const method of object.methods ?? []) { + if (sdkPath(object, method).length === 2) out += emitPyMethod(object, method, objects); + } + return `${out}\n`; +} + +function emitPyMethod(object, method, objects, receiver = { client: "self.client", sticky: "self.sticky" }) { + const typeName = object.type_name || pascal(object.name); + const methodName = sdkMethodName(object, method); + const paramsType = `${typeName}${pascal(method.name)}Params`; + const resultType = `${typeName}${pascal(method.name)}Result`; + const optional = paramsCanBeOmitted(method.params_schema, receiverStickyParamFields(object, method)); + const signature = optional ? `self, *params: ${paramsType}` : `self, params: ${paramsType}`; + const request = optional ? `optional_alias_params("${object.name}.${method.name}", params)` : "params"; + const stickyParam = JSON.stringify(method.sticky_param || ""); + const stickyParams = pyLiteral(method.sticky_params || []); + const stickyFields = pyLiteral(method.sticky_fields || []); + if (method.return?.object) { + const target = objects.find((item) => item.name === method.return.object); + const targetType = target?.type_name || pascal(method.return.object); + const targetStickyFields = pyLiteral(target?.sticky_fields || []); + if (method.return.array) { + return `\n def ${methodName}(${signature}) -> list[${targetType}]:\n request = ${request}\n raw = send_alias_command_with_sticky_params(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "")\n return [${targetType}(${receiver.client}, alias_sticky_from_result(item, "", ${targetStickyFields})) for item in alias_array_from_result(raw, ${JSON.stringify(method.return.unwrap || "")})]\n`; + } + if (method.return.nullable) { + return `\n def ${methodName}(${signature}) -> ${targetType} | None:\n request = ${request}\n raw = send_alias_command_with_sticky_params(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "")\n sticky = alias_sticky_from_result(raw, ${JSON.stringify(method.return.unwrap || "")}, ${targetStickyFields})\n return None if not sticky else ${targetType}(${receiver.client}, sticky)\n`; + } + return `\n def ${methodName}(${signature}) -> ${targetType}:\n request = ${request}\n raw = send_alias_command_with_sticky_params(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "")\n return ${targetType}(${receiver.client}, alias_sticky_from_result(raw, ${JSON.stringify(method.return.unwrap || "")}, ${targetStickyFields}))\n`; + } + return `\n def ${methodName}(${signature}) -> ${resultType}:\n request = ${request}\n raw = send_alias_command_with_sticky_params(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, ${JSON.stringify(method.return?.unwrap || "")})\n return ${resultType}.model_validate(raw if isinstance(raw, dict) else {"value": raw})\n`; +} + +function emitPyModel(name, schema, objectNames, optionalFields = new Set()) { + const properties = schemaPropertiesObject(schema); + if (Object.keys(properties).length === 0) return `class ${name}(AliasModel):\n pass\n\n`; + const required = requiredSet(schema); + let out = `class ${name}(AliasModel):\n`; + for (const [property, propertySchema] of sortedEntries(properties)) { + const optional = !required.has(property) || optionalFields.has(property); + out += ` ${pyFieldName(property)}: ${pyType(propertySchema, `${name}${pascal(property)}`, objectNames)}${optional ? " | None" : ""} = ${optional ? pyFieldDefault(property) : pyRequiredFieldDefault(property)}\n`; + } + return `${out}\n`; +} + +function pyType(schema, suggestedName, objectNames) { + if (!schema || typeof schema !== "object") return "JsonValue"; + if (schema.$ref) { + const ref = refName(schema.$ref); + return ref; + } + const union = unionSchemas(schema); + if (union) { + const parts = union.filter((item) => item.type !== "null").map((item) => pyType(item, suggestedName, objectNames)); + return [...new Set(parts.length ? parts : ["None"])].join(" | "); + } + if (schema.type === "array") return `list[${pyType(schema.items, `${suggestedName}Item`, objectNames)}]`; + if (schema.type === "object" || schema.properties) { + if (schema.properties) return "dict[str, JsonValue]"; + return "dict[str, JsonValue]"; + } + if (Array.isArray(schema.type)) + return schema.type + .filter((item) => item !== "null") + .map((item) => pyType({ type: item }, suggestedName, objectNames)) + .join(" | "); + if (schema.enum) return "JsonValue"; + if (schema.type === "string") return "str"; + if (schema.type === "boolean") return "bool"; + if (schema.type === "integer") return "int"; + if (schema.type === "number") return "float"; + return "JsonValue"; +} + +function pyFieldName(name) { + const normalized = String(name) + .replace(/[A-Z]/g, (value) => `_${value.toLowerCase()}`) + .replace(/[^A-Za-z0-9_]/g, "_"); + const field = normalized.replace(/^_+/, "") || "value"; + return PY_RESERVED_FIELD_NAMES.has(field) ? `${field}_` : field; +} + +const PY_RESERVED_FIELD_NAMES = new Set(["from", "return", "schema"]); + +function pyRequiredFieldDefault(name) { + const field = pyFieldName(name); + if (field === name) return `Field()`; + return `Field(validation_alias=AliasChoices(${JSON.stringify(field)}, ${JSON.stringify(name)}), serialization_alias=${JSON.stringify(name)})`; +} + +function pyFieldDefault(name) { + const field = pyFieldName(name); + if (field === name) return "None"; + return `Field(default=None, validation_alias=AliasChoices(${JSON.stringify(field)}, ${JSON.stringify(name)}), serialization_alias=${JSON.stringify(name)})`; +} + +function pyLiteral(value) { + return JSON.stringify(value, null, 2) + .replace(/\btrue\b/g, "True") + .replace(/\bfalse\b/g, "False") + .replace(/\bnull\b/g, "None"); +} + +if (process.argv[1] && import.meta.url === pathToFileURL(process.argv[1]).href) { + const manifest = JSON.parse(readFileSync("testdata/codegen/stagehand_alias_manifest.json", "utf8")); + const output = "python/tests/stagehand_client_generated/stagehand_client_gen.py"; + const generated = generateClientWithAliasesForPython( + { + language: "python", + name: "StagehandClient", + custom_commands: manifest.custom_commands, + custom_alias_objects: manifest.custom_alias_objects, + default_config: manifest.modcdp_config, + output, + }, + manifest.custom_alias_objects ?? [], + manifest.custom_commands ?? [], + ); + mkdirSync(dirname(output), { recursive: true }); + writeFileSync(output, generated); +} diff --git a/src/codegen/codegen_shared.ts b/src/codegen/codegen_shared.ts new file mode 100644 index 0000000..90bfcae --- /dev/null +++ b/src/codegen/codegen_shared.ts @@ -0,0 +1,162 @@ +function collectDefinitions(objects) { + const defs = new Map(); + for (const object of objects) { + collectSchemaDefinitions(object.sticky_schema, defs); + for (const method of object.methods ?? []) { + collectSchemaDefinitions(method.params_schema, defs); + collectSchemaDefinitions(method.result_schema, defs); + } + } + return defs; +} + +function collectSchemaDefinitions(schema, defs) { + if (!schema || typeof schema !== "object") return; + for (const [name, definition] of Object.entries(schema.$defs ?? {})) { + const typeName = safeTypeName(name); + if (!defs.has(typeName)) defs.set(typeName, definition); + collectSchemaDefinitions(definition, defs); + } + for (const property of Object.values(schema.properties ?? {})) collectSchemaDefinitions(property, defs); + for (const item of [...(schema.anyOf ?? []), ...(schema.oneOf ?? [])]) collectSchemaDefinitions(item, defs); + if (schema.items) collectSchemaDefinitions(schema.items, defs); + if (schema.additionalProperties && typeof schema.additionalProperties === "object") + collectSchemaDefinitions(schema.additionalProperties, defs); +} + +function receiverStickyParamFields(object, method) { + const fields = new Set(); + const properties = schemaProperties(method.params_schema); + if (!aliasObjectCanPopulateStickyParam(object, method)) return fields; + for (const stickyParam of [method.sticky_param, ...(method.sticky_params ?? [])]) { + if (stickyParam && properties.has(stickyParam)) fields.add(stickyParam); + } + return fields; +} + +function aliasObjectCanPopulateStickyParam(object, method) { + const objectFields = object.sticky_fields ?? []; + const methodFields = method.sticky_fields ?? []; + if (objectFields.length === 0) return schemaProperties(object.sticky_schema).size > 0 && methodFields.length === 0; + if (methodFields.length === 0) return true; + return objectFields.some((field) => methodFields.includes(field)); +} + +function paramsCanBeOmitted(schema, optionalFields) { + const required = requiredSet(schema); + for (const field of optionalFields) required.delete(field); + return required.size === 0; +} + +function schemaProperties(schema) { + return new Set(Object.keys(schemaPropertiesObject(schema))); +} + +function schemaPropertiesObject(schema) { + return schema?.properties && typeof schema.properties === "object" ? schema.properties : {}; +} + +function requiredSet(schema) { + return new Set(Array.isArray(schema?.required) ? schema.required : []); +} + +function unionSchemas(schema) { + return Array.isArray(schema?.anyOf) ? schema.anyOf : Array.isArray(schema?.oneOf) ? schema.oneOf : null; +} + +function sortedEntries(object) { + if (object instanceof Map) return [...object.entries()].sort(([left], [right]) => left.localeCompare(right)); + return Object.entries(object).sort(([left], [right]) => left.localeCompare(right)); +} + +function refName(ref) { + return safeTypeName(String(ref).replace("#/$defs/", "")); +} + +function safeTypeName(value) { + const raw = String(value); + if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(raw)) return raw; + return pascal(raw.replace(/[^A-Za-z0-9_]+/g, "_")); +} + +function lowerFirst(value) { + return String(value).slice(0, 1).toLowerCase() + String(value).slice(1); +} + +function pascal(value) { + return String(value) + .split(/[_\-.]+/) + .filter(Boolean) + .map((part) => part.slice(0, 1).toUpperCase() + part.slice(1)) + .join(""); +} + +function sdkPath(object, method = null) { + const raw = method?.sdk_method_name; + if (typeof raw === "string" && raw.includes(".")) return raw.split("."); + return [object.name, method?.name].filter(Boolean); +} + +function sdkMethodName(object, method) { + return sdkPath(object, method)[1] || method.name; +} + +function statelessAliasObjects(objects) { + const stateless = new Set(); + for (const object of objects) { + if ((object.sticky_fields ?? []).length === 0 && schemaProperties(object.sticky_schema).size === 0) + stateless.add(object.name); + } + return stateless; +} + +function sanitizeRegistry(value) { + if (Array.isArray(value)) return value.map((item) => sanitizeRegistry(item)); + if (value == null || typeof value !== "object") return value; + const output = {}; + for (const [key, child] of Object.entries(value)) { + if (key === "description" || key === "$schema" || key === "event_type" || key === "root") continue; + output[key] = sanitizeRegistry(child); + } + return output; +} + +function sanitizeCommandRegistry(value) { + if (Array.isArray(value)) return value.map((item) => sanitizeCommandRegistry(item)); + if (value == null || typeof value !== "object") return value; + const output = {}; + for (const [key, child] of Object.entries(value)) { + if ( + key === "description" || + key === "$schema" || + key === "event_type" || + key === "sdk_method_name" || + key === "root" + ) + continue; + output[key] = sanitizeCommandRegistry(child); + } + return output; +} + +export { + aliasObjectCanPopulateStickyParam, + collectDefinitions, + collectSchemaDefinitions, + lowerFirst, + paramsCanBeOmitted, + pascal, + receiverStickyParamFields, + refName, + requiredSet, + safeTypeName, + sanitizeCommandRegistry, + sanitizeRegistry, + schemaProperties, + schemaPropertiesObject, + sdkMethodName, + sdkPath, + sortedEntries, + statelessAliasObjects, + unionSchemas, +}; diff --git a/src/codegen/codegen_ts.ts b/src/codegen/codegen_ts.ts new file mode 100644 index 0000000..bfb000c --- /dev/null +++ b/src/codegen/codegen_ts.ts @@ -0,0 +1,216 @@ +import { mkdirSync, readFileSync, writeFileSync } from "node:fs"; +import { dirname } from "node:path"; +import { pathToFileURL } from "node:url"; + +import type { GenerateClientWithAliasesOptions } from "./codegen.ts"; +import { + collectDefinitions, + lowerFirst, + paramsCanBeOmitted, + pascal, + receiverStickyParamFields, + refName, + requiredSet, + safeTypeName, + sanitizeCommandRegistry, + sanitizeRegistry, + schemaProperties, + schemaPropertiesObject, + sdkMethodName, + sdkPath, + sortedEntries, + statelessAliasObjects, + unionSchemas, +} from "./codegen_shared.ts"; + +export { generateClientWithAliasesForTs }; + +function generateClientWithAliasesForTs(options: GenerateClientWithAliasesOptions, objects, commands) { + const clientName = safeTypeName(options.name); + const functionPrefix = lowerFirst(clientName); + const objectNames = new Set(objects.map((object) => object.type_name || pascal(object.name))); + const statelessObjects = statelessAliasObjects(objects); + const defs = collectDefinitions(objects); + const hasDefaultConfig = options.default_config != null; + let out = `// Code generated by ModCDP codegen. DO NOT EDIT.\n`; + out += `import { ModCDPClient } from "../../src/client/ModCDPClient.js";\n`; + out += `import { CDPTypes, type CDPTypesConfig } from "../../src/types/CDPTypes.js";\n`; + out += `import type { ModCDPAddCustomCommandParams, ModCDPAliasObject as ModCDPAliasObjectRegistration } from "../../src/types/modcdp.js";\n`; + out += `import { ModCDPAliasObject, aliasArrayFromResult, aliasStickyFromResult, optionalAliasParams, sendAliasCommandWithStickyParams } from "../../src/client/alias.js";\n\n`; + out += `export type JsonValue = null | boolean | number | string | JsonValue[] | { [key: string]: JsonValue };\n\n`; + out += `export type ${clientName}Options = ConstructorParameters[0];\n\n`; + if (hasDefaultConfig) out += emitTSDefaultConfig(clientName, functionPrefix, options.default_config); + for (const [name, schema] of sortedEntries(defs)) out += emitTSInterface(name, schema, objectNames); + for (const object of objects) { + const typeName = object.type_name || pascal(object.name); + if (schemaProperties(object.sticky_schema).size > 0) + out += emitTSInterface(`${typeName}Data`, object.sticky_schema, objectNames); + } + for (const object of objects) { + const typeName = object.type_name || pascal(object.name); + for (const method of object.methods ?? []) { + out += emitTSInterface( + `${typeName}${pascal(method.name)}Params`, + method.params_schema, + objectNames, + receiverStickyParamFields(object, method), + ); + out += emitTSInterface(`${typeName}${pascal(method.name)}Result`, method.result_schema, objectNames); + } + } + out += `export class ${clientName} extends ModCDPClient {\n`; + for (const object of objects) { + if (!statelessObjects.has(object.name)) continue; + const property = sdkPath(object, object.methods?.[0] ?? null)[0] ?? object.name; + out += ` readonly ${property}: ${object.type_name || pascal(object.name)};\n`; + } + out += ` constructor(config: ${clientName}Options = ${hasDefaultConfig ? `${functionPrefix}DefaultOptions()` : "{}"}) {\n`; + out += ` const generatedTypes = ${functionPrefix}Types();\n`; + out += ` const types = config.types instanceof CDPTypes ? config.types.update(generatedTypes) : new CDPTypes(config.types ?? {}).update(generatedTypes);\n`; + out += ` super({ ...config, types });\n`; + for (const object of objects) { + if (!statelessObjects.has(object.name)) continue; + const property = sdkPath(object, object.methods?.[0] ?? null)[0] ?? object.name; + out += ` this.${property} = new ${object.type_name || pascal(object.name)}(this, {});\n`; + } + out += ` }\n`; + out += `}\n\n`; + for (const object of objects) out += emitTSClass(object, objects); + out += `export function ${functionPrefix}Types(): CDPTypesConfig {\n return {\n custom_commands: ${functionPrefix}CustomCommands(),\n custom_alias_objects: ${functionPrefix}AliasObjects(),\n };\n}\n\n`; + out += `export function ${functionPrefix}CustomCommands(): ModCDPAddCustomCommandParams[] {\n return ${JSON.stringify(sanitizeCommandRegistry(commands), null, 2)} as ModCDPAddCustomCommandParams[];\n}\n\n`; + out += `export function ${functionPrefix}AliasObjects(): ModCDPAliasObjectRegistration[] {\n return ${JSON.stringify(sanitizeRegistry(objects), null, 2)} as ModCDPAliasObjectRegistration[];\n}\n`; + return out; +} + +function emitTSDefaultConfig(clientName, functionPrefix, config) { + return ( + `const ${clientName.toUpperCase()}_DEFAULT_OPTIONS_JSON = ${JSON.stringify(JSON.stringify(config))};\n\n` + + `export function ${functionPrefix}DefaultOptions(): ${clientName}Options {\n` + + ` return ${functionPrefix}ReplaceEnv(JSON.parse(${clientName.toUpperCase()}_DEFAULT_OPTIONS_JSON)) as ${clientName}Options;\n` + + `}\n\n` + + `function ${functionPrefix}ReplaceEnv(value: unknown): unknown {\n` + + ` if (typeof value === "string") return value.replaceAll("\${CHROME_PATH}", process.env.CHROME_PATH ?? "").replaceAll("\${STAGEHAND_EXTENSION_PATH}", process.env.STAGEHAND_EXTENSION_PATH ?? "");\n` + + ` if (Array.isArray(value)) return value.map((item) => ${functionPrefix}ReplaceEnv(item));\n` + + ` if (value != null && typeof value === "object") return Object.fromEntries(Object.entries(value).map(([key, item]) => [key, ${functionPrefix}ReplaceEnv(item)]));\n` + + ` return value;\n` + + `}\n\n` + ); +} + +function emitTSClass(object, objects) { + const typeName = object.type_name || pascal(object.name); + const stickyType = schemaProperties(object.sticky_schema).size > 0 ? `${typeName}Data` : "Record"; + let out = `export class ${typeName} extends ModCDPAliasObject${stickyType === "Record" ? "" : ` implements ${stickyType}`} {\n`; + for (const [name, schema] of sortedEntries(schemaPropertiesObject(object.sticky_schema))) { + out += ` ${tsFieldName(name)}?: ${tsType(schema, `${typeName}${pascal(name)}`, new Set(objects.map((item) => item.type_name || pascal(item.name))))} | null;\n`; + } + out += ` constructor(client: ModCDPClient, sticky: Partial<${stickyType}> = {}) {\n super(client, sticky as Record);\n Object.assign(this, sticky);\n }\n`; + for (const method of object.methods ?? []) { + if (sdkPath(object, method).length === 2) out += emitTSMethod(object, method, objects); + } + out += `}\n\n`; + return out; +} + +function emitTSMethod(object, method, objects, receiver = { client: "this.client", sticky: "this.sticky" }) { + const typeName = object.type_name || pascal(object.name); + const methodName = sdkMethodName(object, method); + const paramsType = `${typeName}${pascal(method.name)}Params`; + const resultType = `${typeName}${pascal(method.name)}Result`; + const optional = paramsCanBeOmitted(method.params_schema, receiverStickyParamFields(object, method)); + const signature = optional ? `...params: ${paramsType}[]` : `params: ${paramsType}`; + const paramsValue = optional ? `optionalAliasParams("${object.name}.${method.name}", params)` : "params"; + const stickyParam = JSON.stringify(method.sticky_param || ""); + const stickyParams = JSON.stringify(method.sticky_params || []); + const stickyFields = JSON.stringify(method.sticky_fields || []); + if (method.return?.object) { + const target = objects.find((item) => item.name === method.return.object); + const targetType = target?.type_name || pascal(method.return.object); + const targetStickyFields = JSON.stringify(target?.sticky_fields || []); + if (method.return.array) { + return ` async ${methodName}(${signature}): Promise<${targetType}[]> {\n const request = ${paramsValue};\n const raw = await sendAliasCommandWithStickyParams(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "");\n return aliasArrayFromResult(raw, ${JSON.stringify(method.return.unwrap || "")}).map((item) => new ${targetType}(${receiver.client}, aliasStickyFromResult(item, "", ${targetStickyFields})));\n }\n`; + } + if (method.return.nullable) { + return ` async ${methodName}(${signature}): Promise<${targetType} | null> {\n const request = ${paramsValue};\n const raw = await sendAliasCommandWithStickyParams(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "");\n const unwrapped = aliasStickyFromResult(raw, ${JSON.stringify(method.return.unwrap || "")}, ${targetStickyFields});\n return Object.keys(unwrapped).length === 0 ? null : new ${targetType}(${receiver.client}, unwrapped);\n }\n`; + } + return ` async ${methodName}(${signature}): Promise<${targetType}> {\n const request = ${paramsValue};\n const raw = await sendAliasCommandWithStickyParams(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, "");\n return new ${targetType}(${receiver.client}, aliasStickyFromResult(raw, ${JSON.stringify(method.return.unwrap || "")}, ${targetStickyFields}));\n }\n`; + } + return ` async ${methodName}(${signature}): Promise<${resultType}> {\n const request = ${paramsValue};\n return sendAliasCommandWithStickyParams<${resultType}>(${receiver.client}, ${JSON.stringify(method.command)}, request, ${receiver.sticky}, ${stickyParam}, ${stickyParams}, ${stickyFields}, ${JSON.stringify(method.return?.unwrap || "")});\n }\n`; +} + +function emitTSInterface(name, schema, objectNames, optionalFields = new Set()) { + const properties = schemaPropertiesObject(schema); + if (Object.keys(properties).length === 0) return `export type ${name} = Record;\n\n`; + const required = requiredSet(schema); + let out = `export interface ${name} {\n`; + for (const [property, propertySchema] of sortedEntries(properties)) { + const optional = !required.has(property) || optionalFields.has(property); + out += ` ${tsFieldName(property)}${optional ? "?" : ""}: ${tsType(propertySchema, `${name}${pascal(property)}`, objectNames)}${optional ? " | null" : ""};\n`; + } + out += `}\n\n`; + return out; +} + +function tsType(schema, suggestedName, objectNames) { + if (!schema || typeof schema !== "object") return "JsonValue"; + if (schema.$ref) { + const ref = refName(schema.$ref); + return objectNames.has(ref) ? `${ref}Data` : ref; + } + const union = unionSchemas(schema); + if (union) { + const parts = union.filter((item) => item.type !== "null").map((item) => tsType(item, suggestedName, objectNames)); + return [...new Set(parts.length ? parts : ["null"])].join(" | "); + } + if (schema.type === "array") { + const itemType = tsType(schema.items, `${suggestedName}Item`, objectNames); + return itemType.includes(" | ") ? `(${itemType})[]` : `${itemType}[]`; + } + if (schema.type === "object" || schema.properties) { + if (schema.properties) return tsInlineObjectType(schema, objectNames); + return "{ [key: string]: JsonValue }"; + } + if (Array.isArray(schema.type)) + return schema.type + .filter((item) => item !== "null") + .map((item) => tsType({ type: item }, suggestedName, objectNames)) + .join(" | "); + if (schema.enum) return schema.enum.map((value) => JSON.stringify(value)).join(" | "); + if (schema.type === "string") return "string"; + if (schema.type === "boolean") return "boolean"; + if (schema.type === "integer" || schema.type === "number") return "number"; + return "JsonValue"; +} + +function tsInlineObjectType(schema, objectNames) { + const properties = schemaPropertiesObject(schema); + const required = requiredSet(schema); + const fields = sortedEntries(properties).map(([property, propertySchema]) => { + const optional = !required.has(property); + return `${tsFieldName(property)}${optional ? "?" : ""}: ${tsType(propertySchema, `${pascal(property)}`, objectNames)}${optional ? " | null" : ""}`; + }); + return `{ ${fields.join("; ")} }`; +} + +function tsFieldName(name) { + return /^[A-Za-z_$][A-Za-z0-9_$]*$/.test(name) ? name : JSON.stringify(name); +} + +if (process.argv[1] && import.meta.url === pathToFileURL(process.argv[1]).href) { + const manifest = JSON.parse(readFileSync("testdata/codegen/stagehand_alias_manifest.json", "utf8")); + const output = "js/test/stagehand_client_generated/stagehand_client_gen.ts"; + const generated = generateClientWithAliasesForTs( + { + language: "ts", + name: "StagehandClient", + custom_commands: manifest.custom_commands, + custom_alias_objects: manifest.custom_alias_objects, + default_config: manifest.modcdp_config, + output, + }, + manifest.custom_alias_objects ?? [], + manifest.custom_commands ?? [], + ); + mkdirSync(dirname(output), { recursive: true }); + writeFileSync(output, generated); +} diff --git a/testdata/codegen/stagehand_client_generated_go/go.mod b/testdata/codegen/stagehand_client_generated_go/go.mod new file mode 100644 index 0000000..0333d75 --- /dev/null +++ b/testdata/codegen/stagehand_client_generated_go/go.mod @@ -0,0 +1,15 @@ +module github.com/browserbase/modcdp/testdata/codegen/generated_client_go + +go 1.25.0 + +require github.com/browserbase/modcdp/go v0.0.0 + +require ( + github.com/ArchiveBox/abxbus/abxbus-go/v2 v2.5.2 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/gobwas/ws v1.4.0 // indirect + golang.org/x/sys v0.43.0 // indirect +) + +replace github.com/browserbase/modcdp/go => ../../../go diff --git a/testdata/codegen/stagehand_client_generated_go/go.sum b/testdata/codegen/stagehand_client_generated_go/go.sum new file mode 100644 index 0000000..5407c2d --- /dev/null +++ b/testdata/codegen/stagehand_client_generated_go/go.sum @@ -0,0 +1,11 @@ +github.com/ArchiveBox/abxbus/abxbus-go/v2 v2.5.2 h1:qpCQCuZ/Nfi6mHf00Ed9nI88JBaAMOh+PhXEO1hdcfQ= +github.com/ArchiveBox/abxbus/abxbus-go/v2 v2.5.2/go.mod h1:Nk7W3WwqKgQw07gQ6V+UdaUJzJwa0Qj78F/n7sERs6A= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= diff --git a/testdata/codegen/stagehand_client_generated_go/stagehand_client_live_test.go b/testdata/codegen/stagehand_client_generated_go/stagehand_client_live_test.go new file mode 100644 index 0000000..a9057c8 --- /dev/null +++ b/testdata/codegen/stagehand_client_generated_go/stagehand_client_live_test.go @@ -0,0 +1,551 @@ +// Before running this live test, build a local Stagehand alias manifest and +// regenerate the Go client: +// +// pnpm --dir ../stagehand-server exec node src/protocol/generate_modcdp_alias_manifest.mjs ../modcdp2/testdata/codegen/stagehand_alias_manifest.json +// pnpm run build:go +// pnpm run test:e2e:go +// +// The manifest is intentionally untracked; do not commit it. + +package stagehandclient + +import ( + "encoding/json" + "math" + "path/filepath" + "strconv" + "strings" + "testing" + "time" +) + +const challengeURL = "https://pirate.github.io/stress-tests/challenge.html" + +func TestGeneratedStagehandClientChallengeSuite(t *testing.T) { + fixtureFile, err := filepath.Abs("stagehand_client_live_test.go") + if err != nil { + t.Fatal(err) + } + + client := StagehandClient() + if err := client.Connect(); err != nil { + t.Fatal(err) + } + defer client.Close() + + page, err := client.Browser.NewPage(BrowserNewPageParams{Url: "about:blank"}) + if err != nil { + t.Fatal(err) + } + page, err = page.Goto(PageGotoParams{Url: challengeURL, WaitUntil: "load"}) + if err != nil { + t.Fatal(err) + } + if _, err := page.WaitForExpression(PageWaitForExpressionParams{ + Expression: "document.querySelectorAll('.task').length === 45", + TimeoutMs: 20_000, + }); err != nil { + t.Fatal(err) + } + + type challenge struct { + id string + run func(Page) error + } + + challenges := []challenge{ + {"simple-button", func(page Page) error { + button, err := page.Locate(PageLocateParams{Locator: Locator{Css: "#start-button"}}) + if err != nil { + return err + } + _, err = button.Click() + return err + }}, + {"radio-selection", func(page Page) error { + _, err := page.Click(PageClickParams{Locator: Locator{Css: `input[name="entity"][value="dont-choose-this-one"]`}}) + return err + }}, + {"checkbox-check", func(page Page) error { + firstIdx := 0 + secondIdx := 1 + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "input.challenge-checkbox", Idx: &firstIdx}}); err != nil { + return err + } + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "input.challenge-checkbox", Idx: &secondIdx}}); err != nil { + return err + } + first, err := page.Locate(PageLocateParams{Locator: Locator{Css: "input.challenge-checkbox"}}) + if err != nil { + return err + } + third, err := first.Nth(LocatorNthParams{Index: 2}) + if err != nil { + return err + } + _, err = third.Click() + return err + }}, + {"search-squash", func(page Page) error { + locator := Locator{Css: "#search-input"} + if _, err := page.Fill(PageFillParams{Locator: locator, Value: "squash"}); err != nil { + return err + } + _, err := page.KeyPress(PageKeyPressParams{Locator: locator, Key: "Enter"}) + return err + }}, + {"date-time-input", func(page Page) error { + now := time.Now() + if _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#date-picker"}, Value: now.Format("2006-01-02")}); err != nil { + return err + } + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#time-picker"}, Value: now.Format("15:04")}) + return err + }}, + {"copy-text", func(page Page) error { + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#copy-input"}, Value: "abc"}) + return err + }}, + {"slider-drag", func(page Page) error { + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#range-slider"}, OffsetX: 295, OffsetY: 8}) + return err + }}, + {"hover-element", func(page Page) error { + if _, err := page.Hover(PageHoverParams{Locator: Locator{Css: "#hover-target"}}); err != nil { + return err + } + _, err := page.WaitForTimeout(PageWaitForTimeoutParams{Ms: 1100}) + return err + }}, + {"drag-drop", func(page Page) error { + _, err := page.DragAndDrop(PageDragAndDropParams{ + Start: Locator{Css: "#drag-element"}, + End: Locator{Css: "#drop-target"}, + }) + return err + }}, + {"file-drop", func(page Page) error { + _, err := page.DropFiles(PageDropFilesParams{ + Files: []json.RawMessage{json.RawMessage(strconv.Quote(fixtureFile))}, + Locator: Locator{Css: "#file-drop-area"}, + }) + return err + }}, + {"multi-select", func(page Page) error { + _, err := page.SelectOption(PageSelectOptionParams{ + Locator: Locator{Css: "#multi-select-element"}, + Values: []string{"option1", "option2", "option3"}, + }) + return err + }}, + {"canvas-captcha", func(page Page) error { + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#canvas-text-input"}, Value: "CAPTCHA123"}) + return err + }}, + {"iframe-slider", func(page Page) error { + idx := 2 + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "#iframe-slider", Idx: &idx}, OffsetX: 295, OffsetY: 8}); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#iframe-slider-popover-button"}}) + return err + }}, + {"oopif-form-submit", func(page Page) error { + if _, err := page.WaitForLocator(PageWaitForLocatorParams{Locator: Locator{Css: "#first-name"}, TimeoutMs: 20_000}); err != nil { + return err + } + if _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#first-name"}, Value: "Ada"}); err != nil { + return err + } + if _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#last-name"}, Value: "Lovelace"}); err != nil { + return err + } + if _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#email"}, Value: "ada@example.com"}); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#contact-form button[type='submit']"}}) + return err + }}, + {"oopif-open-shadow-button", func(page Page) error { + if _, err := page.Click(PageClickParams{Locator: Locator{Xpath: "/html[1]/body[1]/div[2]/div[15]/iframe[1]/html[1]/body[1]/shadow-demo[1]//div[1]/button[1]"}}); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#oopif-open-shadow-confirm"}}) + return err + }}, + {"oopif-closed-shadow-button", func(page Page) error { + if _, err := page.Click(PageClickParams{Locator: Locator{Xpath: "/html[1]/body[1]/div[2]/div[16]/iframe[1]/html[1]/body[1]/shadow-demo[1]//div[1]/button[1]"}}); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#oopif-closed-shadow-confirm"}}) + return err + }}, + {"shadow-dom-dblclick", func(page Page) error { + _, err := page.DoubleClick(PageDoubleClickParams{Locator: Locator{Text: "Double-click me"}}) + return err + }}, + {"right-click-component", func(page Page) error { + locator, err := page.Locate(PageLocateParams{Locator: Locator{Text: "Right-click me"}}) + if err != nil { + return err + } + _, err = locator.Click(LocatorClickParams{Button: "secondary"}) + return err + }}, + {"scroll-accept", func(page Page) error { + if _, err := page.Scroll(PageScrollParams{Locator: Locator{Css: "#scroll-container"}, Percent: json.RawMessage("100")}); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#accept-button"}}) + return err + }}, + {"draw-circle", func(page Page) error { + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "#draw-canvas"}, OffsetX: 150, OffsetY: 150}); err != nil { + return err + } + if _, err := page.Scroll(PageScrollParams{DeltaY: 320}); err != nil { + return err + } + canvas, err := page.Locate(PageLocateParams{Locator: Locator{Css: "#draw-canvas"}}) + if err != nil { + return err + } + if canvas.Coordinates == nil || canvas.Coordinates.Left == nil || canvas.Coordinates.Top == nil { + t.Fatalf("draw canvas locator is missing coordinates: %+v", canvas.Coordinates) + } + left := *canvas.Coordinates.Left + top := *canvas.Coordinates.Top + for i := 0; i < 24; i++ { + angle := math.Pi * 2 * float64(i) / 24 + x := left + 150 + math.Cos(angle)*100 + y := top + 150 + math.Sin(angle)*100 + if _, err := page.Hover(PageHoverParams{Locator: Locator{Coordinates: &LocatorCoordinates{X: &x, Y: &y}}}); err != nil { + return err + } + } + return nil + }}, + {"cancel-dialog", func(page Page) error { + _, err := page.Click(PageClickParams{ + Dialog: &PageClickParamsDialog{Accept: false}, + Locator: Locator{Css: "#confirm-button"}, + }) + return err + }}, + {"geolocation-permission", func(page Page) error { + if _, err := page.SetGeolocation(PageSetGeolocationParams{ + Accuracy: 25, + Latitude: 37.7749, + Longitude: -122.4194, + Origin: "https://pirate.github.io", + }); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#request-location-button"}}) + return err + }}, + {"arrow-key-presses", func(page Page) error { + locator := Locator{Css: "#key-press-area"} + if _, err := page.Click(PageClickParams{Locator: locator}); err != nil { + return err + } + _, err := page.KeyPress(PageKeyPressParams{Locator: locator, Key: "ArrowRight", Repeat: 3}) + return err + }}, + {"alert-secret", func(page Page) error { + result, err := page.Click(PageClickParams{ + Dialog: &PageClickParamsDialog{Accept: true}, + Locator: Locator{Css: "#secret-alert-button"}, + }) + if err != nil { + return err + } + if !strings.Contains(result.Message, "avocado") { + t.Fatalf("alert message did not contain secret: %q", result.Message) + } + _, err = page.Fill(PageFillParams{Locator: Locator{Css: "#secret-word-input"}, Value: "avocado"}) + return err + }}, + {"press-hold-button", func(page Page) error { + _, err := page.ClickAndHold(PageClickAndHoldParams{DurationMs: 800, Locator: Locator{Css: "#hold-button"}}) + return err + }}, + {"tooltip-secret", func(page Page) error { + if _, err := page.Hover(PageHoverParams{Locator: Locator{Css: "#tooltip-element"}}); err != nil { + return err + } + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#tooltip-secret-input"}, Value: "octopus"}) + return err + }}, + {"resize-textarea", func(page Page) error { + textarea, err := page.Locate(PageLocateParams{Locator: Locator{Css: "#resizable-textarea"}}) + if err != nil { + return err + } + if textarea.Coordinates == nil || textarea.Coordinates.Right == nil || textarea.Coordinates.Bottom == nil { + t.Fatalf("textarea locator is missing coordinates: %+v", textarea.Coordinates) + } + right := *textarea.Coordinates.Right + bottom := *textarea.Coordinates.Bottom + startX := right - 2 + startY := bottom - 2 + endX := right + 160 + endY := bottom + 90 + if _, err := page.DragAndDrop(PageDragAndDropParams{ + Start: Locator{Coordinates: &LocatorCoordinates{X: &startX, Y: &startY}}, + End: Locator{Coordinates: &LocatorCoordinates{X: &endX, Y: &endY}}, + }); err != nil { + return err + } + _, err = page.Fill(PageFillParams{Locator: Locator{Css: "#resize-secret-input"}, Value: "giraffe"}) + return err + }}, + {"file-upload", func(page Page) error { + _, err := page.SetInputFiles(PageSetInputFilesParams{Files: []string{fixtureFile}, Locator: Locator{Css: "#file-upload-input"}}) + return err + }}, + {"phone-input", func(page Page) error { + locator := Locator{Css: "#phone-input-field"} + if _, err := page.Click(PageClickParams{Locator: locator}); err != nil { + return err + } + if _, err := page.KeyPress(PageKeyPressParams{Locator: locator, Key: "Backspace"}); err != nil { + return err + } + _, err := page.KeyPress(PageKeyPressParams{Method: "type", Locator: locator, Value: "555-1234"}) + return err + }}, + {"expand-details", func(page Page) error { + locator, err := page.Locate(PageLocateParams{Locator: Locator{Css: "#details-element summary"}}) + if err != nil { + return err + } + _, err = locator.Click() + return err + }}, + {"drag-square-to-circle", func(page Page) error { + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "#drag-canvas"}, OffsetX: 175, OffsetY: 125}); err != nil { + return err + } + if _, err := page.Scroll(PageScrollParams{DeltaY: 220}); err != nil { + return err + } + canvas, err := page.Locate(PageLocateParams{Locator: Locator{Css: "#drag-canvas"}}) + if err != nil { + return err + } + if canvas.Coordinates == nil || canvas.Coordinates.Left == nil || canvas.Coordinates.Top == nil { + t.Fatalf("drag canvas locator is missing coordinates: %+v", canvas.Coordinates) + } + left := *canvas.Coordinates.Left + top := *canvas.Coordinates.Top + startX := left + 75 + startY := top + 125 + endX := left + 250 + endY := top + 125 + _, err = page.DragAndDrop(PageDragAndDropParams{ + Start: Locator{Coordinates: &LocatorCoordinates{X: &startX, Y: &startY}}, + End: Locator{Coordinates: &LocatorCoordinates{X: &endX, Y: &endY}}, + }) + return err + }}, + {"audio-transcription", func(page Page) error { + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#transcription-input"}, Value: "everything"}) + return err + }}, + {"dropdown-selections", func(page Page) error { + if _, err := page.SelectOption(PageSelectOptionParams{Locator: Locator{Css: "#color-select"}, Values: []string{"red"}}); err != nil { + return err + } + _, err := page.SelectOption(PageSelectOptionParams{Locator: Locator{Css: "#object-select"}, Values: []string{"ball"}}) + return err + }}, + {"contenteditable-div", func(page Page) error { + locator := Locator{Css: "#editable-content"} + if _, err := page.Click(PageClickParams{Locator: locator}); err != nil { + return err + } + _, err := page.KeyPress(PageKeyPressParams{Method: "type", Locator: locator, Value: "banana"}) + return err + }}, + {"nested-tiny-button", func(page Page) error { + _, err := page.Click(PageClickParams{ + Dialog: &PageClickParamsDialog{Accept: true}, + ExpectTimeoutMs: 20_000, + Locator: Locator{Css: "#deep-tiny-button"}, + }) + return err + }}, + {"wrapped-word-click", func(page Page) error { + if _, err := page.Hover(PageHoverParams{Locator: Locator{Css: "#wrapped-word-paragraph"}}); err != nil { + return err + } + pointResult, err := page.Evaluate(PageEvaluateParams{Expression: `(() => { + const paragraph = document.getElementById("wrapped-word-paragraph"); + if (!(paragraph?.firstChild instanceof Text)) throw new Error("wrapped word paragraph text is missing"); + const paragraphRect = paragraph.getBoundingClientRect(); + const start = paragraph.textContent.indexOf("ox"); + const range = document.createRange(); + range.setStart(paragraph.firstChild, start); + range.setEnd(paragraph.firstChild, start + 2); + const rect = Array.from(range.getClientRects()).at(-1); + range.detach(); + if (rect == null) throw new Error("wrapped word range has no visible rect"); + return { x: rect.left + rect.width / 2 - paragraphRect.left, y: rect.top + rect.height / 2 - paragraphRect.top }; + })()`}) + if err != nil { + return err + } + var point struct { + X float64 `json:"x"` + Y float64 `json:"y"` + } + if err := json.Unmarshal(pointResult.Value, &point); err != nil { + return err + } + _, err = page.Click(PageClickParams{Locator: Locator{Css: "#wrapped-word-paragraph"}, OffsetX: point.X, OffsetY: point.Y}) + return err + }}, + {"long-link-maze", func(page Page) error { + _, err := page.Click(PageClickParams{Locator: Locator{Text: "TARGET-LINK::orion-needle-1847"}}) + return err + }}, + {"wall-secret-word", func(page Page) error { + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#wall-secret-input"}, Value: "cobaltglass"}) + return err + }}, + {"closed-shadow-aria-word", func(page Page) error { + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#closed-shadow-aria-input"}, Value: "violetcircuit"}) + return err + }}, + {"dynamic-frame-ordinal-trap", func(page Page) error { + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "#frame-trap-start"}}); err != nil { + return err + } + if _, err := page.WaitForLocator(PageWaitForLocatorParams{Locator: Locator{Css: "#frame-trap-create-third"}, TimeoutMs: 10_000}); err != nil { + return err + } + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "#frame-trap-create-third"}}); err != nil { + return err + } + if _, err := page.WaitForLocator(PageWaitForLocatorParams{Locator: Locator{Css: "#frame-trap-insert-final"}, TimeoutMs: 10_000}); err != nil { + return err + } + if _, err := page.Click(PageClickParams{Locator: Locator{Css: "#frame-trap-insert-final"}}); err != nil { + return err + } + if _, err := page.WaitForLocator(PageWaitForLocatorParams{Locator: Locator{Css: "#frame-trap-final-button"}, TimeoutMs: 10_000}); err != nil { + return err + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#frame-trap-final-button"}}) + return err + }}, + {"rotated-transform-click", func(page Page) error { + if _, err := page.Hover(PageHoverParams{Locator: Locator{Css: "#rotated-transform-button"}}); err != nil { + return err + } + pointResult, err := page.Evaluate(PageEvaluateParams{Expression: `(() => { + const button = document.getElementById("rotated-transform-button"); + if (button == null) throw new Error("rotated transform button is missing"); + const rect = button.getBoundingClientRect(); + return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 }; + })()`}) + if err != nil { + return err + } + var point struct { + X float64 `json:"x"` + Y float64 `json:"y"` + } + if err := json.Unmarshal(pointResult.Value, &point); err != nil { + return err + } + _, err = page.Click(PageClickParams{Locator: Locator{Coordinates: &LocatorCoordinates{X: &point.X, Y: &point.Y}}}) + return err + }}, + {"transparent-overlay-button", func(page Page) error { + locator, err := page.Locate(PageLocateParams{Locator: Locator{Css: "#transparent-overlay-target"}}) + if err != nil { + return err + } + _, err = locator.Click() + return err + }}, + {"realistic-input-sequence", func(page Page) error { + panel := Locator{Css: "#realistic-scroll-panel"} + if _, err := page.Scroll(PageScrollParams{Locator: panel, DeltaY: 450}); err != nil { + return err + } + if _, err := page.Scroll(PageScrollParams{Locator: panel, DeltaY: 450}); err != nil { + return err + } + locator := Locator{Css: "#realistic-type-input"} + if _, err := page.Click(PageClickParams{Locator: locator}); err != nil { + return err + } + for _, key := range []string{"o", "r", "c", "h", "i", "d"} { + if _, err := page.KeyPress(PageKeyPressParams{Locator: locator, Key: key}); err != nil { + return err + } + } + _, err := page.Click(PageClickParams{Locator: Locator{Css: "#realistic-submit-button"}}) + return err + }}, + {"css-transform-text", func(page Page) error { + _, err := page.Fill(PageFillParams{Locator: Locator{Css: "#css-transform-text-input"}, Value: "skyline"}) + return err + }}, + {"google-docs", func(page Page) error { + locator := Locator{Css: "#google-docs-answer-input"} + if _, err := page.Click(PageClickParams{Locator: locator}); err != nil { + return err + } + _, err := page.KeyPress(PageKeyPressParams{Method: "type", Locator: locator, Value: "snooker"}) + return err + }}, + } + + if len(challenges) != 45 { + t.Fatalf("expected 45 challenges, got %d", len(challenges)) + } + for _, challenge := range challenges { + if err := challenge.run(page); err != nil { + t.Fatalf("%s failed: %v", challenge.id, err) + } + if _, err := page.WaitForExpression(PageWaitForExpressionParams{ + Expression: "document.getElementById(" + strconv.Quote(challenge.id) + ")?.classList.contains('completed') === true", + TimeoutMs: 20_000, + }); err != nil { + t.Fatalf("%s did not complete: %v", challenge.id, err) + } + } + + scoreResult, err := page.Evaluate(PageEvaluateParams{Expression: "document.querySelectorAll('.task.completed').length"}) + if err != nil { + t.Fatal(err) + } + var score int + if err := json.Unmarshal(scoreResult.Value, &score); err != nil { + t.Fatal(err) + } + countsResult, err := page.Evaluate(PageEvaluateParams{Expression: "(() => ({ completed: document.querySelectorAll('.task.completed').length, total: document.querySelectorAll('.task').length }))()"}) + if err != nil { + t.Fatal(err) + } + var counts struct { + Completed int `json:"completed"` + Total int `json:"total"` + } + if err := json.Unmarshal(countsResult.Value, &counts); err != nil { + t.Fatal(err) + } + if score != counts.Completed { + t.Fatalf("score %d did not match completed count %d", score, counts.Completed) + } + if counts.Total != 45 { + t.Fatalf("expected 45 total challenges, got %d", counts.Total) + } + if score != 45 { + t.Fatalf("expected score 45, got %d", score) + } +} diff --git a/tsconfig.json b/tsconfig.json index fa02b7e..a38f797 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -17,9 +17,10 @@ "skipLibCheck": true, "sourceMap": true, "strictNullChecks": false, + "rewriteRelativeImportExtensions": true, "useUnknownInCatchVariables": false, "verbatimModuleSyntax": true }, - "include": ["js/src/**/*.ts", "js/examples/**/*.ts", "js/test/**/*.ts", "extension/src/**/*.ts"], - "exclude": ["dist", "node_modules"] + "include": ["src/**/*.ts", "js/src/**/*.ts", "js/examples/**/*.ts", "js/test/**/*.ts", "extension/src/**/*.ts"], + "exclude": ["dist", "node_modules", "src/codegen/codegen_*.ts"] }