import os
import sys
import traceback

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QFont, QPixmap
from PyQt5.QtWidgets import (
    QApplication,
    QDesktopWidget,
    QFileDialog,
    QFrame,
    QHBoxLayout,
    QLabel,
    QMainWindow,
    QProgressBar,
    QPushButton,
    QScrollArea,
    QSizePolicy,
    QVBoxLayout,
    QWidget,
)
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0
class ImageProcessor:
    """Extract L2-normalised EfficientNet-B0 feature vectors from image files.

    The pretrained model is loaded once at construction time, switched to
    eval mode and moved to CUDA when available, otherwise to the CPU.
    """

    def __init__(self):
        # Load the pretrained EfficientNet-B0 model for inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        self.model.eval()
        self.model = self.model.to(self.device)

        # Preprocessing: fixed 224x224 resize, tensor conversion, then
        # per-channel normalisation with the ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def extract_features(self, image_path):
        """Return a 1-D feature vector for the image at *image_path*.

        The convolutional feature map is global-average-pooled and then
        scaled to unit L2 length, so the dot product of two vectors is their
        cosine similarity.

        Args:
            image_path: Path of the image file to read.

        Returns:
            A 1-D ``numpy.ndarray`` (unit length unless the raw features are
            all zero), or ``None`` if the image could not be loaded or
            processed. Errors are printed, not raised.
        """
        try:
            # Open via a context manager so the file handle is released
            # immediately; convert() loads the pixel data, so the converted
            # image remains usable after the file is closed.
            with Image.open(image_path) as img:
                rgb = img.convert('RGB')
            batch = self.transform(rgb).unsqueeze(0).to(self.device)

            # Feature extraction without autograd bookkeeping.
            with torch.no_grad():
                feature_map = self.model.features(batch)
                pooled = F.adaptive_avg_pool2d(feature_map, (1, 1))
            # flatten() always yields 1-D, unlike squeeze().
            features = pooled.flatten().cpu().numpy()

            # Normalise; skip the division for an all-zero vector, which would
            # otherwise produce NaNs.
            norm = np.linalg.norm(features)
            if norm > 0:
                features = features / norm
            return features
        except Exception as e:
            print(f"Error extracting features from {image_path}: {str(e)}")
            return None
