Refactoring media class

This commit is contained in:
Cédric Leporcq 2021-08-14 21:31:37 +02:00
parent a0920e3b5b
commit 1d488c0154
8 changed files with 80 additions and 161 deletions

View File

@ -8,11 +8,12 @@ from datetime import datetime
import click import click
from ordigi import constants
from ordigi import config from ordigi import config
from ordigi.filesystem import FileSystem from ordigi import constants
from ordigi import log
from ordigi.database import Db from ordigi.database import Db
from ordigi.media.media import Media, get_all_subclasses from ordigi.filesystem import FileSystem
from ordigi.media import Media, get_all_subclasses
from ordigi.summary import Summary from ordigi.summary import Summary
FILESYSTEM = FileSystem() FILESYSTEM = FileSystem()

View File

@ -83,7 +83,7 @@ class Db(object):
# structure might be needed. Some speed up ideas: # structure might be needed. Some speed up ideas:
# - Sort it and inter-half method can be used # - Sort it and inter-half method can be used
# - Use integer part of long or lat as key to get a lower search list # - Use integer part of long or lat as key to get a lower search list
# - Cache a small number of lookups, photos are likely to be taken in # - Cache a small number of lookups, images are likely to be taken in
# clusters around a spot during import. # clusters around a spot during import.
def add_location(self, latitude, longitude, place, write=False): def add_location(self, latitude, longitude, place, write=False):
"""Add a location to the database. """Add a location to the database.

View File

@ -17,8 +17,9 @@ from datetime import datetime, timedelta
from ordigi import constants from ordigi import constants
from ordigi import geolocation from ordigi import geolocation
from ordigi.media.media import get_media_class, get_all_subclasses from ordigi import media
from ordigi.media.photo import Photo from ordigi.media import Media, get_all_subclasses
from ordigi.images import Images
from ordigi.summary import Summary from ordigi.summary import Summary

View File

