Merge pull request #163 from jmathai/use-generator

Convert FileSystem.get_all_files to be a generator function
This commit is contained in:
Jaisen Mathai 2016-12-19 22:04:00 -08:00 committed by GitHub
commit 5b4c74280a
6 changed files with 109 additions and 40 deletions

View File

@ -115,34 +115,21 @@ def _generate_db(source):
result = Result()
source = os.path.abspath(os.path.expanduser(source))
extensions = set()
all_files = set()
valid_files = set()
if not os.path.isdir(source):
log.error('Source is not a valid directory %s' % source)
sys.exit(1)
subclasses = get_all_subclasses(Base)
for cls in subclasses:
extensions.update(cls.extensions)
all_files.update(FILESYSTEM.get_all_files(source, None))
db = Db()
db.backup_hash_db()
db.reset_hash_db()
for current_file in all_files:
if os.path.splitext(current_file)[1][1:].lower() not in extensions:
log.info('Skipping invalid file %s' % current_file)
result.append((current_file, False))
continue
for current_file in FILESYSTEM.get_all_files(source):
result.append((current_file, True))
db.add_hash(db.checksum(current_file), current_file)
log.progress()
db.update_hash_db()
log.progress('', True)
result.write()
@click.command('verify')
@ -152,14 +139,18 @@ def _verify():
for checksum, file_path in db.all():
if not os.path.isfile(file_path):
result.append((file_path, False))
log.progress('x')
continue
actual_checksum = db.checksum(file_path)
if checksum == actual_checksum:
result.append((file_path, True))
log.progress()
else:
result.append((file_path, False))
log.progress('x')
log.progress('', True)
result.write()

View File

@ -14,6 +14,7 @@ import time
from elodie import geolocation
from elodie import log
from elodie.localstorage import Db
from elodie.media.base import Base, get_all_subclasses
class FileSystem(object):
@ -61,17 +62,20 @@ class FileSystem(object):
:param str path string: Path to start recursive file listing
:param tuple(str) extensions: File extensions to include (whitelist)
:returns: generator
"""
files = []
# If extensions is None then we get all supported extensions
if not extensions:
extensions = set()
subclasses = get_all_subclasses(Base)
for cls in subclasses:
extensions.update(cls.extensions)
for dirname, dirnames, filenames in os.walk(path):
# print path to all filenames.
for filename in filenames:
if(
extensions is None or
filename.lower().endswith(extensions)
):
files.append(os.path.join(dirname, filename))
return files
# If file extension is in `extensions` then append to the list
if os.path.splitext(filename)[1][1:].lower() in extensions:
yield os.path.join(dirname, filename)
def get_current_directory(self):
"""Get the current working directory.

View File

