Created
August 31, 2022 09:23
-
-
Save inardini/22879be00e5ca0d8eea037e327c0ecd0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Pseudo code of Cloud Function v2 retraining trigger
import base64
import json
import yaml
from pathlib import Path as p
from datetime import datetime
from google.cloud import aiplatform as vertex_ai

# Load pipeline configuration from the bundled YAML file.
# Fixes: safe_load instead of full_load (plain config data needs no arbitrary
# Python object construction), and the redundant f.close() is dropped — the
# `with` context manager already closes the file.
with open('config.yaml') as f:
    config = yaml.safe_load(f)

# General settings
PROJECT_ID = config['general']['project_id']
BUCKET_URI = config['general']['bucket_uri']
REGION = config['general']['region']

# Pipeline settings: per-run name prefixes, spec/root URIs, and BigQuery
# table-name prefixes that get a run-id suffix appended at trigger time.
PIPELINE_NAME_PREFIX = config['pipeline_config']['pipeline_name']
PIPELINE_SPEC_URI = config['pipeline_config']['pipeline_package_uri']
PIPELINE_ROOT_URI = config['pipeline_config']['pipeline_root_uri']
BQ_TRAINING_TABLE_PREFIX = config['pipeline_config']['bq_training_table_prefix']
BQ_MODEL_TABLE_PREFIX = config['pipeline_config']['bq_model_table_prefix']
BQ_EVALUATE_TIME_SERIES_TABLE_PREFIX = config['pipeline_config']['bq_evaluate_time_series_table_prefix']
BQ_EVALUATE_MODEL_TABLE_PREFIX = config['pipeline_config']['bq_evaluate_model_table_prefix']
BQ_FORECAST_TABLE_PREFIX = config['pipeline_config']['bq_forecast_table_prefix']
BQ_EXPLAIN_FORECAST_TABLE_PREFIX = config['pipeline_config']['bq_explain_forecast_table_prefix']
PERF_THRESHOLD = config['pipeline_config']['performance_threshold']
BQ_LOCATION = config['pipeline_config']['location']
def get_pipeline_run(pipeline_name, pipeline_spec_uri, pipeline_root, parameter_values, project, region):
    """Build a Vertex AI pipeline job, ready to be submitted.

    Args:
        pipeline_name: Display name for the pipeline run.
        pipeline_spec_uri: URI of the compiled pipeline spec.
        pipeline_root: URI used as the pipeline root.
        parameter_values: Runtime parameter values for the pipeline.
        project: GCP project id.
        region: GCP region the pipeline runs in.

    Returns:
        A vertex_ai.PipelineJob instance (not yet submitted).
    """
    # Point the SDK at the target project/region before constructing the job.
    vertex_ai.init(project=project, location=region)

    job_kwargs = {
        "display_name": pipeline_name,
        "template_path": pipeline_spec_uri,
        "pipeline_root": pipeline_root,
        # Retraining should re-run every step, so step caching stays off.
        "enable_caching": False,
        "parameter_values": parameter_values,
    }
    return vertex_ai.PipelineJob(**job_kwargs)
def _bq_truncate_config(bq_dataset, bq_table):
    """Return a BigQuery destination config that overwrites bq_table in bq_dataset."""
    return {
        "destinationTable": {
            "projectId": PROJECT_ID,
            "datasetId": bq_dataset,
            "tableId": bq_table,
        },
        "writeDisposition": "WRITE_TRUNCATE",
    }


