Skip to content

Commit a17de88

Browse files
committed
Method for the Trellis class plotting the FSM.
1 parent c938681 commit a17de88

File tree

1 file changed

+110
-2
lines changed

1 file changed

+110
-2
lines changed

commpy/channelcoding/convcode.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import matplotlib.colors as mcolors
1212
import matplotlib.patches as mpatches
13+
import matplotlib.path as mpath
1314
import matplotlib.pyplot as plt
1415
import numpy as np
1516
from matplotlib.collections import PatchCollection
@@ -116,7 +117,7 @@ def __init__(self, memory, g_matrix, feedback = None, code_type = 'default'):
116117
self.number_inputs], 'int')
117118

118119
if isinstance(feedback, int):
119-
warn('Treillis wiil will only accept feedback as a matricx in the future. '
120+
warn('Trellis will only accept feedback as a matrix in the future. '
120121
'Using the backwards compatibility version that may contain bugs for k > 1.', DeprecationWarning)
121122

122123
if code_type == 'rsc':
@@ -305,6 +306,7 @@ def visualize(self, trellis_length = 2, state_order = None,
305306
to the input.
306307
save_path : str or None
307308
If not None, save the figure to the file specified by its path.
309+
*Default* is no saving.
308310
"""
309311
if edge_colors is None:
310312
edge_colors = [mcolors.hsv_to_rgb((i/self.number_inputs, 1, 1)) for i in range(self.number_inputs)]
@@ -339,6 +341,111 @@ def visualize(self, trellis_length = 2, state_order = None,
339341
if save_path is not None:
340342
plt.savefig(save_path)
341343

344+
def visualize_fsm(self, state_order=None, state_radius=0.04, edge_colors=None, save_path=None):
345+
""" Plot the FSM corresponding to the the trellis
346+
347+
This method is not intended to display large FSMs and its use is advisable only for simple trellises.
348+
349+
Parameters
350+
----------
351+
state_order : list of ints, optional
352+
Specifies the order in the which the states of the trellis are to be displayed starting from the top in the
353+
plot.
354+
*Default* order is [0,...,number_states-1]
355+
state_radius : float, optional
356+
Radius of each state (circle) in the plot.
357+
*Default* value is 0.04
358+
edge_colors : list of hex color codes, optional
359+
A list of length equal to the number_inputs, containing color codes that represent the edge corresponding to
360+
the input.
361+
save_path : str or None
362+
If not None, save the figure to the file specified by its path.
363+
*Default* is no saving.
364+
"""
365+
# Default arguments
366+
if edge_colors is None:
367+
edge_colors = [mcolors.hsv_to_rgb((i/self.number_inputs, 1, 1)) for i in range(self.number_inputs)]
368+
369+
if state_order is None:
370+
state_order = list(range(self.number_states))
371+
372+
# Init the figure
373+
ax = plt.axes((0, 0, 1, 1))
374+
375+
# Plot states
376+
radius = state_radius * self.number_states
377+
angles = 2 * np.pi / self.number_states * np.arange(self.number_states)
378+
positions = [(radius * math.cos(angle), radius * math.sin(angle)) for angle in angles]
379+
380+
state_patches = []
381+
arrows = []
382+
for idx, state in enumerate(state_order):
383+
state_patches.append(mpatches.Circle(positions[idx], state_radius, color="#003399", ec="#cccccc"))
384+
plt.text(positions[idx][0], positions[idx][1], str(state), ha='center', va='center', size=20)
385+
386+
# Plot transition
387+
for input in range(self.number_inputs):
388+
next_state = self.next_state_table[state, input]
389+
next_idx = (state_order == next_state).nonzero()[0][0]
390+
output = self.output_table[state, input]
391+
392+
# Transition arrow
393+
if next_state == state:
394+
# Positions
395+
arrow_start_x = positions[idx][0] + state_radius * math.cos(angles[idx] + math.pi / 6)
396+
arrow_start_y = positions[idx][1] + state_radius * math.sin(angles[idx] + math.pi / 6)
397+
arrow_end_x = positions[idx][0] + state_radius * math.cos(angles[idx] - math.pi / 6)
398+
arrow_end_y = positions[idx][1] + state_radius * math.sin(angles[idx] - math.pi / 6)
399+
arrow_mid_x = positions[idx][0] + state_radius * 2 * math.cos(angles[idx])
400+
arrow_mid_y = positions[idx][1] + state_radius * 2 * math.sin(angles[idx])
401+
402+
# Add text
403+
plt.text(arrow_mid_x, arrow_mid_y, '({})'.format(output),
404+
ha='center', va='center', backgroundcolor=edge_colors[input])
405+
406+
else:
407+
# Positions
408+
dx = positions[next_idx][0] - positions[idx][0]
409+
dy = positions[next_idx][1] - positions[idx][1]
410+
relative_angle = math.atan(dy / dx) + np.where(dx > 0, 0, math.pi)
411+
412+
arrow_start_x = positions[idx][0] + state_radius * math.cos(relative_angle + math.pi * 0.05)
413+
arrow_start_y = positions[idx][1] + state_radius * math.sin(relative_angle + math.pi * 0.05)
414+
arrow_end_x = positions[next_idx][0] - state_radius * math.cos(relative_angle - math.pi * 0.05)
415+
arrow_end_y = positions[next_idx][1] - state_radius * math.sin(relative_angle - math.pi * 0.05)
416+
arrow_mid_x = (arrow_start_x + arrow_end_x) / 2 + \
417+
radius * 0.25 * math.cos((angles[idx] + angles[next_idx]) / 2) * np.sign(dx)
418+
arrow_mid_y = (arrow_start_y + arrow_end_y) / 2 + \
419+
radius * 0.25 * math.sin((angles[idx] + angles[next_idx]) / 2) * np.sign(dx)
420+
text_x = arrow_mid_x + 0.01 * math.cos((angles[idx] + angles[next_idx]) / 2)
421+
text_y = arrow_mid_y + 0.01 * math.sin((angles[idx] + angles[next_idx]) / 2)
422+
423+
# Add text
424+
plt.text(text_x, text_y, '({})'.format(output),
425+
ha='center', va='center', backgroundcolor=edge_colors[input])
426+
427+
# Path creation
428+
codes = (mpath.Path.MOVETO, mpath.Path.CURVE3, mpath.Path.CURVE3)
429+
verts = ((arrow_start_x, arrow_start_y),
430+
(arrow_mid_x, arrow_mid_y),
431+
(arrow_end_x, arrow_end_y))
432+
path = mpath.Path(verts, codes)
433+
434+
# Plot arrow
435+
arrow = mpatches.FancyArrowPatch(path=path, mutation_scale=20, color=edge_colors[input])
436+
ax.add_artist(arrow)
437+
arrows.append(arrow)
438+
439+
# Format and plot
440+
ax.set_xlim(radius * -2, radius * 2)
441+
ax.set_ylim(radius * -2, radius * 2)
442+
ax.add_collection(PatchCollection(state_patches, True))
443+
plt.legend(arrows, [str(i) + "-input" for i in range(self.number_inputs)], loc='lower right')
444+
plt.text(0, 1.5 * radius, 'Finite State Machine (output on transition)', ha='center', size=18)
445+
plt.show()
446+
if save_path is not None:
447+
plt.savefig(save_path)
448+
342449

343450
def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=None):
344451
"""
@@ -574,7 +681,8 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
574681
count = 0
575682
current_number_states = trellis.number_states
576683

577-
coded_bits = coded_bits.clip(-500, 500)
684+
if decoding_type == 'soft':
685+
coded_bits = coded_bits.clip(-500, 500)
578686

579687
for t in range(1, int((L+total_memory)/k)):
580688
# Get the received codeword corresponding to t

0 commit comments

Comments
 (0)