Convert FileSystem.get_all_files to be a generator function

This commit is contained in:
Jaisen Mathai 2016-12-19 21:12:36 -08:00
parent a7c9a5ffbd
commit 64773655d2
6 changed files with 109 additions and 40 deletions

View File

@ -115,34 +115,21 @@ def _generate_db(source):
result = Result() result = Result()
source = os.path.abspath(os.path.expanduser(source)) source = os.path.abspath(os.path.expanduser(source))
extensions = set()
all_files = set()
valid_files = set()
if not os.path.isdir(source): if not os.path.isdir(source):
log.error('Source is not a valid directory %s' % source) log.error('Source is not a valid directory %s' % source)
sys.exit(1) sys.exit(1)
subclasses = get_all_subclasses(Base)
for cls in subclasses:
extensions.update(cls.extensions)
all_files.update(FILESYSTEM.get_all_files(source, None))
db = Db() db = Db()
db.backup_hash_db() db.backup_hash_db()
db.reset_hash_db() db.reset_hash_db()
for current_file in all_files: for current_file in FILESYSTEM.get_all_files(source):
if os.path.splitext(current_file)[1][1:].lower() not in extensions:
log.info('Skipping invalid file %s' % current_file)
result.append((current_file, False))
continue
result.append((current_file, True)) result.append((current_file, True))
db.add_hash(db.checksum(current_file), current_file) db.add_hash(db.checksum(current_file), current_file)
log.progress()
db.update_hash_db() db.update_hash_db()
log.progress('', True)
result.write() result.write()
@click.command('verify') @click.command('verify')
@ -152,14 +139,18 @@ def _verify():
for checksum, file_path in db.all(): for checksum, file_path in db.all():
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
result.append((file_path, False)) result.append((file_path, False))
log.progress('x')
continue continue
actual_checksum = db.checksum(file_path) actual_checksum = db.checksum(file_path)
if checksum == actual_checksum: if checksum == actual_checksum:
result.append((file_path, True)) result.append((file_path, True))
log.progress()
else: else:
result.append((file_path, False)) result.append((file_path, False))
log.progress('x')
log.progress('', True)
result.write() result.write()

View File

@ -14,6 +14,7 @@ import time
from elodie import geolocation from elodie import geolocation
from elodie import log from elodie import log
from elodie.localstorage import Db from elodie.localstorage import Db
from elodie.media.base import Base, get_all_subclasses
class FileSystem(object): class FileSystem(object):
@ -61,17 +62,20 @@ class FileSystem(object):
:param str path string: Path to start recursive file listing :param str path string: Path to start recursive file listing
:param tuple(str) extensions: File extensions to include (whitelist) :param tuple(str) extensions: File extensions to include (whitelist)
:returns: generator
""" """
files = [] # If extensions is None then we get all supported extensions
if not extensions:
extensions = set()
subclasses = get_all_subclasses(Base)
for cls in subclasses:
extensions.update(cls.extensions)
for dirname, dirnames, filenames in os.walk(path): for dirname, dirnames, filenames in os.walk(path):
# print path to all filenames.
for filename in filenames: for filename in filenames:
if( # If file extension is in `extensions` then append to the list
extensions is None or if os.path.splitext(filename)[1][1:].lower() in extensions:
filename.lower().endswith(extensions) yield os.path.join(dirname, filename)
):
files.append(os.path.join(dirname, filename))
return files
def get_current_directory(self): def get_current_directory(self):
"""Get the current working directory. """Get the current working directory.

View File

@ -5,6 +5,8 @@ General file system methods.
""" """
from __future__ import print_function from __future__ import print_function
import sys
from json import dumps from json import dumps
from elodie import constants from elodie import constants
@ -18,6 +20,13 @@ def info_json(payload):
_print(dumps(payload)) _print(dumps(payload))
def progress(message='.', new_line=False):
if not new_line:
print(message, end="")
else:
print(message)
def warn(message): def warn(message):
_print(message) _print(message)

View File

@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirna
from . import helper from . import helper
from elodie.filesystem import FileSystem from elodie.filesystem import FileSystem
from elodie.media.text import Text
from elodie.media.media import Media from elodie.media.media import Media
from elodie.media.photo import Photo from elodie.media.photo import Photo
from elodie.media.video import Video from elodie.media.video import Video
@ -102,35 +103,75 @@ def test_delete_directory_if_empty_when_not_empty():
def test_get_all_files_success(): def test_get_all_files_success():
filesystem = FileSystem() filesystem = FileSystem()
folder = helper.populate_folder(5) folder = helper.populate_folder(5)
files = filesystem.get_all_files(folder)
files = set()
files.update(filesystem.get_all_files(folder))
shutil.rmtree(folder) shutil.rmtree(folder)
length = len(files) length = len(files)
assert length == 5, length assert length == 5, files
def test_get_all_files_by_extension(): def test_get_all_files_by_extension():
filesystem = FileSystem() filesystem = FileSystem()
folder = helper.populate_folder(5) folder = helper.populate_folder(5)
files = filesystem.get_all_files(folder) files = set()
files.update(filesystem.get_all_files(folder))
length = len(files) length = len(files)
assert length == 5, length assert length == 5, length
files = filesystem.get_all_files(folder, 'jpg') files = set()
files.update(filesystem.get_all_files(folder, 'jpg'))
length = len(files) length = len(files)
assert length == 3, length assert length == 3, length
files = filesystem.get_all_files(folder, 'txt') files = set()
files.update(filesystem.get_all_files(folder, 'txt'))
length = len(files) length = len(files)
assert length == 2, length assert length == 2, length
files = filesystem.get_all_files(folder, 'gif') files = set()
files.update(filesystem.get_all_files(folder, 'gif'))
length = len(files) length = len(files)
assert length == 0, length assert length == 0, length
shutil.rmtree(folder) shutil.rmtree(folder)
def test_get_all_files_with_only_invalid_file():
filesystem = FileSystem()
folder = helper.populate_folder(0, include_invalid=True)
files = set()
files.update(filesystem.get_all_files(folder))
shutil.rmtree(folder)
length = len(files)
assert length == 0, length
def test_get_all_files_with_invalid_file():
filesystem = FileSystem()
folder = helper.populate_folder(5, include_invalid=True)
files = set()
files.update(filesystem.get_all_files(folder))
shutil.rmtree(folder)
length = len(files)
assert length == 5, length
def test_get_all_files_for_loop():
filesystem = FileSystem()
folder = helper.populate_folder(5)
files = set()
files.update()
counter = 0
for file in filesystem.get_all_files(folder):
counter += 1
shutil.rmtree(folder)
assert counter == 5, counter
def test_get_current_directory(): def test_get_current_directory():
filesystem = FileSystem() filesystem = FileSystem()
assert os.getcwd() == filesystem.get_current_directory() assert os.getcwd() == filesystem.get_current_directory()

