Skip to content

Commit f9e9fb3

Browse files
committed
Add tests for 100-continue connection pool corruption scenario
1 parent bb4b904 commit f9e9fb3

3 files changed

Lines changed: 335 additions & 0 deletions

File tree

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#!/usr/bin/env python3
2+
"""Client that sends two requests on one TCP connection to reproduce
3+
100-continue connection pool corruption."""
4+
5+
# Licensed to the Apache Software Foundation (ASF) under one
6+
# or more contributor license agreements. See the NOTICE file
7+
# distributed with this work for additional information
8+
# regarding copyright ownership. The ASF licenses this file
9+
# to you under the Apache License, Version 2.0 (the
10+
# "License"); you may not use this file except in compliance
11+
# with the License. You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
from http_utils import wait_for_headers_complete, determine_outstanding_bytes_to_read, drain_socket
22+
23+
import argparse
24+
import socket
25+
import sys
26+
import time
27+
28+
29+
def main() -> int:
    """Send two requests on one proxy connection and report what happened.

    Request 1 is a POST carrying ``Expect: 100-continue`` whose body is
    sent after a short delay without waiting for the interim 100 response.
    Request 2 is a plain GET reusing the same client TCP connection.
    Prints exactly one diagnostic line that the surrounding autest
    asserts on; always returns 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('proxy_address')
    parser.add_argument('proxy_port', type=int)
    parser.add_argument('-s', '--server-hostname', dest='server_hostname',
                        default='example.com')
    args = parser.parse_args()

    host = args.server_hostname
    body_size = 103
    body_data = b'X' * body_size

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((args.proxy_address, args.proxy_port))

    with sock:
        # Request 1: POST with Expect: 100-continue and a body.
        request1 = (
            f'POST /expect-100-corrupted HTTP/1.1\r\n'
            f'Host: {host}\r\n'
            f'Connection: keep-alive\r\n'
            f'Content-Length: {body_size}\r\n'
            f'Expect: 100-continue\r\n'
            f'\r\n'
        ).encode()
        sock.sendall(request1)

        # Send the body after a short delay without waiting for 100-continue.
        time.sleep(0.5)
        sock.sendall(body_data)

        # Drain the response (might be 100 + 301, or just 301).
        resp1_data = wait_for_headers_complete(sock)

        # If we got a 100 Continue, read past it to the real response.
        # The status code is parsed as the second token of the status line
        # rather than tested with a substring match, which could
        # false-positive on any status line that merely contains "100".
        if _status_code(resp1_data) == b'100':
            after_100 = resp1_data.split(b'\r\n\r\n', 1)[1] if b'\r\n\r\n' in resp1_data else b''
            if b'\r\n\r\n' not in after_100:
                after_100 += wait_for_headers_complete(sock)
            resp1_data = after_100

        # Drain the response body so nothing is left buffered before the
        # second request is sent.
        try:
            outstanding = determine_outstanding_bytes_to_read(resp1_data)
            if outstanding > 0:
                drain_socket(sock, resp1_data, outstanding)
        except ValueError:
            # Body framing could not be determined; best effort only.
            pass

        # Let ATS pool the origin connection.
        time.sleep(0.5)

        # Request 2: plain GET on the same client connection.
        request2 = (
            f'GET /second-request HTTP/1.1\r\n'
            f'Host: {host}\r\n'
            f'Connection: close\r\n'
            f'\r\n'
        ).encode()
        sock.sendall(request2)

        resp2_data = wait_for_headers_complete(sock)
        status2 = _status_code(resp2_data)

        if status2 == b'400' or b'corrupted' in resp2_data.lower():
            print('Corruption detected: second request saw corrupted data', flush=True)
        elif status2 == b'502':
            print('Corruption detected: ATS returned 502 (origin parse error)', flush=True)
        else:
            print('No corruption: second request completed normally', flush=True)

    return 0


def _status_code(response: bytes) -> bytes:
    """Return the status-code token of *response*'s status line (b'' if absent)."""
    tokens = response.split(b'\r\n')[0].split(b' ')
    return tokens[1] if len(tokens) >= 2 else b''


if __name__ == '__main__':
    sys.exit(main())
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#!/usr/bin/env python3
2+
"""Origin that sends a 301 without consuming the request body, then checks
3+
whether a reused connection carries leftover (corrupted) data. Handles
4+
multiple connections so that a fixed ATS can open a fresh one for the
5+
second request."""
6+
7+
# Licensed to the Apache Software Foundation (ASF) under one
8+
# or more contributor license agreements. See the NOTICE file
9+
# distributed with this work for additional information
10+
# regarding copyright ownership. The ASF licenses this file
11+
# to you under the Apache License, Version 2.0 (the
12+
# "License"); you may not use this file except in compliance
13+
# with the License. You may obtain a copy of the License at
14+
#
15+
# http://www.apache.org/licenses/LICENSE-2.0
16+
#
17+
# Unless required by applicable law or agreed to in writing, software
18+
# distributed under the License is distributed on an "AS IS" BASIS,
19+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20+
# See the License for the specific language governing permissions and
21+
# limitations under the License.
22+
23+
import argparse
24+
import socket
25+
import sys
26+
import threading
27+
import time
28+
29+
VALID_METHODS = {'GET', 'POST', 'PUT', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH'}


def read_until_headers_complete(conn: socket.socket) -> bytes:
    """Receive from *conn* until the blank-line header terminator or EOF.

    Returns whatever was read, which may be incomplete if the peer closed
    the connection before sending the full header block.
    """
    buffered = bytearray()
    while b'\r\n\r\n' not in buffered:
        piece = conn.recv(4096)
        if not piece:
            break
        buffered.extend(piece)
    return bytes(buffered)


def is_valid_http_request_line(line: str) -> bool:
    """Return True iff *line* looks like ``METHOD <target> HTTP/x.y``."""
    tokens = line.strip().split(' ')
    if len(tokens) < 3:
        return False
    method, version = tokens[0], tokens[-1]
    return method in VALID_METHODS and version.startswith('HTTP/')
47+
48+
49+
def send_200(conn: socket.socket) -> None:
    """Write a minimal 200 OK response with a fixed two-byte body."""
    payload = b'OK'
    header = (b'HTTP/1.1 200 OK\r\n'
              b'Content-Length: ' + str(len(payload)).encode() + b'\r\n'
              b'\r\n')
    conn.sendall(header + payload)
55+
56+
57+
def handle_connection(conn: socket.socket, args: argparse.Namespace,
                      result: dict) -> None:
    """Serve one accepted connection from the proxy.

    POST: reply 301 after ``args.delay`` seconds WITHOUT reading the request
    body, then watch the same connection for up to ``args.timeout`` seconds.
    A valid request line arriving there gets a 200; anything else is
    recorded as corruption (leftover body bytes from the unconsumed POST)
    and answered with a 400.
    GET: treated as the second request arriving on a fresh connection
    (the fixed proxy behavior) and answered with a 200.

    Mutates *result* in place: sets ``result['corrupted']`` or
    ``result['new_connection']`` for the caller to inspect. Never raises.
    """
    try:
        data = read_until_headers_complete(conn)
        if not data:
            # Readiness probe.
            conn.close()
            return

        first_line = data.split(b'\r\n')[0].decode('utf-8', errors='replace')

        if first_line.startswith('POST'):
            # First request: send 301 without consuming the body.
            time.sleep(args.delay)

            body = b'Redirecting'
            response = (
                b'HTTP/1.1 301 Moved Permanently\r\n'
                b'Location: http://example.com/\r\n'
                b'Connection: keep-alive\r\n'
                b'Content-Length: ' + str(len(body)).encode() + b'\r\n'
                b'\r\n' + body
            )
            conn.sendall(response)

            # Wait for potential reuse on this connection.
            conn.settimeout(args.timeout)
            try:
                second_data = b''
                # One request line is enough to judge validity, so stop at
                # the first CRLF rather than waiting for complete headers.
                while b'\r\n' not in second_data:
                    chunk = conn.recv(4096)
                    if not chunk:
                        break
                    second_data += chunk

                if second_data:
                    second_line = second_data.split(b'\r\n')[0].decode('utf-8', errors='replace')
                    if is_valid_http_request_line(second_line):
                        send_200(conn)
                    else:
                        # Garbage where a request line should be: the proxy
                        # pooled a connection with unconsumed body bytes.
                        result['corrupted'] = True
                        err_body = b'corrupted'
                        conn.sendall(
                            b'HTTP/1.1 400 Bad Request\r\n'
                            b'Content-Length: ' + str(len(err_body)).encode() + b'\r\n'
                            b'\r\n' + err_body)
            except socket.timeout:
                # No reuse observed within the window; nothing to record.
                pass

        elif first_line.startswith('GET'):
            # Second request on a new connection (fix is working).
            result['new_connection'] = True
            send_200(conn)

        conn.close()
    except Exception:
        # Best-effort cleanup; this worker must never kill the accept loop.
        try:
            conn.close()
        except Exception:
            pass
117+
118+
119+
def main() -> int:
    """Listen on the given port and hand each connection to a worker thread.

    Accepts at most ten connections, stopping early once no new connection
    arrives within the accept timeout, then waits for the workers to finish.
    Always returns 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('port', type=int)
    parser.add_argument('--delay', type=float, default=1.0)
    parser.add_argument('--timeout', type=float, default=5.0)
    args = parser.parse_args()

    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(('', args.port))
    listener.listen(5)
    # Give up accepting shortly after the per-connection window has passed.
    listener.settimeout(args.timeout + 5)

    shared_result = {'corrupted': False, 'new_connection': False}
    workers = []

    try:
        for _ in range(10):
            try:
                conn, _addr = listener.accept()
            except socket.timeout:
                break
            worker = threading.Thread(target=handle_connection,
                                      args=(conn, args, shared_result))
            worker.daemon = True
            worker.start()
            workers.append(worker)
    except Exception:
        pass

    for worker in workers:
        worker.join(timeout=args.timeout + 2)

    listener.close()
    return 0


if __name__ == '__main__':
    sys.exit(main())
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import sys

# NOTE: this is an AuTest script — Test, When, and Testers are globals
# injected by the autest framework at run time, not Python imports.

Test.Summary = '''
Verify that when an origin responds before consuming the request body on a
connection with Expect: 100-continue, ATS does not return the origin connection
to the pool with unconsumed data.
'''

tr = Test.AddTestRun(
    'Verify 100-continue with early origin response does not corrupt pooled connections.')

# DNS: resolve every hostname (including backend.example.com) to loopback.
dns = tr.MakeDNServer('dns', default='127.0.0.1')

# Origin: replies 301 to the POST without reading its body, then watches
# for leftover bytes on a reused connection (see corruption_origin.py).
Test.GetTcpPort('origin_port')
tr.Setup.CopyAs('corruption_origin.py')
origin = tr.Processes.Process(
    'origin',
    f'{sys.executable} corruption_origin.py '
    f'{Test.Variables.origin_port} --delay 1.0 --timeout 5.0')
origin.Ready = When.PortOpen(Test.Variables.origin_port)

# ATS: cache disabled so both client requests are forwarded to the origin.
ts = tr.MakeATSProcess('ts', enable_cache=False)
ts.Disk.remap_config.AddLine(
    f'map / http://backend.example.com:{Test.Variables.origin_port}')
ts.Disk.records_config.update({
    'proxy.config.diags.debug.enabled': 1,
    'proxy.config.diags.debug.tags': 'http',
    'proxy.config.dns.nameservers': f'127.0.0.1:{dns.Variables.Port}',
    'proxy.config.dns.resolv_conf': 'NULL',
    # Enable ATS's 100-continue response handling so the transaction
    # exercises the code path under test.
    'proxy.config.http.send_100_continue_response': 1,
})

# Client: sends the Expect: 100-continue POST, then a GET on the same
# client connection, and prints a corruption verdict (corruption_client.py).
tr.Setup.CopyAs('corruption_client.py')
tr.Setup.CopyAs('http_utils.py')
tr.Processes.Default.Command = (
    f'{sys.executable} corruption_client.py '
    f'127.0.0.1 {ts.Variables.port} '
    f'-s backend.example.com')
tr.Processes.Default.ReturnCode = 0
tr.Processes.Default.StartBefore(dns)
tr.Processes.Default.StartBefore(origin)
tr.Processes.Default.StartBefore(ts)

# With the fix, ATS should not pool the origin connection when the
# request body was not fully consumed, preventing corruption.
tr.Processes.Default.Streams.stdout += Testers.ContainsExpression(
    'No corruption',
    'The second request should complete normally because ATS '
    'does not pool origin connections with unconsumed body data.')
tr.Processes.Default.Streams.stdout += Testers.ExcludesExpression(
    'Corruption detected',
    'No corruption should be detected.')

0 commit comments

Comments
 (0)