33
44""" Algorithms for Convolutional Codes """
55
6+ from __future__ import division
7+
8+ import math
69from warnings import warn
710
11+ import matplotlib .colors as mcolors
812import matplotlib .patches as mpatches
913import matplotlib .pyplot as plt
1014import numpy as np
11- from commpy .utilities import dec2bitarray , bitarray2dec , hamming_dist , euclid_dist
1215from matplotlib .collections import PatchCollection
1316
17+ from commpy .utilities import dec2bitarray , bitarray2dec , hamming_dist , euclid_dist
18+
1419__all__ = ['Trellis' , 'conv_encode' , 'viterbi_decode' ]
1520
1621class Trellis :
@@ -257,11 +262,11 @@ def _generate_edges(self, trellis_length, grid, state_order, state_radius, edge_
257262 dx = grid_subset [0 , state_count_2 + self .number_states ] - grid_subset [0 ,state_count_1 ] - 2 * state_radius
258263 dy = grid_subset [1 , state_count_2 + self .number_states ] - grid_subset [1 ,state_count_1 ]
259264 if np .count_nonzero (self .next_state_table [state_order [state_count_1 ],:] == state_order [state_count_2 ]):
260- found_index = np .where (self .next_state_table [state_order [state_count_1 ],: ] ==
265+ found_index = np .where (self .next_state_table [state_order [state_count_1 ]] ==
261266 state_order [state_count_2 ])
262267 edge_patch = mpatches .FancyArrow (grid_subset [0 ,state_count_1 ]+ state_radius ,
263268 grid_subset [1 ,state_count_1 ], dx , dy , width = 0.005 ,
264- length_includes_head = True , color = edge_colors [found_index [0 ][0 ]])
269+ length_includes_head = True , color = edge_colors [found_index [0 ][0 ]- 1 ])
265270 edge_patches .append (edge_patch )
266271 input_count = input_count + 1
267272
@@ -280,7 +285,7 @@ def _generate_labels(self, grid, state_order, state_radius, font):
280285
281286
282287 def visualize (self , trellis_length = 2 , state_order = None ,
283- state_radius = 0.04 , edge_colors = None ):
288+ state_radius = 0.04 , edge_colors = None , save_path = None ):
284289 """ Plot the trellis diagram.
285290 Parameters
286291 ----------
@@ -294,13 +299,15 @@ def visualize(self, trellis_length = 2, state_order = None,
294299 state_radius : float, optional
295300 Radius of each state (circle) in the plot.
296301 Default value is 0.04
297- edge_colors = list of hex color codes, optional
302+ edge_colors : list of hex color codes, optional
298303 A list of length equal to the number_inputs,
299304 containing color codes that represent the edge corresponding
300305 to the input.
306+ save_path : str or None
307+ If not None, save the figure to the file specified by its path.
301308 """
302309 if edge_colors is None :
303- edge_colors = ["#9E1BE0" , "#06D65D" ]
310+ edge_colors = [mcolors . hsv_to_rgb (( i / self . number_inputs , 1 , 1 )) for i in range ( self . number_inputs ) ]
304311
305312 if state_order is None :
306313 state_order = list (range (self .number_states ))
@@ -327,9 +334,10 @@ def visualize(self, trellis_length = 2, state_order = None,
327334 ax .add_collection (collection )
328335 ax .set_xticks ([])
329336 ax .set_yticks ([])
330- plt .legend ([edge_patches [0 ], edge_patches [1 ]], ["1-input" , "0-input" ])
331- #plt.savefig('trellis')
337+ plt .legend (edge_patches , [str (i ) + "-input" for i in range (self .number_inputs )])
332338 plt .show ()
339+ if save_path is not None :
340+ plt .savefig (save_path )
333341
334342
335343def conv_encode (message_bits , trellis , termination = 'term' , puncture_matrix = None ):
@@ -371,7 +379,7 @@ def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=Non
371379 if code_type == 'rsc' :
372380 inbits = message_bits
373381 number_inbits = number_message_bits
374- number_outbits = int ((number_inbits + total_memory )/ rate )
382+ number_outbits = int ((number_inbits + k * total_memory )/ rate )
375383 else :
376384 number_inbits = number_message_bits + total_memory + total_memory % k
377385 inbits = np .zeros (number_inbits , 'int' )
@@ -409,19 +417,17 @@ def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=Non
409417 current_state = next_state_table [current_state ][current_input ]
410418 j += 1
411419
412- if puncture_matrix is not None :
413- j = 0
414- for i in range (number_outbits ):
415- if puncture_matrix [0 ][i % np .size (puncture_matrix , 1 )] == 1 :
416- p_outbits [j ] = outbits [i ]
417- j = j + 1
420+ j = 0
421+ for i in range (number_outbits ):
422+ if puncture_matrix [0 ][i % np .size (puncture_matrix , 1 )] == 1 :
423+ p_outbits [j ] = outbits [i ]
424+ j = j + 1
418425
419426 return p_outbits
420427
421428
422429def _where_c (inarray , rows , cols , search_value , index_array ):
423430
424- #cdef int i, j,
425431 number_found = 0
426432 for i in range (rows ):
427433 for j in range (cols ):
@@ -438,10 +444,6 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
438444 decoded_bits , tb_count , t , count ,
439445 tb_depth , current_number_states ):
440446
441- #cdef int state_num, i, j, number_previous_states, previous_state, \
442- # previous_input, i_codeword, number_found, min_idx, \
443- # current_state, dec_symbol
444-
445447 k = trellis .k
446448 n = trellis .n
447449 number_states = trellis .number_states
@@ -452,9 +454,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
452454 next_state_table = trellis .next_state_table
453455 output_table = trellis .output_table
454456 pmetrics = np .empty (number_inputs )
455- i_codeword_array = np .empty (n , 'int' )
456457 index_array = np .empty ([number_states , 2 ], 'int' )
457- decoded_bitarray = np .empty (k , 'int' )
458458
459459 # Loop over all the current states (Time instant: t)
460460 for state_num in range (current_number_states ):
@@ -471,20 +471,16 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
471471
472472 # Using the output table, find the ideal codeword
473473 i_codeword = output_table [previous_state , previous_input ]
474- #dec2bitarray_c(i_codeword, n, i_codeword_array)
475474 i_codeword_array = dec2bitarray (i_codeword , n )
476475
477476 # Compute Branch Metrics
478477 if decoding_type == 'hard' :
479- #branch_metric = hamming_dist_c(r_codeword.astype(int), i_codeword_array.astype(int), n)
480478 branch_metric = hamming_dist (r_codeword .astype (int ), i_codeword_array .astype (int ))
481479 elif decoding_type == 'soft' :
482480 pass
483481 elif decoding_type == 'unquantized' :
484482 i_codeword_array = 2 * i_codeword_array - 1
485483 branch_metric = euclid_dist (r_codeword , i_codeword_array )
486- else :
487- pass
488484
489485 # ADD operation: Add the branch metric to the
490486 # accumulated path metric and store it in the temporary array
@@ -512,8 +508,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
512508 dec_symbol = decoded_symbols [current_state , j ]
513509 previous_state = paths [current_state , j ]
514510 decoded_bitarray = dec2bitarray (dec_symbol , k )
515- decoded_bits [(t - tb_depth - 1 )+ (j + 1 )* k + count :(t - tb_depth - 1 )+ (j + 2 )* k + count ] = \
516- decoded_bitarray
511+ decoded_bits [t - tb_depth + 1 + (j - 1 ) * k + count :t - tb_depth + 1 + j * k + count ] = decoded_bitarray
517512 current_state = previous_state
518513
519514 paths [:,0 :tb_depth - 1 ] = paths [:,1 :]
@@ -523,26 +518,29 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
523518
524519def viterbi_decode (coded_bits , trellis , tb_depth = None , decoding_type = 'hard' ):
525520 """
526- Decodes a stream of convolutionally encoded bits using the Viterbi Algorithm
521+ Decodes a stream of convolutionally encoded bits using the Viterbi Algorithm.
527522 Parameters
528523 ----------
529524 coded_bits : 1D ndarray
530525 Stream of convolutionally encoded bits which are to be decoded.
531- generator_matrix : 2D ndarray of ints
532- Generator matrix G(D) of the convolutional code using which the
533- input bits are to be decoded.
534- M : 1D ndarray of ints
535- Number of memory elements per input of the convolutional encoder.
536- tb_length : int
537- Traceback depth (Typically set to 5*(M+1)).
538- decoding_type : str {'hard', 'unquantized'}
526+ treillis : treillis object
527+ Treillis representing the convolutional code.
528+ tb_depth : int
529+ Traceback depth.
530+ *Default* is 5 times the number of memories in the code.
531+ decoding_type : str {'hard', 'soft', 'unquantized'}
539532 The type of decoding to be used.
540533 'hard' option is used for hard inputs (bits) to the decoder, e.g., BSC channel.
534+ 'soft' option is used for soft inputs (LLRs) to the decoder.
541535 'unquantized' option is used for soft inputs (real numbers) to the decoder, e.g., BAWGN channel.
542536 Returns
543537 -------
544538 decoded_bits : 1D ndarray
545539 Decoded bit stream.
540+ Raises
541+ ------
542+ ValueError
543+ If decoding_type is something else than 'hard', 'soft' or 'unquantized'.
546544 References
547545 ----------
548546 .. [1] Todd K. Moon. Error Correction Coding: Mathematical Methods and
@@ -552,49 +550,41 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
552550 # k = Rows in G(D), n = columns in G(D)
553551 k = trellis .k
554552 n = trellis .n
555- rate = float ( k ) / n
553+ rate = k / n
556554 total_memory = trellis .total_memory
557- number_states = trellis .number_states
558- number_inputs = trellis .number_inputs
559555
560556 if tb_depth is None :
561557 tb_depth = 5 * total_memory
562558
563- next_state_table = trellis .next_state_table
564- output_table = trellis .output_table
565-
566559 # Number of message bits after decoding
567560 L = int (len (coded_bits )* rate )
568561
569- path_metrics = np .empty ([number_states , 2 ])
570- path_metrics [:, :] = 1000000
562+ path_metrics = np .full ((trellis .number_states , 2 ), np .inf )
571563 path_metrics [0 ][0 ] = 0
572- paths = np .empty ([number_states , tb_depth ], 'int' )
573- paths [:, :] = 1000000
564+ paths = np .full ((trellis .number_states , tb_depth ), np .iinfo (int ).max , 'int' )
574565 paths [0 ][0 ] = 0
575566
576- decoded_symbols = np .zeros ([number_states , tb_depth ], 'int' )
577- decoded_bits = np .zeros ( L + tb_depth + k , 'int' )
567+ decoded_symbols = np .zeros ([trellis . number_states , tb_depth ], 'int' )
568+ decoded_bits = np .empty ( math . ceil ( L / k ) * k + tb_depth , 'int' )
578569 r_codeword = np .zeros (n , 'int' )
579570
580571 tb_count = 1
581572 count = 0
582- current_number_states = number_states
573+ current_number_states = trellis . number_states
583574
584- for t in range (1 , int ((L + total_memory + total_memory % k )/ k ) + 1 ):
575+ for t in range (1 , int ((L + total_memory )/ k )):
585576 # Get the received codeword corresponding to t
586- if t <= L :
577+ if t <= L // k :
587578 r_codeword = coded_bits [(t - 1 )* n :t * n ]
588579 else :
589580 if decoding_type == 'hard' :
590581 r_codeword [:] = 0
591582 elif decoding_type == 'soft' :
592583 pass
593584 elif decoding_type == 'unquantized' :
594- r_codeword [:] = 0
595- r_codeword = 2 * r_codeword - 1
585+ r_codeword [:] = - 1
596586 else :
597- pass
587+ raise ValueError ( 'The available decoding types are "hard", "soft" and "unquantized' )
598588
599589 _acs_traceback (r_codeword , trellis , decoding_type , path_metrics , paths ,
600590 decoded_symbols , decoded_bits , tb_count , t , count , tb_depth ,
@@ -609,11 +599,7 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
609599 # Path metrics (at t-1) = Path metrics (at t)
610600 path_metrics [:, 0 ] = path_metrics [:, 1 ]
611601
612- # Force all the paths back to '0' state at the end of decoding
613- if t == (L + total_memory + total_memory % k )/ k :
614- current_number_states = 1
615-
616- return decoded_bits [0 :len (decoded_bits )- tb_depth - 1 ]
602+ return decoded_bits [:L ]
617603
618604def puncturing (message , punct_vec ):
619605 """
0 commit comments