View File

@ -57,7 +57,7 @@ def get_file_path(name):
def get_test_location(): def get_test_location():
return (61.013710, 99.196656, 'Siberia') return (61.013710, 99.196656, 'Siberia')
def populate_folder(number_of_files): def populate_folder(number_of_files, include_invalid=False):
folder = '%s/%s' % (tempfile.gettempdir(), random_string(10)) folder = '%s/%s' % (tempfile.gettempdir(), random_string(10))
os.makedirs(folder) os.makedirs(folder)
@ -67,6 +67,11 @@ def populate_folder(number_of_files):
with open(fname, 'a'): with open(fname, 'a'):
os.utime(fname, None) os.utime(fname, None)
if include_invalid:
fname = '%s/%s' % (folder, 'invalid.invalid')
with open(fname, 'a'):
os.utime(fname, None)
return folder return folder
def random_string(length): def random_string(length):

View File

@ -18,17 +18,20 @@ from elodie import constants
from elodie import log from elodie import log
def call_log_and_assert(func, message, expected): def call_log_and_assert(func, args, expected):
saved_stdout = sys.stdout saved_stdout = sys.stdout
try: try:
out = StringIO() out = StringIO()
sys.stdout = out sys.stdout = out
func(message) func(*args)
output = out.getvalue().strip() output = out.getvalue()
assert output == expected, (func, output) assert output == expected, (expected, func, output)
finally: finally:
sys.stdout = saved_stdout sys.stdout = saved_stdout
def with_new_line(string):
return "{}\n".format(string)
@patch('elodie.log') @patch('elodie.log')
@patch('elodie.constants.debug', True) @patch('elodie.constants.debug', True)
def test_calls_print_debug_true(fake_log): def test_calls_print_debug_true(fake_log):
@ -37,14 +40,14 @@ def test_calls_print_debug_true(fake_log):
fake_log.warn.return_value = expected fake_log.warn.return_value = expected
fake_log.error.return_value = expected fake_log.error.return_value = expected
for func in [log.info, log.warn, log.error]: for func in [log.info, log.warn, log.error]:
call_log_and_assert(func, expected, expected) call_log_and_assert(func, [expected], with_new_line(expected))
expected_json = {'foo':'bar'} expected_json = {'foo':'bar'}
fake_log.info.return_value = expected_json fake_log.info.return_value = expected_json
fake_log.warn.return_value = expected_json fake_log.warn.return_value = expected_json
fake_log.error.return_value = expected_json fake_log.error.return_value = expected_json
for func in [log.info_json, log.warn_json, log.error_json]: for func in [log.info_json, log.warn_json, log.error_json]:
call_log_and_assert(func, expected_json, dumps(expected_json)) call_log_and_assert(func, [expected_json], with_new_line(dumps(expected_json)))
@patch('elodie.log') @patch('elodie.log')
@patch('elodie.constants.debug', False) @patch('elodie.constants.debug', False)
@ -54,11 +57,27 @@ def test_calls_print_debug_false(fake_log):
fake_log.warn.return_value = expected fake_log.warn.return_value = expected
fake_log.error.return_value = expected fake_log.error.return_value = expected
for func in [log.info, log.warn, log.error]: for func in [log.info, log.warn, log.error]:
call_log_and_assert(func, expected, '') call_log_and_assert(func, [expected], '')
expected_json = {'foo':'bar'} expected_json = {'foo':'bar'}
fake_log.info.return_value = expected_json fake_log.info.return_value = expected_json
fake_log.warn.return_value = expected_json fake_log.warn.return_value = expected_json
fake_log.error.return_value = expected_json fake_log.error.return_value = expected_json
for func in [log.info_json, log.warn_json, log.error_json]: for func in [log.info_json, log.warn_json, log.error_json]:
call_log_and_assert(func, expected_json, '') call_log_and_assert(func, [expected_json], '')
@patch('elodie.log')
def test_calls_print_progress_no_new_line(fake_log):
expected = 'some other string'
fake_log.info.return_value = expected
fake_log.warn.return_value = expected
fake_log.error.return_value = expected
call_log_and_assert(log.progress, [expected], expected)
@patch('elodie.log')
def test_calls_print_progress_with_new_line(fake_log):
expected = "some other string\n"
fake_log.info.return_value = expected
fake_log.warn.return_value = expected
fake_log.error.return_value = expected
call_log_and_assert(log.progress, [expected, True], with_new_line(expected))