diff --git a/setup.cfg b/setup.cfg
index aa4695c..13d3587 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -23,7 +23,7 @@ packages =
 
 [extras]
 luigi =
-    luigi>=2.3.3,<3.0.0
+    luigi>=2.8.5,<3.0.0
 examples =
     scipy
     sklearn
diff --git a/spotify_tensorflow/luigi/python_dataflow_task.py b/spotify_tensorflow/luigi/python_dataflow_task.py
deleted file mode 100644
index 31461c0..0000000
--- a/spotify_tensorflow/luigi/python_dataflow_task.py
+++ /dev/null
@@ -1,297 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Copyright 2017-2019 Spotify AB.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-import os
-import subprocess
-import time
-
-import luigi
-from luigi.task import MixinNaiveBulkComplete
-from spotify_tensorflow.luigi.utils import get_uri, run_with_logging
-
-logger = logging.getLogger("luigi-interface")
-
-
-class PythonDataflowTask(MixinNaiveBulkComplete, luigi.Task):
-    """"Luigi wrapper for a dataflow job
-
-    The following properties can be set:
-    python_script = None          # Python script for the dataflow task.
-    project = None                # Name of the project owning the dataflow job.
-    staging_location = None       # GCS path for staging code packages needed by workers.
-    zone = None                   # GCE availability zone for launching workers.
-    region = None                 # GCE region for creating the dataflow job.
-    temp_location = None          # GCS path for saving temporary workflow jobs.
-    num_workers = None            # The number of workers to start the task with.
-    autoscaling_algorithm = None  # Set to "NONE" to disable autoscaling. `num_workers`
-                                  # will then be used for the job.
-    max_num_workers = None        # Used if the autoscaling is enabled.
-    network = None                # Network in GCE to be used for launching workers.
-    subnetwork = None             # Subnetwork in GCE to be used for launching workers.
-    disk_size_gb = None           # Remote worker disk size, if not defined uses default size.
-    worker_machine_type = None    # Machine type to create Dataflow worker VMs. If unset,
-                                  # the Dataflow service will choose a reasonable default.
-    worker_disk_type = None       # Specify SSD for local disk or defaults to hard disk.
-    service_account = None        # Service account of Dataflow VMs/workers. Default is a
-                                    default GCE service account.
-    job_name = None               # Name of the dataflow job
-    requirements_file = None      # Path to a requirements file containing package dependencies.
-    local_runner = False          # If local_runner = True, the job uses DirectRunner,
-                                    otherwise it uses DataflowRunner
-    setup_file = None             # Path to a setup Python file containing package dependencies.
-
-    :Example:
-
-    class AwesomeJob(PythonDataflowJobTask):
-        python_script = "/path/to/python_script"
-        project = "gcp-project"
-        staging_location = "gs://gcp-project-playground/user/staging"
-        temp_location = "gs://gcp-project-playground/user/tmp"
-        max_num_workers = 20
-        region = "europe-west1"
-        service_account_email = "service_account@gcp-project.iam.gserviceaccount.com"
-
-        def output(self):
-            ...
-    """
-    # Required dataflow args
-    python_script = None  # type: str
-    project = None  # type: str
-    staging_location = None  # type: str
-
-    # Dataflow requires one and only one of:
-    zone = None  # type: str
-    region = None  # type: str
-
-    # Optional dataflow args
-    temp_location = None  # type: str
-    num_workers = None  # type: int
-    autoscaling_algorithm = None  # type: str
-    max_num_workers = None  # type: int
-    network = None  # type: str
-    subnetwork = None  # type: str
-    disk_size_gb = None  # type: int
-    worker_machine_type = None  # type: str
-    worker_disk_type = None  # type: str
-    service_account = None  # type: str
-    job_name = None  # type: str
-    requirements_file = None  # type: str
-    local_runner = False  # type: bool
-    setup_file = None  # type: str
-
-    def __init__(self, *args, **kwargs):
-        super(PythonDataflowTask, self).__init__(*args, **kwargs)
-        self._output = self.output()
-        if isinstance(self._output, luigi.Target):
-            self._output = {"output": self._output}
-        if self.job_name is None:
-            # job_name must consist of only the characters [-a-z0-9]
-            cls_name = self.__class__.__name__.replace("_", "-").lower()
-            self.job_name = "{cls_name}-{timestamp}".format(cls_name=cls_name,
-                                                            timestamp=str(time.time())[:-3])
-
-    def on_successful_run(self):
-        """ Callback that gets called right after the dataflow job has finished successfully but
-        before validate_output is run.
-        """
-        pass
-
-    def validate_output(self):
-        """ Callback that can be used to validate your output before it is moved to it's final
-        location. Returning false here will cause the job to fail, and output to be removed instead
-        of published.
-
-        :return: Whether the output is valid or not
-        :rtype: Boolean
-        """
-        return True
-
-    def file_pattern(self):
-        """ If one/some of the input target files are not in the pattern of part-*,
-        we can add the key of the required target and the correct file pattern
-        that should be appended in the command line here. If the input target key is not found
-        in this dict, the file pattern will be assumed to be part-* for that target.
-
-        :return A dictionary of overrided file pattern that is not part-* for the inputs
-        :rtype: Dict of String to String
-        """
-        return {}
-
-    def run(self):
-        cmd_line = self._mk_cmd_line()
-        logger.info(" ".join(cmd_line))
-
-        try:
-            run_with_logging(cmd_line, logger)
-        except subprocess.CalledProcessError as e:
-            logging.error(e, exc_info=True)
-            # exit luigi with the same exit code as the python dataflow job proccess
-            # In this way users can easily exit the job with code 50 to avoid Styx retries
-            # https://github.com/spotify/styx/blob/master/doc/design-overview.md#workflow-state-graph
-            os._exit(e.returncode)
-
-        self.on_successful_run()
-        if self.validate_output():
-            self._publish_outputs()
-        else:
-            raise ValueError("Output is not valid")
-
-    def _publish_outputs(self):
-        for (name, target) in self._output.items():
-            if hasattr(target, "publish"):
-                target.publish(self._output_uris[name])
-
-    def _mk_cmd_line(self):
-        cmd_line = self._dataflow_executable()
-        cmd_line.extend(self._get_dataflow_args())
-        cmd_line.extend(self._get_input_args())
-        cmd_line.extend(self._get_output_args())
-        cmd_line.extend(self.args())
-        return cmd_line
-
-    def _dataflow_executable(self):
-        """
-        Defines the executable used to run the python dataflow job.
-        """
-        return ["python", self.python_script]
-
-    def _get_input_uri(self, file_pattern, target):
-        uri = get_uri(target)
-        uri = uri.rstrip("/") + "/" + file_pattern
-        return uri
-
-    def _get_file_pattern(self):
-        file_pattern = self.file_pattern()
-        if not isinstance(file_pattern, dict):
-            raise ValueError("file_pattern() must return a dict type")
-        return file_pattern
-
-    def _get_input_args(self):
-        """
-        Collects outputs from requires() and converts them to input arguments.
-        file_pattern() is called to construct input file path glob with default value "part-*"
-        """
-        job_input = self.input()
-        if isinstance(job_input, luigi.Target):
-            job_input = {"input": job_input}
-        if not isinstance(job_input, dict):
-            raise ValueError("Input (requires()) must be dict type")
-
-        input_args = []
-        file_pattern_dict = self._get_file_pattern()
-        for (name, targets) in job_input.items():
-            uri_targets = luigi.task.flatten(targets)
-            pattern = file_pattern_dict.get(name, "part-*")
-            uris = [self._get_input_uri(pattern, uri_target) for uri_target in uri_targets]
-            if isinstance(targets, dict):
-                # If targets is a dict that means it had multiple outputs.
-                # Make the input args in that case "<name>-<key>"
-                names = ["%s-%s" % (name, key) for key in targets.keys()]
-            else:
-                names = [name] * len(uris)
-            for (arg_name, uri) in zip(names, uris):
-                input_args.append("--%s=%s" % (arg_name, uri))
-
-        return input_args
-
-    def _get_output_args(self):
-        if not isinstance(self._output, dict):
-            raise ValueError("Output must be dict type")
-
-        output_args = []
-        self._output_uris = {}
-
-        for (name, target) in self._output.items():
-            uri = target.generate_uri() if hasattr(target, "generate_uri") else get_uri(target)
-            uri = uri.rstrip("/")
-            output_args.append("--%s=%s" % (name, uri))
-            self._output_uris[name] = uri
-
-        return output_args
-
-    def _get_runner(self):
-        return "DirectRunner" if self.local_runner else "DataflowRunner"
-
-    def _get_dataflow_args(self):
-        dataflow_args = []
-
-        _runner = self._get_runner()
-        if _runner:
-            dataflow_args += ["--runner={}".format(_runner)]
-        if self.project:
-            dataflow_args += ["--project={}".format(self.project)]
-        if self.staging_location:
-            dataflow_args += ["--staging_location={}".format(self.staging_location)]
-        if self.zone:
-            dataflow_args += ["--zone={}".format(self.zone)]
-        if self.region:
-            dataflow_args += ["--region={}".format(self.region)]
-        if self.temp_location:
-            dataflow_args += ["--temp_location={}".format(self.temp_location)]
-        if self.num_workers:
-            dataflow_args += ["--num_workers={}".format(self.num_workers)]
-        if self.autoscaling_algorithm:
-            dataflow_args += ["--autoscaling_algorithm={}".format(self.autoscaling_algorithm)]
-        if self.max_num_workers:
-            dataflow_args += ["--max_num_workers={}".format(self.max_num_workers)]
-        if self.network:
-            dataflow_args += ["--network={}".format(self.network)]
-        if self.subnetwork:
-            dataflow_args += ["--subnetwork={}".format(self.subnetwork)]
-        if self.disk_size_gb:
-            dataflow_args += ["--disk_size_gb={}".format(self.disk_size_gb)]
-        if self.worker_machine_type:
-            dataflow_args += ["--worker_machine_type={}".format(self.worker_machine_type)]
-        if self.job_name:
-            dataflow_args += ["--job_name={}".format(self.job_name)]
-        if self.worker_disk_type:
-            dataflow_args += ["--worker_disk_type={}".format(self.worker_disk_type)]
-        if self.service_account:
-            dataflow_args += ["--service_account_email={}".format(self.service_account)]
-        if self.requirements_file:
-            dataflow_args += ["--requirements_file={}".format(self.requirements_file)]
-        if self.setup_file:
-            dataflow_args += ["--setup_file={}".format(self.setup_file)]
-
-        return dataflow_args
-
-    def args(self):
-        """ Extra arguments that will be passed to your dataflow job.
-
-        Example:
-            return ["--project=my-gcp-project",
-                    "--zone=a-zone",
-                    "--staging_location=gs://my-gcp-project/dataflow"]
-
-        Note that:
-
-        * You "set" args by overriding this method in your subclass.
-        * This function should return an iterable of strings.
-        """
-        return []
-
-    def get_output_uris(self):
-        """ Returns a dictionary that contains output uris.
-        The key is the name of the output target defined in output(), and the value is
-        the path/uri of the output target. It can be used to write data to different sub directories
-        under one output target.
-
-        :return A dictionary of output uris
-        :rtype: Dict of String to String
-        """
-        return self._output_uris
diff --git a/spotify_tensorflow/luigi/tfx_task.py b/spotify_tensorflow/luigi/tfx_task.py
index 9ec4656..b6ea74b 100644
--- a/spotify_tensorflow/luigi/tfx_task.py
+++ b/spotify_tensorflow/luigi/tfx_task.py
@@ -18,13 +18,19 @@
 
 from abc import abstractmethod
 
