import os

import cv2
import numpy as np
from ultralytics import YOLO

# Trained segmentation weights and the input/output image folders.
# NOTE(review): absolute Windows paths — adjust for the machine this runs on.
model = YOLO(r'D:\zlc\drone\drone-train\runs\segment\train\weights\best.pt')
image_folder = r"D:\zlc\drone\drone-train\data\land\test\images"
output_folder = r"D:\zlc\drone\drone-train\data\land\test\output"
os.makedirs(output_folder, exist_ok=True)

# Class names in class-id order; their count sizes the color map below.
names = ['background', 'land']

# Lower-cased file extensions that are treated as images.
valid_image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']
def create_color_map(num_classes, seed=None):
    """Return a dict mapping each class id in ``range(num_classes)`` to a random color.

    Each color is a tuple of three Python ints in the inclusive range 0-255.

    Args:
        num_classes: number of class ids (0 .. num_classes - 1) to assign colors to.
        seed: optional seed. When given, colors come from a private
            ``np.random.RandomState(seed)`` and are reproducible across runs.
            When None (the default), numpy's global random state is used.

    Returns:
        dict[int, tuple[int, int, int]]
    """
    # randint's upper bound is exclusive: 256 is needed for 255 to be reachable.
    randint = np.random.randint if seed is None else np.random.RandomState(seed).randint
    return {
        class_id: tuple(randint(0, 256, size=3).tolist())
        for class_id in range(num_classes)
    }
# Build one random color per entry in `names`, keyed by class id.
num_classes = len(names)
color_map = create_color_map(num_classes=num_classes)
def _paint_masks(overlay, result):
    """Paint every instance mask of one prediction result onto *overlay*, in place.

    Pixels where a mask value exceeds 0.5 are set to the color of that
    instance's class from the module-level ``color_map`` (black if the class
    id has no entry). Later instances overwrite earlier ones where they overlap.
    """
    masks = getattr(result, "masks", None)
    boxes = getattr(result, "boxes", None)
    # masks/boxes may be None (presumably when nothing is detected); skip then.
    if masks is None or not hasattr(masks, "data") or boxes is None:
        return
    mask_array = masks.data.cpu().numpy()
    boxes = boxes.cpu().numpy() if hasattr(boxes, "cpu") else boxes
    height, width = overlay.shape[:2]
    # zip stops at the shorter sequence, like the old `i < len(masks)` guard.
    for mask, class_id in zip(mask_array, boxes.cls):
        if mask.shape != (height, width):
            # Fallback for masks that are not already at image resolution.
            mask = cv2.resize(mask.astype(np.float32), (width, height))
        overlay[mask > 0.5] = color_map.get(int(class_id), (0, 0, 0))


def process_images_in_folder(folder_path, output_dir=None, alpha=0.5):
    """Segment every image in *folder_path* and save a color-overlay copy of it.

    For each file (non-recursive, in sorted order) whose lower-cased extension
    is in ``valid_image_extensions``, run the module-level ``model`` on it.
    Each predicted instance mask is painted in its class color, the overlay is
    blended onto the image, and the result is written under the same filename.

    Args:
        folder_path: directory to scan for images.
        output_dir: directory to write results to; defaults to the module-level
            ``output_folder``. Created if it does not exist.
        alpha: weight of the color overlay in the blend (default 0.5).
    """
    if output_dir is None:
        output_dir = output_folder
    os.makedirs(output_dir, exist_ok=True)

    for filename in sorted(os.listdir(folder_path)):
        ext = os.path.splitext(filename)[1].lower()
        if ext not in valid_image_extensions:
            continue

        image_path = os.path.join(folder_path, filename)
        image = cv2.imread(image_path)
        if image is None:
            print(f"无法读取图像: {image_path}")
            continue

        # retina_masks=True asks ultralytics for masks at the original image
        # resolution, so they align with `image` instead of the letterboxed
        # inference size (which includes padding and misaligns on resize).
        results = model.predict(
            source=image_path,
            conf=0.25,
            iou=0.45,
            imgsz=640,
            retina_masks=True,
            show_labels=False,
            show_boxes=False,
            show_conf=False,
        )

        # Stay in OpenCV's native BGR order throughout: cv2.imwrite expects
        # BGR, so blending an RGB copy swapped red and blue in the saved file.
        mask_overlay = np.zeros_like(image)
        for result in results:
            _paint_masks(mask_overlay, result)

        output_image = cv2.addWeighted(image, 1, mask_overlay, alpha, 0)
        output_path = os.path.join(output_dir, filename)
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(output_path, output_image):
            print(f"已处理并保存: {output_path}")
        else:
            print(f"保存失败: {output_path}")