Skip to content

Commit 37b71d4

Browse files
authored
Merge pull request #1029 from warshallrho/master
[fix bug] (1) add (non)trainable_weights in Layerlist (2) remove redu…
2 parents d842262 + c396f24 commit 37b71d4

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ To release a new version, please update the changelog as followed:
110110
- Set allow_pickle=True in np.load() (#PR 1021)
111111
- Remove `private_method` decorator (#PR 1025)
112112
- Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `ModelLayer` (#PR 1026)
113+
- Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `LayerList` (#PR 1029)
114+
- remove redundant parts in `model.all_layers` (#PR 1029)
113115

114116
### Removed
115117

@@ -119,7 +121,7 @@ To release a new version, please update the changelog as followed:
119121

120122
- @zsdonghao
121123
- @ChrisWu1997: #1010 #1015 #1025
122-
- @warshallrho: #1017 #1021 #1026
124+
- @warshallrho: #1017 #1021 #1026 #1029
123125
- @ArnoldLIULJ: #1023
124126
- @JingqingZ: #1023
125127

tensorlayer/layers/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,8 @@ def __init__(self, layers, name=None):
602602

603603
is_built = True
604604
for layer in self.layers:
605+
self._trainable_weights.extend(layer.trainable_weights)
606+
self._nontrainable_weights.extend(layer.nontrainable_weights)
605607
if layer._built is False:
606608
is_built = False
607609
if layer._built and layer.all_weights is not None:

tensorlayer/models/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,12 @@ def all_layers(self):
358358
attr_list.remove("all_weights")
359359
attr_list.remove("trainable_weights")
360360
attr_list.remove("nontrainable_weights")
361+
attr_list.remove("_all_weights")
362+
attr_list.remove("_trainable_weights")
363+
attr_list.remove("_nontrainable_weights")
361364
attr_list.remove("all_layers")
365+
attr_list.remove("_all_layers")
366+
attr_list.remove("n_weights")
362367
for idx, attr in enumerate(attr_list):
363368
try:
364369
if isinstance(getattr(self, attr), Layer):

0 commit comments

Comments
 (0)