Merge pull request #163 from jmathai/use-generator
Convert FileSystem.get_all_files to be a generator function
This commit is contained in:
		
						commit
						5b4c74280a
					
				
							
								
								
									
										23
									
								
								elodie.py
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								elodie.py
									
									
									
									
									
								
							@ -115,34 +115,21 @@ def _generate_db(source):
 | 
			
		||||
    result = Result()
 | 
			
		||||
    source = os.path.abspath(os.path.expanduser(source))
 | 
			
		||||
 | 
			
		||||
    extensions = set()
 | 
			
		||||
    all_files = set()
 | 
			
		||||
    valid_files = set()
 | 
			
		||||
 | 
			
		||||
    if not os.path.isdir(source):
 | 
			
		||||
        log.error('Source is not a valid directory %s' % source)
 | 
			
		||||
        sys.exit(1)
 | 
			
		||||
        
 | 
			
		||||
    subclasses = get_all_subclasses(Base)
 | 
			
		||||
    for cls in subclasses:
 | 
			
		||||
        extensions.update(cls.extensions)
 | 
			
		||||
 | 
			
		||||
    all_files.update(FILESYSTEM.get_all_files(source, None))
 | 
			
		||||
 | 
			
		||||
    db = Db()
 | 
			
		||||
    db.backup_hash_db()
 | 
			
		||||
    db.reset_hash_db()
 | 
			
		||||
 | 
			
		||||
    for current_file in all_files:
 | 
			
		||||
        if os.path.splitext(current_file)[1][1:].lower() not in extensions:
 | 
			
		||||
            log.info('Skipping invalid file %s' % current_file)
 | 
			
		||||
            result.append((current_file, False))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
    for current_file in FILESYSTEM.get_all_files(source):
 | 
			
		||||
        result.append((current_file, True))
 | 
			
		||||
        db.add_hash(db.checksum(current_file), current_file)
 | 
			
		||||
        log.progress()
 | 
			
		||||
    
 | 
			
		||||
    db.update_hash_db()
 | 
			
		||||
    log.progress('', True)
 | 
			
		||||
    result.write()
 | 
			
		||||
 | 
			
		||||
@click.command('verify')
 | 
			
		||||
@ -152,14 +139,18 @@ def _verify():
 | 
			
		||||
    for checksum, file_path in db.all():
 | 
			
		||||
        if not os.path.isfile(file_path):
 | 
			
		||||
            result.append((file_path, False))
 | 
			
		||||
            log.progress('x')
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        actual_checksum = db.checksum(file_path)
 | 
			
		||||
        if checksum == actual_checksum:
 | 
			
		||||
            result.append((file_path, True))
 | 
			
		||||
            log.progress()
 | 
			
		||||
        else:
 | 
			
		||||
            result.append((file_path, False))
 | 
			
		||||
            log.progress('x')
 | 
			
		||||
 | 
			
		||||
    log.progress('', True)
 | 
			
		||||
    result.write()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -14,6 +14,7 @@ import time
 | 
			
		||||
from elodie import geolocation
 | 
			
		||||
from elodie import log
 | 
			
		||||
from elodie.localstorage import Db
 | 
			
		||||
from elodie.media.base import Base, get_all_subclasses
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FileSystem(object):
 | 
			
		||||
@ -61,17 +62,20 @@ class FileSystem(object):
 | 
			
		||||
 | 
			
		||||
        :param str path string: Path to start recursive file listing
 | 
			
		||||
        :param tuple(str) extensions: File extensions to include (whitelist)
 | 
			
		||||
        :returns: generator
 | 
			
		||||
        """
 | 
			
		||||
        files = []
 | 
			
		||||
        # If extensions is None then we get all supported extensions
 | 
			
		||||
        if not extensions:
 | 
			
		||||
            extensions = set()
 | 
			
		||||
            subclasses = get_all_subclasses(Base)
 | 
			
		||||
            for cls in subclasses:
 | 
			
		||||
                extensions.update(cls.extensions)
 | 
			
		||||
 | 
			
		||||
        for dirname, dirnames, filenames in os.walk(path):
 | 
			
		||||
            # print path to all filenames.
 | 
			
		||||
            for filename in filenames:
 | 
			
		||||
                if(
 | 
			
		||||
                    extensions is None or
 | 
			
		||||
                    filename.lower().endswith(extensions)
 | 
			
		||||
                ):
 | 
			
		||||
                    files.append(os.path.join(dirname, filename))
 | 
			
		||||
        return files
 | 
			
		||||
                # If file extension is in `extensions` then append to the list
 | 
			
		||||
                if os.path.splitext(filename)[1][1:].lower() in extensions:
 | 
			
		||||
                    yield os.path.join(dirname, filename)
 | 
			
		||||
 | 
			
		||||
    def get_current_directory(self):
 | 
			
		||||
        """Get the current working directory.
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,8 @@ General file system methods.
 | 
			
		||||
