본문 바로가기
Studies & Courses/Machine Learning & Vision

[Object Detection] Tensorflow Hub 활용하기 (inception resnet V2)

by Air’s Big Data 2021. 3. 31.

Tensorflow Hub 에 있는 object detection model을 어떻게 사용하는지 알아보기 위해 간단한 구현을 하려고 합니다. 해당 코드는 CoLab에서도 확인 가능합니다. 

  - Tensorflow Hub에서 object detection model 찾아보기

  - 나의 workspace에 models load하기

  - Inference를 위해 image를 preprocess하기

  - models에 inference하고 output을 inspect하기


Imports

import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image
from PIL import ImageOps
import tempfile
from six.moves.urllib.request import urlopen
from six import BytesIO

Tensorflow Hub에서 model 다운로드하기

Tensorflow Hub는 재사용이 가능한 trained machine learning models의 repository입니다.

 - 카테고리별 확인은 여기에서 확인이 가능합니다.

 - 이 실험을 위해서는 image object detection subcategory 를 확인합니다.

 - 모델의 URL을 사용해 다운로드할 수 있습니다.

 - 이 실험을 위해서는 inception resnet version 2 을 선택합니다.

 

# inception resnet version 2
module_handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"

# 실험 결과에 따라 ssd mobilenet version 2 을 사용할 수 있습니다
#module_handle = "https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1"

Model을 load 하기

module_handle 이라고 정의한 model을 load합니다.

model = hub.load(module_handle)


Default signature 선택하기

Tensorflow hub에 있는 model은 다양한 tasks를 위해 사용됩니다. 그래서 각 모델의 문서는 model을 실행할 때 무슨 signature가 필요한지 보여줘야 합니다. 우리는 이 실험에서 default signature만 사용할 것입니다.

# model에 대해 available한 signature를 확인하기
model.signatures.keys()

Detection model을 위해 'default' signature를 선택합니다. 'default' signature은 image tensors의 batch를 받아들일 것입니다. 탐지된 객체를 설명할 dictionary를 output할 것입니다.

detector = model.signatures['default']

 


Image를 download하고 pre-processes하기

아래 function은 주어진 "url"을 다운로드하고 전처리하고 disk에 저장하는 function입니다. online에 있는 image를 가져오고 , resize한 뒤 로컬 자오에 저장합니다.

def download_and_resize_image(url, new_width=256, new_height=256):
    
    # ".jpg"로 끝나는 임시파일 만들기
    _, filename = tempfile.mkstemp(suffix=".jpg")
    
    # URL 열기
    response = urlopen(url)
    
    # URL로부터 이미지 가져와서 읽기
    image_data = response.read()
    
    # 이미지 데이터를 memory buffer에 넣기
    image_data = BytesIO(image_data)
    
    # Image 열기
    pil_image = Image.open(image_data)
    
    # Image를 resize하기 (ratio가 다를 경우 crop)
    pil_image = ImageOps.fit(pil_image, (new_width, new_height), Image.ANTIALIAS)
    
    # RGB colorspace로 convert하기
    pil_image_rgb = pil_image.convert("RGB")
    
    # Image를 임시파일에 저장하기 
    pil_image_rgb.save(filename, format="JPEG", quality=90)
    
    print("Image downloaded to %s." % filename)
    
    return filename

 download_and_resize_image 를 사용해 샘플이미지를 온라인에서 가져와서 로컬에 저장합니다. 이미지의 크기는 결과에 따라 수정할 수 있습니다.

 

# URL은 수정 가능
image_url = "https://upload.wikimedia.org/wikipedia/commons/f/fb/20130807_dublin014.JPG"

# image 다운로드 후 기존 높이, 넓이 사용하기
downloaded_image_path = download_and_resize_image(image_url, 3872, 2592)

Image를 load하고 detector를 run하기

이제 run_detector을 정의합니다. 이 function은 object detection model인 detector와 sample image로의 path를 받습니다. 그리고 object를 detect하고 예상 class를 나타내는데 이 model을 사용합니다.

load_img

JPEG 이미지를 받아서 tensor로 변환합니다.

def load_img(path):
    
    # 파일 읽기
    img = tf.io.read_file(path)
    
    # tensor로 변환하기
    img = tf.image.decode_jpeg(img, channels=3)
    
    return img

run_detector

object detection model을 이용해 local file에 inference 실행합니다.

   detector (model) -- TF Hub에서 가져온 detection model

   path (string) -- 이미지가 저장된 local path

def run_detector(detector, path):
    
    # local file path로부터 이미지 가져오기 
    img = load_img(path)

    # tensor 앞에 batch dimension 추가하기
    converted_img  = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
    
    # model을 이용해 inference 실행하기
    result = detector(converted_img)

    # dictionary에 결과 저장하기
    result = {key:value.numpy() for key,value in result.items()}

    # results 출력하기
    print("Found %d objects." % len(result["detection_scores"]))

    print(result["detection_scores"])
    print(result["detection_class_entities"])
    print(result["detection_boxes"])

Image에 inference 실행하기

run_detector function을 가져와서 실행합니다. 아래 3가지 list에 따라 object의 개수가 출력됩니다.

 - 각 object가 의 detection scores (모델을 얼마나 신뢰할만한가)

 - 각 object의 class

 - 각 object의 class의 bounding boxes

# object detection model을 실행하고 찾은 object 정보를 출력
run_detector(detector, downloaded_image_path)

 

 

댓글