-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathsqlalchemy_spanner_asyncio.py
More file actions
147 lines (104 loc) · 3.99 KB
/
sqlalchemy_spanner_asyncio.py
File metadata and controls
147 lines (104 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import asyncio
from .sqlalchemy_spanner import SpannerDialect
from sqlalchemy.connectors.asyncio import (
AsyncAdapt_dbapi_connection,
AsyncAdapt_dbapi_cursor,
AsyncAdapt_dbapi_module,
)
from sqlalchemy.util.concurrency import await_only
class AsyncIODBAPISpannerCursor:
    """Awaitable facade over a blocking DBAPI-style cursor.

    Plain attribute reads are answered directly by the wrapped cursor, while
    every method call on it is dispatched through ``asyncio.to_thread`` and
    awaited.
    """

    def __init__(self, sync_cursor):
        # The blocking cursor that performs all the real work.
        self._sync_cursor = sync_cursor

    @property
    def description(self):
        """``description`` of the wrapped cursor."""
        wrapped = self._sync_cursor
        return wrapped.description

    @property
    def rowcount(self):
        """``rowcount`` of the wrapped cursor."""
        wrapped = self._sync_cursor
        return wrapped.rowcount

    @property
    def lastrowid(self):
        """``lastrowid`` of the wrapped cursor."""
        wrapped = self._sync_cursor
        return wrapped.lastrowid

    @property
    def arraysize(self):
        """``arraysize`` of the wrapped cursor; assignments are forwarded."""
        wrapped = self._sync_cursor
        return wrapped.arraysize

    @arraysize.setter
    def arraysize(self, value):
        wrapped = self._sync_cursor
        wrapped.arraysize = value

    async def close(self):
        """Call the wrapped cursor's ``close`` in a worker thread."""
        closer = self._sync_cursor.close
        await asyncio.to_thread(closer)

    async def execute(self, operation, parameters=None):
        """Run ``execute(operation, parameters)`` in a worker thread.

        Returns whatever the wrapped cursor's ``execute`` returns.
        """
        pending = asyncio.to_thread(
            self._sync_cursor.execute, operation, parameters
        )
        return await pending

    async def executemany(self, operation, seq_of_parameters):
        """Run ``executemany(operation, seq_of_parameters)`` in a worker thread."""
        pending = asyncio.to_thread(
            self._sync_cursor.executemany, operation, seq_of_parameters
        )
        return await pending

    async def fetchone(self):
        """Return the result of the wrapped ``fetchone`` call."""
        fetch = self._sync_cursor.fetchone
        return await asyncio.to_thread(fetch)

    async def fetchmany(self, size=None):
        """Return the result of the wrapped ``fetchmany(size)`` call.

        ``size`` is forwarded unchanged, including the default ``None``.
        """
        fetch = self._sync_cursor.fetchmany
        return await asyncio.to_thread(fetch, size)

    async def fetchall(self):
        """Return the result of the wrapped ``fetchall`` call."""
        fetch = self._sync_cursor.fetchall
        return await asyncio.to_thread(fetch)

    async def nextset(self):
        """Forward ``nextset`` when the wrapped cursor provides it.

        Returns ``None`` without calling anything if the wrapped cursor has
        no ``nextset`` attribute.
        """
        wrapped = self._sync_cursor
        if not hasattr(wrapped, "nextset"):
            return None
        return await asyncio.to_thread(wrapped.nextset)

    async def __aenter__(self):
        """Enter ``async with``; yields this cursor itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave ``async with``; closes the cursor and never suppresses errors."""
        await self.close()
class AsyncIODBAPISpannerConnection:
    """Awaitable facade over a blocking DBAPI-style connection.

    ``commit``, ``rollback`` and ``close`` run the wrapped connection's
    method of the same name through ``asyncio.to_thread``.  Any attribute not
    defined on this class is looked up on the wrapped connection.
    """

    def __init__(self, sync_conn):
        # The blocking connection that performs all the real work.
        self._sync_conn = sync_conn

    async def commit(self):
        """Call the wrapped connection's ``commit`` in a worker thread."""
        await asyncio.to_thread(self._sync_conn.commit)

    async def rollback(self):
        """Call the wrapped connection's ``rollback`` in a worker thread."""
        await asyncio.to_thread(self._sync_conn.rollback)

    async def close(self):
        """Call the wrapped connection's ``close`` in a worker thread."""
        await asyncio.to_thread(self._sync_conn.close)

    def cursor(self):
        """Return a new ``AsyncIODBAPISpannerCursor`` over a fresh cursor.

        The underlying ``cursor()`` call is made synchronously, not in a
        worker thread.
        """
        return AsyncIODBAPISpannerCursor(self._sync_conn.cursor())

    def __getattr__(self, name):
        """Delegate attributes missing on the wrapper to the wrapped connection.

        Raises:
            AttributeError: if the wrapped connection has no such attribute,
                or if the wrapper has no wrapped connection at all.
        """
        # __getattr__ only runs after normal lookup has failed.  If
        # ``_sync_conn`` itself is what is missing (instance created without
        # __init__, e.g. by copy/pickle machinery probing for special
        # methods), reading ``self._sync_conn`` below would re-enter this
        # method forever and end in RecursionError instead of AttributeError.
        if name == "_sync_conn":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._sync_conn, name)
