RoMa를 사용하여 이미지 매칭해봤습니다. RoMa의 깃허브 저장소는 https://github.com/Parskatt/RoMa 입니다.

2025. 2. 20 최초 작성

다음 포스트에 나온 대로 conda 환경을 구성한 후 진행하는 것이 좋습니다.

Visual Studio Code와 Miniconda를 사용한 Python 개발 환경 만들기( Windows, Ubuntu, WSL2)

https://webnautes.com/visual-studio-codewa-minicondareul-sayonghan-python-gaebal-hwangyeong-mandeulgi-windows-ubuntu-wsl2/

이제 RoMa를 테스트하기 위한 환경을 구성합니다.

conda create -n roma python=3.10

conda activate roma

git clone https://github.com/Parskatt/RoMa.git

cd RoMa

pip install -e .

필요한 파이썬 패키지가 설치되면 RoMa를 사용하기 위한 romatch 패키지를 사용할 준비가 됩니다.

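pip install -e . 로 설치했기 때문에 romatch 패키지는 다른 폴더에서도 import할 수 있습니다. RoMa 안에 있는 romatch 폴더를 원하는 곳에 두고 사용해도 되지만, editable 설치는 원래 RoMa 폴더를 가리키므로 폴더를 옮기기보다는 복사해서 사용하는 편이 안전합니다.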

추가 패키지를 설치합니다. 

pip install pyqt5

포스트 아래쪽에 있는 코드를 처음 실행하면 필요한 모델을 다운로드하기 때문에 시간이 좀 걸립니다.

테스트는 다음 2장의 사진을 사용했습니다.

https://github.com/Parskatt/RoMa/blob/main/assets/toronto_A.jpg

https://github.com/Parskatt/RoMa/blob/main/assets/toronto_B.jpg 

CPU로 동작해서인지 너무 느려서 CUDA를 지원하는 PyTorch를 설치했습니다.

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

매칭 결과가 나오기까지 걸리는 시간이 짧아졌습니다.

실행 결과입니다. Load Image 1 버튼과 Load Image 2 버튼을 눌러 2장의 이미지를 차례로 선택한 후, Match Images 버튼을 누르고 잠시 기다리면 아래 스크린샷처럼 매칭 결과가 보입니다.

전체 코드입니다. 

import sys
import os
# Allow more than one OpenMP runtime to be loaded in the same process.
# This flag is commonly used to avoid Intel OpenMP's "OMP: Error #15" abort
# when several native libraries each bring their own copy.
# It is set before the heavy imports below so it is in effect when they load.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
                            QVBoxLayout, QHBoxLayout, QWidget, QFileDialog, QProgressBar)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt, QThread, pyqtSignal
import torch
import numpy as np
import cv2
from romatch import roma_outdoor
import warnings
import time
from functools import lru_cache

# Hide every UserWarning (from any library) to keep the console output readable.
warnings.filterwarnings("ignore", category=UserWarning)

class MatcherThread(QThread):
    """Run RoMa matching on a worker thread so the GUI stays responsive.

    Signals:
        finished(dict): emitted once when ``run()`` completes. On success the
            dict holds ``img1``/``img2`` (BGR arrays loaded with OpenCV),
            ``matches`` (rows of ``x1, y1, x2, y2`` in RoMa's normalized
            [-1, 1] coordinates) and ``certainty`` (one value per match).
            On failure it also contains an ``"error"`` message.
        progress(str): human-readable status updates.

    NOTE: ``finished`` shadows QThread's built-in argument-less ``finished``
    signal; callers connect to this dict-carrying version.
    """

    finished = pyqtSignal(dict)
    progress = pyqtSignal(str)

    # Number of most-certain matches kept before geometric verification.
    TOP_K = 500
    # RANSAC inlier threshold in *pixels* (max distance to the epipolar line).
    RANSAC_THRESHOLD_PX = 3.0
    # Confidence level passed to cv2.findFundamentalMat.
    RANSAC_CONFIDENCE = 0.999
    # cv2.FM_RANSAC needs at least 8 correspondences.
    MIN_RANSAC_MATCHES = 8

    def __init__(self, roma_model, image1_path, image2_path, device):
        super().__init__()
        self.roma_model = roma_model
        self.image1_path = image1_path
        self.image2_path = image2_path
        self.device = device

    def run(self):
        """QThread entry point: process the images and emit ``finished``."""
        try:
            self.progress.emit("Processing images...")
            result = self.process_images()
            self.finished.emit(result)
        except Exception as e:
            # Last-resort guard so the GUI always receives a result.
            self.finished.emit({"error": str(e)})

    def process_images(self):
        """Match the two images with RoMa and filter the correspondences.

        Returns:
            dict: see the class docstring. Failures are reported through an
            ``"error"`` key (with empty placeholders for the other keys)
            instead of being raised.
        """
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Load the images (cached per path). They are drawn on later and
            # provide the sizes needed to convert normalized coords to pixels.
            img1 = self.preprocess_image(self.image1_path)
            img2 = self.preprocess_image(self.image2_path)

            # Run RoMa matching.
            self.progress.emit("Performing ROMA matching...")
            print("Performing ROMA matching...")

            warp, certainty = self.roma_model.match(
                self.image1_path,
                self.image2_path,
                device=self.device
            )

            # Sample sparse correspondences from the dense warp (RoMa defaults).
            matches, match_certainty = self.roma_model.sample(warp, certainty)

            # Move tensors (possibly on the GPU) to host NumPy arrays.
            if torch.is_tensor(matches):
                matches = matches.cpu().numpy()
            if torch.is_tensor(match_certainty):
                match_certainty = match_certainty.cpu().numpy()

            print(f"Initial matches: {len(matches)}")

            # Only report/filter by certainty when there is something to filter;
            # min()/max() would raise on an empty array.
            if match_certainty is not None and len(match_certainty) > 0:
                print(f"Certainty range: {match_certainty.min():.3f} - {match_certainty.max():.3f}")
                matches, match_certainty = self._select_top_k(matches, match_certainty)
                matches, match_certainty = self._ransac_filter(
                    matches, match_certainty, img1, img2
                )

        except Exception as e:
            import traceback
            error_msg = f"Error in process_images: {str(e)}\n{traceback.format_exc()}"
            print("\nError in process_images:")
            print(error_msg)
            self.progress.emit(f"Error: {str(e)}")
            # The "error" key makes the GUI report the failure instead of
            # trying to draw matches on images that were never loaded.
            return {
                "error": str(e),
                "img1": None,
                "img2": None,
                "matches": np.empty((0, 4)),
                "certainty": np.empty((0,))
            }

        if len(matches) == 0:
            print("No matches found after filtering")
            self.progress.emit("No matches found after filtering")
            # Keep the (N, 4) shape even when empty so callers can slice columns.
            return {
                "img1": img1,
                "img2": img2,
                "matches": np.empty((0, 4)),
                "certainty": np.empty((0,))
            }

        print(f"Final matches: {len(matches)}")
        return {
            "img1": img1,
            "img2": img2,
            "matches": matches,
            "certainty": match_certainty
        }

    def _select_top_k(self, matches, match_certainty):
        """Keep the ``TOP_K`` matches with the highest certainty."""
        top_k = min(self.TOP_K, len(matches))
        top_indices = np.argsort(match_certainty)[-top_k:]
        matches = matches[top_indices]
        match_certainty = match_certainty[top_indices]
        print(f"Selected top {len(matches)} matches")
        return matches, match_certainty

    def _ransac_filter(self, matches, match_certainty, img1, img2):
        """Keep only matches consistent with a RANSAC-estimated fundamental matrix.

        The RANSAC threshold is in pixels, so the normalized coordinates are
        first mapped onto each image's pixel grid. A threshold applied directly
        to [-1, 1] coordinates would accept nearly every match. If there are
        too few matches, or estimation fails, the inputs are returned unchanged.
        """
        if len(matches) < self.MIN_RANSAC_MATCHES:
            return matches, match_certainty

        src_pts = self._to_pixel_coords(matches[:, :2], img1)
        dst_pts = self._to_pixel_coords(matches[:, 2:], img2)
        print("\nDebug - Before RANSAC:")
        print(f"Source points shape: {src_pts.shape}")
        print(f"First few source points: \n{src_pts[:5]}")
        print(f"Destination points shape: {dst_pts.shape}")
        print(f"First few destination points: \n{dst_pts[:5]}")

        try:
            _, mask = cv2.findFundamentalMat(
                src_pts,
                dst_pts,
                cv2.FM_RANSAC,
                self.RANSAC_THRESHOLD_PX,
                self.RANSAC_CONFIDENCE
            )
        except cv2.error as e:
            print(f"RANSAC failed: {str(e)}")
            return matches, match_certainty

        if mask is None:
            print("RANSAC failed: no model estimated")
            return matches, match_certainty

        inliers = mask.ravel().astype(bool)
        matches = matches[inliers]
        match_certainty = match_certainty[inliers]
        print(f"Matches after RANSAC: {len(matches)}")
        print("Debug - After RANSAC filtering:")
        print(f"First few matches: \n{matches[:5]}")
        return matches, match_certainty

    @staticmethod
    def _to_pixel_coords(points, image):
        """Map normalized (x, y) points in [-1, 1] to pixel coords of *image*.

        Uses the same ``(v + 1) * size / 2`` mapping as the drawing code
        (``ImageMatcher.draw_matches``), and returns float32 as OpenCV expects.
        """
        height, width = image.shape[:2]
        scale = np.array([width / 2.0, height / 2.0], dtype=np.float32)
        return ((np.asarray(points, dtype=np.float32) + 1.0) * scale).astype(np.float32)

    @staticmethod
    @lru_cache(maxsize=32)
    def preprocess_image(image_path):
        """Load *image_path* as a BGR image, cached per path.

        The bytes are read with NumPy and decoded with ``cv2.imdecode``
        because ``cv2.imread`` cannot open non-ASCII (e.g. Korean) file paths
        on Windows.

        Raises:
            ValueError: if the file cannot be read or decoded.
        """
        try:
            buffer = np.fromfile(image_path, dtype=np.uint8)
        except OSError as e:
            raise ValueError(f"Could not load image: {image_path}") from e
        # imdecode rejects an empty buffer, so treat an empty file as unreadable.
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR) if buffer.size else None
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")
        return img  # original resolution, no resizing

class ImageMatcher(QMainWindow):
    """Main window: pick two images, match them with RoMa, show the result.

    The RoMa model is created on the first "Match Images" click, and the
    matching itself runs in a ``MatcherThread`` so the UI stays responsive.
    """

    # How many matches get their coordinates printed by draw_matches.
    _DEBUG_MATCH_COUNT = 5

    def __init__(self):
        super().__init__()
        self.roma_model = None  # created on demand by initialize_model()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.image1_path = None
        self.image2_path = None
        self.matcher_thread = None
        self.initUI()

    def initUI(self):
        """Build the two previews, the result view, status widgets and button."""
        self.setWindowTitle('Image Matcher')
        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)

        images_layout = QHBoxLayout()
        left_layout = QVBoxLayout()
        self.image1_label = QLabel()
        self.image1_label.setFixedSize(400, 400)
        self.image1_label.setAlignment(Qt.AlignCenter)
        self.image1_label.setStyleSheet("border: 2px solid black")
        self.load_image1_btn = QPushButton('Load Image 1')
        self.load_image1_btn.clicked.connect(self.load_image1)
        left_layout.addWidget(self.image1_label)
        left_layout.addWidget(self.load_image1_btn)

        right_layout = QVBoxLayout()
        self.image2_label = QLabel()
        self.image2_label.setFixedSize(400, 400)
        self.image2_label.setAlignment(Qt.AlignCenter)
        self.image2_label.setStyleSheet("border: 2px solid black")
        self.load_image2_btn = QPushButton('Load Image 2')
        self.load_image2_btn.clicked.connect(self.load_image2)
        right_layout.addWidget(self.image2_label)
        right_layout.addWidget(self.load_image2_btn)

        images_layout.addLayout(left_layout)
        images_layout.addLayout(right_layout)

        self.result_label = QLabel()
        self.result_label.setFixedSize(800, 400)
        self.result_label.setAlignment(Qt.AlignCenter)
        self.result_label.setStyleSheet("border: 2px solid black")

        self.progress_bar = QProgressBar()
        self.progress_bar.setTextVisible(True)
        self.progress_bar.hide()

        self.status_label = QLabel()
        self.status_label.setAlignment(Qt.AlignCenter)

        self.match_btn = QPushButton('Match Images')
        self.match_btn.clicked.connect(self.initialize_and_match)
        self.match_btn.setEnabled(False)

        layout.addLayout(images_layout)
        layout.addWidget(self.result_label)
        layout.addWidget(self.progress_bar)
        layout.addWidget(self.status_label)
        layout.addWidget(self.match_btn)

        self.setGeometry(100, 100, 850, 900)
        self.show()

    def initialize_model(self):
        """Create the RoMa outdoor model once.

        Returns:
            bool: True if the model is ready, False if initialization failed
            (the error is shown in the status label).
        """
        if self.roma_model is None:
            try:
                status_msg = f"Initializing model on {self.device}..."
                print(status_msg)
                self.status_label.setText(status_msg)
                # Let the label repaint before the blocking model load below.
                QApplication.processEvents()

                # Resolutions are chosen as multiples of 14.
                self.roma_model = roma_outdoor(
                    device=self.device,
                    coarse_res=322,  # 14 * 23
                    upsample_res=(644, 644),  # 14 * 46
                ).to(self.device)

                success_msg = "Model initialized successfully"
                print(success_msg)
                self.status_label.setText(success_msg)
                return True
            except Exception as e:
                import traceback
                error_msg = f"Error initializing model: {str(e)}\n{traceback.format_exc()}"
                print("\nError in initialize_model:")
                print(error_msg)
                self.status_label.setText(f"Error initializing model: {str(e)}")
                return False
        return True

    def initialize_and_match(self):
        """Slot for "Match Images": ensure the model exists, then start matching."""
        if not self.initialize_model():
            return

        self.match_btn.setEnabled(False)
        # A maximum of 0 puts the progress bar into "busy" (indeterminate) mode.
        self.progress_bar.setMaximum(0)
        self.progress_bar.show()

        self.matcher_thread = MatcherThread(
            self.roma_model,
            self.image1_path,
            self.image2_path,
            self.device
        )
        self.matcher_thread.finished.connect(self.handle_matching_result)
        self.matcher_thread.progress.connect(self.update_progress)
        self.matcher_thread.start()

    def handle_matching_result(self, result):
        """Show the worker's result (or error) and re-enable the UI."""
        if "error" in result:
            error_msg = f"Error: {result['error']}"
            print(error_msg)
            self.status_label.setText(error_msg)
        else:
            self.visualize_matches(result)

        self.progress_bar.hide()
        self.match_btn.setEnabled(True)
        # The signal is emitted from inside run(), so the thread may still be
        # winding down. Wait before dropping the last reference; destroying a
        # running QThread aborts the process.
        if self.matcher_thread is not None:
            self.matcher_thread.wait()
        self.matcher_thread = None

    def closeEvent(self, event):
        """Let a running match finish before the window (and its thread) is destroyed."""
        if self.matcher_thread is not None and self.matcher_thread.isRunning():
            self.matcher_thread.wait()
        super().closeEvent(event)

    def update_progress(self, message):
        """Mirror worker progress messages to the console and status label."""
        print(message)
        self.status_label.setText(message)

    def visualize_matches(self, result):
        """Render a MatcherThread result into ``result_label`` and update the status."""
        try:
            img1 = result["img1"]
            img2 = result["img2"]
            if img1 is None or img2 is None:
                status_msg = "Matching failed - no images to display"
                print(status_msg)
                self.status_label.setText(status_msg)
                return

            # An empty result may arrive as a 1-D np.array([]). Reshape it so the
            # column slicing below also works with zero matches.
            matches = np.asarray(result["matches"]).reshape(-1, 4)
            match_certainty = result["certainty"]

            # Visualize the matches on the original-resolution images.
            matched_img = self.draw_matches(
                img1,
                img2,
                matches,
                matches[:, :2],  # image 1 coordinates
                matches[:, 2:],  # image 2 coordinates
                match_certainty
            )

            # Fit the side-by-side image into the label, keeping its aspect ratio.
            display_height = self.result_label.height()
            display_width = self.result_label.width()

            img_height, img_width = matched_img.shape[:2]
            aspect_ratio = img_width / img_height

            if img_width / display_width > img_height / display_height:
                new_width = display_width
                new_height = max(1, int(display_width / aspect_ratio))
            else:
                new_height = display_height
                new_width = max(1, int(display_height * aspect_ratio))

            matched_img_resized = cv2.resize(matched_img, (new_width, new_height),
                                             interpolation=cv2.INTER_AREA)

            # Convert to RGB and show it.
            matched_img_rgb = cv2.cvtColor(matched_img_resized, cv2.COLOR_BGR2RGB)
            height, width = matched_img_rgb.shape[:2]
            bytes_per_line = 3 * width
            # QImage does not own the buffer it is constructed from; copy() makes
            # it independent of the temporary bytes object.
            q_img = QImage(matched_img_rgb.tobytes(), width, height,
                           bytes_per_line, QImage.Format_RGB888).copy()
            pixmap = QPixmap.fromImage(q_img)
            self.result_label.setPixmap(pixmap)

            match_count = len(matches)
            if match_count > 0:
                avg_certainty = float(np.nanmean(match_certainty)) if match_certainty is not None else 0
                status_msg = f"Matching completed - Found {match_count} matches (Avg certainty: {avg_certainty:.2f})"
            else:
                status_msg = "No reliable matches found. Try adjusting the matching parameters or using different images."

            print(status_msg)
            self.status_label.setText(status_msg)

        except Exception as e:
            import traceback
            error_msg = f"Error in visualization: {str(e)}\n{traceback.format_exc()}"
            print("\nError in visualization:")
            print(error_msg)
            self.status_label.setText(f"Error in visualization: {str(e)}")

    @staticmethod
    def _certainty_color(cert):
        """Return the BGR line color for one match certainty value."""
        if cert < 0.7:
            return (0, 0, 255)  # red (BGR)
        if cert < 0.85:
            return (0, 255, 0)  # green
        return (255, 0, 0)  # blue

    def draw_matches(self, img1, img2, matches, kpts1, kpts2, match_certainty=None):
        """Draw color-coded match lines on img1 and img2 placed side by side.

        Args:
            img1, img2: 3-channel BGR images.
            matches: rows of ``(x1, y1, x2, y2)`` in normalized [-1, 1] coordinates.
            kpts1, kpts2: kept for API compatibility. Only ``len(kpts1)`` is used,
                to size the default colors when no certainty is given.
            match_certainty: optional per-match certainty used for coloring
                (red < 0.7 <= green < 0.85 <= blue).

        Returns:
            np.ndarray: uint8 canvas of shape (max(H1, H2), W1 + W2, 3).
        """
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]

        # Paste both images onto one canvas, image 2 to the right of image 1.
        out = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
        out[:h1, :w1] = img1
        out[:h2, w1:] = img2

        if match_certainty is not None:
            colors = [self._certainty_color(cert) for cert in match_certainty]
        else:
            colors = [(0, 255, 0)] * len(kpts1)  # all green

        print(f"Drawing {len(matches)} match lines...")

        for i, match in enumerate(matches):
            # Normalized [-1, 1] -> pixel coordinates of the respective image.
            x1 = int((match[0] + 1) * w1 / 2)
            y1 = int((match[1] + 1) * h1 / 2)
            x2 = int((match[2] + 1) * w2 / 2)
            y2 = int((match[3] + 1) * h2 / 2)

            # Print coordinates for a few matches only; printing every match
            # floods the console and slows drawing down.
            if i < self._DEBUG_MATCH_COUNT:
                print(f"\nMatch {i}:")
                print(f"Original normalized - P1:({match[0]:.2f}, {match[1]:.2f}) P2:({match[2]:.2f}, {match[3]:.2f})")
                print(f"Image coordinates - P1:({x1}, {y1}) P2:({x2}, {y2})")

            # Image 2 sits to the right of image 1 on the canvas.
            x2 += w1

            color = colors[i]
            cv2.line(out, (x1, y1), (x2, y2), color, 2, cv2.LINE_AA)
            cv2.circle(out, (x1, y1), 4, color, -1, cv2.LINE_AA)
            cv2.circle(out, (x2, y2), 4, color, -1, cv2.LINE_AA)

        print("Match visualization completed")
        return out

    def load_image1(self):
        """Slot for "Load Image 1". A cancelled dialog keeps the previous image."""
        path = self.load_image(self.image1_label)
        if path:
            self.image1_path = path
            print(f"Loaded image 1: {self.image1_path}")
        self.check_enable_match()

    def load_image2(self):
        """Slot for "Load Image 2". A cancelled dialog keeps the previous image."""
        path = self.load_image(self.image2_label)
        if path:
            self.image2_path = path
            print(f"Loaded image 2: {self.image2_path}")
        self.check_enable_match()

    def load_image(self, label):
        """Ask the user for an image file and preview it in *label*.

        Returns:
            str | None: the chosen path, or None if the dialog was cancelled
            or the file could not be loaded as a pixmap.
        """
        try:
            file_name, _ = QFileDialog.getOpenFileName(
                self,
                "Open Image File",
                "",
                # Only formats OpenCV can decode too, since the matcher thread
                # reads the same file with OpenCV (XPM is not one of them).
                "Images (*.png *.jpg *.jpeg *.bmp)"
            )
            if file_name and os.path.exists(file_name):
                pixmap = QPixmap(file_name)
                if not pixmap.isNull():
                    scaled_pixmap = pixmap.scaled(label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
                    label.setPixmap(scaled_pixmap)
                    return file_name
            return None
        except Exception as e:
            error_msg = f"Error loading image: {str(e)}"
            print(error_msg)
            return None

    def check_enable_match(self):
        """Enable "Match Images" only when both image paths are set."""
        should_enable = bool(self.image1_path) and bool(self.image2_path)
        self.match_btn.setEnabled(should_enable)
        if should_enable:
            print("Match button enabled - Ready to process images")
        else:
            print("Match button disabled - Please load both images")

if __name__ == '__main__':
    # Exactly one QApplication must exist before any widget is created.
    qt_app = QApplication(sys.argv)
    # Hold a reference so the window is not garbage-collected while it is shown.
    main_window = ImageMatcher()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)