@@ -15,7 +15,6 @@ def unet_train():
1515 print (n )
1616 X_train , y_train = [], []
1717 for i in range (n ):
18- print ("正在读取第%d张图片" % i )
1918 img = cv2 .imread (path + 'train_image/%d.png' % i )
2019 label = cv2 .imread (path + 'train_label/%d.png' % i )
2120 X_train .append (img )
@@ -91,24 +90,21 @@ def Conv2dT_BN(x, filters, kernel_size, strides=(2, 2), padding='same'):
9190 model .summary ()
9291
9392 print ("开始训练u-net" )
94- model .fit (X_train , y_train , epochs = 100 , batch_size = 15 )#epochs和batch_size看个人情况调整,batch_size不要过大,否则内存容易溢出
95- #我11G显存也只能设置15-20左右,我训练最终loss降低至250左右,acc约95%左右
93+ model .fit (X_train , y_train , epochs = 100 , batch_size = 15 )
9694 model .save ('unet.h5' )
9795 print ('unet.h5保存成功!!!' )
9896
9997
10098def unet_predict (unet , img_src_path ):
101- img_src = cv2 .imdecode (np .fromfile (img_src_path , dtype = np .uint8 ), - 1 ) # 从中文路径读取时用
102- # img_src=cv2.imread(img_src_path)
99+ img_src = cv2 .imdecode (np .fromfile (img_src_path , dtype = np .uint8 ), - 1 )
103100 if img_src .shape != (512 , 512 , 3 ):
104- img_src = cv2 .resize (img_src , dsize = (512 , 512 ), interpolation = cv2 .INTER_AREA )[:, :, :3 ] # dsize=(宽度,高度),[:,:,:3]是防止图片为4通道图片,后续无法reshape
105- img_src = img_src .reshape (1 , 512 , 512 , 3 ) # 预测图片shape为(1,512,512,3)
106-
107- img_mask = unet .predict (img_src ) # 归一化除以255后进行预测
108- img_src = img_src .reshape (512 , 512 , 3 ) # 将原图reshape为3维
109- img_mask = img_mask .reshape (512 , 512 , 3 ) # 将预测后图片reshape为3维
110- img_mask = img_mask / np .max (img_mask ) * 255 # 归一化后乘以255
111- img_mask [:, :, 2 ] = img_mask [:, :, 1 ] = img_mask [:, :, 0 ] # 三个通道保持相同
112- img_mask = img_mask .astype (np .uint8 ) # 将img_mask类型转为int型
101+ img_src = cv2 .resize (img_src , dsize = (512 , 512 ), interpolation = cv2 .INTER_AREA )[:, :, :3 ]
102+ img_src = img_src .reshape (1 , 512 , 512 , 3 )
103+ img_mask = unet .predict (img_src )
104+ img_src = img_src .reshape (512 , 512 , 3 )
105+ img_mask = img_mask .reshape (512 , 512 , 3 )
106+ img_mask = img_mask / np .max (img_mask ) * 255
107+ img_mask [:, :, 2 ] = img_mask [:, :, 1 ] = img_mask [:, :, 0 ]
108+ img_mask = img_mask .astype (np .uint8 )
113109
114110 return img_src , img_mask
0 commit comments