Skip to content

Commit 2957aed

Browse files
committed
jupyter kernel ports are now determined in pyscript and made available to hass_pyscript_kernel.py via a state variable
1 parent 1a2d764 commit 2957aed

File tree

2 files changed

+107
-50
lines changed

2 files changed

+107
-50
lines changed

custom_components/pyscript/__init__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Component to allow running Python scripts."""
22

33
import glob
4+
import json
45
import logging
56
import os
67

@@ -114,6 +115,12 @@ async def jupyter_kernel_start(call):
114115
handler_func.install_ast_funcs(ast_ctx)
115116
kernel = Kernel(call.data, ast_ctx, global_ctx_name, global_ctx_mgr)
116117
await kernel.session_start()
118+
hass.states.async_set(call.data["state_var"], json.dumps(kernel.get_ports()))
119+
120+
def state_var_remove():
121+
hass.states.async_remove(call.data["state_var"])
122+
123+
kernel.set_session_cleanup_callback(state_var_remove)
117124

118125
hass.services.async_register(
119126
DOMAIN, SERVICE_JUPYTER_KERNEL_START, jupyter_kernel_start
@@ -123,12 +130,10 @@ async def state_changed(event):
123130
var_name = event.data["entity_id"]
124131
# attr = event.data["new_state"].attributes
125132
if "new_state" not in event.data or event.data["new_state"] is None:
126-
_LOGGER.debug(
127-
"state_changed: missing new_state in event.data=%s; ignoring",
128-
event.data,
129-
)
130-
return
131-
new_val = event.data["new_state"].state
133+
# state variable has been deleted
134+
new_val = None
135+
else:
136+
new_val = event.data["new_state"].state
132137
old_val = event.data["old_state"].state if event.data["old_state"] else None
133138
new_vars = {var_name: new_val, f"{var_name}.old": old_val}
134139
func_args = {

custom_components/pyscript/jupyter_kernel.py

Lines changed: 96 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import logging
1616
import logging.handlers
17+
import random
1718
import re
1819
from struct import pack, unpack
1920
import traceback
@@ -36,14 +37,6 @@ def str_to_bytes(string):
3637
"""Encode a string in bytes."""
3738
return string.encode('ascii')
3839

39-
def bind(socket, connection, port):
40-
"""Bind a socket."""
41-
if port <= 0:
42-
return socket.bind_to_random_port(connection)
43-
# _LOGGER.debug(f"binding to %s:%s" % (connection, port))
44-
socket.bind("%s:%s" % (connection, port))
45-
return port
46-
4740
class KernelBufferingHandler(logging.handlers.BufferingHandler):
4841
"""Memory-based handler for logging; send via stdout queue."""
4942
def __init__(self, housekeep_q):
@@ -61,7 +54,6 @@ def shouldFlush(self, record):
6154

6255

6356
################################################################
64-
#
6557
class ZmqSocket:
6658
"""Defines a minimal implementation of a small subset of ZMQ,
6759
allowing pyscript to work with Jupyter without the real zmq
@@ -193,7 +185,6 @@ def __init__(self, config, ast_ctx, global_ctx_name, global_ctx_mgr):
193185
self.global_ctx_mgr = global_ctx_mgr
194186
self.ast_ctx = ast_ctx
195187