"""
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
from json import dumps
 | 
			
		||||
 | 
			
		||||
from elodie import constants
 | 
			
		||||
@ -18,6 +20,13 @@ def info_json(payload):
 | 
			
		||||
    _print(dumps(payload))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def progress(message='.', new_line=False):
    """Write a progress marker to stdout.

    This calls ``print`` directly and flushes afterwards so that markers
    show up one by one while a long-running loop is still working.

    :param str message: Text to emit; defaults to a single dot.
    :param bool new_line: When true the message is followed by a newline,
        otherwise nothing is appended and the next marker continues on
        the same line.
    """
    if not new_line:
        print(message, end="")
    else:
        print(message)

    # Without a trailing newline a line-buffered stdout may hold the marker
    # back until the line is finished, which defeats a progress indicator.
    sys.stdout.flush()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def warn(message):
 | 
			
		||||
    _print(message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirna
 | 
			
		||||
 | 
			
		||||
from . import helper
 | 
			
		||||
from elodie.filesystem import FileSystem
 | 
			
		||||
from elodie.media.text import Text
 | 
			
		||||
from elodie.media.media import Media
 | 
			
		||||
from elodie.media.photo import Photo
 | 
			
		||||
from elodie.media.video import Video
 | 
			
		||||
@ -102,35 +103,75 @@ def test_delete_directory_if_empty_when_not_empty():
 | 
			
		||||
def test_get_all_files_success():
    filesystem = FileSystem()
    folder = helper.populate_folder(5)

    # get_all_files is consumed into a set *before* the folder is removed;
    # the result must be fully collected while the files still exist.
    files = set(filesystem.get_all_files(folder))
    shutil.rmtree(folder)

    length = len(files)
    # Report the collected paths on failure so a miscount is easy to debug.
    assert length == 5, files
 | 
			
		||||
 | 
			
		||||
def test_get_all_files_by_extension():
    filesystem = FileSystem()
    folder = helper.populate_folder(5)

    # No extension filter: every populated file is expected.
    files = set(filesystem.get_all_files(folder))
    length = len(files)
    assert length == 5, length

    # Filtered by a single extension: (extension, expected number of files).
    for extension, expected_count in (('jpg', 3), ('txt', 2), ('gif', 0)):
        files = set(filesystem.get_all_files(folder, extension))
        length = len(files)
        assert length == expected_count, length

    shutil.rmtree(folder)
 | 
			
		||||
 | 
			
		||||
def test_get_all_files_with_only_invalid_file():
    # The folder holds no regular test files, only the extra file requested
    # through include_invalid; nothing should be returned for it.
    fs = FileSystem()
    tmp_folder = helper.populate_folder(0, include_invalid=True)

    found = set(fs.get_all_files(tmp_folder))
    shutil.rmtree(tmp_folder)

    length = len(found)
    assert length == 0, length
 | 
			
		||||
 | 
			
		||||
def test_get_all_files_with_invalid_file():
    # Five regular test files plus the extra file requested through
    # include_invalid; only the five regular ones should be returned.
    fs = FileSystem()
    tmp_folder = helper.populate_folder(5, include_invalid=True)

    found = set(fs.get_all_files(tmp_folder))
    shutil.rmtree(tmp_folder)

    length = len(found)
    assert length == 5, length
 | 
			
		||||
 | 
			
		||||
def test_get_all_files_for_loop():
    filesystem = FileSystem()
    folder = helper.populate_folder(5)

    # Iterate the result directly in a for loop (rather than collecting it
    # into a container) and count how many paths it produces.
    counter = 0
    for _ in filesystem.get_all_files(folder):
        counter += 1
    shutil.rmtree(folder)

    assert counter == 5, counter
 | 
			
		||||
 | 
			
		||||
def test_get_current_directory():
 | 
			
		||||
    filesystem = FileSystem()
 | 
			
		||||
    assert os.getcwd() == filesystem.get_current_directory()
 | 
			
		||||
 | 
			
		||||
@ -57,7 +57,7 @@ def get_file_path(name):
 | 
			
		||||
def get_test_location():
 | 
			
		||||
    return (61.013710, 99.196656, 'Siberia')
 | 
			
		||||
 | 
			
		||||
def populate_folder(number_of_files):
 | 
			
		||||
def populate_folder(number_of_files, include_invalid=False):
 | 
			
		||||
    folder = '%s/%s' % (tempfile.gettempdir(), random_string(10))
 | 
			
		||||
    os.makedirs(folder)
 | 
			
		||||
 | 
			
		||||
@ -67,6 +67,11 @@ def populate_folder(number_of_files):
 | 
			
		||||
        with open(fname, 'a'):
 | 
			
		||||
            os.utime(fname, None)
 | 
			
		||||
 | 
			
		||||
    if include_invalid:
 | 
			
		||||
        fname = '%s/%s' % (folder, 'invalid.invalid')
 | 
			
		||||
        with open(fname, 'a'):
 | 
			
		||||
            os.utime(fname, None)
 | 
			
		||||
 | 
			
		||||
    return folder
 | 
			
		||||
 | 
			
		||||
def random_string(length):
 | 
			
		||||
 | 
			
		||||
@ -18,17 +18,20 @@ from elodie import constants
 | 
			
		||||
from elodie import log
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def call_log_and_assert(func, args, expected):
    """Call ``func(*args)`` with stdout captured and assert on what it wrote.

    :param callable func: Function under test; it is expected to write to
        ``sys.stdout``.
    :param list args: Positional arguments that are unpacked into ``func``.
    :param str expected: Exact text ``func`` must have written. The captured
        output is compared as-is, so trailing newlines count.
    :raises AssertionError: If the captured output differs from ``expected``.
    """
    saved_stdout = sys.stdout
    try:
        out = StringIO()
        sys.stdout = out
        func(*args)
        output = out.getvalue()
        assert output == expected, (expected, func, output)
    finally:
        # Restore the real stdout even when the call or the assertion fails.
        sys.stdout = saved_stdout
 | 
			
		||||
 | 
			
		||||
def with_new_line(string):
    """Return ``string`` formatted as text with a single newline appended."""
    newline = "\n"
    return "{}{}".format(string, newline)
 | 
			
		||||
 | 
			
		||||
@patch('elodie.log')
 | 
			
		||||
@patch('elodie.constants.debug', True)
 | 
			
		||||
def test_calls_print_debug_true(fake_log):
 | 
			
		||||
@ -37,14 +40,14 @@ def test_calls_print_debug_true(fake_log):
 | 
			
		||||
    fake_log.warn.return_value = expected
 | 
			
		||||
    fake_log.error.return_value = expected
 | 
			
		||||
    for func in [log.info, log.warn, log.error]:
 | 
			
		||||
        call_log_and_assert(func, expected, expected)
 | 
			
		||||
        call_log_and_assert(func, [expected], with_new_line(expected))
 | 
			
		||||
 | 
			
		||||
    expected_json = {'foo':'bar'}
 | 
			
		||||
    fake_log.info.return_value = expected_json
 | 
			
		||||
    fake_log.warn.return_value = expected_json
 | 
			
		||||
    fake_log.error.return_value = expected_json
 | 
			
		||||
    for func in [log.info_json, log.warn_json, log.error_json]:
 | 
			
		||||
        call_log_and_assert(func, expected_json, dumps(expected_json))
 | 
			
		||||
        call_log_and_assert(func, [expected_json], with_new_line(dumps(expected_json)))
 | 
			
		||||
 | 
			
		||||
@patch('elodie.log')
 | 
			
		||||
@patch('elodie.constants.debug', False)
 | 
			
		||||
@ -54,11 +57,27 @@ def test_calls_print_debug_false(fake_log):
 | 
			
		||||
    fake_log.warn.return_value = expected
 | 
			
		||||
    fake_log.error.return_value = expected
 | 
			
		||||
    for func in [log.info, log.warn, log.error]:
 | 
			
		||||
        call_log_and_assert(func, expected, '')
 | 
			
		||||
        call_log_and_assert(func, [expected], '')
 | 
			
		||||
 | 
			
		||||
    expected_json = {'foo':'bar'}
 | 
			
		||||
    fake_log.info.return_value = expected_json
 | 
			
		||||
    fake_log.warn.return_value = expected_json
 | 
			
		||||
    fake_log.error.return_value = expected_json
 | 
			
		||||
    for func in [log.info_json, log.warn_json, log.error_json]:
 | 
			
		||||
        call_log_and_assert(func, expected_json, '')
 | 
			
		||||
        call_log_and_assert(func, [expected_json], '')
 | 
			
		||||
 | 
			
		||||
@patch('elodie.log')
def test_calls_print_progress_no_new_line(fake_log):
    # log.progress called with only a message should write exactly that
    # message, with nothing appended.
    message = 'some other string'
    for level in (fake_log.info, fake_log.warn, fake_log.error):
        level.return_value = message

    call_log_and_assert(log.progress, [message], message)
 | 
			
		||||
 | 
			
		||||
@patch('elodie.log')
def test_calls_print_progress_with_new_line(fake_log):
    # The message already ends in a newline; with new_line=True the output
    # is expected to carry one more newline after it.
    message = "some other string\n"
    for level in (fake_log.info, fake_log.warn, fake_log.error):
        level.return_value = message

    call_log_and_assert(log.progress, [message, True], with_new_line(message))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user