From 3bd06f0321c899d93c5a4446f02faee8f0c84c1e Mon Sep 17 00:00:00 2001 From: Cedric Leporcq Date: Sat, 14 Aug 2021 21:35:12 +0200 Subject: [PATCH] Fix filter-by-ext --- ordigi/filesystem.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/ordigi/filesystem.py b/ordigi/filesystem.py index 6cfb7fb..7ee100d 100644 --- a/ordigi/filesystem.py +++ b/ordigi/filesystem.py @@ -27,13 +27,19 @@ class FileSystem(object): """A class for interacting with the file system.""" def __init__(self, cache=False, day_begins=0, dry_run=False, exclude_regex_list=set(), - filter_by_ext=(), logger=logging.getLogger(), max_deep=None, + filter_by_ext=set(), logger=logging.getLogger(), max_deep=None, mode='copy', path_format=None): self.cache = cache self.day_begins = day_begins self.dry_run = dry_run self.exclude_regex_list = exclude_regex_list - self.filter_by_ext = filter_by_ext + + if '%media' in filter_by_ext: + filter_by_ext.remove('%media') + self.filter_by_ext = filter_by_ext.union(media.extensions) + else: + self.filter_by_ext = filter_by_ext + self.items = self.get_items() self.logger = logger self.max_deep = max_deep @@ -507,24 +513,13 @@ class FileSystem(object): return self.summary, has_errors - - def get_files_in_path(self, path, extensions=False): + def get_files_in_path(self, path, extensions=set()): """Recursively get files which match a path and extension. :param str path string: Path to start recursive file listing :param tuple(str) extensions: File extensions to include (whitelist) :returns: file_path, subdirs """ - if self.filter_by_ext != () and not extensions: - # Filtering files by extensions. - if '%media' in self.filter_by_ext: - extensions = set() - subclasses = get_all_subclasses() - for cls in subclasses: - extensions.update(cls.extensions) - else: - extensions = self.filter_by_ext - file_list = set() if os.path.isfile(path): if not self.should_exclude(path, self.exclude_regex_list, True): @@ -547,7 +542,7 @@ class FileSystem(object): # Then append to the list filename_path = os.path.join(dirname, filename) if ( - extensions == False + extensions == set() or os.path.splitext(filename)[1][1:].lower() in extensions and not self.should_exclude(filename_path, compiled_regex_list, False) ):