196-
self.connection = self.config["transport"] + "://" + self.config["ip"]
197188
self.secure_key = str_to_bytes(self.config["key"])
198189
self.signature_schemes = {"hmac-sha256": hashlib.sha256}
199190
self.auth = hmac.HMAC(
@@ -209,17 +200,30 @@ def __init__(self, config, ast_ctx, global_ctx_name, global_ctx_mgr):
209200
self.stdin_server = None
210201
self.shell_server = None
211202

203+
self.heartbeat_port = None
204+
self.iopub_port = None
205+
self.control_port = None
206+
self.stdin_port = None
207+
self.shell_port = None
208+
self.avail_port = random.randrange(40000, 50000)
209+
212210
# there can be multiple iopub subscribers, with corresponding tasks
213211
self.iopub_socket = set()
214212

215213
self.tasks = {}
216214
self.task_cnt = 0
217215
self.task_cnt_max = 0
218216

217+
self.session_cleanup_callback = None
218+
219219
self.housekeep_q = asyncio.Queue(0)
220220

221221
self.parent_header = None
222222

223+
#
224+
# we create a logging handler so that output from the log functions
225+
# gets delivered back to Jupyter as stdout
226+
#
223227
self.console = KernelBufferingHandler(self.housekeep_q)
224228
self.console.setLevel(logging.DEBUG)
225229
# set a format which is just the message
@@ -514,13 +518,10 @@ async def shell_handler(self, shell_socket, msg):
514518
else:
515519
_LOGGER.error("unknown msg_type: %s", msg['header']["msg_type"])
516520

517-
518-
##########################################
519-
# Control:
520521
async def control_listen(self, reader, writer):
521522
"""Task that listens to control messages."""
522523
try:
523-
# _LOGGER.debug("control_listen connected")
524+
_LOGGER.debug("control_listen connected")
524525
await self.housekeep_q.put(["register", "control", current_task()])
525526
control_socket = ZmqSocket(reader, writer, "ROUTER")
526527
await control_socket.handshake()
@@ -533,41 +534,37 @@ async def control_listen(self, reader, writer):
533534
except asyncio.CancelledError: # pylint: disable=try-except-raise
534535
raise
535536
except EOFError:
536-
# _LOGGER.debug("control_listen got eof")
537+
_LOGGER.debug("control_listen got eof")
537538
await self.housekeep_q.put(["unregister", "control", current_task()])
538539
control_socket.close()
539540
except Exception as err: # pylint: disable=broad-except
540541
_LOGGER.error("control_listen exception %s", err)
541542
await self.housekeep_q.put(["shutdown"])
542543

543-
##########################################
544-
# Stdin:
545544
async def stdin_listen(self, reader, writer):
546545
"""Task that listens to stdin messages."""
547546
try:
548-
# _LOGGER.debug("stdin_listen connected")
547+
_LOGGER.debug("stdin_listen connected")
549548
await self.housekeep_q.put(["register", "stdin", current_task()])
550549
stdin_socket = ZmqSocket(reader, writer, "ROUTER")
551550
await stdin_socket.handshake()
552551
while 1:
553552
_ = await stdin_socket.recv_multipart()
554-
# _LOGGER.debug("stdin_listen received %s", raw_msg)
553+
# _LOGGER.debug("stdin_listen received %s", _)
555554
except asyncio.CancelledError: # pylint: disable=try-except-raise
556555
raise
557556
except EOFError:
558-
# _LOGGER.debug("stdin_listen got eof")
557+
_LOGGER.debug("stdin_listen got eof")
559558
await self.housekeep_q.put(["unregister", "stdin", current_task()])
560559
stdin_socket.close()
561560
except Exception: # pylint: disable=broad-except
562561
_LOGGER.error("stdin_listen exception %s", traceback.format_exc(-1))
563562
await self.housekeep_q.put(["shutdown"])
564563

565-
##########################################
566-
# Shell:
567564
async def shell_listen(self, reader, writer):
568565
"""Task that listens to shell messages."""
569566
try:
570-
# _LOGGER.debug("shell_listen connected")
567+
_LOGGER.debug("shell_listen connected")
571568
await self.housekeep_q.put(["register", "shell", current_task()])
572569
shell_socket = ZmqSocket(reader, writer, "ROUTER")
573570
await shell_socket.handshake()
@@ -578,62 +575,56 @@ async def shell_listen(self, reader, writer):
578575
shell_socket.close()
579576
raise
580577
except EOFError:
581-
# _LOGGER.debug("shell_listen got eof")
578+
_LOGGER.debug("shell_listen got eof")
582579
await self.housekeep_q.put(["unregister", "shell", current_task()])
583580
shell_socket.close()
584581
except Exception: # pylint: disable=broad-except
585582
_LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
586583
await self.housekeep_q.put(["shutdown"])
587584

588-
##########################################
589-
# Heartbeat:
590585
async def heartbeat_listen(self, reader, writer):
591586
"""Task that listens and responds to heart beat messages."""
592587
try:
593-
# _LOGGER.debug("heartbeat_listen connected")
588+
_LOGGER.debug("heartbeat_listen connected")
594589
await self.housekeep_q.put(["register", "heartbeat", current_task()])
595590
heartbeat_socket = ZmqSocket(reader, writer, "REP")
596591
await heartbeat_socket.handshake()
597592
while 1:
598593
msg = await heartbeat_socket.recv()
599-
# _LOGGER.debug(f"heartbeat_listen: got {msg}")
594+
# _LOGGER.debug("heartbeat_listen: got %s", msg)
600595
await heartbeat_socket.send(msg)
601596
except asyncio.CancelledError: # pylint: disable=try-except-raise
602597
raise
603598
except EOFError:
604-
# _LOGGER.debug("heartbeat_listen got eof")
599+
_LOGGER.debug("heartbeat_listen got eof")
605600
await self.housekeep_q.put(["unregister", "heartbeat", current_task()])
606601
heartbeat_socket.close()
607602
except Exception: # pylint: disable=broad-except
608603
_LOGGER.error("heartbeat_listen exception: %s", traceback.format_exc(-1))
609604
await self.housekeep_q.put(["shutdown"])
610605

611-
##########################################
612-
# IOPub/Sub:
613606
async def iopub_listen(self, reader, writer):
614607
"""Task that listens to iopub messages."""
615608
try:
616-
# _LOGGER.debug("iopub_listen connected")
609+
_LOGGER.debug("iopub_listen connected")
617610
await self.housekeep_q.put(["register", "iopub", current_task()])
618611
iopub_socket = ZmqSocket(reader, writer, "PUB")
619612
await iopub_socket.handshake()
620613
self.iopub_socket.add(iopub_socket)
621614
while 1:
622615
_ = await iopub_socket.recv_multipart()
623-
# _LOGGER.debug("iopub received %s", wire_msg)
616+
# _LOGGER.debug("iopub received %s", _)
624617
except asyncio.CancelledError: # pylint: disable=try-except-raise
625618
raise
626619
except EOFError:
627620
await self.housekeep_q.put(["unregister", "iopub", current_task()])
628621
iopub_socket.close()
629622
self.iopub_socket.discard(iopub_socket)
630-
# _LOGGER.debug("iopub_listen got eof")
623+
_LOGGER.debug("iopub_listen got eof")
631624
except Exception: # pylint: disable=broad-except
632625
_LOGGER.error("iopub_listen exception %s", traceback.format_exc(-1))
633626
await self.housekeep_q.put(["shutdown"])
634627

635-
##########################################
636-
# Housekeeping
637628
async def housekeep_run(self):
638629
"""Housekeeping, including closing servers after startup, and doing orderly shutdown."""
639630
while True:
@@ -652,6 +643,12 @@ async def housekeep_run(self):
652643
self.tasks[msg[1]].add(msg[2])
653644
self.task_cnt += 1
654645
self.task_cnt_max = max(self.task_cnt_max, self.task_cnt)
646+
#
647+
# now a couple of things are connected, call the session_cleanup_callback
648+
#
649+
if self.task_cnt > 1 and self.session_cleanup_callback:
650+
self.session_cleanup_callback()
651+
self.session_cleanup_callback = None
655652
elif msg[0] == "unregister":
656653
if msg[1] in self.tasks:
657654
self.tasks[msg[1]].discard(msg[2])
@@ -670,35 +667,89 @@ async def housekeep_run(self):
670667
except Exception: # pylint: disable=broad-except
671668
_LOGGER.error("housekeep task exception: %s", traceback.format_exc(-1))
672669

670+
async def startup_timeout(self):
671+
"""Shut down the session if nothing connects after 30 seconds."""
672+
await self.housekeep_q.put(["register", "startup_timeout", current_task()])
673+
await asyncio.sleep(30)
674+
if self.task_cnt_max == 1:
675+
#
676+
# nothing started other than us, so shut down the session
677+
#
678+
_LOGGER.error("No connections to session %s; shutting down", self.global_ctx_name)
679+
if self.session_cleanup_callback:
680+
self.session_cleanup_callback()
681+
self.session_cleanup_callback = None
682+
await self.housekeep_q.put(["shutdown"])
683+
await self.housekeep_q.put(["unregister", "startup_timeout", current_task()])
684+
685+
async def start_one_server(self, callback):
686+
"""Start a server by finding an available port."""
687+
for _ in range(2048):
688+
try:
689+
server = await asyncio.start_server(callback, self.config["ip"], self.avail_port)
690+
return server, self.avail_port
691+
except OSError:
692+
self.avail_port += 1
693+
_LOGGER.error("unable to find an available port on host %s, last port %d", self.config["ip"], self.avail_port)
694+
return None, None
695+
696+
def get_ports(self):
697+
"""Return a dict of the port numbers this kernel session is listening to."""
698+
return {
699+
"iopub_port": self.iopub_port,
700+
"hb_port": self.heartbeat_port,
701+
"control_port": self.control_port,
702+
"stdin_port": self.stdin_port,
703+
"shell_port": self.shell_port,
704+
}
705+
706+
def set_session_cleanup_callback(self, callback):
707+
"""Set a cleanup callback which is called right after the session has started."""
708+
self.session_cleanup_callback = callback
709+
673710
async def session_start(self):
674711
"""Start the kernel session."""
675712
self.ast_ctx.add_logger_handler(self.console)
676713
_LOGGER.info("Starting session %s", self.global_ctx_name)
677714

678715
self.tasks["housekeep"] = {asyncio.create_task(self.housekeep_run())}
716+
self.tasks["startup_timeout"] = {asyncio.create_task(self.startup_timeout())}
679717

680-
self.iopub_server = await asyncio.start_server(self.iopub_listen, self.config["ip"], self.config["iopub_port"])
681-
self.heartbeat_server = await asyncio.start_server(self.heartbeat_listen, self.config["ip"], self.config["hb_port"])
682-
self.control_server = await asyncio.start_server(self.control_listen, self.config["ip"], self.config["control_port"])
683-
self.stdin_server = await asyncio.start_server(self.stdin_listen, self.config["ip"], self.config["stdin_port"])
684-
self.shell_server = await asyncio.start_server(self.shell_listen, self.config["ip"], self.config["shell_port"])
718+
self.iopub_server, self.iopub_port = await self.start_one_server(self.iopub_listen)
719+
self.heartbeat_server, self.heartbeat_port = await self.start_one_server(self.heartbeat_listen)
720+
self.control_server, self.control_port = await self.start_one_server(self.control_listen)
721+
self.stdin_server, self.stdin_port = await self.start_one_server(self.stdin_listen)
722+
self.shell_server, self.shell_port = await self.start_one_server(self.shell_listen)
685723

686724
#
687725
# For debugging, can use the real ZMQ library instead on certain sockets; comment out
688726
# the corresponding asyncio.start_server() call above if you enable the ZMQ-based
689-
# functions here. The two most important ones are shown here.
727+
# functions here. You can then turn of verbosity level 4 (-vvvv) in hass_pyscript_kernel.py
728+
# to see all the byte data in case you need to debug the simple ZMQ implementation here.
729+
# The two most important zmq functions are shown below.
690730
#
691731
# import zmq
692732
# import zmq.asyncio
693733
#
734+
# def zmq_bind(socket, connection, port):
735+
# """Bind a socket."""
736+
# if port <= 0:
737+
# return socket.bind_to_random_port(connection)
738+
# # _LOGGER.debug(f"binding to %s:%s" % (connection, port))
739+
# socket.bind("%s:%s" % (connection, port))
740+
# return port
741+
#
694742
# zmq_ctx = zmq.asyncio.Context()
743+
#
695744
# ##########################################
696745
# # Shell using real ZMQ for debugging:
697746
# async def shell_listen_zmq():
698747
# """Task that listens to shell messages using ZMQ."""
699748
# try:
749+
# _LOGGER.debug("shell_listen_zmq connected")
750+
# connection = self.config["transport"] + "://" + self.config["ip"]
700751
# shell_socket = zmq_ctx.socket(zmq.ROUTER) # pylint: disable=no-member
701-
# self.config["shell_port"] = bind(shell_socket, self.connection, self.config["shell_port"])
752+
# self.shell_port = zmq_bind(shell_socket, connection, -1)
702753
# _LOGGER.debug("shell_listen_zmq connected")
703754
# while 1:
704755
# msg = await shell_socket.recv_multipart()
@@ -716,8 +767,9 @@ async def session_start(self):
716767
# """Task that listens to iopub messages using ZMQ."""
717768
# try:
718769
# _LOGGER.debug("iopub_listen_zmq connected")
770+
# connection = self.config["transport"] + "://" + self.config["ip"]
719771
# iopub_socket = zmq_ctx.socket(zmq.PUB) # pylint: disable=no-member
720-
# self.config["iopub_port"] = bind(self.iopub_socket, self.connection, self.config["iopub_port"])
772+
# self.iopub_port = zmq_bind(self.iopub_socket, connection, -1)
721773
# self.iopub_socket.add(iopub_socket)
722774
# while 1:
723775
# wire_msg = await iopub_socket.recv_multipart()

0 commit comments

Comments
 (0)