-from spotify_tensorflow.luigi.python_dataflow_task import PythonDataflowTask
+from luigi.contrib.beam_dataflow import BeamDataflowJobTask
 
 
-class TFXBaseTask(PythonDataflowTask):
+class TFXBaseTask(BeamDataflowJobTask):
+    python_script = None  # type: str
+
     def __init__(self, *args, **kwargs):
         super(TFXBaseTask, self).__init__(*args, **kwargs)
 
+    def dataflow_executable(self):
+        """ Overrides the abstract method required by BeamDataflowJobTask. """
+        return ["python", self.python_script]
+
     def tfx_args(self):
         """ Extra arguments that will be passed to your tfx dataflow job.
 
@@ -44,6 +50,8 @@ def _mk_cmd_line(self):
 
 
 class TFTransformTask(TFXBaseTask):
+    # Required dataflow arg
+
     def __init__(self, *args, **kwargs):
         super(TFTransformTask, self).__init__(*args, **kwargs)
 
diff --git a/tests/python_dataflow_task_test.py b/tests/python_dataflow_task_test.py
deleted file mode 100644
index 8597eee..0000000
--- a/tests/python_dataflow_task_test.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Copyright 2017-2019 Spotify AB.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import, division, print_function
-
-from unittest import TestCase
-
-import luigi
-from spotify_tensorflow.luigi.python_dataflow_task import PythonDataflowTask
-from tests.test_utils import MockGCSTarget
-
-
-class DummyRawFeature(luigi.ExternalTask):
-    def output(self):
-        return MockGCSTarget("output_uri")
-
-
-class DummyPythonDataflowTask(PythonDataflowTask):
-    python_script = "pybeamjob.py"
-    requirements_file = "tfx_requirement.txt"
-    zone = "zone"
-    region = "region"
-    project = "dummy"
-    worker_machine_type = "n1-standard-4"
-    num_workers = 5
-    max_num_workers = 20
-    autoscaling_algorithm = "THROUGHPUT_BASED"
-    service_account = "dummy@dummy.iam.gserviceaccount.com"
-    local_runner = True
-    staging_location = "staging_uri"
-    temp_location = "tmp"
-    network = "network"
-    subnetwork = "subnetwork"
-    disk_size_gb = 30
-    worker_disk_type = "disc_type"
-    job_name = "dummy"
-    setup_file = "setup.py"
-
-    def requires(self):
-        return {"input": DummyRawFeature()}
-
-    def args(self):
-        return ["--foo=bar"]
-
-    def output(self):
-        return MockGCSTarget(path="output_uri")
-
-
-class PythonDataflowTaskFailedOnValidation(PythonDataflowTask):
-    python_script = "pybeamjob.py"
-
-    # override to construct a test run
-    def _mk_cmd_line(self):
-        return ["python", "-c", "\"print(1)\""]
-
-    def validate_output(self):
-        return False
-
-    def args(self):
-        return ["--foo=bar"]
-
-    def output(self):
-        return MockGCSTarget(path="output_uri")
-
-
-class PythonDataflowTaskTest(TestCase):
-    def test_python_dataflow_task(self):
-        task = DummyPythonDataflowTask()
-
-        expected = [
-            "python",
-            "pybeamjob.py",
-            "--runner=DirectRunner",
-            "--project=dummy",
-            "--autoscaling_algorithm=THROUGHPUT_BASED",
-            "--num_workers=5",
-            "--max_num_workers=20",
-            "--service_account_email=dummy@dummy.iam.gserviceaccount.com",
-            "--input=output_uri/part-*",
-            "--output=output_uri",
-            "--staging_location=staging_uri",
-            "--requirements_file=tfx_requirement.txt",
-            "--worker_machine_type=n1-standard-4",
-            "--foo=bar",
-            "--temp_location=tmp",
-            "--network=network",
-            "--subnetwork=subnetwork",
-            "--disk_size_gb=30",
-            "--worker_disk_type=disc_type",
-            "--job_name=dummy",
-            "--zone=zone",
-            "--region=region",
-            "--setup_file=setup.py"
-        ]
-        actual = task._mk_cmd_line()
-        self.assertEquals(actual[:2], expected[:2])
-        self.assertEquals(set(actual[2:]), set(expected[2:]))
-
-    def test_task_failed_on_validation(self):
-        task = PythonDataflowTaskFailedOnValidation()
-        try:
-            task.run()
-            self.assertTrue(False)
-        except ValueError:
-            self.assertTrue(True)