Skip to content

Commit 41ad792

Browse files
committed
Visualize, conv_encode and viterbi_decode works for 2/3-rate codes.
1 parent 8740ce8 commit 41ad792

File tree

2 files changed

+51
-64
lines changed

2 files changed

+51
-64
lines changed

commpy/channelcoding/convcode.py

Lines changed: 46 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33

44
""" Algorithms for Convolutional Codes """
55

6+
from __future__ import division
7+
8+
import math
69
from warnings import warn
710

11+
import matplotlib.colors as mcolors
812
import matplotlib.patches as mpatches
913
import matplotlib.pyplot as plt
1014
import numpy as np
11-
from commpy.utilities import dec2bitarray, bitarray2dec, hamming_dist, euclid_dist
1215
from matplotlib.collections import PatchCollection
1316

17+
from commpy.utilities import dec2bitarray, bitarray2dec, hamming_dist, euclid_dist
18+
1419
__all__ = ['Trellis', 'conv_encode', 'viterbi_decode']
1520

1621
class Trellis:
@@ -257,11 +262,11 @@ def _generate_edges(self, trellis_length, grid, state_order, state_radius, edge_
257262
dx = grid_subset[0, state_count_2+self.number_states] - grid_subset[0,state_count_1] - 2*state_radius
258263
dy = grid_subset[1, state_count_2+self.number_states] - grid_subset[1,state_count_1]
259264
if np.count_nonzero(self.next_state_table[state_order[state_count_1],:] == state_order[state_count_2]):
260-
found_index = np.where(self.next_state_table[state_order[state_count_1],:] ==
265+
found_index = np.where(self.next_state_table[state_order[state_count_1]] ==
261266
state_order[state_count_2])
262267
edge_patch = mpatches.FancyArrow(grid_subset[0,state_count_1]+state_radius,
263268
grid_subset[1,state_count_1], dx, dy, width=0.005,
264-
length_includes_head = True, color = edge_colors[found_index[0][0]])
269+
length_includes_head = True, color = edge_colors[found_index[0][0]-1])
265270
edge_patches.append(edge_patch)
266271
input_count = input_count + 1
267272

@@ -280,7 +285,7 @@ def _generate_labels(self, grid, state_order, state_radius, font):
280285

281286

282287
def visualize(self, trellis_length = 2, state_order = None,
283-
state_radius = 0.04, edge_colors = None):
288+
state_radius = 0.04, edge_colors = None, save_path = None):
284289
""" Plot the trellis diagram.
285290
Parameters
286291
----------
@@ -294,13 +299,15 @@ def visualize(self, trellis_length = 2, state_order = None,
294299
state_radius : float, optional
295300
Radius of each state (circle) in the plot.
296301
Default value is 0.04
297-
edge_colors = list of hex color codes, optional
302+
edge_colors : list of hex color codes, optional
298303
A list of length equal to the number_inputs,
299304
containing color codes that represent the edge corresponding
300305
to the input.
306+
save_path : str or None
307+
If not None, save the figure to the file specified by its path.
301308
"""
302309
if edge_colors is None:
303-
edge_colors = ["#9E1BE0", "#06D65D"]
310+
edge_colors = [mcolors.hsv_to_rgb((i/self.number_inputs, 1, 1)) for i in range(self.number_inputs)]
304311

305312
if state_order is None:
306313
state_order = list(range(self.number_states))
@@ -327,9 +334,10 @@ def visualize(self, trellis_length = 2, state_order = None,
327334
ax.add_collection(collection)
328335
ax.set_xticks([])
329336
ax.set_yticks([])
330-
plt.legend([edge_patches[0], edge_patches[1]], ["1-input", "0-input"])
331-
#plt.savefig('trellis')
337+
plt.legend(edge_patches, [str(i) + "-input" for i in range(self.number_inputs)])
332338
plt.show()
339+
if save_path is not None:
340+
plt.savefig(save_path)
333341

334342

335343
def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=None):
@@ -371,7 +379,7 @@ def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=Non
371379
if code_type == 'rsc':
372380
inbits = message_bits
373381
number_inbits = number_message_bits
374-
number_outbits = int((number_inbits + total_memory)/rate)
382+
number_outbits = int((number_inbits + k * total_memory)/rate)
375383
else:
376384
number_inbits = number_message_bits + total_memory + total_memory % k
377385
inbits = np.zeros(number_inbits, 'int')
@@ -409,19 +417,17 @@ def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=Non
409417
current_state = next_state_table[current_state][current_input]
410418
j += 1
411419

412-
if puncture_matrix is not None:
413-
j = 0
414-
for i in range(number_outbits):
415-
if puncture_matrix[0][i % np.size(puncture_matrix, 1)] == 1:
416-
p_outbits[j] = outbits[i]
417-
j = j + 1
420+
j = 0
421+
for i in range(number_outbits):
422+
if puncture_matrix[0][i % np.size(puncture_matrix, 1)] == 1:
423+
p_outbits[j] = outbits[i]
424+
j = j + 1
418425

419426
return p_outbits
420427

421428

422429
def _where_c(inarray, rows, cols, search_value, index_array):
423430

424-
#cdef int i, j,
425431
number_found = 0
426432
for i in range(rows):
427433
for j in range(cols):
@@ -438,10 +444,6 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
438444
decoded_bits, tb_count, t, count,
439445
tb_depth, current_number_states):
440446

441-
#cdef int state_num, i, j, number_previous_states, previous_state, \
442-
# previous_input, i_codeword, number_found, min_idx, \
443-
# current_state, dec_symbol
444-
445447
k = trellis.k
446448
n = trellis.n
447449
number_states = trellis.number_states
@@ -452,9 +454,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
452454
next_state_table = trellis.next_state_table
453455
output_table = trellis.output_table
454456
pmetrics = np.empty(number_inputs)
455-
i_codeword_array = np.empty(n, 'int')
456457
index_array = np.empty([number_states, 2], 'int')
457-
decoded_bitarray = np.empty(k, 'int')
458458

459459
# Loop over all the current states (Time instant: t)
460460
for state_num in range(current_number_states):
@@ -471,20 +471,16 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
471471

472472
# Using the output table, find the ideal codeword
473473
i_codeword = output_table[previous_state, previous_input]
474-
#dec2bitarray_c(i_codeword, n, i_codeword_array)
475474
i_codeword_array = dec2bitarray(i_codeword, n)
476475

477476
# Compute Branch Metrics
478477
if decoding_type == 'hard':
479-
#branch_metric = hamming_dist_c(r_codeword.astype(int), i_codeword_array.astype(int), n)
480478
branch_metric = hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int))
481479
elif decoding_type == 'soft':
482480
pass
483481
elif decoding_type == 'unquantized':
484482
i_codeword_array = 2*i_codeword_array - 1
485483
branch_metric = euclid_dist(r_codeword, i_codeword_array)
486-
else:
487-
pass
488484

489485
# ADD operation: Add the branch metric to the
490486
# accumulated path metric and store it in the temporary array
@@ -512,8 +508,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
512508
dec_symbol = decoded_symbols[current_state, j]
513509
previous_state = paths[current_state, j]
514510
decoded_bitarray = dec2bitarray(dec_symbol, k)
515-
decoded_bits[(t-tb_depth-1)+(j+1)*k+count:(t-tb_depth-1)+(j+2)*k+count] = \
516-
decoded_bitarray
511+
decoded_bits[t - tb_depth + 1 + (j - 1) * k + count:t - tb_depth + 1 + j * k + count] = decoded_bitarray
517512
current_state = previous_state
518513

519514
paths[:,0:tb_depth-1] = paths[:,1:]
@@ -523,26 +518,29 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
523518

524519
def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
525520
"""
526-
Decodes a stream of convolutionally encoded bits using the Viterbi Algorithm
521+
Decodes a stream of convolutionally encoded bits using the Viterbi Algorithm.
527522
Parameters
528523
----------
529524
coded_bits : 1D ndarray
530525
Stream of convolutionally encoded bits which are to be decoded.
531-
generator_matrix : 2D ndarray of ints
532-
Generator matrix G(D) of the convolutional code using which the
533-
input bits are to be decoded.
534-
M : 1D ndarray of ints
535-
Number of memory elements per input of the convolutional encoder.
536-
tb_length : int
537-
Traceback depth (Typically set to 5*(M+1)).
538-
decoding_type : str {'hard', 'unquantized'}
526+
treillis : treillis object
527+
Treillis representing the convolutional code.
528+
tb_depth : int
529+
Traceback depth.
530+
*Default* is 5 times the number of memories in the code.
531+
decoding_type : str {'hard', 'soft', 'unquantized'}
539532
The type of decoding to be used.
540533
'hard' option is used for hard inputs (bits) to the decoder, e.g., BSC channel.
534+
'soft' option is used for soft inputs (LLRs) to the decoder.
541535
'unquantized' option is used for soft inputs (real numbers) to the decoder, e.g., BAWGN channel.
542536
Returns
543537
-------
544538
decoded_bits : 1D ndarray
545539
Decoded bit stream.
540+
Raises
541+
------
542+
ValueError
543+
If decoding_type is something else than 'hard', 'soft' or 'unquantized'.
546544
References
547545
----------
548546
.. [1] Todd K. Moon. Error Correction Coding: Mathematical Methods and
@@ -552,49 +550,41 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
552550
# k = Rows in G(D), n = columns in G(D)
553551
k = trellis.k
554552
n = trellis.n
555-
rate = float(k)/n
553+
rate = k/n
556554
total_memory = trellis.total_memory
557-
number_states = trellis.number_states
558-
number_inputs = trellis.number_inputs
559555

560556
if tb_depth is None:
561557
tb_depth = 5*total_memory
562558

563-
next_state_table = trellis.next_state_table
564-
output_table = trellis.output_table
565-
566559
# Number of message bits after decoding
567560
L = int(len(coded_bits)*rate)
568561

569-
path_metrics = np.empty([number_states, 2])
570-
path_metrics[:, :] = 1000000
562+
path_metrics = np.full((trellis.number_states, 2), np.inf)
571563
path_metrics[0][0] = 0
572-
paths = np.empty([number_states, tb_depth], 'int')
573-
paths[:, :] = 1000000
564+
paths = np.full((trellis.number_states, tb_depth), np.iinfo(int).max, 'int')
574565
paths[0][0] = 0
575566

576-
decoded_symbols = np.zeros([number_states, tb_depth], 'int')
577-
decoded_bits = np.zeros(L+tb_depth+k, 'int')
567+
decoded_symbols = np.zeros([trellis.number_states, tb_depth], 'int')
568+
decoded_bits = np.empty(math.ceil(L / k) * k + tb_depth, 'int')
578569
r_codeword = np.zeros(n, 'int')
579570

580571
tb_count = 1
581572
count = 0
582-
current_number_states = number_states
573+
current_number_states = trellis.number_states
583574

584-
for t in range(1, int((L+total_memory+total_memory%k)/k) + 1):
575+
for t in range(1, int((L+total_memory)/k)):
585576
# Get the received codeword corresponding to t
586-
if t <= L:
577+
if t <= L // k:
587578
r_codeword = coded_bits[(t-1)*n:t*n]
588579
else:
589580
if decoding_type == 'hard':
590581
r_codeword[:] = 0
591582
elif decoding_type == 'soft':
592583
pass
593584
elif decoding_type == 'unquantized':
594-
r_codeword[:] = 0
595-
r_codeword = 2*r_codeword - 1
585+
r_codeword[:] = -1
596586
else:
597-
pass
587+
raise ValueError('The available decoding types are "hard", "soft" and "unquantized')
598588

599589
_acs_traceback(r_codeword, trellis, decoding_type, path_metrics, paths,
600590
decoded_symbols, decoded_bits, tb_count, t, count, tb_depth,
@@ -609,11 +599,7 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
609599
# Path metrics (at t-1) = Path metrics (at t)
610600
path_metrics[:, 0] = path_metrics[:, 1]
611601

612-
# Force all the paths back to '0' state at the end of decoding
613-
if t == (L+total_memory+total_memory%k)/k:
614-
current_number_states = 1
615-
616-
return decoded_bits[0:len(decoded_bits)-tb_depth-1]
602+
return decoded_bits[:L]
617603

618604
def puncturing(message, punct_vec):
619605
"""

commpy/channelcoding/tests/test_convcode.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from numpy import array
55
from numpy.random import randint
6-
from numpy.testing import assert_array_equal
6+
from numpy.testing import assert_array_equal, dec
77

88
from commpy.channelcoding.convcode import Trellis, conv_encode, viterbi_decode
99

@@ -100,6 +100,7 @@ def test_conv_encode(self):
100100
def test_viterbi_decode(self):
101101
pass
102102

103+
@dec.slow
103104
def test_conv_encode_viterbi_decode(self):
104105
niters = 10
105106
blocklength = 1000
@@ -108,10 +109,10 @@ def test_conv_encode_viterbi_decode(self):
108109
msg = randint(0, 2, blocklength)
109110

110111
# Previous tests
111-
for i in range(2):
112+
for i in range(4):
112113
coded_bits = conv_encode(msg, self.trellis[i])
113114
decoded_bits = viterbi_decode(coded_bits.astype(float), self.trellis[i], 15)
114-
assert_array_equal(decoded_bits[:-2], msg)
115+
assert_array_equal(decoded_bits[:len(msg)], msg)
115116

116117
coded_bits = conv_encode(msg, self.trellis[i], termination='cont')
117118
decoded_bits = viterbi_decode(coded_bits.astype(float), self.trellis[i], 15)
@@ -120,4 +121,4 @@ def test_conv_encode_viterbi_decode(self):
120121
coded_bits = conv_encode(msg, self.trellis[i])
121122
coded_syms = 2.0 * coded_bits - 1
122123
decoded_bits = viterbi_decode(coded_syms, self.trellis[i], 15, 'unquantized')
123-
assert_array_equal(decoded_bits[:-2], msg)
124+
assert_array_equal(decoded_bits[:len(msg)], msg)

0 commit comments

Comments
 (0)