|
5 | 5 | import sys |
6 | 6 | import timeit |
7 | 7 | import traceback |
8 | | -import pandas as pd |
9 | | -from daft import col, DataType, TimeUnit |
| 8 | +from daft import col, DataType |
10 | 9 |
|
# Module-global cache for the ClickBench `hits` DataFrame; populated lazily
# by run_single_query() on the first call so import stays cheap.
hits = None
current_dir = os.path.dirname(os.path.abspath(__file__))

# Load the benchmark queries: queries.sql holds one statement per ';',
# so split on the terminator and keep only the non-empty statements.
with open("queries.sql") as f:
    raw_sql = f.read()
sql_list = [stmt.strip() for stmt in raw_sql.split(';') if stmt.strip()]
|
23 | | -def daft_offset(df, start ,end): |
24 | | - pandas_df = df.to_pandas() |
25 | | - sliced_df = pandas_df.iloc[start:end] |
26 | | - return sliced_df |
27 | | - |
28 | | -queries = [] |
29 | | -for idx, sql in enumerate(sql_list): |
30 | | - query_entry = {"sql": sql} |
31 | | - |
32 | | - # Current limitations and workarounds for Daft execution: |
33 | | - |
34 | | - # 1. Queries q18, q35, q42 require manual API workarounds: |
35 | | - # - q18: The function `extract(minute FROM EventTime)` causes an error: |
36 | | - # `expected input to minute to be temporal, got UInt32`. |
37 | | - # - q35: Error is `duplicate field name ClientIP in the schema`. |
38 | | - # Attempts to alias the column in SQL but still failed. |
39 | | - # - q42: The function `DATE_TRUNC('minute', EventTime)` causes an error: |
40 | | - # `Unsupported SQL: Function date_trunc not found`. |
41 | | - if idx in [18, 35, 42]: |
42 | | - if idx == 18: |
43 | | - query_entry["lambda"] = lambda: ( |
44 | | - hits.with_column("m", col("EventTime").dt.minute()) |
45 | | - .groupby("UserID", "m", "SearchPhrase") |
46 | | - .agg(daft.sql_expr("COUNT(1)").alias("COUNT(*)")) |
47 | | - .sort("COUNT(*)", desc=True) |
48 | | - .limit(10) |
49 | | - .select("UserID", "m", "SearchPhrase", "COUNT(*)") |
50 | | - ) |
51 | | - elif idx == 35: |
52 | | - query_entry["lambda"] = lambda: ( |
53 | | - hits.groupby( |
54 | | - "ClientIP", |
55 | | - daft.sql_expr("ClientIP - 1").alias("ClientIP - 1"), |
56 | | - daft.sql_expr("ClientIP - 2").alias("ClientIP - 2"), |
57 | | - daft.sql_expr("ClientIP - 3").alias("ClientIP - 3")) |
58 | | - .agg(daft.sql_expr("COUNT(1)").alias("c")) |
59 | | - .sort("c", desc=True) |
60 | | - .limit(10) |
61 | | - .select("ClientIP", "ClientIP - 1", "ClientIP - 2", "ClientIP - 3", "c") |
62 | | - ) |
63 | | - elif idx == 42: |
64 | | - query_entry["lambda"] = lambda: ( |
65 | | - hits.with_column("M", col("EventTime").dt.truncate("1 minute")) |
66 | | - .where("CounterID = 62 AND EventDate >= '2013-07-14' AND EventDate <= '2013-07-15' AND IsRefresh = 0 AND DontCountHits = 0") |
67 | | - .groupby("M") |
68 | | - .agg(daft.sql_expr("COUNT(1)").alias("PageViews")) |
69 | | - .sort("M", desc=False) |
70 | | - .limit(1010) |
71 | | - .select("M", "PageViews") |
72 | | - ) |
73 | | - |
74 | | - # 2. OFFSET operator not supported in Daft: |
75 | | - # For queries q38, q39, q40, q41, q42, after executing the query, |
76 | | - # manually implement the `OFFSET` truncation logic via the API |
77 | | - if 38 <= idx <= 42: |
78 | | - if idx == 38: |
79 | | - query_entry["extra_api"] = lambda df: daft_offset(df, 1000, 1010) |
80 | | - elif idx == 39: |
81 | | - query_entry["extra_api"] = lambda df: daft_offset(df, 1000, 1010) |
82 | | - elif idx == 40: |
83 | | - query_entry["extra_api"] = lambda df: daft_offset(df, 100, 110) |
84 | | - elif idx == 41: |
85 | | - query_entry["extra_api"] = lambda df: daft_offset(df, 10000, 10010) |
86 | | - elif idx == 42: |
87 | | - query_entry["extra_api"] = lambda df: daft_offset(df, 1000, 1010) |
88 | | - |
89 | | - queries.append(query_entry) |
90 | | - |
def run_single_query(sql, i):
    """Run one benchmark SQL statement against the `hits` table and time it.

    On the first call, lazily loads the parquet dataset into the module
    global `hits` and prepares its schema (timestamp/date casts plus UTF-8
    decoding of the string columns).  NOTE(review): that one-time setup is
    inside the timed region, so run 0 measures the cold path — presumably
    intentional for ClickBench-style cold/warm reporting.

    Args:
        sql: The SQL text to execute via daft.sql().
        i: Repetition index (0-2); currently unused inside the body.

    Returns:
        Elapsed wall-clock seconds rounded to 3 decimals, or None if the
        query raised (the error is printed to stderr with a traceback).
    """
    global hits
    try:
        start = timeit.default_timer()

        if hits is None:
            hits = daft.read_parquet(parquet_path)
            hits = hits.with_column("EventTime", col("EventTime").cast(DataType.timestamp("s")))
            hits = hits.with_column("EventDate", col("EventDate").cast(DataType.date()))
            # These string columns arrive as binary in the parquet file;
            # decode each one to UTF-8 so SQL string functions work on them.
            for utf8_name in ("URL", "Title", "Referer", "MobilePhoneModel", "SearchPhrase"):
                hits = hits.with_column(utf8_name, col(utf8_name).decode("utf-8"))

        # Execute eagerly: collect() forces materialization so the timing
        # covers actual query work, not just plan construction.
        daft.sql(sql).collect()

        return round(timeit.default_timer() - start, 3)
    except Exception as e:
        # Best-effort benchmark harness: report the failure (query_idx is a
        # module global identifying the current query) and signal it with None.
        print(f"Error executing query {query_idx}: {str(e)[:100]}", file=sys.stderr)
        traceback.print_exc()
        return None
125 | 46 |
|
if __name__ == "__main__":
    # query_idx is a module global (parsed earlier, e.g. from argv) selecting
    # which of the ClickBench statements to benchmark.
    sql = sql_list[query_idx]

    # ClickBench convention: one cold run followed by two warm runs.
    times = []
    for i in range(3):
        elapsed = run_single_query(sql, i)
        # BUGFIX: run_single_query returns None on failure, but a very fast
        # query could legitimately time 0.0 — test identity against None
        # instead of truthiness so a 0.0 result is not reported as a failure.
        times.append(f"{elapsed}" if elapsed is not None else "")

    # Emit the three timings as a single CSV line (blank field = failed run).
    print(','.join(times))