U2-Net: U Square Net을 이미지 배경 제거에 사용해봅니다. 

이미지 내에서 가장 시각적으로 두드러지거나 중요한 객체를 정확하게 분할(segmentation)해줍니다. 

최초작성 2025. 4. 28

0. conda 환경을 사용하여 진행합니다.

Visual Studio Code와 Miniconda를 사용한 Python 개발 환경 만들기( Windows, Ubuntu, WSL2)

https://webnautes.com/visual-studio-codewa-minicondareul-sayonghan-python-gaebal-hwangyeong-mandeulgi-windows-ubuntu-wsl2/ 

파이썬 가상환경을 생성합니다. 

$ conda create -n u2-net-test python=3.12

파이썬 가상환경을 활성화합니다.

$ conda activate u2-net-test 

1. 저장소를 다운로드합니다.

$ git clone https://github.com/NathanUA/U-2-Net.git

해당 디렉토리로 이동합니다.

$ cd U-2-Net

2.아래 링크에서 모델 u2net.pth을 다운로드하여 깃허브 저장소를 저장한 디렉토리에 복사해줍니다.

https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view?usp=sharing 

$ mkdir -p ./saved_models/u2net/

$ cp ~/다운로드/u2net.pth ./saved_models/u2net/

3. 필요한 패키지를 설치합니다.

$ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

$ pip3 install opencv-python

4. 현재 위치에 Visual Studio Code를 실행하고 다음 예제 코드를 실행해봅니다.

$ code .

탐색기의 폴더 이름에 가져가면 보이는 New 아이콘을 실행하여 탐색기에 파이썬 코드 파일을 추가하고 오른쪽 아래에 파이썬 확장 설치 물어보면 꼭 설치해야 합니다. 그런  다음 코드를 붙여넣고 오른쪽 위에 있는 실행 아이콘을 클릭하여 실행합니다.  소스코드에선 현재 폴더에 복사해둔  cat.png 이미지 파일을 입력으로 사용합니다.

import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import cv2
from PIL import Image

# U-2-Net 모델 불러오기
from model import U2NET

model_dir = 'saved_models/u2net/u2net.pth'  # 사전 훈련된 모델 경로

# 이미지 전처리 함수
def transform_image(image):
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    ])
    return transform(image)

# 모델 출력 정규화 함수
def normalize_output(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d-mi)/(ma-mi)
    return dn

# 모델 로드 및 평가 모드 설정
def load_model():
    net = U2NET(3, 1)
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()
    return net

# 이미지 테스트 함수
def test_single_image(net, image_path):
    # 이미지 로드
    image = Image.open(image_path).convert('RGB')
    image_cv = cv2.imread(image_path)
    image_cv = cv2.resize(image_cv, (320, 320))
   
    inputs = transform_image(image)
   
    # 배치 차원 추가 및 모델에 입력
    inputs = inputs.unsqueeze(0)
    if torch.cuda.is_available():
        inputs = inputs.cuda()
   
    # 모델 추론
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(inputs)
   
    # 결과 처리 (d1이 최종 출력)
    pred = d1[:, 0, :, :]
    pred = normalize_output(pred)
   
    # numpy 배열로 변환
    pred = pred.squeeze().cpu().numpy()
   
    # 마스크 생성
    mask = (pred * 255).astype(np.uint8)
   
    # 원본 이미지에 마스크 적용
    image_np = np.array(image.resize((320, 320)))
    mask_3channel = np.stack([mask, mask, mask], axis=2)
    result = (image_np * (mask_3channel / 255.0)).astype(np.uint8)
   
    # BGR로 변환 (OpenCV는 BGR 형식 사용)
    result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
   
    # 시각화 이미지 생성
    # 원본, 마스크, 결과를 가로로 합치기
    vis_width = 320 * 3
    vis_height = 320
    visualization = np.zeros((vis_height, vis_width, 3), dtype=np.uint8)
   
    # 원본 이미지
    visualization[:, 0:320, :] = image_cv
   
    # 마스크 (그레이스케일을 BGR로 변환)
    mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
    visualization[:, 320:640, :] = mask_bgr
   
    # 결과 이미지
    visualization[:, 640:960, :] = result_bgr
   
    # 텍스트 추가
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(visualization, 'Original Image', (10, 30), font, 0.7, (255, 255, 255), 2)
    cv2.putText(visualization, 'Predicted Mask', (330, 30), font, 0.7, (255, 255, 255), 2)
    cv2.putText(visualization, 'Segmentation Result', (650, 30), font, 0.7, (255, 255, 255), 2)
   
    # 화면에 표시
    cv2.imshow("U-2-Net Segmentation Results", visualization)
    cv2.waitKey(0# 키 입력 대기
   
    return pred, result



if __name__ == "__main__":

    # 모델 로드
    net = load_model()
    print("모델이 성공적으로 로드되었습니다.")
   
       
    mask, result = test_single_image(net, 'cat.png')

   
    cv2.destroyAllWindows()  # 모든 창 닫기

실행 결과입니다.

원본 이미지, 마스크 이미지 그리고 세그멘테이션 결과가 보여집니다. 세그멘테이션 결과는 원본 이미지에서 마스크 영역만 보여준 결과입니다.