diff --git a/requests_unixsocket/__init__.py b/requests_unixsocket/__init__.py index 0fb5e1f..516ccef 100644 --- a/requests_unixsocket/__init__.py +++ b/requests_unixsocket/__init__.py @@ -1,20 +1,25 @@ +import os import requests import sys from .adapters import UnixAdapter -DEFAULT_SCHEME = 'http+unix://' +DEFAULT_SCHEMES = os.getenv( + 'REQUESTS_UNIXSOCKET_URL_SCHEMES', + 'http+unix://,http://sock.local/' +).split(',') class Session(requests.Session): - def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs): + def __init__(self, url_schemes=DEFAULT_SCHEMES, *args, **kwargs): super(Session, self).__init__(*args, **kwargs) - self.mount(url_scheme, UnixAdapter()) + for url_scheme in url_schemes: + self.mount(url_scheme, UnixAdapter()) class monkeypatch(object): - def __init__(self, url_scheme=DEFAULT_SCHEME): - self.session = Session() + def __init__(self, url_schemes=DEFAULT_SCHEMES): + self.session = Session(url_schemes=url_schemes) requests = self._get_global_requests_module() # Methods to replace diff --git a/requests_unixsocket/adapters.py b/requests_unixsocket/adapters.py index a2c1564..c645df1 100644 --- a/requests_unixsocket/adapters.py +++ b/requests_unixsocket/adapters.py @@ -1,3 +1,4 @@ +import os import socket from requests.adapters import HTTPAdapter @@ -14,6 +15,36 @@ import urllib3 +def get_unix_socket(path_or_name, timeout=None, type=socket.SOCK_STREAM): + sock = socket.socket(family=socket.AF_UNIX, type=type) + if timeout: + sock.settimeout(timeout) + sock.connect(path_or_name) + return sock + + +def get_sock_path_and_req_path(path): + i = 1 + while True: + try: + items = path.rsplit('/', i) + sock_path = items[0] + rest = items[1:] + except ValueError: + return None, None + + if os.path.exists(sock_path): + return sock_path, '/' + '/'.join(rest) + + # Detect abstract namespace socket, starting with `/%00` + if '/' not in sock_path[1:] and sock_path[1:4] == '%00': + return '\x00' + sock_path[4:], '/' + '/'.join(rest) + + if sock_path == '': + return None, None + i += 1 + + # The following was adapted from some code from docker-py # https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py class UnixHTTPConnection(httplib.HTTPConnection, object): @@ -35,11 +66,13 @@ def __del__(self): # base class does not have d'tor self.sock.close() def connect(self): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(self.timeout) - socket_path = unquote(urlparse(self.unix_socket_url).netloc) - sock.connect(socket_path) - self.sock = sock + path = urlparse(self.unix_socket_url).path + socket_path, req_path = get_sock_path_and_req_path(path) + if not socket_path: + socket_path = urlparse(self.unix_socket_url).path + if '\x00' not in socket_path and not os.path.exists(socket_path): + socket_path = unquote(urlparse(self.unix_socket_url).netloc) + self.sock = get_unix_socket(socket_path, timeout=self.timeout) class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): @@ -83,7 +116,11 @@ def get_connection(self, url, proxies=None): return pool def request_url(self, request, proxies): - return request.path_url + sock_path, req_path = get_sock_path_and_req_path(request.path_url) + if req_path: + return req_path + else: + return request.path_url def close(self): self.pools.clear() diff --git a/requests_unixsocket/tests/test_requests_unixsocket.py b/requests_unixsocket/tests/test_requests_unixsocket.py index 733aa87..b3946b7 100755 --- a/requests_unixsocket/tests/test_requests_unixsocket.py +++ b/requests_unixsocket/tests/test_requests_unixsocket.py @@ -41,6 +41,31 @@ def test_unix_domain_adapter_ok(): assert r.text == 'Hello world!' +def test_unix_domain_adapter_ok_alt_scheme(): + with UnixSocketServerThread() as usock_thread: + session = requests_unixsocket.Session('http+unix://') + url = 'http+unix://unix.socket%s/path/to/page' % usock_thread.usock + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', + 'options']: + logger.debug('Calling session.%s(%r) ...', method, url) + r = getattr(session, method)(url) + logger.debug( + 'Received response: %r with text: %r and headers: %r', + r, r.text, r.headers) + assert r.status_code == 200 + assert r.headers['server'] == 'waitress' + assert r.headers['X-Transport'] == 'unix domain socket' + assert r.headers['X-Requested-Path'] == '/path/to/page' + assert r.headers['X-Socket-Path'] == usock_thread.usock + assert isinstance(r.connection, requests_unixsocket.UnixAdapter) + assert r.url.lower() == url.lower() + if method == 'head': + assert r.text == '' + else: + assert r.text == 'Hello world!' + + def test_unix_domain_adapter_url_with_query_params(): with UnixSocketServerThread() as usock_thread: session = requests_unixsocket.Session('http+unix://') @@ -69,6 +94,33 @@ def test_unix_domain_adapter_url_with_query_params(): assert r.text == 'Hello world!' +def test_unix_domain_adapter_url_with_query_params_alt_scheme(): + with UnixSocketServerThread() as usock_thread: + session = requests_unixsocket.Session('http+unix://') + url = ('http+unix://unix.socket%s' + '/containers/nginx/logs?timestamp=true' % usock_thread.usock) + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', + 'options']: + logger.debug('Calling session.%s(%r) ...', method, url) + r = getattr(session, method)(url) + logger.debug( + 'Received response: %r with text: %r and headers: %r', + r, r.text, r.headers) + assert r.status_code == 200 + assert r.headers['server'] == 'waitress' + assert r.headers['X-Transport'] == 'unix domain socket' + assert r.headers['X-Requested-Path'] == '/containers/nginx/logs' + assert r.headers['X-Requested-Query-String'] == 'timestamp=true' + assert r.headers['X-Socket-Path'] == usock_thread.usock + assert isinstance(r.connection, requests_unixsocket.UnixAdapter) + assert r.url.lower() == url.lower() + if method == 'head': + assert r.text == '' + else: + assert r.text == 'Hello world!' + + def test_unix_domain_adapter_connection_error(): session = requests_unixsocket.Session('http+unix://') @@ -78,6 +130,15 @@ def test_unix_domain_adapter_connection_error(): 'http+unix://socket_does_not_exist/path/to/page') +def test_unix_domain_adapter_connection_error_alt_scheme(): + session = requests_unixsocket.Session('http+unix://') + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']: + with pytest.raises(requests.ConnectionError): + getattr(session, method)( + 'http+unix://unix.socket/socket_does_not_exist/path/to/page') + + def test_unix_domain_adapter_connection_proxies_error(): session = requests_unixsocket.Session('http+unix://') @@ -90,6 +151,18 @@ def test_unix_domain_adapter_connection_proxies_error(): in str(excinfo.value)) +def test_unix_domain_adapter_connection_proxies_error_alt_scheme(): + session = requests_unixsocket.Session('http+unix://') + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']: + with pytest.raises(ValueError) as excinfo: + getattr(session, method)( + 'http+unix://unix.socket/socket_does_not_exist/path/to/page', + proxies={"http+unix": "http://10.10.1.10:1080"}) + assert ('UnixAdapter does not support specifying proxies' + in str(excinfo.value)) + + def test_unix_domain_adapter_monkeypatch(): with UnixSocketServerThread() as usock_thread: with requests_unixsocket.monkeypatch('http+unix://'): @@ -119,3 +192,29 @@ def test_unix_domain_adapter_monkeypatch(): for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']: with pytest.raises(requests.exceptions.InvalidSchema): getattr(requests, method)(url) + + +def test_unix_domain_adapter_monkeypatch_alt_scheme(): + with UnixSocketServerThread() as usock_thread: + with requests_unixsocket.monkeypatch(): + url = 'http://sock.local/%s/path/to/page' % usock_thread.usock + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', + 'options']: + logger.debug('Calling session.%s(%r) ...', method, url) + r = getattr(requests, method)(url) + logger.debug( + 'Received response: %r with text: %r and headers: %r', + r, r.text, r.headers) + assert r.status_code == 200 + assert r.headers['server'] == 'waitress' + assert r.headers['X-Transport'] == 'unix domain socket' + assert r.headers['X-Requested-Path'] == '/path/to/page' + assert r.headers['X-Socket-Path'] == usock_thread.usock + assert isinstance(r.connection, + requests_unixsocket.UnixAdapter) + assert r.url.lower() == url.lower() + if method == 'head': + assert r.text == '' + else: + assert r.text == 'Hello world!'