Skip to content

Commit 804132e

Browse files
author
Joel Collins
committed
Removed circular references to LabThing
1 parent abd3d75 commit 804132e

File tree

8 files changed

+138
-52
lines changed

8 files changed

+138
-52
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from ..find import current_labthing
2+
from ..view import View
3+
4+
5+
class RootView(View):
6+
def get(self):
7+
return current_labthing().thing_description.to_dict()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from ..sockets import SocketSubscriber, socket_handler_loop
2+
from ..find import current_labthing
3+
4+
import logging
5+
6+
7+
def socket_handler(ws):
8+
# Create a socket subscriber
9+
wssub = SocketSubscriber(ws)
10+
current_labthing().subscribers.add(wssub)
11+
logging.info(f"Added subscriber {wssub}")
12+
# Start the socket connection handler loop
13+
socket_handler_loop(ws)
14+
# Remove the subscriber once the loop returns
15+
current_labthing().subscribers.remove(wssub)
16+
logging.info(f"Removed subscriber {wssub}")

labthings/server/find.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from flask import current_app
3+
import weakref
34

45
from .names import EXTENSION_NAME
56

@@ -13,14 +14,15 @@ def current_labthing(app=None):
1314
# reach the Flask app object. Just using current_app returns
1415
# a wrapper, which breaks it's use in Task threads
1516
if not app:
16-
app = current_app._get_current_object() # skipcq: PYL-W0212
17-
if not app:
18-
return None
19-
logging.debug("Active app extensions:")
20-
logging.debug(app.extensions)
21-
logging.debug("Active labthing:")
22-
logging.debug(app.extensions.get(EXTENSION_NAME))
23-
return app.extensions.get(EXTENSION_NAME, None)
17+
try:
18+
app = current_app._get_current_object() # skipcq: PYL-W0212
19+
except RuntimeError:
20+
return None
21+
ext = app.extensions.get(EXTENSION_NAME, None)
22+
if isinstance(ext, weakref.ref):
23+
return ext()
24+
else:
25+
return ext
2426

2527

2628
def registered_extensions(labthing_instance=None):

labthings/server/labthing.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717
from .spec.utilities import get_spec
1818
from .spec.td import ThingDescription
1919
from .decorators import tag
20-
from .sockets import Sockets, SocketSubscriber, socket_handler_loop
20+
from .sockets import Sockets
2121

2222
from .default_views.extensions import ExtensionList
2323
from .default_views.tasks import TaskList, TaskView
2424
from .default_views.docs import docs_blueprint, SwaggerUIView
25+
from .default_views.root import RootView
26+
from .default_views.sockets import socket_handler
2527

2628
from ..core.utilities import get_docstring
2729

30+
import weakref
2831
import logging
2932

3033

