Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions luigi/contrib/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@
import tempfile
import time
from io import BytesIO
from typing import IO, Any, Optional
from urllib.parse import urlsplit

from tenacity import after_log, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, wait_exponential

import luigi.target
from luigi.contrib import gcp
from luigi.format import FileWrapper
from luigi.format import FileWrapper, Format

logger = logging.getLogger("luigi-interface")

# Retry when the following errors happen
RETRYABLE_ERRORS = None
RETRYABLE_ERRORS: tuple[type[BaseException], ...] = ()

try:
import httplib2
Expand Down Expand Up @@ -74,7 +75,7 @@ def is_error_5xx(err):
wait=wait_exponential(multiplier=1, min=1, max=10),
stop=stop_after_attempt(5),
reraise=True,
after=after_log(logger, logging.WARNING),
after=after_log(logger, logging.WARNING), # type: ignore[arg-type]
)


Expand Down Expand Up @@ -432,17 +433,17 @@ def move_to_final_destination(self):


class GCSTarget(luigi.target.FileSystemTarget):
fs = None
fs: GCSClient

def __init__(self, path, format=None, client=None):
def __init__(self, path: str, format: Optional[Format] = None, client: Optional[GCSClient] = None):
super(GCSTarget, self).__init__(path)
if format is None:
format = luigi.format.get_default_format()

self.format = format
self.fs = client or GCSClient()

def open(self, mode="r"):
def open(self, mode: str = "r") -> IO[Any]:
if mode == "r":
return self.format.pipe_reader(FileWrapper(io.BufferedReader(self.fs.download(self.path))))
elif mode == "w":
Expand Down Expand Up @@ -471,9 +472,9 @@ class GCSFlagTarget(GCSTarget):
If we have 1,000,000 output files, then we have to rename 1,000,000 objects.
"""

fs = None
fs: GCSClient

def __init__(self, path, format=None, client=None, flag="_SUCCESS"):
def __init__(self, path: str, format: Optional[Format] = None, client: Optional[GCSClient] = None, flag: str = "_SUCCESS"):
"""
Initializes a GCSFlagTarget.

Expand Down
4 changes: 2 additions & 2 deletions luigi/contrib/hadoop.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# See benchmark at https://gist.github.com/mvj3/02dca2bcc8b0ef1bbfb5
import ujson as json
except ImportError:
import json
import json # type: ignore[no-redef]

logger = logging.getLogger("luigi-interface")

Expand Down Expand Up @@ -666,7 +666,7 @@ class BaseHadoopJobTask(luigi.Task):
mr_priority = NotImplemented
package_binary = None

_counter_dict = {}
_counter_dict: dict[tuple[str, ...], int] = {}
task_id = None

def _get_pool(self):
Expand Down
11 changes: 6 additions & 5 deletions luigi/contrib/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import warnings
from configparser import NoSectionError
from multiprocessing.pool import ThreadPool
from typing import Optional
from urllib.parse import urlsplit

