Commit 319923c1 authored by razvan841

initial scripts and file system

parent 464e80a8
import cv2
import numpy as np
from tensorflow.keras.models import load_model
import os


# Extract 9 patches from the image, assuming the image contains only the
# cube face and is properly aligned (one patch per sticker, row by row).
def extract_patches(image):
    height, width, _ = image.shape
    square_height = height // 3
    square_width = width // 3
    patches = []
    for row in range(3):
        for col in range(3):
            y_start = row * square_height
            y_end = (row + 1) * square_height
            x_start = col * square_width
            x_end = (col + 1) * square_width
            patch = image[y_start:y_end, x_start:x_end]
            patches.append(patch)
    return patches


# Preprocess patches: resize to 32x32 and normalize to [0, 1].
def preprocess_patches(patches):
    patches_resized = [cv2.resize(patch, (32, 32)) for patch in patches]
    patches_normalized = np.array(patches_resized) / 255.0
    return patches_normalized


# Run the trained model on all 9 patches and return the predicted class
# index for each patch.
def predict_colors(model, image):
    patches = extract_patches(image)
    patches_preprocessed = preprocess_patches(patches)
    predictions = model.predict(patches_preprocessed)
    predicted_colors = [np.argmax(pred) for pred in predictions]
    return predicted_colors


def main():
    model_path = "rubiks_cube_model.h5"
    if not os.path.exists(model_path):
        print("Error: Model file not found.")
        return

    print("Loading model...")
    model = load_model(model_path)

    # Initialize OpenCV camera
    camera = cv2.VideoCapture(0)  # Use the default camera (index 0)
    if not camera.isOpened():
        print("Error: Could not open the camera.")
        return

    while True:
        # Capture a frame
        ret, frame = camera.read()
        if not ret:
            print("Error: Failed to capture an image.")
            continue

        # Display the live camera feed
        cv2.imshow("Rubik's Cube - Press 'c' to capture", frame)

        # Wait for the user to press 'c' to capture or 'q' to quit
        key = cv2.waitKey(1) & 0xFF
        if key == ord('c'):
            print("Processing captured image...")
            predicted_colors = predict_colors(model, frame)
            print("Predicted Colors:", predicted_colors)
        elif key == ord('q'):
            print("Exiting...")
            break

    camera.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
# label_map = {'white': 0, 'yellow': 1, 'red': 2, 'blue': 3, 'green': 4, 'orange': 5}
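# A minimal sketch (an assumption, not part of the original script) of how the
# integer predictions could be decoded back into colour names using the
# label_map above:
#   index_to_color = {idx: name for name, idx in label_map.items()}
#   color_names = [index_to_color[idx] for idx in predicted_colors]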
# pip install -r requirements.txt
opencv-python==4.10.0.84
numpy==1.25.2
tensorflow==2.16.1
import cv2
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import os
import json


# Extract 9 patches from the image, assuming the image
# contains only the cube face and is properly aligned.
def extract_patches(image):
    height, width, _ = image.shape
    square_height = height // 3
    square_width = width // 3
    patches = []
    for row in range(3):
        for col in range(3):
            y_start = row * square_height
            y_end = (row + 1) * square_height
            x_start = col * square_width
            x_end = (col + 1) * square_width
            patch = image[y_start:y_end, x_start:x_end]
            patches.append(patch)
    return patches


# Preprocess patches (resize to 32x32 and normalize to [0, 1]).
def preprocess_patches(patches):
    patches_resized = [cv2.resize(patch, (32, 32)) for patch in patches]
    patches_normalized = np.array(patches_resized) / 255.0
    return patches_normalized


# CNN model: a single convolutional block followed by a small dense head.
def build_model():
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(64, activation='relu'),
        Dense(6, activation='softmax')  # 6 classes for Rubik's cube colors
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model


# Load training data.
def load_training_data(data_dir):
    """
    Loads and preprocesses the training data.

    Args:
        data_dir (str): Directory containing labeled folders of cube colors.

    Returns:
        tuple: Training images and labels.
    """
    images = []
    labels = []
    label_map = {'white': 0, 'yellow': 1, 'red': 2, 'blue': 3, 'green': 4, 'orange': 5}
    for label_name, label_idx in label_map.items():
        folder_path = os.path.join(data_dir, label_name)
        if not os.path.exists(folder_path):
            continue
        for file_name in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file_name)
            image = cv2.imread(file_path)
            if image is not None:
                image_resized = cv2.resize(image, (32, 32))
                images.append(image_resized)
                labels.append(label_idx)
    images = np.array(images) / 255.0  # Normalize to [0, 1]
    labels = np.array(labels)
    return images, labels
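
# Assumed training_data layout, inferred from label_map and the folder scan
# above (the image file names are hypothetical):
#   training_data/
#       white/   img_001.jpg, img_002.jpg, ...
#       yellow/  ...
#       red/  blue/  green/  orange/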

def load_test_labels(label_file):
    with open(label_file, 'r') as f:
        labels = {}
        for line in f:
            image_name, label_str = line.strip().split(":")
            labels[image_name.strip()] = json.loads(label_str.strip())
    return labels
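
# Expected test_labels.txt format, one entry per line (example values taken
# from the bundled test_labels.txt):
#   test_image1.jpg: [0, 1, 2, 3, 4, 5, 0, 1, 2]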

def main():
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Paths
    training_data_dir = os.path.join(script_dir, "training_data")
    test_images_dir = os.path.join(script_dir, "test_images")
    test_labels_file = os.path.join(test_images_dir, "test_labels.txt")

    print("Loading training data...")
    X_train, y_train = load_training_data(training_data_dir)
    print(f"Loaded {len(X_train)} training samples.")

    print("Building model...")
    model = build_model()

    print("Training model...")
    model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2)
    model.save("rubiks_cube_model.h5")
    print("Model saved as rubiks_cube_model.h5.")

    print("Loading test labels...")
    test_labels = load_test_labels(test_labels_file)

    print("Processing test images...")
    correct = 0
    total = 0
    for file_name in os.listdir(test_images_dir):
        if file_name.endswith(".jpg"):
            file_path = os.path.join(test_images_dir, file_name)
            test_image = cv2.imread(file_path)
            if test_image is None:
                continue  # Skip unreadable images
            test_patches = extract_patches(test_image)
            test_patches_preprocessed = preprocess_patches(test_patches)
            predictions = model.predict(test_patches_preprocessed)
            predicted_colors = [np.argmax(pred) for pred in predictions]
            if file_name in test_labels:
                true_colors = test_labels[file_name]
                total += 9
                correct += sum(1 for pred, true in zip(predicted_colors, true_colors) if pred == true)

    # Per-sticker accuracy over all labelled test images
    accuracy = (correct / total) * 100 if total > 0 else 0
    print(f"Accuracy: {accuracy:.2f}%")


if __name__ == "__main__":
    main()
test_image1.jpg: [0, 1, 2, 3, 4, 5, 0, 1, 2]