Source code for renoir.color.extraction

"""
Color extraction functions for artwork analysis.

This module provides tools for extracting dominant colors from artworks
using k-means clustering and other computational methods.
"""

import json
import os

import numpy as np
from typing import List, Dict, Tuple, Optional, Union
from PIL import Image
from collections import Counter

try:
    from sklearn.cluster import KMeans

    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False


def _validate_export_filename(filename: str) -> None:
    """Raise ValueError if filename contains path-traversal components."""
    # Block explicit '..' traversal in the raw filename
    if ".." in filename.replace("\\", "/").split("/"):
        raise ValueError("filename must not contain path traversal components ('..').")
    # Block null bytes
    if "\x00" in filename:
        raise ValueError("filename must not contain null bytes.")


[docs] class ColorExtractor: """ Extract dominant colors and palettes from artwork images. This class provides methods for extracting color information from digital images, designed for educational use in teaching computational color analysis to art and design students. Attributes: use_sklearn: Whether scikit-learn is available for k-means clustering """ def __init__(self): """Initialize the ColorExtractor.""" self.use_sklearn = SKLEARN_AVAILABLE if not SKLEARN_AVAILABLE: print("Warning: scikit-learn not available. Some features may be limited.") print("Install with: pip install scikit-learn")
[docs] def extract_dominant_colors( self, image: Union[Image.Image, np.ndarray], n_colors: int = 5, method: str = "kmeans", sample_size: Optional[int] = 10000, random_state: int = 42, filter_extremes: bool = True, ) -> List[Tuple[int, int, int]]: """ Extract dominant colors from an image using k-means clustering. This method identifies the most prominent colors in an artwork by clustering pixel colors and finding cluster centers. Ideal for teaching students about color quantization and computational analysis. Args: image: PIL Image or numpy array of the artwork n_colors: Number of dominant colors to extract (default: 5) method: Extraction method - 'kmeans' or 'frequency' (default: 'kmeans') sample_size: Number of pixels to sample for faster processing None = use all pixels (default: 10000) random_state: Random seed for reproducible pixel sampling and k-means clustering (default: 42) filter_extremes: Whether to remove pure black (0,0,0) and pure white (255,255,255) pixels before clustering. Set to False when analysing artworks where true black or white are meaningful (default: True) Returns: List of RGB tuples representing dominant colors, ordered by prominence Raises: ValueError: If n_colors is invalid ValueError: If sample_size is invalid ValueError: If method is not recognized TypeError: If image is not PIL Image or numpy array ValueError: If image dimensions are invalid Example: >>> extractor = ColorExtractor() >>> from PIL import Image >>> img = Image.open('artwork.jpg') >>> colors = extractor.extract_dominant_colors(img, n_colors=5) >>> print(colors) [(120, 89, 143), (201, 178, 156), ...] """ # Input validation if not isinstance(n_colors, int): raise ValueError("n_colors must be an integer") if n_colors < 1: raise ValueError("n_colors must be at least 1") if n_colors > 256: raise ValueError("n_colors cannot exceed 256") if sample_size is not None: if not isinstance(sample_size, int): raise ValueError("sample_size must be an integer or None") if sample_size < 1: raise ValueError("sample_size must be positive") if method not in ["kmeans", "frequency"]: raise ValueError("method must be 'kmeans' or 'frequency'") # Validate and convert image try: if isinstance(image, Image.Image): img_array = np.array(image) elif isinstance(image, np.ndarray): img_array = image else: raise TypeError( "image must be a PIL Image or numpy array, " f"got {type(image).__name__}" ) except Exception as e: raise TypeError(f"Failed to convert image to array: {str(e)}") # Validate image dimensions if img_array.ndim not in [2, 3]: raise ValueError(f"Image must be 2D or 3D array, got {img_array.ndim}D") if img_array.ndim == 3: if img_array.shape[-1] not in [3, 4]: raise ValueError( f"Image must have 3 (RGB) or 4 (RGBA) channels, " f"got {img_array.shape[-1]}" ) # Handle RGBA images by removing alpha channel if img_array.shape[-1] == 4: img_array = img_array[:, :, :3] else: # Grayscale - convert to RGB img_array = np.stack([img_array] * 3, axis=-1) # Check image is not empty if img_array.size == 0: raise ValueError("Image is empty") # Reshape to 2D array of pixels try: pixels = img_array.reshape(-1, 3) except Exception as e: raise ValueError(f"Failed to reshape image: {str(e)}") # Remove any invalid pixels (e.g., all zeros, all 255s) if filter_extremes: valid_pixels = pixels[~np.all(pixels == 0, axis=1)] valid_pixels = valid_pixels[~np.all(valid_pixels == 255, axis=1)] else: valid_pixels = pixels if len(valid_pixels) == 0: print("Warning: No valid pixels found in image") return [(0, 0, 0)] * n_colors if len(valid_pixels) < n_colors: print( f"Warning: Only {len(valid_pixels)} unique pixels, fewer than requested {n_colors} colors" ) n_colors = len(valid_pixels) # Sample pixels for faster processing if sample_size and len(valid_pixels) > sample_size: rng = np.random.RandomState(random_state) indices = rng.choice(len(valid_pixels), sample_size, replace=False) sampled_pixels = valid_pixels[indices] else: sampled_pixels = valid_pixels try: if method == "kmeans" and self.use_sklearn: return self._extract_kmeans(sampled_pixels, n_colors, random_state) else: return self._extract_frequency(sampled_pixels, n_colors) except Exception as e: raise RuntimeError(f"Color extraction failed: {str(e)}") from e
def _extract_kmeans( self, pixels: np.ndarray, n_colors: int, random_state: int = 42 ) -> List[Tuple[int, int, int]]: """ Extract colors using k-means clustering. Args: pixels: Array of pixel RGB values n_colors: Number of clusters/colors to extract random_state: Random seed for reproducible clustering Returns: List of RGB tuples ordered by cluster size """ # Perform k-means clustering kmeans = KMeans(n_clusters=n_colors, random_state=random_state, n_init=10) kmeans.fit(pixels) # Get cluster centers (dominant colors) colors = kmeans.cluster_centers_.astype(int) # Count pixels in each cluster to order by prominence labels = kmeans.labels_ label_counts = Counter(labels) # Sort colors by cluster size (most prominent first) sorted_indices = sorted( range(n_colors), key=lambda i: label_counts[i], reverse=True ) sorted_colors = [tuple(colors[i]) for i in sorted_indices] return sorted_colors def _extract_frequency( self, pixels: np.ndarray, n_colors: int ) -> List[Tuple[int, int, int]]: """ Extract colors by frequency (fallback method without sklearn). Args: pixels: Array of pixel RGB values n_colors: Number of most frequent colors to return Returns: List of RGB tuples ordered by frequency """ # Convert to tuples for counting pixel_tuples = [tuple(pixel) for pixel in pixels] # Count frequencies color_counts = Counter(pixel_tuples) # Get n most common most_common = color_counts.most_common(n_colors) return [color for color, count in most_common]
[docs] def extract_palette_from_artwork( self, artwork_dict: Dict, n_colors: int = 5 ) -> Dict: """ Extract color palette from a WikiArt artwork dictionary. Convenience method that works directly with renoir artwork data. Args: artwork_dict: Artwork dictionary from WikiArt dataset n_colors: Number of colors to extract Returns: Dictionary with 'colors' (RGB tuples) and 'metadata' (artwork info) Example: >>> from renoir import ArtistAnalyzer >>> analyzer = ArtistAnalyzer() >>> works = analyzer.extract_artist_works('claude-monet') >>> extractor = ColorExtractor() >>> palette = extractor.extract_palette_from_artwork(works[0]) """ image = artwork_dict["image"] colors = self.extract_dominant_colors(image, n_colors=n_colors) return { "colors": colors, "artwork": artwork_dict.get("title", "Unknown"), "artist": artwork_dict.get("artist", "Unknown"), "n_colors": n_colors, }
[docs] def extract_average_color( self, image: Union[Image.Image, np.ndarray] ) -> Tuple[int, int, int]: """ Calculate the average color of an image. Simple method useful for teaching color concepts to beginners. Args: image: PIL Image or numpy array Returns: RGB tuple representing the average color Example: >>> extractor = ColorExtractor() >>> avg_color = extractor.extract_average_color(img) >>> print(f"Average color: RGB{avg_color}") """ if isinstance(image, Image.Image): img_array = np.array(image) else: img_array = image # Handle RGBA if img_array.shape[-1] == 4: img_array = img_array[:, :, :3] # Calculate mean for each channel avg_color = np.mean(img_array, axis=(0, 1)).astype(int) return tuple(avg_color)
[docs] def rgb_to_hex(self, rgb: Tuple[int, int, int]) -> str: """ Convert RGB tuple to hexadecimal color code. Args: rgb: Tuple of (R, G, B) values (0-255) Returns: Hexadecimal color string (e.g., '#FF5733') Example: >>> extractor = ColorExtractor() >>> hex_color = extractor.rgb_to_hex((255, 87, 51)) >>> print(hex_color) # '#FF5733' """ return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2])
[docs] def hex_to_rgb(self, hex_color: str) -> Tuple[int, int, int]: """ Convert hexadecimal color code to RGB tuple. Args: hex_color: Hexadecimal color string (e.g., '#FF5733' or 'FF5733') Returns: Tuple of (R, G, B) values (0-255) Example: >>> extractor = ColorExtractor() >>> rgb = extractor.hex_to_rgb('#FF5733') >>> print(rgb) # (255, 87, 51) """ # Remove '#' if present hex_color = hex_color.lstrip("#") # Convert to RGB return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
[docs] def palette_to_dict( self, colors: List[Tuple[int, int, int]], include_hex: bool = True ) -> Dict: """ Convert color palette to a dictionary format. Useful for exporting palettes or educational demonstrations. Args: colors: List of RGB tuples include_hex: Whether to include hex codes (default: True) Returns: Dictionary with palette information Example: >>> colors = [(255, 87, 51), (100, 200, 150)] >>> palette_dict = extractor.palette_to_dict(colors) """ # Convert numpy types to native Python types for JSON compatibility colors_native = [(int(r), int(g), int(b)) for r, g, b in colors] palette = {"rgb_values": colors_native, "n_colors": len(colors_native)} if include_hex: palette["hex_values"] = [self.rgb_to_hex(color) for color in colors_native] return palette
[docs] def export_palette_css( self, colors: List[Tuple[int, int, int]], filename: str, prefix: str = "color" ) -> None: """ Export color palette as CSS variables. Useful for design students to use extracted palettes in web projects. Args: colors: List of RGB tuples filename: Output CSS filename prefix: Variable name prefix (default: 'color') Example: >>> colors = [(255, 87, 51), (100, 200, 150)] >>> extractor.export_palette_css(colors, 'palette.css') """ _validate_export_filename(filename) with open(filename, "w") as f: f.write(":root {\n") for i, color in enumerate(colors, 1): hex_color = self.rgb_to_hex(color) f.write(f" --{prefix}-{i}: {hex_color};\n") f.write("}\n") print(f"Palette exported to {filename}")
[docs] def export_palette_json( self, colors: List[Tuple[int, int, int]], filename: str ) -> None: """ Export color palette as JSON. Args: colors: List of RGB tuples filename: Output JSON filename Example: >>> colors = [(255, 87, 51), (100, 200, 150)] >>> extractor.export_palette_json(colors, 'palette.json') """ _validate_export_filename(filename) palette_dict = self.palette_to_dict(colors) with open(filename, "w") as f: json.dump(palette_dict, f, indent=2) print(f"Palette exported to {filename}")
[docs] def check_color_extraction_support() -> bool: """ Check if color extraction dependencies are available. Returns: True if scikit-learn is available, False otherwise """ if SKLEARN_AVAILABLE: print("✅ Color extraction fully supported (scikit-learn available)") print(" You can use k-means clustering for optimal color extraction") else: print("⚠️ Limited color extraction support") print(" Install scikit-learn for k-means clustering:") print(" pip install scikit-learn") print(" Fallback frequency-based extraction will be used") return SKLEARN_AVAILABLE