YOLO (You Only Look Once) is an object detection algorithm that can be used to detect, classify, and track objects in near real-time. The first research paper about YOLO was published in May 2016. The 4th version of YOLO (YOLOv4) was introduced in April 2020.
This tutorial shows how to use a pre-trained YOLOv4 model to detect objects in an image. We will use a YOLOv4 Python package that is implemented in TensorFlow 2.
Using the pip package manager, install tensorflow and tf2-yolov4 from the command line.
pip install tensorflow
pip install tf2-yolov4
Download the YOLOv4 weights (yolov4.weights) from the AlexeyAB/darknet repository. The model was trained on the COCO (Common Objects in Context) dataset, which contains 80 object categories.
The tf2-yolov4 package includes the convert-darknet-weights command, which converts Darknet weights to TensorFlow weights.
convert-darknet-weights yolov4.weights -o yolov4.h5
We read and preprocess the image. The tf.io.decode_image function detects the image format (JPEG, PNG, BMP, or GIF) and converts the input bytes into a Tensor. The image is then resized with the tf.image.resize function.
We have a single image of shape [height, width, channels]. However, the model requires input of shape [batch, height, width, channels], so the batch dimension is missing. We can add a batch dimension by passing axis=0 to the tf.expand_dims function.
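An image contains pixels whose values are between 0 and 255. We normalize the image data by dividing each pixel value by 255 to get a range of 0 to 1. Scaling inputs to a small, consistent range is standard practice for neural networks, and a pre-trained model should be given inputs on the same scale it was trained with.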
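To make the two preprocessing steps concrete, here is a minimal sketch in plain Python rather than TensorFlow, with nested lists standing in for tensors. The helper names (shape, add_batch_dim, normalize) and the tiny 2x2 image are hypothetical; they only illustrate what adding a batch dimension at axis 0 and dividing by 255 do to the data.

```python
def shape(nested):
    """Return the shape of a regularly nested list, e.g. [2, 2, 3]."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


def add_batch_dim(image):
    """Wrap a single image in an outer list: [H, W, C] -> [1, H, W, C]."""
    return [image]


def normalize(nested):
    """Divide every pixel value by 255 so values fall in the range 0 to 1."""
    if isinstance(nested, list):
        return [normalize(item) for item in nested]
    return nested / 255


# A hypothetical 2x2 RGB image with pixel values in 0..255.
image = [
    [[0, 128, 255], [255, 255, 255]],
    [[64, 64, 64], [0, 0, 0]],
]

images = normalize(add_batch_dim(image))

print(shape(image))   # [2, 2, 3]
print(shape(images))  # [1, 2, 2, 3]
```

In the real pipeline the same two steps are done on a Tensor, so no helper functions are needed.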
We define the YOLOv4 model. The yolo_max_boxes parameter defines the maximum number of objects detected and classified in an image.
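The Intersection over Union (IoU) threshold controls how overlapping boxes are handled: when two boxes overlap by more than this threshold, they are treated as the same object and only the higher-scoring one is kept. The score threshold filters out detections whose score is lower than the threshold.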
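The role of the two thresholds can be sketched with a generic greedy non-maximum suppression routine in plain Python. This is an illustration only: the function names (iou, filter_detections) and the sample boxes are hypothetical, and the package's internal filtering may differ in its details.

```python
def iou(box_a, box_b):
    """Intersection over Union of two (xmin, ymin, xmax, ymax) boxes."""
    ix_min = max(box_a[0], box_b[0])
    iy_min = max(box_a[1], box_b[1])
    ix_max = min(box_a[2], box_b[2])
    iy_max = min(box_a[3], box_b[3])
    # Width and height of the overlap are clamped at zero for disjoint boxes.
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(boxes, scores, iou_threshold=0.5,
                      score_threshold=0.5, max_boxes=50):
    """Return indices of boxes kept after score filtering and greedy NMS."""
    # 1. Drop boxes whose score is below the score threshold.
    candidates = [i for i, s in enumerate(scores) if s >= score_threshold]
    # 2. Visit the remaining boxes from highest to lowest score.
    candidates.sort(key=lambda i: scores[i], reverse=True)
    kept = []
    for i in candidates:
        if len(kept) == max_boxes:
            break
        # 3. Keep a box only if it does not overlap a kept box too much.
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in kept):
            kept.append(i)
    return kept


# Hypothetical sample: boxes 0 and 1 overlap heavily, box 3 has a low score.
boxes = [
    (0.10, 0.10, 0.50, 0.50),
    (0.12, 0.12, 0.52, 0.52),
    (0.60, 0.60, 0.90, 0.90),
    (0.00, 0.00, 0.20, 0.20),
]
scores = [0.9, 0.8, 0.7, 0.3]

kept = filter_detections(boxes, scores)
print(kept)  # [0, 2]
```

Box 1 is dropped because its IoU with the higher-scoring box 0 exceeds 0.5, and box 3 is dropped because its score is below 0.5. Raising the IoU threshold keeps more overlapping boxes, while raising the score threshold keeps fewer, more confident ones.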
We load the weights that were trained on the COCO dataset, then use the pre-trained model to detect objects in an image.
- The boxes list contains bounding boxes for detected objects. Values are between 0 and 1.
- The scores list contains the confidence score for each predicted object. Values are between 0 and 1.
- The classes list indicates which of the 80 classes each detected object belongs to. Values are between 0 and 79.
- The detections variable indicates how many objects were detected in an image.
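As a rough sketch of how these four outputs fit together, the plain-Python example below converts normalized boxes to pixel coordinates and skips empty slots. All values are hypothetical, and the sketch assumes that unused slots are padded with a score of 0, which is what the score > 0 check in the tutorial code suggests.

```python
WIDTH, HEIGHT = 1024, 768

# Hypothetical outputs for one image with three slots, two of them used.
boxes = [
    [0.25, 0.5, 0.5, 0.75],  # (xmin, ymin, xmax, ymax), normalized to 0..1
    [0.0, 0.0, 1.0, 1.0],
    [0.0, 0.0, 0.0, 0.0],    # padding slot (assumed to be all zeros)
]
scores = [0.5, 0.75, 0.0]
classes = [0, 2, 0]
detections = 2

# Subset of the COCO class names used in this tutorial.
CLASS_NAMES = {0: 'person', 2: 'car'}


def to_pixels(box, width, height):
    """Scale a normalized (xmin, ymin, xmax, ymax) box to pixel coordinates."""
    xmin, ymin, xmax, ymax = box
    return (xmin * width, ymin * height, xmax * width, ymax * height)


results = [
    (CLASS_NAMES[class_idx], score, to_pixels(box, WIDTH, HEIGHT))
    for box, score, class_idx in zip(boxes, scores, classes)
    if score > 0  # skip padding slots
]

for name, score, box in results:
    print(name, score, box)
# person 0.5 (256.0, 384.0, 512.0, 576.0)
# car 0.75 (0.0, 0.0, 1024.0, 768.0)
```

The number of results matches the detections count, since only the used slots have a non-zero score.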
import matplotlib.pyplot as plt
import tensorflow as tf
from tf2_yolov4.anchors import YOLOV4_ANCHORS
from tf2_yolov4.model import YOLOv4

# Input size used for the model (width, height).
WIDTH, HEIGHT = (1024, 768)

# Read the file, decode it to a 3-channel tensor, and resize it.
image = tf.io.read_file('test1.jpg')
image = tf.io.decode_image(image, channels=3)
image = tf.image.resize(image, (HEIGHT, WIDTH))
# Add the batch dimension and scale pixel values to the range 0 to 1.
images = tf.expand_dims(image, axis=0) / 255

# Define the model for the 80 COCO classes.
model = YOLOv4(
    input_shape=(HEIGHT, WIDTH, 3),
    anchors=YOLOV4_ANCHORS,
    num_classes=80,
    training=False,
    yolo_max_boxes=50,
    yolo_iou_threshold=0.5,
    yolo_score_threshold=0.5,
)

# Load the converted pre-trained weights.
model.load_weights('yolov4.h5')

# Run detection and take the results for the first (only) image in the batch.
boxes, scores, classes, detections = model.predict(images)
# Convert normalized box coordinates to pixel coordinates.
boxes = boxes[0] * [WIDTH, HEIGHT, WIDTH, HEIGHT]
scores = scores[0]
classes = classes[0].astype(int)
detections = detections[0]

# COCO class names, indexed by class ID (0 to 79).
CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Show the image and draw a labeled rectangle for each detection.
plt.imshow(images[0])
ax = plt.gca()

for (xmin, ymin, xmax, ymax), score, class_idx in zip(boxes, scores, classes):
    # Slots without a detection have a score of 0 and are skipped.
    if score > 0:
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             fill=False, color='green')
        ax.add_patch(rect)

        text = '{}: {:.2f}'.format(CLASSES[class_idx], score)
        ax.text(xmin, ymin, text, fontsize=9, bbox=dict(facecolor='yellow', alpha=0.6))

plt.title('Objects detected: {}'.format(detections))
plt.axis('off')
plt.show()
Finally, we draw bounding boxes and class labels for each detected object.