|
10 | 10 |
|
11 | 11 | import matplotlib.colors as mcolors |
12 | 12 | import matplotlib.patches as mpatches |
| 13 | +import matplotlib.path as mpath |
13 | 14 | import matplotlib.pyplot as plt |
14 | 15 | import numpy as np |
15 | 16 | from matplotlib.collections import PatchCollection |
@@ -116,7 +117,7 @@ def __init__(self, memory, g_matrix, feedback = None, code_type = 'default'): |
116 | 117 | self.number_inputs], 'int') |
117 | 118 |
|
118 | 119 | if isinstance(feedback, int): |
119 | | - warn('Treillis wiil will only accept feedback as a matricx in the future. ' |
| 120 | + warn('Trellis will only accept feedback as a matrix in the future. ' |
120 | 121 | 'Using the backwards compatibility version that may contain bugs for k > 1.', DeprecationWarning) |
121 | 122 |
|
122 | 123 | if code_type == 'rsc': |
@@ -305,6 +306,7 @@ def visualize(self, trellis_length = 2, state_order = None, |
305 | 306 | to the input. |
306 | 307 | save_path : str or None |
307 | 308 | If not None, save the figure to the file specified by its path. |
| 309 | + *Default* is no saving. |
308 | 310 | """ |
309 | 311 | if edge_colors is None: |
310 | 312 | edge_colors = [mcolors.hsv_to_rgb((i/self.number_inputs, 1, 1)) for i in range(self.number_inputs)] |
@@ -339,6 +341,111 @@ def visualize(self, trellis_length = 2, state_order = None, |
339 | 341 | if save_path is not None: |
340 | 342 | plt.savefig(save_path) |
341 | 343 |
|
| 344 | + def visualize_fsm(self, state_order=None, state_radius=0.04, edge_colors=None, save_path=None): |
| 345 | + """ Plot the FSM corresponding to the the trellis |
| 346 | +
|
| 347 | + This method is not intended to display large FSMs and its use is advisable only for simple trellises. |
| 348 | +
|
| 349 | + Parameters |
| 350 | + ---------- |
| 351 | + state_order : list of ints, optional |
| 352 | + Specifies the order in the which the states of the trellis are to be displayed starting from the top in the |
| 353 | + plot. |
| 354 | + *Default* order is [0,...,number_states-1] |
| 355 | + state_radius : float, optional |
| 356 | + Radius of each state (circle) in the plot. |
| 357 | + *Default* value is 0.04 |
| 358 | + edge_colors : list of hex color codes, optional |
| 359 | + A list of length equal to the number_inputs, containing color codes that represent the edge corresponding to |
| 360 | + the input. |
| 361 | + save_path : str or None |
| 362 | + If not None, save the figure to the file specified by its path. |
| 363 | + *Default* is no saving. |
| 364 | + """ |
| 365 | + # Default arguments |
| 366 | + if edge_colors is None: |
| 367 | + edge_colors = [mcolors.hsv_to_rgb((i/self.number_inputs, 1, 1)) for i in range(self.number_inputs)] |
| 368 | + |
| 369 | + if state_order is None: |
| 370 | + state_order = list(range(self.number_states)) |
| 371 | + |
| 372 | + # Init the figure |
| 373 | + ax = plt.axes((0, 0, 1, 1)) |
| 374 | + |
| 375 | + # Plot states |
| 376 | + radius = state_radius * self.number_states |
| 377 | + angles = 2 * np.pi / self.number_states * np.arange(self.number_states) |
| 378 | + positions = [(radius * math.cos(angle), radius * math.sin(angle)) for angle in angles] |
| 379 | + |
| 380 | + state_patches = [] |
| 381 | + arrows = [] |
| 382 | + for idx, state in enumerate(state_order): |
| 383 | + state_patches.append(mpatches.Circle(positions[idx], state_radius, color="#003399", ec="#cccccc")) |
| 384 | + plt.text(positions[idx][0], positions[idx][1], str(state), ha='center', va='center', size=20) |
| 385 | + |
| 386 | + # Plot transition |
| 387 | + for input in range(self.number_inputs): |
| 388 | + next_state = self.next_state_table[state, input] |
| 389 | + next_idx = (state_order == next_state).nonzero()[0][0] |
| 390 | + output = self.output_table[state, input] |
| 391 | + |
| 392 | + # Transition arrow |
| 393 | + if next_state == state: |
| 394 | + # Positions |
| 395 | + arrow_start_x = positions[idx][0] + state_radius * math.cos(angles[idx] + math.pi / 6) |
| 396 | + arrow_start_y = positions[idx][1] + state_radius * math.sin(angles[idx] + math.pi / 6) |
| 397 | + arrow_end_x = positions[idx][0] + state_radius * math.cos(angles[idx] - math.pi / 6) |
| 398 | + arrow_end_y = positions[idx][1] + state_radius * math.sin(angles[idx] - math.pi / 6) |
| 399 | + arrow_mid_x = positions[idx][0] + state_radius * 2 * math.cos(angles[idx]) |
| 400 | + arrow_mid_y = positions[idx][1] + state_radius * 2 * math.sin(angles[idx]) |
| 401 | + |
| 402 | + # Add text |
| 403 | + plt.text(arrow_mid_x, arrow_mid_y, '({})'.format(output), |
| 404 | + ha='center', va='center', backgroundcolor=edge_colors[input]) |
| 405 | + |
| 406 | + else: |
| 407 | + # Positions |
| 408 | + dx = positions[next_idx][0] - positions[idx][0] |
| 409 | + dy = positions[next_idx][1] - positions[idx][1] |
| 410 | + relative_angle = math.atan(dy / dx) + np.where(dx > 0, 0, math.pi) |
| 411 | + |
| 412 | + arrow_start_x = positions[idx][0] + state_radius * math.cos(relative_angle + math.pi * 0.05) |
| 413 | + arrow_start_y = positions[idx][1] + state_radius * math.sin(relative_angle + math.pi * 0.05) |
| 414 | + arrow_end_x = positions[next_idx][0] - state_radius * math.cos(relative_angle - math.pi * 0.05) |
| 415 | + arrow_end_y = positions[next_idx][1] - state_radius * math.sin(relative_angle - math.pi * 0.05) |
| 416 | + arrow_mid_x = (arrow_start_x + arrow_end_x) / 2 + \ |
| 417 | + radius * 0.25 * math.cos((angles[idx] + angles[next_idx]) / 2) * np.sign(dx) |
| 418 | + arrow_mid_y = (arrow_start_y + arrow_end_y) / 2 + \ |
| 419 | + radius * 0.25 * math.sin((angles[idx] + angles[next_idx]) / 2) * np.sign(dx) |
| 420 | + text_x = arrow_mid_x + 0.01 * math.cos((angles[idx] + angles[next_idx]) / 2) |
| 421 | + text_y = arrow_mid_y + 0.01 * math.sin((angles[idx] + angles[next_idx]) / 2) |
| 422 | + |
| 423 | + # Add text |
| 424 | + plt.text(text_x, text_y, '({})'.format(output), |
| 425 | + ha='center', va='center', backgroundcolor=edge_colors[input]) |
| 426 | + |
| 427 | + # Path creation |
| 428 | + codes = (mpath.Path.MOVETO, mpath.Path.CURVE3, mpath.Path.CURVE3) |
| 429 | + verts = ((arrow_start_x, arrow_start_y), |
| 430 | + (arrow_mid_x, arrow_mid_y), |
| 431 | + (arrow_end_x, arrow_end_y)) |
| 432 | + path = mpath.Path(verts, codes) |
| 433 | + |
| 434 | + # Plot arrow |
| 435 | + arrow = mpatches.FancyArrowPatch(path=path, mutation_scale=20, color=edge_colors[input]) |
| 436 | + ax.add_artist(arrow) |
| 437 | + arrows.append(arrow) |
| 438 | + |
| 439 | + # Format and plot |
| 440 | + ax.set_xlim(radius * -2, radius * 2) |
| 441 | + ax.set_ylim(radius * -2, radius * 2) |
| 442 | + ax.add_collection(PatchCollection(state_patches, True)) |
| 443 | + plt.legend(arrows, [str(i) + "-input" for i in range(self.number_inputs)], loc='lower right') |
| 444 | + plt.text(0, 1.5 * radius, 'Finite State Machine (output on transition)', ha='center', size=18) |
| 445 | + plt.show() |
| 446 | + if save_path is not None: |
| 447 | + plt.savefig(save_path) |
| 448 | + |
342 | 449 |
|
343 | 450 | def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=None): |
344 | 451 | """ |
@@ -574,7 +681,8 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'): |
574 | 681 | count = 0 |
575 | 682 | current_number_states = trellis.number_states |
576 | 683 |
|
577 | | - coded_bits = coded_bits.clip(-500, 500) |
| 684 | + if decoding_type == 'soft': |
| 685 | + coded_bits = coded_bits.clip(-500, 500) |
578 | 686 |
|
579 | 687 | for t in range(1, int((L+total_memory)/k)): |
580 | 688 | # Get the received codeword corresponding to t |
|
0 commit comments