@ -5,6 +5,8 @@ General file system methods.
"""
from __future__ import print_function
import sys
from json import dumps
from elodie import constants
@ -18,6 +20,13 @@ def info_json(payload):
_print(dumps(payload))
def progress(message='.', new_line=False):
    """Print a progress marker to stdout.

    :param str message: Text to emit; defaults to a single dot.
    :param bool new_line: When true the message is followed by a newline,
        otherwise nothing is appended so successive markers share a line.
    """
    if not new_line:
        print(message, end="")
    else:
        print(message)
    # Markers written without a trailing newline can sit in the stdout
    # buffer until the line is finished; flush so each one shows up as soon
    # as it is emitted.
    sys.stdout.flush()
def warn(message):
_print(message)

View File

@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirna
from . import helper
from elodie.filesystem import FileSystem
from elodie.media.text import Text
from elodie.media.media import Media
from elodie.media.photo import Photo
from elodie.media.video import Video
@ -102,35 +103,75 @@ def test_delete_directory_if_empty_when_not_empty():
def test_get_all_files_success():
filesystem = FileSystem()
folder = helper.populate_folder(5)
files = filesystem.get_all_files(folder)
files = set()
files.update(filesystem.get_all_files(folder))
shutil.rmtree(folder)
length = len(files)
assert length == 5, length
assert length == 5, files
def test_get_all_files_by_extension():
filesystem = FileSystem()
folder = helper.populate_folder(5)
files = filesystem.get_all_files(folder)
files = set()
files.update(filesystem.get_all_files(folder))
length = len(files)
assert length == 5, length
files = filesystem.get_all_files(folder, 'jpg')
files = set()
files.update(filesystem.get_all_files(folder, 'jpg'))
length = len(files)
assert length == 3, length
files = filesystem.get_all_files(folder, 'txt')
files = set()
files.update(filesystem.get_all_files(folder, 'txt'))
length = len(files)
assert length == 2, length
files = filesystem.get_all_files(folder, 'gif')
files = set()
files.update(filesystem.get_all_files(folder, 'gif'))
length = len(files)
assert length == 0, length
shutil.rmtree(folder)
def test_get_all_files_with_only_invalid_file():
    # The folder holds nothing but the 'invalid.invalid' file, so the
    # listing is expected to be empty.
    fs = FileSystem()
    tmp_dir = helper.populate_folder(0, include_invalid=True)
    # Drain the generator before the folder is removed.
    found = set(fs.get_all_files(tmp_dir))
    shutil.rmtree(tmp_dir)
    count = len(found)
    assert count == 0, count
def test_get_all_files_with_invalid_file():
    # Five regular files plus one 'invalid.invalid' file: only the five
    # are expected back.
    fs = FileSystem()
    tmp_dir = helper.populate_folder(5, include_invalid=True)
    # Drain the generator before the folder is removed.
    found = set(fs.get_all_files(tmp_dir))
    shutil.rmtree(tmp_dir)
    count = len(found)
    assert count == 5, count
def test_get_all_files_for_loop():
    # Consume get_all_files directly in a for loop (rather than collecting
    # it into a set first) and count how many paths it produces.
    filesystem = FileSystem()
    folder = helper.populate_folder(5)

    counter = 0
    # The yielded path itself is irrelevant here; only the count matters.
    for _ in filesystem.get_all_files(folder):
        counter += 1
    shutil.rmtree(folder)

    assert counter == 5, counter
def test_get_current_directory():
    # get_current_directory is expected to agree with os.getcwd().
    fs = FileSystem()
    expected = os.getcwd()
    assert expected == fs.get_current_directory()

View File

@ -57,7 +57,7 @@ def get_file_path(name):
def get_test_location():
    """Return the fixed location tuple used by the tests.

    The tuple is two floats followed by a place name; presumably
    (latitude, longitude, name) -- confirm against callers.
    """
    coordinates = (61.013710, 99.196656)
    label = 'Siberia'
    return coordinates + (label,)
def populate_folder(number_of_files):
def populate_folder(number_of_files, include_invalid=False):
folder = '%s/%s' % (tempfile.gettempdir(), random_string(10))
os.makedirs(folder)
@ -67,6 +67,11 @@ def populate_folder(number_of_files):
with open(fname, 'a'):
os.utime(fname, None)
if include_invalid:
fname = '%s/%s' % (folder, 'invalid.invalid')
with open(fname, 'a'):
os.utime(fname, None)
return folder
def random_string(length):

View File

@ -18,17 +18,20 @@ from elodie import constants
from elodie import log
def call_log_and_assert(func, message, expected):
def call_log_and_assert(func, args, expected):
saved_stdout = sys.stdout
try:
out = StringIO()
sys.stdout = out
func(message)
output = out.getvalue().strip()
assert output == expected, (func, output)
func(*args)
output = out.getvalue()
assert output == expected, (expected, func, output)
finally:
sys.stdout = saved_stdout
def with_new_line(string):
    """Return ``string`` formatted with a single trailing newline appended."""
    template = "{}\n"
    return template.format(string)
@patch('elodie.log')
@patch('elodie.constants.debug', True)
def test_calls_print_debug_true(fake_log):
@ -37,14 +40,14 @@ def test_calls_print_debug_true(fake_log):
fake_log.warn.return_value = expected
fake_log.error.return_value = expected
for func in [log.info, log.warn, log.error]:
call_log_and_assert(func, expected, expected)
call_log_and_assert(func, [expected], with_new_line(expected))
expected_json = {'foo':'bar'}
fake_log.info.return_value = expected_json
fake_log.warn.return_value = expected_json
fake_log.error.return_value = expected_json
for func in [log.info_json, log.warn_json, log.error_json]:
call_log_and_assert(func, expected_json, dumps(expected_json))
call_log_and_assert(func, [expected_json], with_new_line(dumps(expected_json)))
@patch('elodie.log')
@patch('elodie.constants.debug', False)
@ -54,11 +57,27 @@ def test_calls_print_debug_false(fake_log):
fake_log.warn.return_value = expected
fake_log.error.return_value = expected
for func in [log.info, log.warn, log.error]:
call_log_and_assert(func, expected, '')
call_log_and_assert(func, [expected], '')
expected_json = {'foo':'bar'}
fake_log.info.return_value = expected_json
fake_log.warn.return_value = expected_json
fake_log.error.return_value = expected_json
for func in [log.info_json, log.warn_json, log.error_json]:
call_log_and_assert(func, expected_json, '')
call_log_and_assert(func, [expected_json], '')
@patch('elodie.log')
def test_calls_print_progress_no_new_line(fake_log):
    # Without the new_line argument the captured output must equal the
    # message exactly, i.e. no newline is appended.
    message = 'some other string'
    for stub in (fake_log.info, fake_log.warn, fake_log.error):
        stub.return_value = message
    call_log_and_assert(log.progress, [message], message)
@patch('elodie.log')
def test_calls_print_progress_with_new_line(fake_log):
    # The message itself already ends in "\n"; with new_line=True the
    # captured output is expected to carry one more newline after it.
    message = "some other string\n"
    for stub in (fake_log.info, fake_log.warn, fake_log.error):
        stub.return_value = message
    call_log_and_assert(log.progress, [message, True], with_new_line(message))