딥러닝 모델을 사용하여 이미지를 그룹별로 묶어서 보여주는 PyQt5 예제 코드입니다.

최초작성 2025. 2. 21

다음 포스트에 나온대로 conda 환경을 구성후 하는게 좋습니다.

 

Visual Studio Code와 Miniconda를 사용한 Python 개발 환경 만들기( Windows, Ubuntu, WSL2)

https://webnautes.com/visual-studio-codewa-minicondareul-sayonghan-python-gaebal-hwangyeong-mandeulgi-windows-ubuntu-wsl2/ 

이제 테스트하기 위한 환경을 구성합니다.

conda create -n test python=3.10

conda activate test

cuda 가능하도록 파이토치를 설치합니다.

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

pyqt5 패키지를 설치합니다. 

pip install pyqt5

이제 다음 코드를 실행합니다.

import sys
import os
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                          QHBoxLayout, QPushButton, QFileDialog, QScrollArea,
                          QLabel, QSizePolicy, QDesktopWidget, QFrame, QProgressBar)
from PyQt5.QtGui import QPixmap, QFont
from PyQt5.QtCore import Qt, QThread, pyqtSignal
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import numpy as np
from PIL import Image

class ImageProcessor:
    def __init__(self):
        # EfficientNet 모델 로드
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        self.model.eval()
        self.model = self.model.to(self.device)

        # 이미지 전처리
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

    def extract_features(self, image_path):
        try:
            # 이미지 로드 및 전처리
            image = Image.open(image_path).convert('RGB')
            image = self.transform(image).unsqueeze(0).to(self.device)

            # 특징 추출
            with torch.no_grad():
                features = self.model.features(image)
                features = F.adaptive_avg_pool2d(features, (1, 1))
                features = features.squeeze().cpu().numpy()
               
                # 정규화
                features = features / np.linalg.norm(features)
               
            return features
        except Exception as e:
            print(f"Error extracting features from {image_path}: {str(e)}")
            return None

class ProcessThread(QThread):
    progressChanged = pyqtSignal(int)
    finished = pyqtSignal(object)
   
    def __init__(self, image_files):
        super().__init__()
        self.image_files = image_files
        self.processor = ImageProcessor()
       
    def run(self):
        try:
            # 특징 벡터 추출
            features = {}
            total_files = len(self.image_files)
           
            for i, img_path in enumerate(self.image_files):
                vector = self.processor.extract_features(img_path)
                if vector is not None:
                    features[img_path] = vector
                self.progressChanged.emit(int((i + 1) / total_files * 40))

            # 이미지 그룹화
            similarity_threshold = 0.92  # 유사도 임계값
            groups = []
            used_images = set()

            # 각 이미지에 대해
            for base_img in self.image_files:
                if base_img in used_images:
                    continue

                if base_img not in features:
                    continue

                base_vector = features[base_img]
                current_group = [(base_img, 1.0)]
                group_members = {base_img}
               
                # 다른 모든 이미지와의 유사도 계산
                for other_img in self.image_files:
                    if other_img in group_members or other_img in used_images:
                        continue
                       
                    if other_img not in features:
                        continue

                    other_vector = features[other_img]
                   
                    # 코사인 유사도 계산
                    similarity = float(np.dot(base_vector, other_vector))

                    # 유사도가 높은 이미지를 그룹에 추가
                    if similarity >= similarity_threshold:
                        # 기존 그룹 멤버들과의 평균 유사도 검사
                        group_similarities = [
                            float(np.dot(features[group_img], other_vector))
                            for group_img in group_members
                        ]
                        avg_group_similarity = sum(group_similarities) / len(group_similarities)
                       
                        if avg_group_similarity >= similarity_threshold:
                            current_group.append((other_img, similarity))
                            group_members.add(other_img)
                            used_images.add(other_img)

                if len(group_members) > 1:
                    current_group.sort(key=lambda x: x[1], reverse=True)
                    groups.append(current_group)
                    used_images.add(base_img)

                self.progressChanged.emit(40 + int(len(used_images) / total_files * 60))

            # 그룹 크기로 정렬
            groups.sort(key=len, reverse=True)
           
            self.progressChanged.emit(100)
            self.finished.emit(groups)
           
        except Exception as e:
            print(f"Error in processing thread: {str(e)}")
            import traceback
            print(traceback.format_exc())
            self.finished.emit([])


