import sys
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
                             QVBoxLayout, QHBoxLayout, QWidget, QFileDialog,
                             QProgressBar)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt, QThread, pyqtSignal
import torch
import numpy as np
import cv2
from romatch import roma_outdoor
import warnings
import time
from functools import lru_cache
warnings.filterwarnings("ignore", category=UserWarning)
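
# MatcherThread runs RoMa matching on a worker thread so the Qt GUI stays responsive,
# reporting status text via the `progress` signal and the result dict via `finished`.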
class MatcherThread(QThread):
    finished = pyqtSignal(dict)
    progress = pyqtSignal(str)

    def __init__(self, roma_model, image1_path, image2_path, device):
        super().__init__()
        self.roma_model = roma_model
        self.image1_path = image1_path
        self.image2_path = image2_path
        self.device = device
    def run(self):
        try:
            self.progress.emit("Processing images...")
            result = self.process_images()
            self.finished.emit(result)
        except Exception as e:
            self.finished.emit({"error": str(e)})
    def process_images(self):
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Load (and cache) the preprocessed input images
            img1 = self.preprocess_image(self.image1_path)
            img2 = self.preprocess_image(self.image2_path)
            # Run RoMa matching
            self.progress.emit("Performing ROMA matching...")
            print("Performing ROMA matching...")
            warp, certainty = self.roma_model.match(
                self.image1_path, self.image2_path, device=self.device
            )

            # Default sampling of correspondences
            matches, match_certainty = self.roma_model.sample(warp, certainty)

            # Move CUDA tensors to the CPU and convert to NumPy
            if torch.is_tensor(matches):
                matches = matches.cpu().numpy()
            if torch.is_tensor(match_certainty):
                match_certainty = match_certainty.cpu().numpy()

            print(f"Initial matches: {len(matches)}")
            print(f"Certainty range: {match_certainty.min():.3f} - {match_certainty.max():.3f}")

            # Relax the post-matching filtering: keep only the most confident matches
            if match_certainty is not None and len(match_certainty) > 0:
                # Select the top 500 matches with the highest certainty
                top_k = min(500, len(matches))
                top_indices = np.argsort(match_certainty)[-top_k:]
                matches = matches[top_indices]
                match_certainty = match_certainty[top_indices]
                print(f"Selected top {len(matches)} matches")
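
            # Note (assumption based on the RoMa examples): sample() returns matches in
            # normalized [-1, 1] coordinates. The upstream examples convert them with
            # roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) before geometric
            # estimation; this code keeps the normalized values and converts them manually
            # in draw_matches(), so the RANSAC threshold of 3.0 below is very permissive
            # in normalized space.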
            # Geometric verification with RANSAC (optional)
            if len(matches) >= 4:
                src_pts = matches[:, :2]
                dst_pts = matches[:, 2:]

                print("\nDebug - Before RANSAC:")
                print(f"Source points shape: {src_pts.shape}")
                print(f"First few source points: \n{src_pts[:5]}")
                print(f"Destination points shape: {dst_pts.shape}")
                print(f"First few destination points: \n{dst_pts[:5]}")

                try:
                    H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0)
                    if mask is not None:
                        inliers = mask.ravel().astype(bool)
                        matches = matches[inliers]
                        match_certainty = match_certainty[inliers]
                        print(f"Matches after RANSAC: {len(matches)}")
                        print("Debug - After RANSAC filtering:")
                        print(f"First few matches: \n{matches[:5]}")
                except Exception as e:
                    print(f"RANSAC failed: {str(e)}")
        except Exception as e:
            import traceback
            error_msg = f"Error in process_images: {str(e)}\n{traceback.format_exc()}"
            print("\nError in process_images:")
            print(error_msg)
            self.progress.emit(f"Error: {str(e)}")
            return {
                "img1": None,
                "img2": None,
                "matches": np.array([]),
                "certainty": np.array([])
            }
        if len(matches) == 0:
            print("No matches found after filtering")
            self.progress.emit("No matches found after filtering")
            return {
                "img1": img1,
                "img2": img2,
                "matches": np.array([]),
                "certainty": np.array([])
            }
print(f"Final matches: {len(matches)}") return { "img1": img1, "img2": img2, "matches": matches, "certainty": match_certainty }
    @staticmethod
    @lru_cache(maxsize=32)
    def preprocess_image(image_path):
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")
        return img  # Return the original image unchanged
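
# ImageMatcher is the main window: it loads the two input images, lazily initializes
# the RoMa model on first use, runs matching in a MatcherThread, and displays the result.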
class ImageMatcher(QMainWindow):
    def __init__(self):
        super().__init__()
        self.roma_model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.image1_path = None
        self.image2_path = None
        self.matcher_thread = None
        self.initUI()
    def initUI(self):
        self.setWindowTitle('Image Matcher')

        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)

        images_layout = QHBoxLayout()

        left_layout = QVBoxLayout()
        self.image1_label = QLabel()
        self.image1_label.setFixedSize(400, 400)
        self.image1_label.setAlignment(Qt.AlignCenter)
        self.image1_label.setStyleSheet("border: 2px solid black")
        self.load_image1_btn = QPushButton('Load Image 1')
        self.load_image1_btn.clicked.connect(self.load_image1)
        left_layout.addWidget(self.image1_label)
        left_layout.addWidget(self.load_image1_btn)

        right_layout = QVBoxLayout()
        self.image2_label = QLabel()
        self.image2_label.setFixedSize(400, 400)
        self.image2_label.setAlignment(Qt.AlignCenter)
        self.image2_label.setStyleSheet("border: 2px solid black")
        self.load_image2_btn = QPushButton('Load Image 2')
        self.load_image2_btn.clicked.connect(self.load_image2)
        right_layout.addWidget(self.image2_label)
        right_layout.addWidget(self.load_image2_btn)

        images_layout.addLayout(left_layout)
        images_layout.addLayout(right_layout)

        self.result_label = QLabel()
        self.result_label.setFixedSize(800, 400)
        self.result_label.setAlignment(Qt.AlignCenter)
        self.result_label.setStyleSheet("border: 2px solid black")

        self.progress_bar = QProgressBar()
        self.progress_bar.setTextVisible(True)
        self.progress_bar.hide()

        self.status_label = QLabel()
        self.status_label.setAlignment(Qt.AlignCenter)

        self.match_btn = QPushButton('Match Images')
        self.match_btn.clicked.connect(self.initialize_and_match)
        self.match_btn.setEnabled(False)

        layout.addLayout(images_layout)
        layout.addWidget(self.result_label)
        layout.addWidget(self.progress_bar)
        layout.addWidget(self.status_label)
        layout.addWidget(self.match_btn)

        self.setGeometry(100, 100, 850, 900)
        self.show()
    def initialize_model(self):
        if self.roma_model is None:
            try:
                status_msg = f"Initializing model on {self.device}..."
                print(status_msg)
                self.status_label.setText(status_msg)
                QApplication.processEvents()

                # Resolutions are set to multiples of 14
                self.roma_model = roma_outdoor(
                    device=self.device,
                    coarse_res=322,           # 14 * 23 = 322 (close to 320)
                    upsample_res=(644, 644)   # 14 * 46 = 644 (close to 640)
                ).to(self.device)

                success_msg = "Model initialized successfully"
                print(success_msg)
                self.status_label.setText(success_msg)
                return True
            except Exception as e:
                import traceback
                error_msg = f"Error initializing model: {str(e)}\n{traceback.format_exc()}"
                print("\nError in initialize_model:")
                print(error_msg)
                self.status_label.setText(f"Error initializing model: {str(e)}")
                return False
        return True
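
    # Note (assumption): roma_outdoor() builds the pretrained outdoor model and its
    # weights are usually downloaded on first use, so the first "Match Images" click
    # can take noticeably longer than later runs.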
    def initialize_and_match(self):
        if not self.initialize_model():
            return

        self.match_btn.setEnabled(False)
        self.progress_bar.setMaximum(0)
        self.progress_bar.show()

        self.matcher_thread = MatcherThread(
            self.roma_model, self.image1_path, self.image2_path, self.device
        )
        self.matcher_thread.finished.connect(self.handle_matching_result)
        self.matcher_thread.progress.connect(self.update_progress)
        self.matcher_thread.start()
    def handle_matching_result(self, result):
        if "error" in result:
            error_msg = f"Error: {result['error']}"
            print(error_msg)
            self.status_label.setText(error_msg)
        else:
            self.visualize_matches(result)

        self.progress_bar.hide()
        self.match_btn.setEnabled(True)
        self.matcher_thread = None
    def update_progress(self, message):
        print(message)
        self.status_label.setText(message)
    def visualize_matches(self, result):
        try:
            img1 = result["img1"]
            img2 = result["img2"]
            matches = result["matches"]
            match_certainty = result["certainty"]
            # Visualize the matching result on the original images
            matched_img = self.draw_matches(
                img1, img2, matches,
                matches[:, :2],  # use the original (normalized) coordinates
                matches[:, 2:],  # use the original (normalized) coordinates
                match_certainty
            )

            # Shrink the result image to fit the UI
            display_height = self.result_label.height()
            display_width = self.result_label.width()

            # Resize while preserving the aspect ratio
            img_height, img_width = matched_img.shape[:2]
            aspect_ratio = img_width / img_height
            if img_width / display_width > img_height / display_height:
                new_width = display_width
                new_height = int(display_width / aspect_ratio)
            else:
                new_height = display_height
                new_width = int(display_height * aspect_ratio)

            matched_img_resized = cv2.resize(matched_img, (new_width, new_height),
                                             interpolation=cv2.INTER_AREA)

            # Display the result
            matched_img_rgb = cv2.cvtColor(matched_img_resized, cv2.COLOR_BGR2RGB)
            height, width = matched_img_rgb.shape[:2]
            bytes_per_line = 3 * width
            q_img = QImage(matched_img_rgb.tobytes(), width, height,
                           bytes_per_line, QImage.Format_RGB888)
            pixmap = QPixmap.fromImage(q_img)
            self.result_label.setPixmap(pixmap)
            match_count = len(matches)
            if match_count > 0:
                avg_certainty = float(np.nanmean(match_certainty)) if match_certainty is not None else 0
                status_msg = f"Matching completed - Found {match_count} matches (Avg certainty: {avg_certainty:.2f})"
            else:
                status_msg = "No reliable matches found. Try adjusting the matching parameters or using different images."
            print(status_msg)
            self.status_label.setText(status_msg)
        except Exception as e:
            import traceback
            error_msg = f"Error in visualization: {str(e)}\n{traceback.format_exc()}"
            print("\nError in visualization:")
            print(error_msg)
            self.status_label.setText(f"Error in visualization: {str(e)}")
    def draw_matches(self, img1, img2, matches, kpts1, kpts2, match_certainty=None):
        # Create the output canvas
        new_width = img1.shape[1] + img2.shape[1]
        new_height = max(img1.shape[0], img2.shape[0])
        out = np.zeros((new_height, new_width, 3), dtype=np.uint8)

        # Place the two images side by side
        out[:img1.shape[0], :img1.shape[1]] = img1
        out[:img2.shape[0], img1.shape[1]:] = img2
        # Choose a color per match according to its certainty
        if match_certainty is not None:
            colors = []
            for cert in match_certainty:
                if cert < 0.7:
                    color = (0, 0, 255)   # red (BGR)
                elif cert < 0.85:
                    color = (0, 255, 0)   # green
                else:
                    color = (255, 0, 0)   # blue
                colors.append(color)
        else:
            colors = [(0, 255, 0)] * len(kpts1)  # all green
print(f"Drawing {len(matches)} match lines...") # 매칭 라인과 점 그리기 for i in range(len(matches)): # 정규화된 좌표를 실제 이미지 좌표로 변환 x1 = int((matches[i][0] + 1) * img1.shape[1] / 2) y1 = int((matches[i][1] + 1) * img1.shape[0] / 2) x2 = int((matches[i][2] + 1) * img2.shape[1] / 2) y2 = int((matches[i][3] + 1) * img2.shape[0] / 2) # 좌표 디버깅 print(f"\nMatch {i}:") print(f"Original normalized - P1:({matches[i][0]:.2f}, {matches[i][1]:.2f}) P2:({matches[i][2]:.2f}, {matches[i][3]:.2f})") print(f"Image coordinates - P1:({x1}, {y1}) P2:({x2}, {y2})") # 두 번째 이미지의 x 좌표 조정 x2 += img1.shape[1] # 매칭 라인 그리기 (두께를 2로 증가) cv2.line(out, (x1, y1), (x2, y2), colors[i], 2, cv2.LINE_AA) # 매칭점 그리기 cv2.circle(out, (x1, y1), 4, colors[i], -1, cv2.LINE_AA) cv2.circle(out, (x2, y2), 4, colors[i], -1, cv2.LINE_AA)
print("Match visualization completed") return out
    def load_image1(self):
        self.image1_path = self.load_image(self.image1_label)
        if self.image1_path:
            print(f"Loaded image 1: {self.image1_path}")
        self.check_enable_match()
    def load_image2(self):
        self.image2_path = self.load_image(self.image2_label)
        if self.image2_path:
            print(f"Loaded image 2: {self.image2_path}")
        self.check_enable_match()
    def load_image(self, label):
        try:
            file_name, _ = QFileDialog.getOpenFileName(
                self, "Open Image File", "",
                "Images (*.png *.xpm *.jpg *.bmp)"
            )
            if file_name and os.path.exists(file_name):
                pixmap = QPixmap(file_name)
                if not pixmap.isNull():
                    scaled_pixmap = pixmap.scaled(label.size(), Qt.KeepAspectRatio,
                                                  Qt.SmoothTransformation)
                    label.setPixmap(scaled_pixmap)
                    return file_name
            return None
        except Exception as e:
            error_msg = f"Error loading image: {str(e)}"
            print(error_msg)
            return None
    def check_enable_match(self):
        should_enable = bool(self.image1_path) and bool(self.image2_path)
        self.match_btn.setEnabled(should_enable)
        if should_enable:
            print("Match button enabled - Ready to process images")
        else:
            print("Match button disabled - Please load both images")
if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = ImageMatcher()
    sys.exit(app.exec_())
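
# Usage sketch (assumed setup): save this as e.g. image_matcher.py (any filename works),
# install PyQt5, torch, numpy, opencv-python and the romatch package, run the script
# with Python, load two images, and press "Match Images".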