@@ -72,7 +75,7 @@ def __init__(
7275

7376
# Logging handler
7477
# TODO: Add cleanup code
75-
self.log_handler = LabThingLogger(self)
78+
self.log_handler = LabThingLogger()
7679
logging.getLogger().addHandler(self.log_handler)
7780

7881
self.spec = APISpec(
@@ -87,6 +90,12 @@ def __init__(
8790
if app is not None:
8891
self.init_app(app)
8992

93+
def __enter__(self):
94+
return self
95+
96+
def __exit__(self, *args, **kwargs):
97+
self.app = None
98+
9099
@property
91100
def description(self,):
92101
return self._description
@@ -128,11 +137,9 @@ def version(self, version: str):
128137
def init_app(self, app):
129138
self.app = app
130139

131-
app.teardown_appcontext(self.teardown)
132-
133140
# Register Flask extension
134141
app.extensions = getattr(app, "extensions", {})
135-
app.extensions[EXTENSION_NAME] = self
142+
app.extensions[EXTENSION_NAME] = weakref.ref(self)
136143

137144
# Flask error formatter
138145
if self.format_flask_exceptions:
@@ -154,12 +161,9 @@ def init_app(self, app):
154161
self.sockets = Sockets(app)
155162
self._create_base_sockets()
156163

157-
def teardown(self, exception):
158-
pass
159-
160164
def _create_base_routes(self):
161165
# Add root representation
162-
self.app.add_url_rule(self._complete_url("/", ""), "root", self.root)
166+
self.add_view(RootView, self._complete_url("/", ""), endpoint="root")
163167
# Add thing descriptions
164168
self.app.register_blueprint(
165169
docs_blueprint, url_prefix=f"{self.url_prefix}/docs"
@@ -175,19 +179,7 @@ def _create_base_routes(self):
175179
self.add_view(TaskView, "/tasks/<task_id>", endpoint=TASK_ENDPOINT)
176180

177181
def _create_base_sockets(self):
178-
self.sockets.add_view(self._complete_url("", ""), self._socket_handler)
179-
180-
def _socket_handler(self, ws):
181-
# Create a socket subscriber
182-
wssub = SocketSubscriber(ws)
183-
self.subscribers.add(wssub)
184-
logging.info(f"Added subscriber {wssub}")
185-
# Start the socket connection handler loop
186-
socket_handler_loop(ws)
187-
# Remove the subscriber once the loop returns
188-
self.subscribers.remove(wssub)
189-
logging.info(f"Removed subscriber {wssub}")
190-
logging.debug(list(self.subscribers))
182+
self.sockets.add_view(self._complete_url("", ""), socket_handler)
191183

192184
# Device stuff
193185

@@ -309,6 +301,8 @@ def _register_view(self, app, view, *urls, endpoint=None, **kwargs):
309301

310302
# There might be a better way to do this than _rules_by_endpoint,
311303
# but I can't find one so this will do for now. Skipping PYL-W0212
304+
# FIXME: There is a MASSIVE memory leak or something going on in APISpec!
305+
# This is grinding tests to a halt, and is really annoying... Should be fixed.
312306
flask_rules = app.url_map._rules_by_endpoint.get(endpoint) # skipcq: PYL-W0212
313307
for flask_rule in flask_rules:
314308
self.spec.path(**rule_to_apispec_path(flask_rule, view, self.spec))
@@ -345,8 +339,3 @@ def add_root_link(self, view, rel, kwargs=None, params=None):
345339
if params is None:
346340
params = {}
347341
self.thing_description.add_link(view, rel, kwargs=kwargs, params=params)
348-
349-
# Description
350-
def root(self):
351-
"""Root representation"""
352-
return self.thing_description.to_dict()

labthings/server/logging.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
from .find import current_labthing
2+
13
from logging import StreamHandler
24
import datetime
35

46

57
class LabThingLogger(StreamHandler):
6-
def __init__(self, labthing):
7-
StreamHandler.__init__(self)
8-
self.labthing = labthing
8+
def __init__(self, *args, **kwargs):
9+
StreamHandler.__init__(self, *args, **kwargs)
910

1011
def emit(self, record):
1112
log_event = self.rest_format_record(record)
1213

1314
# Broadcast to subscribers
14-
subscribers = getattr(self.labthing, "subscribers", [])
15+
subscribers = getattr(current_labthing(), "subscribers", [])
1516
for sub in subscribers:
1617
sub.event_notify(log_event)
1718

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ def thing(app):
190190
return thing
191191

192192

193+
@pytest.fixture()
194+
def thing_ctx(thing):
195+
with thing.app.app_context():
196+
yield thing.app
197+
198+
193199
@pytest.fixture()
194200
def debug_app(request):
195201

tests/test_server_find.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from labthings.server import find
2+
3+
from labthings.server.extensions import BaseExtension
4+
5+
6+
def test_current_labthing(thing, thing_ctx):
7+
with thing_ctx.test_request_context():
8+
assert find.current_labthing() is thing
9+
10+
11+
def test_current_labthing_explicit_app(thing, thing_ctx):
12+
with thing_ctx.test_request_context():
13+
assert find.current_labthing(thing.app) is thing
14+
15+
16+
def test_current_labthing_missing_app():
17+
assert find.current_labthing() is None
18+
19+
20+
def test_registered_extensions(thing_ctx):
21+
with thing_ctx.test_request_context():
22+
assert find.registered_extensions() == {}
23+
24+
25+
def test_registered_extensions_explicit_thing(thing):
26+
assert find.registered_extensions(thing) == {}
27+
28+
29+
def test_registered_components(thing_ctx):
30+
with thing_ctx.test_request_context():
31+
assert find.registered_components() == {}
32+
33+
34+
def test_registered_components_explicit_thing(thing):
35+
assert find.registered_components(thing) == {}
36+
37+
38+
def test_find_component(thing, thing_ctx):
39+
component = type("component", (object,), {})
40+
thing.add_component(component, "org.labthings.tests.component")
41+
42+
with thing_ctx.test_request_context():
43+
assert find.find_component("org.labthings.tests.component") == component
44+
45+
46+
def test_find_component_explicit_thing(thing):
47+
component = type("component", (object,), {})
48+
thing.add_component(component, "org.labthings.tests.component")
49+
50+
assert find.find_component("org.labthings.tests.component", thing) == component
51+
52+
53+
def test_find_component_missing_component(thing_ctx):
54+
with thing_ctx.test_request_context():
55+
assert find.find_component("org.labthings.tests.component") is None
56+
57+
58+
def test_find_extension(thing, thing_ctx):
59+
extension = BaseExtension("org.labthings.tests.extension")
60+
thing.register_extension(extension)
61+
62+
with thing_ctx.test_request_context():
63+
assert find.find_extension("org.labthings.tests.extension") == extension
64+
65+
66+
def test_find_extension_explicit_thing(thing):
67+
extension = BaseExtension("org.labthings.tests.extension")
68+
thing.register_extension(extension)
69+
70+
assert find.find_extension("org.labthings.tests.extension", thing) == extension
71+
72+
73+
def test_find_extension_missing_extesion(thing_ctx):
74+
with thing_ctx.test_request_context():
75+
assert find.find_extension("org.labthings.tests.extension") is None

tests/test_server_labthing.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def test_init_app(app):
2424
thing = labthing.LabThing()
2525
thing.init_app(app)
2626

27-
assert app.extensions.get(EXTENSION_NAME) == thing
27+
# Check weakref
28+
assert app.extensions.get(EXTENSION_NAME)() == thing
29+
2830
assert app.json_encoder == LabThingsJSONEncoder
2931
assert 400 in app.error_handler_spec.get(None)
3032

@@ -218,11 +220,6 @@ def test_td_add_link_options(thing, view_cls):
218220
} in thing.thing_description._links
219221

220222

221-
def test_root_rep(thing, app_ctx):
222-
with app_ctx.test_request_context():
223-
assert thing.root() == thing.thing_description.to_dict()
224-
225-
226223
def test_description(thing):
227224
assert thing.description == ""
228225
thing.description = "description"
@@ -242,10 +239,3 @@ def test_version(thing):
242239
thing.version = "x.x.x"
243240
assert thing.version == "x.x.x"
244241
assert thing.spec.version == "x.x.x"
245-
246-
247-
def test_socket_handler(thing, fake_websocket):
248-
ws = fake_websocket("", recieve_once=True)
249-
thing._socket_handler(ws)
250-
# Expect no response
251-
assert ws.responses == []

0 commit comments

Comments
 (0)