class SimilarImageViewer(QMainWindow):
    def __init__(self):
        super().__init__()
        self.initUI()
       
    def initUI(self):
        self.setWindowTitle("유사 이미지 그룹 뷰어")
       
        # 화면 크기 구하기
        screen = QDesktopWidget().screenGeometry()
        window_width = 1200
        window_height = 800
       
        # 창 크기와 위치 설정
        self.setGeometry(
            (screen.width() - window_width) // 2,
            (screen.height() - window_height) // 2,
            window_width,
            window_height
        )

        # 메인 위젯 설정
        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)
        layout.setContentsMargins(5, 5, 5, 5)
        layout.setSpacing(5)

        # 상단 영역
        top_layout = QVBoxLayout()
       
        # 버튼 영역
        button_layout = QHBoxLayout()
        self.select_button = QPushButton("디렉토리 선택")
        self.select_button.clicked.connect(self.select_directory)
        self.select_button.setFixedHeight(30)
        button_layout.addWidget(self.select_button)
        button_layout.addStretch()
        top_layout.addLayout(button_layout)
       
        # 프로그레스바
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        top_layout.addWidget(self.progress_bar)
       
        layout.addLayout(top_layout)

        # 스크롤 영역
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        layout.addWidget(scroll)

        # 이미지 그룹들을 담을 컨테이너
        self.content_widget = QWidget()
        self.content_layout = QVBoxLayout(self.content_widget)
        self.content_layout.setAlignment(Qt.AlignTop)
        self.content_layout.setSpacing(20)
        scroll.setWidget(self.content_widget)

    def select_directory(self):
        directory = QFileDialog.getExistingDirectory(self, "디렉토리 선택")
        if directory:
            self.select_button.setEnabled(False)
            self.progress_bar.setVisible(True)
            self.progress_bar.setValue(0)
           
            # 이미지 목록 수집
            image_files = []
            for root, _, files in os.walk(directory):
                for file in files:
                    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                        image_files.append(os.path.join(root, file))
           
            if image_files:
                # 처리 스레드 시작
                self.process_thread = ProcessThread(image_files)
                self.process_thread.progressChanged.connect(self.update_progress)
                self.process_thread.finished.connect(self.show_image_groups)
                self.process_thread.start()

    def update_progress(self, value):
        self.progress_bar.setValue(value)

    def clear_layout(self):
        while self.content_layout.count():
            item = self.content_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

    def show_image_groups(self, groups):
        self.clear_layout()
       
        # 각 그룹을 한 줄에 표시
        for group in groups:
            if len(group) > 1# 2개 이상의 이미지가 있는 그룹만 표시
                # 그룹을 위한 행 위젯
                row_widget = QWidget()
                row_layout = QHBoxLayout(row_widget)
                row_layout.setAlignment(Qt.AlignLeft)
                row_layout.setSpacing(5)
               
                # 그룹의 각 이미지 표시
                for img_path, similarity in group:
                    container = QWidget()
                    container_layout = QVBoxLayout(container)
                    container_layout.setContentsMargins(0, 0, 0, 0)
                   
                    # 이미지
                    image_label = QLabel()
                    pixmap = QPixmap(img_path)
                    scaled_pixmap = pixmap.scaled(
                        150, 150,
                        Qt.KeepAspectRatio,
                        Qt.SmoothTransformation
                    )
                    image_label.setPixmap(scaled_pixmap)
                    container_layout.addWidget(image_label)
                   
                    # 파일명
                    filename = os.path.basename(img_path)
                    if len(filename) > 20:
                        filename = filename[:17] + "..."
                    filename_label = QLabel(filename)
                    filename_label.setAlignment(Qt.AlignCenter)
                    filename_label.setWordWrap(True)
                    small_font = QFont()
                    small_font.setPointSize(8)
                    filename_label.setFont(small_font)
                    container_layout.addWidget(filename_label)
                   
                    # 유사도
                    if similarity != 1.0# 기준 이미지가 아닌 경우에만 유사도 표시
                        similarity_label = QLabel(f"유사도: {similarity:.1%}")
                    else:
                        similarity_label = QLabel("기준 이미지")
                    similarity_label.setAlignment(Qt.AlignCenter)
                    similarity_label.setFont(small_font)
                    container_layout.addWidget(similarity_label)
                   
                    row_layout.addWidget(container)
               
                # 오른쪽 여백을 위한 스트레치 추가
                row_layout.addStretch()
               
                # 구분선 추가
                separator = QFrame()
                separator.setFrameShape(QFrame.HLine)
                separator.setFrameShadow(QFrame.Sunken)
               
                # 행과 구분선 추가
                self.content_layout.addWidget(row_widget)
                self.content_layout.addWidget(separator)
       
        # 아래쪽 여백을 위한 스트레치 추가
        self.content_layout.addStretch()
       
        # UI 상태 복원
        self.select_button.setEnabled(True)
        self.progress_bar.setVisible(False)

if __name__ == '__main__':
    app = QApplication(sys.argv)
    viewer = SimilarImageViewer()
    viewer.show()
    sys.exit(app.exec_())

실행 결과입니다. 디렉토리 선택을 눌러 디렉토리를 지정해주면 해당 디렉토리에 있는 이미지를 유사 이미지별로 묶어서 아래 스크린샷처럼 그룹별로 나누어 보여줍니다.