diff --git a/setup.cfg b/setup.cfg
index aa4695c..13d3587 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -23,7 +23,7 @@ packages =
[extras]
luigi =
- luigi>=2.3.3,<3.0.0
+ luigi>=2.8.5,<3.0.0
examples =
scipy
sklearn
diff --git a/spotify_tensorflow/luigi/python_dataflow_task.py b/spotify_tensorflow/luigi/python_dataflow_task.py
deleted file mode 100644
index 31461c0..0000000
--- a/spotify_tensorflow/luigi/python_dataflow_task.py
+++ /dev/null
@@ -1,297 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Copyright 2017-2019 Spotify AB.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-import os
-import subprocess
-import time
-
-import luigi
-from luigi.task import MixinNaiveBulkComplete
-from spotify_tensorflow.luigi.utils import get_uri, run_with_logging
-
-logger = logging.getLogger("luigi-interface")
-
-
-class PythonDataflowTask(MixinNaiveBulkComplete, luigi.Task):
-    """Luigi wrapper for a dataflow job
-
- The following properties can be set:
- python_script = None # Python script for the dataflow task.
- project = None # Name of the project owning the dataflow job.
- staging_location = None # GCS path for staging code packages needed by workers.
- zone = None # GCE availability zone for launching workers.
- region = None # GCE region for creating the dataflow job.
- temp_location = None # GCS path for saving temporary workflow jobs.
- num_workers = None # The number of workers to start the task with.
- autoscaling_algorithm = None # Set to "NONE" to disable autoscaling. `num_workers`
- # will then be used for the job.
- max_num_workers = None # Used if the autoscaling is enabled.
- network = None # Network in GCE to be used for launching workers.
- subnetwork = None # Subnetwork in GCE to be used for launching workers.
-    disk_size_gb = None # Remote worker disk size; if not set, the default size is used.
-    worker_machine_type = None # Machine type to create Dataflow worker VMs. If unset,
-                               # the Dataflow service will choose a reasonable default.
-    worker_disk_type = None # Specify SSD for local disk; defaults to hard disk.
-    service_account = None # Service account of Dataflow VMs/workers. Defaults to the
-                           # project's default GCE service account.
- job_name = None # Name of the dataflow job
- requirements_file = None # Path to a requirements file containing package dependencies.
-    local_runner = False # If local_runner = True, the job uses DirectRunner;
-                         # otherwise it uses DataflowRunner.
- setup_file = None # Path to a setup Python file containing package dependencies.
-
- :Example:
-
-    class AwesomeJob(PythonDataflowTask):
- python_script = "/path/to/python_script"
- project = "gcp-project"
- staging_location = "gs://gcp-project-playground/user/staging"
- temp_location = "gs://gcp-project-playground/user/tmp"
- max_num_workers = 20
- region = "europe-west1"
-        service_account = "service_account@gcp-project.iam.gserviceaccount.com"
-
- def output(self):
- ...
- """
- # Required dataflow args
- python_script = None # type: str
- project = None # type: str
- staging_location = None # type: str
-
- # Dataflow requires one and only one of:
- zone = None # type: str
- region = None # type: str
-
- # Optional dataflow args
- temp_location = None # type: str
- num_workers = None # type: int
- autoscaling_algorithm = None # type: str
- max_num_workers = None # type: int
- network = None # type: str
- subnetwork = None # type: str
- disk_size_gb = None # type: int
- worker_machine_type = None # type: str
- worker_disk_type = None # type: str
- service_account = None # type: str
- job_name = None # type: str
- requirements_file = None # type: str
- local_runner = False # type: bool
- setup_file = None # type: str
-
- def __init__(self, *args, **kwargs):
- super(PythonDataflowTask, self).__init__(*args, **kwargs)
- self._output = self.output()
- if isinstance(self._output, luigi.Target):
- self._output = {"output": self._output}
- if self.job_name is None:
- # job_name must consist of only the characters [-a-z0-9]
- cls_name = self.__class__.__name__.replace("_", "-").lower()
-            self.job_name = "{cls_name}-{timestamp}".format(cls_name=cls_name,
-                                                            timestamp=int(time.time()))
-
- def on_successful_run(self):
- """ Callback that gets called right after the dataflow job has finished successfully but
- before validate_output is run.
- """
- pass
-
- def validate_output(self):
-        """ Callback that can be used to validate your output before it is moved to its final
-        location. Returning False here will cause the job to fail, and the output to be
-        removed instead of published.
-
- :return: Whether the output is valid or not
- :rtype: Boolean
- """
- return True
-
-    def file_pattern(self):
-        """ Override the default file pattern for specific input targets. If an input
-        target's files do not match the default "part-*" pattern, map that target's key to
-        the file pattern that should be appended to its URI on the command line. Any input
-        target key not found in this dict falls back to "part-*".
-
-        :return: A dictionary of file pattern overrides for inputs that are not part-*
-        :rtype: Dict of String to String
-        """
-        return {}
-
- def run(self):
- cmd_line = self._mk_cmd_line()
- logger.info(" ".join(cmd_line))
-
- try:
- run_with_logging(cmd_line, logger)
- except subprocess.CalledProcessError as e:
- logging.error(e, exc_info=True)
-            # exit luigi with the same exit code as the python dataflow job process
- # In this way users can easily exit the job with code 50 to avoid Styx retries
- # https://github.com/spotify/styx/blob/master/doc/design-overview.md#workflow-state-graph
- os._exit(e.returncode)
-
- self.on_successful_run()
- if self.validate_output():
- self._publish_outputs()
- else:
- raise ValueError("Output is not valid")
-
- def _publish_outputs(self):
- for (name, target) in self._output.items():
- if hasattr(target, "publish"):
- target.publish(self._output_uris[name])
-
- def _mk_cmd_line(self):
- cmd_line = self._dataflow_executable()
- cmd_line.extend(self._get_dataflow_args())
- cmd_line.extend(self._get_input_args())
- cmd_line.extend(self._get_output_args())
- cmd_line.extend(self.args())
- return cmd_line
-
- def _dataflow_executable(self):
- """
- Defines the executable used to run the python dataflow job.
- """
- return ["python", self.python_script]
-
- def _get_input_uri(self, file_pattern, target):
- uri = get_uri(target)
- uri = uri.rstrip("/") + "/" + file_pattern
- return uri
-
- def _get_file_pattern(self):
- file_pattern = self.file_pattern()
- if not isinstance(file_pattern, dict):
- raise ValueError("file_pattern() must return a dict type")
- return file_pattern
-
- def _get_input_args(self):
- """
- Collects outputs from requires() and converts them to input arguments.
- file_pattern() is called to construct input file path glob with default value "part-*"
- """
- job_input = self.input()
- if isinstance(job_input, luigi.Target):
- job_input = {"input": job_input}
- if not isinstance(job_input, dict):
- raise ValueError("Input (requires()) must be dict type")
-
- input_args = []
- file_pattern_dict = self._get_file_pattern()
- for (name, targets) in job_input.items():
- uri_targets = luigi.task.flatten(targets)
- pattern = file_pattern_dict.get(name, "part-*")
- uris = [self._get_input_uri(pattern, uri_target) for uri_target in uri_targets]
- if isinstance(targets, dict):
-                # If targets is a dict, the upstream task produced multiple named outputs.
-                # Qualify each input arg name as "<name>-<key>" in that case.
- names = ["%s-%s" % (name, key) for key in targets.keys()]
- else:
- names = [name] * len(uris)
- for (arg_name, uri) in zip(names, uris):
- input_args.append("--%s=%s" % (arg_name, uri))
-
- return input_args
-
- def _get_output_args(self):
- if not isinstance(self._output, dict):
- raise ValueError("Output must be dict type")
-
- output_args = []
- self._output_uris = {}
-
- for (name, target) in self._output.items():
- uri = target.generate_uri() if hasattr(target, "generate_uri") else get_uri(target)
- uri = uri.rstrip("/")
- output_args.append("--%s=%s" % (name, uri))
- self._output_uris[name] = uri
-
- return output_args
-
- def _get_runner(self):
- return "DirectRunner" if self.local_runner else "DataflowRunner"
-
- def _get_dataflow_args(self):
- dataflow_args = []
-
- _runner = self._get_runner()
- if _runner:
- dataflow_args += ["--runner={}".format(_runner)]
- if self.project:
- dataflow_args += ["--project={}".format(self.project)]
- if self.staging_location:
- dataflow_args += ["--staging_location={}".format(self.staging_location)]
- if self.zone:
- dataflow_args += ["--zone={}".format(self.zone)]
- if self.region:
- dataflow_args += ["--region={}".format(self.region)]
- if self.temp_location:
- dataflow_args += ["--temp_location={}".format(self.temp_location)]
- if self.num_workers:
- dataflow_args += ["--num_workers={}".format(self.num_workers)]
- if self.autoscaling_algorithm:
- dataflow_args += ["--autoscaling_algorithm={}".format(self.autoscaling_algorithm)]
- if self.max_num_workers:
- dataflow_args += ["--max_num_workers={}".format(self.max_num_workers)]
- if self.network:
- dataflow_args += ["--network={}".format(self.network)]
- if self.subnetwork:
- dataflow_args += ["--subnetwork={}".format(self.subnetwork)]
- if self.disk_size_gb:
- dataflow_args += ["--disk_size_gb={}".format(self.disk_size_gb)]
- if self.worker_machine_type:
- dataflow_args += ["--worker_machine_type={}".format(self.worker_machine_type)]
- if self.job_name:
- dataflow_args += ["--job_name={}".format(self.job_name)]
- if self.worker_disk_type:
- dataflow_args += ["--worker_disk_type={}".format(self.worker_disk_type)]
- if self.service_account:
- dataflow_args += ["--service_account_email={}".format(self.service_account)]
- if self.requirements_file:
- dataflow_args += ["--requirements_file={}".format(self.requirements_file)]
- if self.setup_file:
- dataflow_args += ["--setup_file={}".format(self.setup_file)]
-
- return dataflow_args
-
- def args(self):
- """ Extra arguments that will be passed to your dataflow job.
-
- Example:
- return ["--project=my-gcp-project",
- "--zone=a-zone",
- "--staging_location=gs://my-gcp-project/dataflow"]
-
- Note that:
-
- * You "set" args by overriding this method in your subclass.
- * This function should return an iterable of strings.
- """
- return []
-
- def get_output_uris(self):
-        """ Returns a dictionary that contains the output URIs.
-        The key is the name of the output target defined in output(), and the value is
-        the path/URI of that output target. It can be used to write data to different
-        subdirectories under one output target.
-
-        :return: A dictionary of output URIs
-        :rtype: Dict of String to String
-        """
- return self._output_uris
diff --git a/spotify_tensorflow/luigi/tfx_task.py b/spotify_tensorflow/luigi/tfx_task.py
index 9ec4656..b6ea74b 100644
--- a/spotify_tensorflow/luigi/tfx_task.py
+++ b/spotify_tensorflow/luigi/tfx_task.py
@@ -18,13 +18,19 @@
from abc import abstractmethod
-from spotify_tensorflow.luigi.python_dataflow_task import PythonDataflowTask
+from luigi.contrib.beam_dataflow import BeamDataflowJobTask
-class TFXBaseTask(PythonDataflowTask):
+class TFXBaseTask(BeamDataflowJobTask):
+ python_script = None # type: str
+
def __init__(self, *args, **kwargs):
super(TFXBaseTask, self).__init__(*args, **kwargs)
+ def dataflow_executable(self):
+        """ Must be overridden; required by BeamDataflowJobTask. """
+ return ["python", self.python_script]
+
def tfx_args(self):
""" Extra arguments that will be passed to your tfx dataflow job.
@@ -44,6 +50,8 @@ def _mk_cmd_line(self):
class TFTransformTask(TFXBaseTask):
+ # Required dataflow arg
+
def __init__(self, *args, **kwargs):
super(TFTransformTask, self).__init__(*args, **kwargs)
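For context, a job written against the new base class might look like the following sketch. It is minimal and hypothetical: the task name, script path, and bucket paths are invented, and it assumes the full tfx_task.py satisfies any remaining BeamDataflowJobTask requirements (for example its dataflow_params definition) that this hunk does not show.

# Hypothetical usage sketch -- all names, paths, and buckets are illustrative.
from spotify_tensorflow.luigi.tfx_task import TFXBaseTask


class MyTFXJob(TFXBaseTask):
    python_script = "/path/to/tfx_script.py"    # consumed by dataflow_executable()
    project = "my-gcp-project"                  # standard BeamDataflowJobTask attributes
    region = "europe-west1"
    staging_location = "gs://my-bucket/staging"
    temp_location = "gs://my-bucket/tmp"

    def output(self):
        ...  # return a luigi Target for the job's output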
diff --git a/tests/python_dataflow_task_test.py b/tests/python_dataflow_task_test.py
deleted file mode 100644
index 8597eee..0000000
--- a/tests/python_dataflow_task_test.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Copyright 2017-2019 Spotify AB.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import, division, print_function
-
-from unittest import TestCase
-
-import luigi
-from spotify_tensorflow.luigi.python_dataflow_task import PythonDataflowTask
-from tests.test_utils import MockGCSTarget
-
-
-class DummyRawFeature(luigi.ExternalTask):
- def output(self):
- return MockGCSTarget("output_uri")
-
-
-class DummyPythonDataflowTask(PythonDataflowTask):
- python_script = "pybeamjob.py"
- requirements_file = "tfx_requirement.txt"
- zone = "zone"
- region = "region"
- project = "dummy"
- worker_machine_type = "n1-standard-4"
- num_workers = 5
- max_num_workers = 20
- autoscaling_algorithm = "THROUGHPUT_BASED"
- service_account = "dummy@dummy.iam.gserviceaccount.com"
- local_runner = True
- staging_location = "staging_uri"
- temp_location = "tmp"
- network = "network"
- subnetwork = "subnetwork"
- disk_size_gb = 30
- worker_disk_type = "disc_type"
- job_name = "dummy"
- setup_file = "setup.py"
-
- def requires(self):
- return {"input": DummyRawFeature()}
-
- def args(self):
- return ["--foo=bar"]
-
- def output(self):
- return MockGCSTarget(path="output_uri")
-
-
-class PythonDataflowTaskFailedOnValidation(PythonDataflowTask):
- python_script = "pybeamjob.py"
-
- # override to construct a test run
- def _mk_cmd_line(self):
- return ["python", "-c", "\"print(1)\""]
-
- def validate_output(self):
- return False
-
- def args(self):
- return ["--foo=bar"]
-
- def output(self):
- return MockGCSTarget(path="output_uri")
-
-
-class PythonDataflowTaskTest(TestCase):
- def test_python_dataflow_task(self):
- task = DummyPythonDataflowTask()
-
- expected = [
- "python",
- "pybeamjob.py",
- "--runner=DirectRunner",
- "--project=dummy",
- "--autoscaling_algorithm=THROUGHPUT_BASED",
- "--num_workers=5",
- "--max_num_workers=20",
- "--service_account_email=dummy@dummy.iam.gserviceaccount.com",
- "--input=output_uri/part-*",
- "--output=output_uri",
- "--staging_location=staging_uri",
- "--requirements_file=tfx_requirement.txt",
- "--worker_machine_type=n1-standard-4",
- "--foo=bar",
- "--temp_location=tmp",
- "--network=network",
- "--subnetwork=subnetwork",
- "--disk_size_gb=30",
- "--worker_disk_type=disc_type",
- "--job_name=dummy",
- "--zone=zone",
- "--region=region",
- "--setup_file=setup.py"
- ]
- actual = task._mk_cmd_line()
-        self.assertEqual(actual[:2], expected[:2])
-        self.assertEqual(set(actual[2:]), set(expected[2:]))
-
-    def test_task_failed_on_validation(self):
-        task = PythonDataflowTaskFailedOnValidation()
-        # validate_output() returns False, so run() must raise ValueError
-        with self.assertRaises(ValueError):
-            task.run()
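Since the deleted tests covered command-line construction and validation failure, an equivalent check against the new base class could look roughly like the sketch below. MyTFXJob refers to the illustrative subclass sketched after the tfx_task.py hunk above, and the expected command mirrors the dataflow_executable() added in that hunk.

# Hypothetical test sketch (names are illustrative).
from unittest import TestCase


class TFXBaseTaskTest(TestCase):
    def test_dataflow_executable(self):
        task = MyTFXJob()  # illustrative subclass from the earlier sketch
        # dataflow_executable() should run the configured script with python
        self.assertEqual(["python", "/path/to/tfx_script.py"],
                         task.dataflow_executable())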