class ProcessThread(QThread):
    """Background worker that extracts image features and groups similar images.

    Signals:
        progressChanged(int): progress in percent; 0-40 covers feature
            extraction, 40-100 covers grouping.
        finished(object): emitted once with the list of groups, or an empty
            list if processing failed. Each group is a list of
            ``(path, similarity)`` tuples. The first tuple is the reference
            image with similarity 1.0. The remaining members follow in
            descending similarity to it.
    """

    progressChanged = pyqtSignal(int)
    # NOTE: this name shadows QThread's built-in ``finished`` signal. Callers
    # connect to it under this name, so it is kept for compatibility.
    finished = pyqtSignal(object)

    def __init__(self, image_files, similarity_threshold=0.92):
        """Store the work list; the model itself is loaded in ``run``.

        Args:
            image_files: Sequence of image file paths to analyse.
            similarity_threshold: Minimum cosine similarity an image needs,
                both to the group's reference image and on average to all
                current group members, to join that group.
        """
        super().__init__()
        self.image_files = image_files
        self.similarity_threshold = similarity_threshold
        # Constructing ImageProcessor loads EfficientNet, which is slow and
        # may fail. It is deferred to run() so it happens on the worker
        # thread, and failures end up in run()'s error handler.
        self.processor = None

    def run(self):
        """Extract features, group similar images and emit ``finished``."""
        try:
            if self.processor is None:
                self.processor = ImageProcessor()
            features = self._extract_all()
            groups = self._group(features)
            # Largest groups first.
            groups.sort(key=len, reverse=True)
            self.progressChanged.emit(100)
            self.finished.emit(groups)
        except Exception as e:
            print(f"Error in processing thread: {str(e)}")
            print(traceback.format_exc())
            self.finished.emit([])

    def _extract_all(self):
        """Return ``{path: feature vector}`` for every image that could be read."""
        features = {}
        total_files = len(self.image_files)
        for i, img_path in enumerate(self.image_files):
            vector = self.processor.extract_features(img_path)
            if vector is not None:
                features[img_path] = vector
            self.progressChanged.emit(int((i + 1) / total_files * 40))
        return features

    def _group(self, features):
        """Greedily build groups of images whose cosine similarity meets the threshold."""
        threshold = self.similarity_threshold
        total_files = len(self.image_files)
        groups = []
        used_images = set()

        # dicts keep insertion order, so images are visited in the order of
        # self.image_files (minus those whose extraction failed).
        for base_img, base_vector in features.items():
            if base_img in used_images:
                continue
            used_images.add(base_img)

            members = []
            # Running sum of member vectors: mean_i(dot(m_i, v)) equals
            # dot(sum_i m_i, v) / k, so the group-average check needs one dot
            # product per candidate. float64 limits accumulated rounding.
            group_sum = base_vector.astype(np.float64)
            group_size = 1
            for other_img, other_vector in features.items():
                if other_img in used_images:
                    continue
                similarity = float(np.dot(base_vector, other_vector))
                if similarity < threshold:
                    continue
                avg_group_similarity = float(np.dot(group_sum, other_vector)) / group_size
                if avg_group_similarity >= threshold:
                    members.append((other_img, similarity))
                    used_images.add(other_img)
                    group_sum += other_vector
                    group_size += 1

            if members:
                # Sort members only, so the reference image always stays first
                # even if float rounding gives a duplicate a similarity > 1.0.
                members.sort(key=lambda item: item[1], reverse=True)
                groups.append([(base_img, 1.0)] + members)

            self.progressChanged.emit(40 + int(len(used_images) / total_files * 60))
        return groups