from luigi import configuration
from luigi.format import get_default_format
from luigi.format import Format, get_default_format
from luigi.parameter import OptionalParameter, Parameter
from luigi.target import AtomicLocalFile, FileAlreadyExists, FileSystem, FileSystemException, FileSystemTarget, MissingParentDirectory
from luigi.task import ExternalTask
Expand Down Expand Up @@ -618,9 +619,9 @@ class S3Target(FileSystemTarget):
:param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload`
"""

fs = None
fs: FileSystem

def __init__(self, path, format=None, client=None, **kwargs):
def __init__(self, path: str, format: Optional[Format] = None, client: Optional[S3Client] = None, **kwargs):
super(S3Target, self).__init__(path)
if format is None:
format = get_default_format()
Expand Down Expand Up @@ -665,9 +666,9 @@ class S3FlagTarget(S3Target):
If we have 1,000,000 output files, then we have to rename 1,000,000 objects.
"""

fs = None
fs: FileSystem

def __init__(self, path, format=None, client=None, flag="_SUCCESS"):
def __init__(self, path: str, format: Optional[Format] = None, client: Optional[S3Client] = None, flag="_SUCCESS"):
"""
Initializes a S3FlagTarget.

Expand Down
9 changes: 5 additions & 4 deletions luigi/contrib/webhdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@
"""

import logging
from typing import Optional

import luigi.contrib.hdfs
from luigi.format import get_default_format
from luigi.target import AtomicLocalFile, FileSystemTarget
from luigi.format import Format, get_default_format
from luigi.target import AtomicLocalFile, FileSystem, FileSystemTarget

logger = logging.getLogger("luigi-interface")


class WebHdfsTarget(FileSystemTarget):
fs = None
fs: FileSystem

def __init__(self, path, client=None, format=None):
def __init__(self, path: str, client: Optional[FileSystem] = None, format: Optional[Format] = None):
super(WebHdfsTarget, self).__init__(path)
path = self.path
self.fs = client or WebHdfsClient()
Expand Down
2 changes: 1 addition & 1 deletion luigi/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,5 +491,5 @@ def pipe_writer(self, output_pipe):
MixedUnicodeBytes = MixedUnicodeBytesFormat()


def get_default_format():
def get_default_format() -> Format:
return Text
61 changes: 34 additions & 27 deletions luigi/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import tempfile
import warnings
from contextlib import contextmanager
from types import TracebackType
from typing import IO, Any, Generator, Iterator, Optional, Union

logger = logging.getLogger("luigi-interface")

Expand All @@ -44,7 +46,7 @@ class Target(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def exists(self):
def exists(self) -> bool:
"""
Returns ``True`` if the :py:class:`Target` exists and ``False`` otherwise.
"""
Expand Down Expand Up @@ -99,7 +101,7 @@ class FileSystem(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def exists(self, path):
def exists(self, path: str) -> bool:
"""
Return ``True`` if a file or directory at ``path`` exists, ``False`` otherwise

Expand All @@ -108,7 +110,7 @@ def exists(self, path):
pass

@abc.abstractmethod
def remove(self, path, recursive=True, skip_trash=True):
def remove(self, path: str, recursive: bool = True, skip_trash: bool = True) -> None:
"""Remove file or directory at location ``path``

:param str path: a path within the FileSystem to remove.
Expand All @@ -117,7 +119,7 @@ def remove(self, path, recursive=True, skip_trash=True):
"""
pass

def mkdir(self, path, parents=True, raise_if_exists=False):
def mkdir(self, path: str, parents: bool = True, raise_if_exists: bool = False) -> None:
"""
Create directory at location ``path``

Expand All @@ -133,7 +135,7 @@ def mkdir(self, path, parents=True, raise_if_exists=False):
"""
raise NotImplementedError("mkdir() not implemented on {0}".format(self.__class__.__name__))

def isdir(self, path):
def isdir(self, path: str) -> bool:
"""
Return ``True`` if the location at ``path`` is a directory. If not, return ``False``.

Expand All @@ -143,7 +145,7 @@ def isdir(self, path):
"""
raise NotImplementedError("isdir() not implemented on {0}".format(self.__class__.__name__))

def listdir(self, path):
def listdir(self, path: str) -> Iterator[str]:
"""Return a list of files rooted in path.

This returns an iterable of the files rooted at ``path``. This is intended to be a
Expand All @@ -155,13 +157,13 @@ def listdir(self, path):
"""
raise NotImplementedError("listdir() not implemented on {0}".format(self.__class__.__name__))

def move(self, path, dest):
def move(self, path: str, dest: str) -> None:
"""
Move a file, as one would expect.
"""
raise NotImplementedError("move() not implemented on {0}".format(self.__class__.__name__))

def rename_dont_move(self, path, dest):
def rename_dont_move(self, path: str, dest: str) -> None:
"""
Potentially rename ``path`` to ``dest``, but don't move it into the
``dest`` folder (if it is a folder). This relates to :ref:`AtomicWrites`.
Expand All @@ -175,13 +177,13 @@ def rename_dont_move(self, path, dest):
raise FileAlreadyExists()
self.move(path, dest)

def rename(self, *args, **kwargs):
def rename(self, *args: Any, **kwargs: Any) -> None:
"""
Alias for ``move()``
"""
self.move(*args, **kwargs)

def copy(self, path, dest):
def copy(self, path: str, dest: str) -> None:
"""
Copy a file or a directory with contents.
Currently, LocalFileSystem and MockFileSystem support only single file
Expand Down Expand Up @@ -209,7 +211,7 @@ class FileSystemTarget(Target):
target.exists() # False
"""

def __init__(self, path):
def __init__(self, path: Union[str, "os.PathLike[str]"]) -> None:
"""
Initializes a FileSystemTarget instance.

Expand All @@ -218,19 +220,19 @@ def __init__(self, path):
# cast to str to allow path to be objects like pathlib.PosixPath and py._path.local.LocalPath
self.path = str(path)

def __str__(self):
def __str__(self) -> str:
return self.path

@property
@abc.abstractmethod
def fs(self):
def fs(self) -> FileSystem:
"""
The :py:class:`FileSystem` associated with this FileSystemTarget.
"""
raise NotImplementedError()

@abc.abstractmethod
def open(self, mode):
def open(self, mode: str) -> IO[Any]:
"""
Open the FileSystem target.

Expand All @@ -244,7 +246,7 @@ def open(self, mode):
"""
pass

def exists(self):
def exists(self) -> bool:
"""
Returns ``True`` if the path for this FileSystemTarget exists; ``False`` otherwise.

Expand All @@ -255,7 +257,7 @@ def exists(self):
logger.warning("Using wildcards in path %s might lead to processing of an incomplete dataset; override exists() to suppress the warning.", path)
return self.fs.exists(path)

def remove(self):
def remove(self) -> None:
"""
Remove the resource at the path specified by this FileSystemTarget.

Expand All @@ -264,7 +266,7 @@ def remove(self):
self.fs.remove(self.path)

@contextmanager
def temporary_path(self):
def temporary_path(self) -> Generator[str, None, None]:
"""
A context manager that enables a reasonably short, general and
magic-less way to solve the :ref:`AtomicWrites`.
Expand Down Expand Up @@ -301,11 +303,11 @@ def run(self):
# We won't reach here if there was a user exception.
self.fs.rename_dont_move(_temp_path, self.path)

def _touchz(self):
def _touchz(self) -> None:
with self.open("w"):
pass

def _trailing_slash(self):
def _trailing_slash(self) -> str:
# I suppose one day schema-like paths, like
# file:///path/blah.txt?params=etc can be parsed too
return self.path[-1] if self.path[-1] in r"\/" else ""
Expand All @@ -320,31 +322,36 @@ class AtomicLocalFile(io.BufferedWriter):
:class:`luigi.local_target.LocalTarget` for example
"""

def __init__(self, path):
def __init__(self, path: str) -> None:
self.__tmp_path = self.generate_tmp_path(path)
self.path = path
super(AtomicLocalFile, self).__init__(io.FileIO(self.__tmp_path, "w"))

def close(self):
def close(self) -> None:
super(AtomicLocalFile, self).close()
self.move_to_final_destination()

def generate_tmp_path(self, path):
def generate_tmp_path(self, path: str) -> str:
return os.path.join(tempfile.gettempdir(), "luigi-s3-tmp-%09d" % random.randrange(0, 10_000_000_000))

def move_to_final_destination(self):
def move_to_final_destination(self) -> None:
raise NotImplementedError()

def __del__(self):
def __del__(self) -> None:
if os.path.exists(self.tmp_path):
os.remove(self.tmp_path)

@property
def tmp_path(self):
def tmp_path(self) -> str:
return self.__tmp_path

def __exit__(self, exc_type, exc, traceback):
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"Close/commit the file if there is no exception"
if exc_type:
return
return super(AtomicLocalFile, self).__exit__(exc_type, exc, traceback)
super(AtomicLocalFile, self).__exit__(exc_type, exc, traceback)
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ common = [
"types-python-dateutil",
"types-requests",
"types-toml",
"types-ujson>=5.10.0.20250822",
]

docs = [
Expand Down Expand Up @@ -182,8 +183,6 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"luigi.contrib.gcs",
"luigi.contrib.hadoop",
"luigi.contrib.hdfs.config",
"luigi.contrib.postgres",
"luigi.contrib.redis_store",
Expand Down
Loading
Loading