class AsyncAdapt_spanner_cursor(AsyncAdapt_dbapi_cursor):
    """Adapted cursor that also exposes the adapted connection it belongs to."""

    @property
    def connection(self):
        """Return the object stored in ``_adapt_connection``."""
        owner = self._adapt_connection
        return owner
class AsyncAdapt_spanner_connection(AsyncAdapt_dbapi_connection):
    """Adapted connection that produces ``AsyncAdapt_spanner_cursor`` cursors.

    Attributes not found on this object are looked up on ``_connection``.
    NOTE(review): ``_connection`` is presumably assigned by the base class
    constructor from its second argument -- confirm against the SQLAlchemy
    version in use.
    """

    _cursor_cls = AsyncAdapt_spanner_cursor

    @property
    def connection(self):
        """Return the ``_sync_conn`` attribute of the wrapped ``_connection``."""
        return self._connection._sync_conn

    def __getattr__(self, name):
        """Delegate attributes missing on the adapter to ``_connection``.

        Raises:
            AttributeError: if ``_connection`` has no such attribute, or if
                ``_connection`` itself has not been set on this instance.
        """
        # __getattr__ only runs after normal lookup has failed.  If
        # ``_connection`` itself is what is missing (instance created without
        # __init__, e.g. by copy/pickle machinery), reading
        # ``self._connection`` below would re-enter this method forever and
        # end in RecursionError instead of AttributeError.
        if name == "_connection":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._connection, name)
class AsyncAdapt_spanner_dbapi(AsyncAdapt_dbapi_module):
    """Module-like adapter around a synchronous DBAPI module.

    Every attribute of the wrapped module whose name does not start with a
    double underscore is copied onto this object, except ``connect``, which
    is replaced by a version that returns an adapted connection.
    """

    await_ = staticmethod(await_only)

    def __init__(self, spanner_dbapi):
        """Store the wrapped module and mirror its attributes.

        Args:
            spanner_dbapi: the synchronous DBAPI module to adapt.
        """
        self.spanner_dbapi = spanner_dbapi
        for name in dir(spanner_dbapi):
            # ``connect`` is deliberately not copied so that the method
            # defined on this class is the one callers reach.
            if not name.startswith("__") and name != "connect":
                setattr(self, name, getattr(spanner_dbapi, name))

    def connect(self, *arg, **kw):
        """Create a connection and wrap it for asyncio use.

        Args:
            *arg: positional arguments forwarded to the connection factory.
            **kw: keyword arguments forwarded to the connection factory.
                ``async_creator_fn``, when present and truthy, is removed
                from ``kw`` and used as the factory instead of the wrapped
                module's ``connect``.

        Returns:
            An ``AsyncAdapt_spanner_connection`` wrapping an
            ``AsyncIODBAPISpannerConnection`` around the created connection.
        """
        # Local import, matching the other function-scope imports in this file.
        import inspect

        async_creator_fn = kw.pop("async_creator_fn", None)
        if async_creator_fn:
            connection = async_creator_fn(*arg, **kw)
            # An async creator hands back an awaitable, not a connection.
            # Resolve it before wrapping; otherwise the coroutine object
            # itself would be treated as the connection.  A creator that
            # returns a connection directly keeps working unchanged.
            if inspect.isawaitable(connection):
                connection = self.await_(connection)
        else:
            connection = self.spanner_dbapi.connect(*arg, **kw)
        return AsyncAdapt_spanner_connection(
            self, AsyncIODBAPISpannerConnection(connection)
        )
class SpannerDialect_asyncio(SpannerDialect):
    """``SpannerDialect`` variant flagged as async (``is_async = True``).

    Its DBAPI is the ``google.cloud.spanner_dbapi`` module wrapped in
    ``AsyncAdapt_spanner_dbapi``.
    """

    driver = "spanner_asyncio"
    is_async = True
    supports_statement_cache = True

    @classmethod
    def import_dbapi(cls):
        """Import ``google.cloud.spanner_dbapi`` and return it adapted."""
        from google.cloud import spanner_dbapi as sync_dbapi

        adapted = AsyncAdapt_spanner_dbapi(sync_dbapi)
        return adapted

    @classmethod
    def dbapi(cls):
        """Return the same adapted module as ``import_dbapi``."""
        adapted = cls.import_dbapi()
        return adapted

    @classmethod
    def get_pool_class(cls, url):
        """Return ``AsyncAdaptedQueuePool``; ``url`` is not consulted."""
        from sqlalchemy import pool as sa_pool

        return sa_pool.AsyncAdaptedQueuePool

    def get_driver_connection(self, connection):
        """Return the ``_connection`` attribute of ``connection``."""
        inner = connection._connection
        return inner