@ -1,5 +1,5 @@
""" """
The photo module contains the :class:`Photo` class, which is used to track The image module contains the :class:`Images` class, which is used to track
image objects (JPG, DNG, etc.). image objects (JPG, DNG, etc.).
.. moduleauthor:: Jaisen Mathai <jaisen@jmathai.com> .. moduleauthor:: Jaisen Mathai <jaisen@jmathai.com>
@ -10,50 +10,36 @@ import imghdr
import logging import logging
import numpy as np import numpy as np
import os import os
from PIL import Image, UnidentifiedImageError from PIL import Image as img
from PIL import UnidentifiedImageError
import time import time
from .media import Media # HEIC extension support (experimental, not tested)
PYHEIF = False
try:
class Photo(Media):
"""A photo object.
:param str source: The fully qualified path to the photo file
"""
__name__ = 'Photo'
#: Valid extensions for photo files.
extensions = ('arw', 'cr2', 'dng', 'gif', 'heic', 'jpeg', 'jpg', 'nef', 'png', 'rw2')
def __init__(self, source=None, hash_size=8, ignore_tags=set(),
logger=logging.getLogger()):
super().__init__(source, ignore_tags)
self.hash_size = hash_size
self.logger = logger
logger.setLevel(logging.INFO)
# HEIC extension support (experimental, not tested)
self.pyheif = False
try:
from pyheif_pillow_opener import register_heif_opener from pyheif_pillow_opener import register_heif_opener
self.pyheif = True PYHEIF = True
# Allow to open HEIF/HEIC images from pillow # Allow to open HEIF/HEIC image from pillow
register_heif_opener() register_heif_opener()
except ImportError as e: except ImportError as e:
self.logger.info(e) logging.info(e)
def is_image(self, img_path):
class Image():
def __init__(self, img_path, hash_size=8):
self.img_path = img_path
self.hash_size = hash_size
def is_image(self):
"""Check whether the file is an image. """Check whether the file is an image.
:returns: bool :returns: bool
""" """
# gh-4 This checks if the source file is an image. # gh-4 This checks if the file is an image.
# It doesn't validate against the list of supported types. # It doesn't validate against the list of supported types.
# We check with imghdr and pillow. # We check with imghdr and pillow.
if imghdr.what(img_path) is None: if imghdr.what(self.img_path) is None:
# Pillow is used as a fallback # Pillow is used as a fallback
# imghdr won't detect all variants of images (https://bugs.python.org/issue28591) # imghdr won't detect all variants of images (https://bugs.python.org/issue28591)
# see https://github.com/jmathai/elodie/issues/281 # see https://github.com/jmathai/elodie/issues/281
@ -65,7 +51,7 @@ class Photo(Media):
# things like mode, size, and other properties required to decode the file, # things like mode, size, and other properties required to decode the file,
# but the rest of the file is not processed until later. # but the rest of the file is not processed until later.
try: try:
im = Image.open(img_path) im = img.open(self.img_path)
except (IOError, UnidentifiedImageError): except (IOError, UnidentifiedImageError):
return False return False
@ -74,26 +60,48 @@ class Photo(Media):
return True return True
def get_images(self, file_paths): def get_hash(self):
with img.open(self.img_path) as img_path:
return imagehash.average_hash(img_path, self.hash_size).hash
class Images():
"""A image object.
:param str img_path: The fully qualified path to the image file
"""
#: Valid extensions for image files.
extensions = ('arw', 'cr2', 'dng', 'gif', 'heic', 'jpeg', 'jpg', 'nef', 'png', 'rw2')
def __init__(self, file_paths=None, hash_size=8, logger=logging.getLogger()):
self.file_paths = file_paths
self.hash_size = hash_size
self.duplicates = []
self.logger = logger
def get_images(self):
''':returns: img_path generator
''' '''
:returns: img_path generator for img_path in self.file_paths:
''' image = Image(img_path)
for img_path in file_paths: if image.is_image():
if self.is_image(img_path):
yield img_path yield img_path
def get_images_hashes(self, file_paths): def get_images_hashes(self):
"""Get image hashes""" """Get image hashes"""
hashes = {} hashes = {}
duplicates = []
# Searching for duplicates. # Searching for duplicates.
for img_path in self.get_images(file_paths): for img_path in self.get_images():
with Image.open(img_path) as img: with img.open(img_path) as img:
yield imagehash.average_hash(img, self.hash_size) yield imagehash.average_hash(img, self.hash_size)
def find_duplicates(self, file_paths): def find_duplicates(self, img_path):
"""Find duplicates""" """Find duplicates"""
for temp_hash in get_images_hashes(file_paths): duplicates = []
for temp_hash in get_images_hashes(self.file_paths):
if temp_hash in hashes: if temp_hash in hashes:
self.logger.info("Duplicate {} \nfound for image {}\n".format(img_path, hashes[temp_hash])) self.logger.info("Duplicate {} \nfound for image {}\n".format(img_path, hashes[temp_hash]))
duplicates.append(img_path) duplicates.append(img_path)
@ -118,10 +126,6 @@ class Photo(Media):
else: else:
self.logger.info("No duplicates found") self.logger.info("No duplicates found")
def get_hash(self, img_path):
with Image.open(img_path) as img:
return imagehash.average_hash(img, self.hash_size).hash
def diff(self, hash1, hash2): def diff(self, hash1, hash2):
return np.count_nonzero(hash1 != hash2) return np.count_nonzero(hash1 != hash2)
@ -131,24 +135,25 @@ class Photo(Media):
return similarity_img return similarity_img
def find_similar(self, image, file_paths, similarity=80): def find_similar(self, image, similarity=80):
''' '''
Find similar images Find similar images
:returns: img_path generator :returns: img_path generator
''' '''
hash1 = '' hash1 = ''
if self.is_image(image): image = Image(image)
hash1 = self.get_hash(image) if image.is_image():
hash1 = image.get_hash()
self.logger.info(f'Finding similar images to {image}') self.logger.info(f'Finding similar images to {image}')
threshold = 1 - similarity/100 threshold = 1 - similarity/100
diff_limit = int(threshold*(self.hash_size**2)) diff_limit = int(threshold*(self.hash_size**2))
for img_path in self.get_images(file_paths): for img_path in self.get_images():
if img_path == image: if img_path == image:
continue continue
hash2 = self.get_hash(img_path) hash2 = image.get_hash()
img_diff = self.diff(hash1, hash2) img_diff = self.diff(hash1, hash2)
if img_diff <= diff_limit: if img_diff <= diff_limit:
similarity_img = self.similarity(img_diff) similarity_img = self.similarity(img_diff)