def trigger(event, context):
    """Triggered from a message on a Cloud Pub/Sub topic.

    Decodes a BigQuery audit-log message and, when new rows were inserted,
    submits the retraining pipeline parameterized for the affected table.

    Args:
        event (dict): Event payload; event['data'] is base64-encoded JSON.
        context (google.cloud.functions.Context): Metadata for the event.
    """
    # Read the pubsub message
    pubsub_message = json.loads(base64.b64decode(event['data']).decode('utf-8'))
    print(pubsub_message)

    # Check insertedRowsCount --> set both when a new table has been created
    # and when an existing one has been updated.
    n_rows = (pubsub_message.get('protoPayload', {})
              .get('metadata', {})
              .get('tableDataChange', {})
              .get('insertedRowsCount'))
    if not n_rows or int(n_rows) <= 0:
        return

    # Fallback run id derived from the event timestamp; the table-name suffix
    # below takes precedence when available.
    tid = None
    try:
        tid = datetime.strptime(
            pubsub_message.get('timestamp'), "%Y-%m-%dT%H:%M:%S.%f%z"
        ).strftime("%Y%m%d%H%M%S")
    except (AttributeError, TypeError, ValueError) as e:
        # TypeError: timestamp missing (None); ValueError: unexpected format.
        print(e)

    # resourceName has the form projects/<p>/datasets/<dataset>/tables/<table>.
    # Bug fix: the original caught AttributeError here (p(None) actually raises
    # TypeError), printed it, and then crashed with a NameError on bq_dataset;
    # a missing resourceName now skips the trigger explicitly.
    resource_name = pubsub_message.get('protoPayload', {}).get('resourceName')
    if not resource_name:
        print('No resourceName in message: skipping retraining trigger.')
        return
    resource_path = p(resource_name)
    bq_dataset = resource_path.parts[3]
    bq_orders_table = resource_path.parts[-1]
    # The table suffix identifies this retraining run; fall back to the
    # timestamp-derived id if the table name carries no suffix.
    tid = bq_orders_table.split('_')[-1] or tid

    # Derive per-run names for the pipeline and all BigQuery tables.
    pipeline_name = f"{PIPELINE_NAME_PREFIX}_{tid}"
    bq_training_table = f"{BQ_TRAINING_TABLE_PREFIX}_{tid}"
    bq_model_table = f"{BQ_MODEL_TABLE_PREFIX}_{tid}"
    bq_evaluate_time_series_table = f"{BQ_EVALUATE_TIME_SERIES_TABLE_PREFIX}_{tid}"
    bq_evaluate_model_table = f"{BQ_EVALUATE_MODEL_TABLE_PREFIX}_{tid}"
    bq_forecast_table = f"{BQ_FORECAST_TABLE_PREFIX}_{tid}"
    bq_explain_forecast_table = f"{BQ_EXPLAIN_FORECAST_TABLE_PREFIX}_{tid}"

    # Define pipeline parameters dictionary; the shared helper replaces the
    # five hand-copied WRITE_TRUNCATE destination blocks of the original.
    parameter_values = {
        'bq_dataset': bq_dataset,
        'bq_orders_table': bq_orders_table,
        'bq_training_table': bq_training_table,
        'bq_train_configuration': _bq_truncate_config(bq_dataset, bq_training_table),
        'bq_model_table': bq_model_table,
        'bq_evaluate_time_series_configuration': _bq_truncate_config(bq_dataset, bq_evaluate_time_series_table),
        'bq_evaluate_model_configuration': _bq_truncate_config(bq_dataset, bq_evaluate_model_table),
        'performance_threshold': PERF_THRESHOLD,
        'bq_forecast_configuration': _bq_truncate_config(bq_dataset, bq_forecast_table),
        'bq_explain_forecast_configuration': _bq_truncate_config(bq_dataset, bq_explain_forecast_table),
        'project': PROJECT_ID,
        'location': BQ_LOCATION,
    }

    # Trigger the pipeline
    print(f'Pipeline {pipeline_name} configuration: ', parameter_values)
    print('Pipeline run is starting...')
    pipeline_job = get_pipeline_run(pipeline_name=pipeline_name,
                                    parameter_values=parameter_values,
                                    pipeline_spec_uri=PIPELINE_SPEC_URI,
                                    pipeline_root=PIPELINE_ROOT_URI,
                                    project=PROJECT_ID,
                                    region=REGION)
    pipeline_job.submit()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment