1-
2-
31# Authors: Veeresh Taranalli <veeresht@gmail.com>
42# License: BSD 3-Clause
53
1816class Trellis :
1917 """
2018 Class defining a Trellis corresponding to a k/n - rate convolutional code.
21-
2219 Parameters
2320 ----------
2421 memory : 1D ndarray of ints
2522 Number of memory elements per input of the convolutional encoder.
26-
2723 g_matrix : 2D ndarray of ints (octal representation)
2824 Generator matrix G(D) of the convolutional encoder. Each element of
2925 G(D) represents a polynomial.
30-
3126 feedback : int, optional
3227 Feedback polynomial of the convolutional encoder. Default value is 00.
33-
3428 code_type : {'default', 'rsc'}, optional
3529 Use 'rsc' to generate a recursive systematic convolutional code.
36-
3730 If 'rsc' is specified, then the first 'k x k' sub-matrix of
38-
3931 G(D) must represent a identity matrix along with a non-zero
4032 feedback polynomial.
41-
42-
4333 Attributes
4434 ----------
4535 k : int
4636 Size of the smallest block of input bits that can be encoded using
4737 the convolutional code.
48-
4938 n : int
5039 Size of the smallest block of output bits generated using
5140 the convolutional code.
52-
5341 total_memory : int
5442 Total number of delay elements needed to implement the convolutional
5543 encoder.
56-
5744 number_states : int
5845 Number of states in the convolutional code trellis.
59-
6046 number_inputs : int
6147 Number of branches from each state in the convolutional code trellis.
62-
6348 next_state_table : 2D ndarray of ints
6449 Table representing the state transition matrix of the
6550 convolutional code trellis. Rows represent current states and
6651 columns represent current inputs in decimal. Elements represent the
6752 corresponding next states in decimal.
68-
6953 output_table : 2D ndarray of ints
7054 Table representing the output matrix of the convolutional code trellis.
7155 Rows represent current states and columns represent current inputs in
7256 decimal. Elements represent corresponding outputs in decimal.
73-
7457 Examples
7558 --------
7659 >>> from numpy import array
@@ -98,7 +81,6 @@ class Trellis:
9881 [3 0]
9982 [1 2]
10083 [2 1]]
101-
10284 """
10385 def __init__ (self , memory , g_matrix , feedback = 0 , code_type = 'default' ):
10486
@@ -107,7 +89,8 @@ def __init__(self, memory, g_matrix, feedback = 0, code_type = 'default'):
10789 if code_type == 'rsc' :
10890 for i in range (self .k ):
10991 g_matrix [i ][i ] = feedback
110-
92+ self .code_type = code_type
93+
11194 self .total_memory = memory .sum ()
11295 self .number_states = pow (2 , self .total_memory )
11396 self .number_inputs = pow (2 , self .k )
@@ -176,7 +159,7 @@ def _generate_grid(self, trellis_length):
176159 """ Private method """
177160
178161 grid = np .mgrid [0.12 :0.22 * trellis_length :(trellis_length + 1 )* (0 + 1j ),
179- 0.1 :0.1 + self .number_states * 0.1 :self .number_states * (0 + 1j )].reshape (2 , - 1 )
162+ 0.1 :0.5 + self .number_states * 0.1 :self .number_states * (0 + 1j )].reshape (2 , - 1 )
180163
181164 return grid
182165
@@ -231,27 +214,22 @@ def _generate_labels(self, grid, state_order, state_radius, font):
231214 def visualize (self , trellis_length = 2 , state_order = None ,
232215 state_radius = 0.04 , edge_colors = None ):
233216 """ Plot the trellis diagram.
234-
235217 Parameters
236218 ----------
237219 trellis_length : int, optional
238220 Specifies the number of time steps in the trellis diagram.
239221 Default value is 2.
240-
241222 state_order : list of ints, optional
242223 Specifies the order in the which the states of the trellis
243224 are to be displayed starting from the top in the plot.
244225 Default order is [0,...,number_states-1]
245-
246226 state_radius : float, optional
247227 Radius of each state (circle) in the plot.
248228 Default value is 0.04
249-
250229 edge_colors = list of hex color codes, optional
251230 A list of length equal to the number_inputs,
252231 containing color codes that represent the edge corresponding
253232 to the input.
254-
255233 """
256234 if edge_colors is None :
257235 edge_colors = ["#9E1BE0" , "#06D65D" ]
@@ -260,7 +238,7 @@ def visualize(self, trellis_length = 2, state_order = None,
260238 state_order = list (range (self .number_states ))
261239
262240 font = "sans-serif"
263- fig = plt .figure ()
241+ fig = plt .figure (figsize = ( 12 , 6 ), dpi = 150 )
264242 ax = plt .axes ([0 ,0 ,1 ,1 ])
265243 trellis_patches = []
266244
@@ -281,26 +259,23 @@ def visualize(self, trellis_length = 2, state_order = None,
281259 ax .add_collection (collection )
282260 ax .set_xticks ([])
283261 ax .set_yticks ([])
284- #plt.legend([edge_patches[0], edge_patches[1]], ["1-input", "0-input"])
262+ plt .legend ([edge_patches [0 ], edge_patches [1 ]], ["1-input" , "0-input" ])
263+ #plt.savefig('trellis')
285264 plt .show ()
286265
287266
288- def conv_encode (message_bits , trellis , code_type = 'default ' , puncture_matrix = None ):
267+ def conv_encode (message_bits , trellis , termination = 'term ' , puncture_matrix = None ):
289268 """
290269 Encode bits using a convolutional code.
291-
292270 Parameters
293271 ----------
294272 message_bits : 1D ndarray containing {0, 1}
295273 Stream of bits to be convolutionally encoded.
296-
297- generator_matrix : 2-D ndarray of ints
298- Generator matrix G(D) of the convolutional code using which the input
299- bits are to be encoded.
300-
301- M : 1D ndarray of ints
302- Number of memory elements per input of the convolutional encoder.
303-
274+ trellis: pre-initialized Trellis structure.
275+ termination: {'cont', 'term'}, optional
276+ Create ('term') or not ('cont') termination bits.
277+ puncture_matrix: 2D ndarray containing {0, 1}, optional
278+ Matrix used for the puncturing algorithm
304279 Returns
305280 -------
306281 coded_bits : 1D ndarray containing {0, 1}
@@ -311,26 +286,30 @@ def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=No
311286 n = trellis .n
312287 total_memory = trellis .total_memory
313288 rate = float (k )/ n
289+
290+ code_type = trellis .code_type
314291
315292 if puncture_matrix is None :
316293 puncture_matrix = np .ones ((trellis .k , trellis .n ))
317294
318295 number_message_bits = np .size (message_bits )
319-
320- # Initialize an array to contain the message bits plus the truncation zeros
321- if code_type == 'default' :
322- inbits = np .zeros (number_message_bits + total_memory + total_memory % k ,
323- 'int' )
324- number_inbits = number_message_bits + total_memory + total_memory % k
325-
326- # Pad the input bits with M zeros (L-th terminated truncation)
327- inbits [0 :number_message_bits ] = message_bits
328- number_outbits = int (number_inbits / rate )
329-
330- else :
296+
297+ if termination == 'cont' :
331298 inbits = message_bits
332299 number_inbits = number_message_bits
333- number_outbits = int ((number_inbits + total_memory )/ rate )
300+ number_outbits = int (number_inbits / rate )
301+ else :
302+ # Initialize an array to contain the message bits plus the truncation zeros
303+ if code_type == 'rsc' :
304+ inbits = message_bits
305+ number_inbits = number_message_bits
306+ number_outbits = int ((number_inbits + total_memory )/ rate )
307+ else :
308+ number_inbits = number_message_bits + total_memory + total_memory % k
309+ inbits = np .zeros (number_inbits , 'int' )
310+ # Pad the input bits with M zeros (L-th terminated truncation)
311+ inbits [0 :number_message_bits ] = message_bits
312+ number_outbits = int (number_inbits / rate )
334313
335314 outbits = np .zeros (number_outbits , 'int' )
336315 p_outbits = np .zeros (int (number_outbits *
@@ -349,8 +328,7 @@ def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=No
349328 current_state = next_state_table [current_state ][current_input ]
350329 j += 1
351330
352- if code_type == 'rsc' :
353-
331+ if code_type == 'rsc' and termination == 'term' :
354332 term_bits = dec2bitarray (current_state , trellis .total_memory )
355333 term_bits = term_bits [::- 1 ]
356334 for i in range (trellis .total_memory ):
@@ -360,11 +338,12 @@ def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=No
360338 current_state = next_state_table [current_state ][current_input ]
361339 j += 1
362340
363- j = 0
364- for i in range (number_outbits ):
365- if puncture_matrix [0 ][i % np .size (puncture_matrix , 1 )] == 1 :
366- p_outbits [j ] = outbits [i ]
367- j = j + 1
341+ if puncture_matrix is not None :
342+ j = 0
343+ for i in range (number_outbits ):
344+ if puncture_matrix [0 ][i % np .size (puncture_matrix , 1 )] == 1 :
345+ p_outbits [j ] = outbits [i ]
346+ j = j + 1
368347
369348 return p_outbits
370349
@@ -474,32 +453,25 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
474453def viterbi_decode (coded_bits , trellis , tb_depth = None , decoding_type = 'hard' ):
475454 """
476455 Decodes a stream of convolutionally encoded bits using the Viterbi Algorithm
477-
478456 Parameters
479457 ----------
480458 coded_bits : 1D ndarray
481459 Stream of convolutionally encoded bits which are to be decoded.
482-
483460 generator_matrix : 2D ndarray of ints
484461 Generator matrix G(D) of the convolutional code using which the
485462 input bits are to be decoded.
486-
487463 M : 1D ndarray of ints
488464 Number of memory elements per input of the convolutional encoder.
489-
490465 tb_length : int
491466 Traceback depth (Typically set to 5*(M+1)).
492-
493467 decoding_type : str {'hard', 'unquantized'}
494468 The type of decoding to be used.
495469 'hard' option is used for hard inputs (bits) to the decoder, e.g., BSC channel.
496470 'unquantized' option is used for soft inputs (real numbers) to the decoder, e.g., BAWGN channel.
497-
498471 Returns
499472 -------
500473 decoded_bits : 1D ndarray
501474 Decoded bit stream.
502-
503475 References
504476 ----------
505477 .. [1] Todd K. Moon. Error Correction Coding: Mathematical Methods and
@@ -571,3 +543,49 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
571543 current_number_states = 1
572544
573545 return decoded_bits [0 :len (decoded_bits )- tb_depth - 1 ]
546+
547+ def puncturing (message , punct_vec ):
548+ '''
549+ Applying of the punctured procedure.
550+ Parameters
551+ ----------
552+ message: input message {0,1}
553+ punct_vec: puncturing vector {0,1}
554+ Returns
555+ -------
556+ punctured: output punctured vector {0,1}
557+ '''
558+ shift = 0
559+ N = len (punct_vec )
560+ punctured = []
561+ for idx , item in enumerate (message ):
562+ if punct_vec [idx - shift * N ] == 1 :
563+ punctured .append (item )
564+ if idx % N == 0 :
565+ shift = shift + 1
566+ return np .array (punctured )
567+
568+ def depuncturing (punctured , punct_vec , shouldbe ):
569+ '''
570+ Applying of the inserting zeros procedure.
571+ Parameters
572+ ----------
573+ punctured: input punctured message {0,1}
574+ punct_vec: puncturing vector {0,1}
575+ shouldbe: length of the initial message (before puncturing)
576+ Returns
577+ -------
578+ depunctured: output vector {0,1}
579+ '''
580+ shift = 0
581+ shift2 = 0
582+ N = len (punct_vec )
583+ depunctured = np .zeros ((shouldbe ,))
584+ for idx , item in enumerate (depunctured ):
585+ if punct_vec [idx - shift * N ] == 1 :
586+ depunctured [idx ] = float (punctured [idx - shift2 ])
587+ else :
588+ shift2 = shift2 + 1
589+ if idx % N == 0 :
590+ shift = shift + 1 ;
591+ return depunctured
0 commit comments