View File

@ -1,13 +1,11 @@
""" """
Base :class:`Media` class for media objects Media :class:`Media` class to get file metadata
The Media class provides some base functionality used by all the media types.
Sub-classes (:class:`~ordigi.media.Audio`, :class:`~ordigi.media.Photo`, and :class:`~ordigi.media.Video`).
""" """
import logging
import mimetypes import mimetypes
import os import os
import six import six
import logging
# load modules # load modules
from dateutil.parser import parse from dateutil.parser import parse
@ -18,11 +16,9 @@ class Media():
"""The media class for all media objects. """The media class for all media objects.
:param str source: The fully qualified path to the video file. :param str file_path: The fully qualified path to the media file.
""" """
__name__ = 'Media'
d_coordinates = { d_coordinates = {
'latitude': 'latitude_ref', 'latitude': 'latitude_ref',
'longitude': 'longitude_ref' 'longitude': 'longitude_ref'
@ -34,8 +30,8 @@ class Media():
extensions = PHOTO + AUDIO + VIDEO extensions = PHOTO + AUDIO + VIDEO
def __init__(self, sources=None, ignore_tags=set(), logger=logging.getLogger()): def __init__(self, file_path, ignore_tags=set(), logger=logging.getLogger()):
self.source = sources self.file_path = file_path
self.ignore_tags = ignore_tags self.ignore_tags = ignore_tags
self.tags_keys = self.get_tags() self.tags_keys = self.get_tags()
self.exif_metadata = None self.exif_metadata = None
@ -104,7 +100,7 @@ class Media():
:returns: str or None :returns: str or None
""" """
mimetype = mimetypes.guess_type(self.source) mimetype = mimetypes.guess_type(self.file_path)
if(mimetype is None): if(mimetype is None):
return None return None
@ -198,7 +194,7 @@ class Media():
:returns: dict :returns: dict
""" """
# Get metadata from exiftool. # Get metadata from exiftool.
self.exif_metadata = ExifToolCaching(self.source, logger=self.logger).asdict() self.exif_metadata = ExifToolCaching(self.file_path, logger=self.logger).asdict()
# TODO to be removed # TODO to be removed
self.metadata = {} self.metadata = {}
@ -224,9 +220,9 @@ class Media():
self.metadata[key] = formated_data self.metadata[key] = formated_data
self.metadata['base_name'] = os.path.basename(os.path.splitext(self.source)[0]) self.metadata['base_name'] = os.path.basename(os.path.splitext(self.file_path)[0])
self.metadata['ext'] = os.path.splitext(self.source)[1][1:] self.metadata['ext'] = os.path.splitext(self.file_path)[1][1:]
self.metadata['directory_path'] = os.path.dirname(self.source) self.metadata['directory_path'] = os.path.dirname(self.file_path)
return self.metadata return self.metadata
@ -245,8 +241,7 @@ class Media():
def get_class_by_file(cls, _file, classes, ignore_tags=set(), logger=logging.getLogger()): def get_class_by_file(cls, _file, classes, ignore_tags=set(), logger=logging.getLogger()):
"""Static method to get a media object by file. """Static method to get a media object by file.
""" """
basestring = (bytes, str) if not os.path.isfile(_file):
if not isinstance(_file, basestring) or not os.path.isfile(_file):
return None return None
extension = os.path.splitext(_file)[1][1:].lower() extension = os.path.splitext(_file)[1][1:].lower()
@ -254,13 +249,9 @@ class Media():
if len(extension) > 0: if len(extension) > 0:
for i in classes: for i in classes:
if(extension in i.extensions): if(extension in i.extensions):
return i(_file, ignore_tags=ignore_tags) return i(_file, ignore_tags=ignore_tags, logger=logger)
exclude_list = ['.DS_Store', '.directory'] return Media(_file, logger, ignore_tags=ignore_tags, logger=logger)
if os.path.basename(_file) == '.DS_Store':
return None
else:
return Media(_file, ignore_tags=ignore_tags, logger=logger)
def set_date_taken(self, date_key, time): def set_date_taken(self, date_key, time):
"""Set the date/time a photo was taken. """Set the date/time a photo was taken.
@ -309,7 +300,7 @@ class Media():
:returns: bool :returns: bool
""" """
folder = os.path.basename(os.path.dirname(self.source)) folder = os.path.basename(os.path.dirname(self.file_path))
return set_value(self, 'album', folder) return set_value(self, 'album', folder)

View File

@ -1,36 +0,0 @@
"""
The audio module contains classes specifically for dealing with audio files.
The :class:`Audio` class inherits from the :class:`~ordigi.media.Media`
class.
.. moduleauthor:: Jaisen Mathai <jaisen@jmathai.com>
"""
import os
from .media import Media
class Audio(Media):
    """An audio object.

    :param str source: The fully qualified path to the audio file.
    :param set ignore_tags: Passed through unchanged to :class:`Media`;
        a new empty set is used when omitted.
    """

    __name__ = 'Audio'

    #: Valid extensions for audio files.
    extensions = ('m4a',)

    def __init__(self, source=None, ignore_tags=None):
        # The previous implementation handed a brand-new ``set()`` to the
        # parent, silently discarding whatever the caller supplied. Forward
        # the real value, and build the empty default per call so that no
        # set instance is shared between objects.
        if ignore_tags is None:
            ignore_tags = set()
        super().__init__(source, ignore_tags=ignore_tags)

    def is_valid(self):
        """Check the file extension against valid file extensions.

        The list of valid file extensions come from self.extensions.
        The comparison is case-insensitive.

        :returns: bool -- ``False`` when no source path is set.
        """
        source = self.source
        # ``source`` defaults to None; os.path.splitext(None) would raise
        # TypeError, so treat a missing path as "not valid".
        if not source:
            return False
        return os.path.splitext(source)[1][1:].lower() in self.extensions

View File

@ -1,43 +0,0 @@
"""
The video module contains the :class:`Video` class, which represents video
objects (AVI, MOV, etc.).
.. moduleauthor:: Jaisen Mathai <jaisen@jmathai.com>
"""
# load modules
from datetime import datetime
import os
import re
import time
from .media import Media
class Video(Media):
    """A video object.

    :param str source: The fully qualified path to the video file.
    :param set ignore_tags: Passed through unchanged to :class:`Media`;
        a new empty set is used when omitted.
    """

    __name__ = 'Video'

    #: Valid extensions for video files.
    extensions = ('avi', 'm4v', 'mov', 'mp4', 'mpg', 'mpeg', '3gp', 'mts')

    def __init__(self, source=None, ignore_tags=None):
        # The previous implementation handed a brand-new ``set()`` to the
        # parent, silently discarding whatever the caller supplied. Forward
        # the real value, and build the empty default per call so that no
        # set instance is shared between objects.
        if ignore_tags is None:
            ignore_tags = set()
        super().__init__(source, ignore_tags=ignore_tags)
        # self.set_gps_ref = False

    def is_valid(self):
        """Check the file extension against valid file extensions.

        The list of valid file extensions come from self.extensions.
        The comparison is case-insensitive.

        :returns: bool -- ``False`` when no source path is set.
        """
        source = self.source
        # ``source`` defaults to None; os.path.splitext(None) would raise
        # TypeError, so treat a missing path as "not valid".
        if not source:
            return False
        return os.path.splitext(source)[1][1:].lower() in self.extensions