Skip to content

Commit 994b926

Browse files
committed
TL model - num of weights
1 parent ac9d43b commit 994b926

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tensorlayer/models/core.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,24 @@ def all_weights(self):
431431

432432
return self._all_weights.copy()
433433

434+
@property
435+
def n_weights(self):
436+
"""Return the number of weights (parameters) in this network."""
437+
n_weights = 0
438+
for i, w in enumerate(self.all_weights):
439+
n = 1
440+
# for s in p.eval().shape:
441+
for s in w.get_shape():
442+
try:
443+
s = int(s)
444+
except:
445+
s = 1
446+
if s:
447+
n = n * s
448+
n_weights = n_weights + n
449+
# print("num of weights (parameters) %d" % n_weights)
450+
return n_weights
451+
434452
@property
435453
def config(self):
436454
if self._config is not None and len(self._config) > 0:

0 commit comments

Comments
 (0)