class SimilarImageViewer(QMainWindow):
    """Main window: choose a directory and browse groups of similar images."""

    # Lower-case file extensions treated as images when scanning a directory.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

    def __init__(self):
        super().__init__()
        self.initUI()

    def initUI(self):
        """Build the widgets: a select button, a progress bar and a scrollable result area."""
        self.setWindowTitle("유사 이미지 그룹 뷰어")

        # Centre a fixed-size window on the screen.
        screen = QDesktopWidget().screenGeometry()
        window_width = 1200
        window_height = 800
        self.setGeometry(
            (screen.width() - window_width) // 2,
            (screen.height() - window_height) // 2,
            window_width,
            window_height,
        )

        # Main widget.
        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)
        layout.setContentsMargins(5, 5, 5, 5)
        layout.setSpacing(5)

        # Top area: button row and progress bar.
        top_layout = QVBoxLayout()
        button_layout = QHBoxLayout()
        self.select_button = QPushButton("디렉토리 선택")
        self.select_button.clicked.connect(self.select_directory)
        self.select_button.setFixedHeight(30)
        button_layout.addWidget(self.select_button)
        button_layout.addStretch()
        top_layout.addLayout(button_layout)

        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        top_layout.addWidget(self.progress_bar)
        layout.addLayout(top_layout)

        # Scrollable result area.
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        layout.addWidget(scroll)

        # Container holding one row per image group.
        self.content_widget = QWidget()
        self.content_layout = QVBoxLayout(self.content_widget)
        self.content_layout.setAlignment(Qt.AlignTop)
        self.content_layout.setSpacing(20)
        scroll.setWidget(self.content_widget)

    def _set_busy(self, busy):
        """Switch between the 'processing' and 'idle' UI states."""
        self.select_button.setEnabled(not busy)
        self.progress_bar.setVisible(busy)
        if busy:
            self.progress_bar.setValue(0)

    def _collect_image_files(self, directory):
        """Return paths of all files under *directory* with an image extension."""
        image_files = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    image_files.append(os.path.join(root, file))
        return image_files

    def select_directory(self):
        """Ask for a directory and start background processing of its images."""
        directory = QFileDialog.getExistingDirectory(self, "디렉토리 선택")
        if not directory:
            return

        self._set_busy(True)
        image_files = self._collect_image_files(directory)
        if not image_files:
            # Nothing to process. Restore the UI here, otherwise the button
            # stays disabled and the progress bar stays visible.
            self._set_busy(False)
            return

        try:
            self.process_thread = ProcessThread(image_files)
        except Exception:
            # An exception escaping a Qt slot would leave the UI stuck in
            # the busy state, so report it and go back to idle.
            print(traceback.format_exc())
            self._set_busy(False)
            return
        self.process_thread.progressChanged.connect(self.update_progress)
        self.process_thread.finished.connect(self.show_image_groups)
        self.process_thread.start()

    def update_progress(self, value):
        """Show *value* (percent) on the progress bar."""
        self.progress_bar.setValue(value)

    def clear_layout(self):
        """Remove and schedule deletion of everything in the result area."""
        while self.content_layout.count():
            item = self.content_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

    def _create_image_tile(self, img_path, similarity, is_base):
        """Build one tile: thumbnail, (truncated) file name, and similarity caption."""
        container = QWidget()
        container_layout = QVBoxLayout(container)
        container_layout.setContentsMargins(0, 0, 0, 0)

        # Thumbnail scaled to fit 150x150, keeping the aspect ratio.
        image_label = QLabel()
        pixmap = QPixmap(img_path)
        scaled_pixmap = pixmap.scaled(150, 150, Qt.KeepAspectRatio, Qt.SmoothTransformation)
        image_label.setPixmap(scaled_pixmap)
        container_layout.addWidget(image_label)

        # File name, truncated to 20 characters.
        filename = os.path.basename(img_path)
        if len(filename) > 20:
            filename = filename[:17] + "..."
        filename_label = QLabel(filename)
        filename_label.setAlignment(Qt.AlignCenter)
        filename_label.setWordWrap(True)
        small_font = QFont()
        small_font.setPointSize(8)
        filename_label.setFont(small_font)
        container_layout.addWidget(filename_label)

        # The reference image is captioned as such; the others show their similarity.
        if is_base:
            similarity_label = QLabel("기준 이미지")
        else:
            similarity_label = QLabel(f"유사도: {similarity:.1%}")
        similarity_label.setAlignment(Qt.AlignCenter)
        similarity_label.setFont(small_font)
        container_layout.addWidget(similarity_label)
        return container

    def show_image_groups(self, groups):
        """Replace the displayed results with *groups* and return the UI to idle.

        Args:
            groups: List of groups, each a list of ``(path, similarity)``
                tuples as emitted by ``ProcessThread``. The first tuple of
                a group is treated as its reference image. This relies on
                the grouping thread putting the reference first; it no
                longer relies on ``similarity == 1.0``, which exact
                duplicates also reach.
        """
        self.clear_layout()

        for group in groups:
            if len(group) < 2:  # only groups with two or more images are shown
                continue

            # One row per group.
            row_widget = QWidget()
            row_layout = QHBoxLayout(row_widget)
            row_layout.setAlignment(Qt.AlignLeft)
            row_layout.setSpacing(5)
            for index, (img_path, similarity) in enumerate(group):
                tile = self._create_image_tile(img_path, similarity, is_base=index == 0)
                row_layout.addWidget(tile)
            row_layout.addStretch()  # keep tiles packed to the left

            separator = QFrame()
            separator.setFrameShape(QFrame.HLine)
            separator.setFrameShadow(QFrame.Sunken)

            self.content_layout.addWidget(row_widget)
            self.content_layout.addWidget(separator)

        # Push rows to the top.
        self.content_layout.addStretch()
        self._set_busy(False)
if __name__ == '__main__':
    # Entry point: create the Qt application, show the viewer and exit with
    # the event loop's return code. (Stray trailing "|" residue removed.)
    app = QApplication(sys.argv)
    viewer = SimilarImageViewer()
    viewer.show()
    sys.exit(app.exec_())
# Member discussion