Nova-UTD/navigator

Fix Traffic Light Detector Bug (Not Launching)

Opened this issue · 3 comments

Fix Traffic Light Detector Bug (Not Launching)

Code to test for the bug fix (testing still to be done):
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import rclpy
from cv_bridge import CvBridge
from navigator_msgs.msg import TrafficLightDetection, TrafficLight
from rclpy.node import Node
from sensor_msgs.msg import Image

class TrafficLightDetectorNode(Node):
    """ROS 2 node that runs a fine-tuned Faster R-CNN traffic light detector.

    Subscribes to ``/cameras/camera0`` (``sensor_msgs/Image``), runs the
    detector on every frame, and publishes the detections whose score exceeds
    the threshold on ``/traffic_light/detections``
    (``navigator_msgs/TrafficLightDetection``).

    ROS parameters:
        model_weights_path (str): path to the fine-tuned state dict. A
            relative path is resolved against the node's working directory.
    """

    # Default location of the fine-tuned weights (relative to the CWD).
    DEFAULT_MODEL_WEIGHTS_PATH = 'data/perception/models/fasterrcnn_resnet50_fpn.pth'
    # Output classes of the fine-tuned box predictor, including background.
    NUM_CLASSES = 4

    def __init__(self):
        super().__init__('traffic_light_detector_node')

        self.bridge = CvBridge()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = self._build_model()

        self.subscription = self.create_subscription(
            Image, '/cameras/camera0', self.image_callback, 10)

        self.publisher = self.create_publisher(
            TrafficLightDetection, '/traffic_light/detections', 10)

    def _build_model(self):
        """Build the detector, load the fine-tuned weights, and set eval mode.

        Returns:
            The Faster R-CNN model on ``self.device`` in evaluation mode.

        Raises:
            FileNotFoundError: if the configured weights file does not exist.
        """
        model_weights_path = self.declare_parameter(
            'model_weights_path', self.DEFAULT_MODEL_WEIGHTS_PATH).value
        if not os.path.isfile(model_weights_path):
            # A relative path depends on where the node was launched from, so
            # report the absolute path that was actually checked.
            message = ('Traffic light model weights not found: '
                       f'{os.path.abspath(model_weights_path)}')
            self.get_logger().error(message)
            raise FileNotFoundError(message)

        # NOTE: pretrained=True builds the same architecture the fine-tuned
        # checkpoint presumably came from. Every parameter is overwritten by
        # load_state_dict below. It does download the COCO weights on first run.
        model = fasterrcnn_resnet50_fpn(pretrained=True)

        # Replace the COCO head with one sized to the fine-tuned label set.
        # FastRCNNPredictor has no `in_features` attribute; the input width
        # is stored on its `cls_score` Linear layer.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.NUM_CLASSES)

        # weights_only=True refuses to unpickle arbitrary objects from the file.
        state_dict = torch.load(
            model_weights_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)

        model.to(self.device)
        model.eval()
        return model

    def image_callback(self, msg):
        """Detect traffic lights in one camera frame and publish the result."""
        image_data = self.bridge.imgmsg_to_cv2(msg, desired_encoding='rgb8')
        traffic_light_detections = self.detect_traffic_lights(
            image_data, header=msg.header)
        self.publisher.publish(traffic_light_detections)

    def detect_traffic_lights(self, image_data, header=None):
        """Run the detector on an RGB image and build a detections message.

        Args:
            image_data: RGB image as returned by ``CvBridge.imgmsg_to_cv2``.
            header: optional header of the source image. When given, its
                stamp is copied so detections line up with the frame they came
                from (this also works under ros2 bag playback). Otherwise the
                node clock's current time is used.

        Returns:
            A ``TrafficLightDetection`` message with one ``TrafficLight`` per
            detection that passed ``postprocess_predictions``.
        """
        image_tensor = self.preprocess_image(image_data).to(self.device)

        with torch.no_grad():
            # torchvision detection models take a list of CxHxW tensors;
            # iterating the batched tensor yields exactly that.
            predictions = self.model(list(image_tensor))

        boxes, scores, labels = self.postprocess_predictions(predictions)

        detections_msg = TrafficLightDetection()
        if header is not None:
            detections_msg.header.stamp = header.stamp
        else:
            detections_msg.header.stamp = self.get_clock().now().to_msg()
        # NOTE(review): consider header.frame_id here once downstream
        # consumers no longer expect the literal "camera_frame".
        detections_msg.header.frame_id = "camera_frame"

        for box, score, label in zip(boxes, scores, labels):
            # Boxes are corner coordinates (x_min, y_min, x_max, y_max).
            x_min, y_min, x_max, y_max = (float(v) for v in box)
            traffic_light = TrafficLight()
            traffic_light.x = x_min
            traffic_light.y = y_min
            traffic_light.width = x_max - x_min
            traffic_light.height = y_max - y_min
            traffic_light.label = int(label)
            traffic_light.score = float(score)
            detections_msg.traffic_lights.append(traffic_light)

        return detections_msg

    def preprocess_image(self, image_data):
        """Convert an HxWxC uint8 image into a 1xCxHxW float tensor in [0, 1].

        ``ToTensor`` accepts the numpy array directly, so no PIL round-trip
        is needed.
        """
        return transforms.ToTensor()(image_data).unsqueeze(0)

    def postprocess_predictions(self, predictions, score_threshold=0.5):
        """Keep the first image's detections whose score exceeds a threshold.

        Args:
            predictions: list of torchvision detection dicts with ``boxes``,
                ``scores`` and ``labels`` tensors; only element 0 is used.
            score_threshold: detections with ``score <= score_threshold``
                are dropped.

        Returns:
            ``(boxes, scores, labels)`` as numpy arrays on the CPU.
        """
        prediction = predictions[0]
        keep = prediction['scores'] > score_threshold
        boxes = prediction['boxes'][keep].cpu().numpy()
        scores = prediction['scores'][keep].cpu().numpy()
        labels = prediction['labels'][keep].cpu().numpy()
        return boxes, scores, labels

def main(args=None):
    """Initialise rclpy, spin the detector node, and always clean up."""
    rclpy.init(args=args)
    node = None
    try:
        node = TrafficLightDetectorNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; exit without a traceback.
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # The signal handler may already have shut the context down;
        # calling shutdown twice raises.
        if rclpy.ok():
            rclpy.shutdown()

if __name__ == '__main__':
    main()

Status:
Status:

Status: In progress (almost done).
Next: test whether it works with a ros2 bag.
Projected completion: week of September 16.
Update: