|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Auto-fix script for httpx.Client() violations. |
| 3 | +
|
| 4 | +This script automatically fixes violations by adding **get_httpx_client_kwargs() |
| 5 | +to httpx.Client() and httpx.AsyncClient() calls. |
| 6 | +""" |
| 7 | + |
| 8 | +import ast |
| 9 | +import sys |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | + |
| 13 | +class HttpxClientFixer(ast.NodeTransformer): |
| 14 | + """AST transformer to fix httpx.Client() usage violations.""" |
| 15 | + |
| 16 | + def __init__(self, filename: str): |
| 17 | + """Initialize the fixer with a filename. |
| 18 | +
|
| 19 | + Args: |
| 20 | + filename: The path to the file being fixed. |
| 21 | + """ |
| 22 | + self.filename = filename |
| 23 | + self.fixes_applied = 0 |
| 24 | + self.needs_import = False |
| 25 | + self.has_httpx_import = False |
| 26 | + self.has_get_httpx_client_kwargs_import = False |
| 27 | + |
| 28 | + def visit_Import(self, node: ast.Import) -> ast.Import: |
| 29 | + """Check for httpx imports.""" |
| 30 | + for alias in node.names: |
| 31 | + if alias.name == "httpx": |
| 32 | + self.has_httpx_import = True |
| 33 | + return node |
| 34 | + |
| 35 | + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: |
| 36 | + """Check for imports from httpx or get_httpx_client_kwargs.""" |
| 37 | + if node.module == "httpx": |
| 38 | + self.has_httpx_import = True |
| 39 | + elif node.module and "get_httpx_client_kwargs" in [ |
| 40 | + alias.name for alias in (node.names or []) |
| 41 | + ]: |
| 42 | + self.has_get_httpx_client_kwargs_import = True |
| 43 | + return node |
| 44 | + |
| 45 | + def visit_Call(self, node: ast.Call) -> ast.Call: |
| 46 | + """Fix httpx.Client() and httpx.AsyncClient() calls.""" |
| 47 | + if self._is_httpx_client_call(node): |
| 48 | + if not self._is_using_get_httpx_client_kwargs(node): |
| 49 | + # Add **get_httpx_client_kwargs() to the call |
| 50 | + new_keyword = ast.keyword( |
| 51 | + arg=None, # **kwargs |
| 52 | + value=ast.Call( |
| 53 | + func=ast.Name(id="get_httpx_client_kwargs", ctx=ast.Load()), |
| 54 | + args=[], |
| 55 | + keywords=[], |
| 56 | + ), |
| 57 | + ) |
| 58 | + |
| 59 | + # Create a new call with the added keyword |
| 60 | + new_node = ast.Call( |
| 61 | + func=node.func, |
| 62 | + args=node.args, |
| 63 | + keywords=node.keywords + [new_keyword], |
| 64 | + ) |
| 65 | + |
| 66 | + # Copy location information |
| 67 | + ast.copy_location(new_node, node) |
| 68 | + |
| 69 | + self.fixes_applied += 1 |
| 70 | + self.needs_import = True |
| 71 | + |
| 72 | + return new_node |
| 73 | + |
| 74 | + return self.generic_visit(node) |
| 75 | + |
| 76 | + def _is_httpx_client_call(self, node: ast.Call) -> bool: |
| 77 | + """Check if the call is httpx.Client() or httpx.AsyncClient().""" |
| 78 | + if isinstance(node.func, ast.Attribute): |
| 79 | + if ( |
| 80 | + isinstance(node.func.value, ast.Name) |
| 81 | + and node.func.value.id == "httpx" |
| 82 | + and node.func.attr in ("Client", "AsyncClient") |
| 83 | + ): |
| 84 | + return True |
| 85 | + elif isinstance(node.func, ast.Name) and node.func.id in ( |
| 86 | + "Client", |
| 87 | + "AsyncClient", |
| 88 | + ): |
| 89 | + return self.has_httpx_import |
| 90 | + return False |
| 91 | + |
| 92 | + def _is_using_get_httpx_client_kwargs(self, node: ast.Call) -> bool: |
| 93 | + """Check if the call already uses **get_httpx_client_kwargs().""" |
| 94 | + for keyword in node.keywords: |
| 95 | + if keyword.arg is None and isinstance(keyword.value, ast.Call): |
| 96 | + if isinstance(keyword.value.func, ast.Name): |
| 97 | + if keyword.value.func.id == "get_httpx_client_kwargs": |
| 98 | + return True |
| 99 | + elif isinstance(keyword.value.func, ast.Attribute): |
| 100 | + if keyword.value.func.attr == "get_httpx_client_kwargs": |
| 101 | + return True |
| 102 | + return False |
| 103 | + |
| 104 | + |
| 105 | +def fix_file(filepath: Path) -> bool: |
| 106 | + """Fix a single Python file for httpx.Client() violations.""" |
| 107 | + try: |
| 108 | + with open(filepath, "r", encoding="utf-8") as f: |
| 109 | + content = f.read() |
| 110 | + |
| 111 | + tree = ast.parse(content, filename=str(filepath)) |
| 112 | + fixer = HttpxClientFixer(str(filepath)) |
| 113 | + |
| 114 | + # Transform the AST |
| 115 | + new_tree = fixer.visit(tree) |
| 116 | + |
| 117 | + if fixer.fixes_applied > 0: |
| 118 | + # Add import if needed and not already present |
| 119 | + if fixer.needs_import and not fixer.has_get_httpx_client_kwargs_import: |
| 120 | + # Find a good place to add the import |
| 121 | + import_added = False |
| 122 | + for i, node in enumerate(new_tree.body): |
| 123 | + if ( |
| 124 | + isinstance(node, ast.ImportFrom) |
| 125 | + and node.module |
| 126 | + and "uipath" in node.module |
| 127 | + ): |
| 128 | + # Add to existing uipath import if possible |
| 129 | + if any( |
| 130 | + alias.name == "get_httpx_client_kwargs" |
| 131 | + for alias in (node.names or []) |
| 132 | + ): |
| 133 | + break # Already imported |
| 134 | + # Add new import after existing uipath imports |
| 135 | + new_import = ast.ImportFrom( |
| 136 | + module="uipath._utils._ssl_context", |
| 137 | + names=[ |
| 138 | + ast.alias(name="get_httpx_client_kwargs", asname=None) |
| 139 | + ], |
| 140 | + level=0, |
| 141 | + ) |
| 142 | + new_tree.body.insert(i + 1, new_import) |
| 143 | + import_added = True |
| 144 | + break |
| 145 | + |
| 146 | + if not import_added: |
| 147 | + # Add at the beginning after other imports |
| 148 | + insert_pos = 0 |
| 149 | + for i, node in enumerate(new_tree.body): |
| 150 | + if isinstance(node, (ast.Import, ast.ImportFrom)): |
| 151 | + insert_pos = i + 1 |
| 152 | + else: |
| 153 | + break |
| 154 | + |
| 155 | + new_import = ast.ImportFrom( |
| 156 | + module="uipath._utils._ssl_context", |
| 157 | + names=[ast.alias(name="get_httpx_client_kwargs", asname=None)], |
| 158 | + level=0, |
| 159 | + ) |
| 160 | + new_tree.body.insert(insert_pos, new_import) |
| 161 | + |
| 162 | + # Convert back to code |
| 163 | + import astor |
| 164 | + |
| 165 | + fixed_content = astor.to_source(new_tree) |
| 166 | + |
| 167 | + # Write back to file |
| 168 | + with open(filepath, "w", encoding="utf-8") as f: |
| 169 | + f.write(fixed_content) |
| 170 | + |
| 171 | + print(f"Fixed {fixer.fixes_applied} violations in {filepath}") |
| 172 | + return True |
| 173 | + |
| 174 | + return False |
| 175 | + |
| 176 | + except Exception as e: |
| 177 | + print(f"Error fixing {filepath}: {e}", file=sys.stderr) |
| 178 | + return False |
| 179 | + |
| 180 | + |
| 181 | +def main(): |
| 182 | + """Main function to run the fixer.""" |
| 183 | + if len(sys.argv) > 1: |
| 184 | + paths = [Path(p) for p in sys.argv[1:]] |
| 185 | + else: |
| 186 | + # Default to checking src and tests directories |
| 187 | + paths = [Path("src"), Path("tests")] |
| 188 | + |
| 189 | + total_files_fixed = 0 |
| 190 | + |
| 191 | + for path in paths: |
| 192 | + if path.is_file() and path.suffix == ".py": |
| 193 | + if fix_file(path): |
| 194 | + total_files_fixed += 1 |
| 195 | + elif path.is_dir(): |
| 196 | + for py_file in path.rglob("*.py"): |
| 197 | + if fix_file(py_file): |
| 198 | + total_files_fixed += 1 |
| 199 | + |
| 200 | + if total_files_fixed > 0: |
| 201 | + print(f"\nFixed {total_files_fixed} files. Run the linter again to verify.") |
| 202 | + else: |
| 203 | + print("No files needed fixing.") |
| 204 | + |
| 205 | + |
| 206 | +if __name__ == "__main__": |
| 207 | + main() |
0 commit comments