Skip to content

Instantly share code, notes, and snippets.

@inardini
Created August 31, 2022 09:23
Show Gist options
  • Select an option

  • Save inardini/22879be00e5ca0d8eea037e327c0ecd0 to your computer and use it in GitHub Desktop.

Select an option

Save inardini/22879be00e5ca0d8eea037e327c0ecd0 to your computer and use it in GitHub Desktop.
### Pseudo code of Cloud Function v2 retraining trigger
import base64
import json
import yaml
from pathlib import Path as p
from datetime import datetime
from google.cloud import aiplatform as vertex_ai
# Load pipeline configuration from config.yaml.
# BUG FIX: the original called f.close() redundantly (the `with` block already
# closes the file) and used yaml.full_load, which can construct arbitrary
# Python objects; safe_load is sufficient for a plain scalar/mapping config
# and is the recommended loader for files you did not author at runtime.
with open('config.yaml') as f:
    config = yaml.safe_load(f)

# General project settings.
PROJECT_ID = config['general']['project_id']
BUCKET_URI = config['general']['bucket_uri']
REGION = config['general']['region']

# Vertex AI pipeline settings.
PIPELINE_NAME_PREFIX = config['pipeline_config']['pipeline_name']
PIPELINE_SPEC_URI = config['pipeline_config']['pipeline_package_uri']
PIPELINE_ROOT_URI = config['pipeline_config']['pipeline_root_uri']

# BigQuery table-name prefixes; a timestamp id is appended per pipeline run.
BQ_TRAINING_TABLE_PREFIX = config['pipeline_config']['bq_training_table_prefix']
BQ_MODEL_TABLE_PREFIX = config['pipeline_config']['bq_model_table_prefix']
BQ_EVALUATE_TIME_SERIES_TABLE_PREFIX = config['pipeline_config']['bq_evaluate_time_series_table_prefix']
BQ_EVALUATE_MODEL_TABLE_PREFIX = config['pipeline_config']['bq_evaluate_model_table_prefix']
BQ_FORECAST_TABLE_PREFIX = config['pipeline_config']['bq_forecast_table_prefix']
BQ_EXPLAIN_FORECAST_TABLE_PREFIX = config['pipeline_config']['bq_explain_forecast_table_prefix']

# Retraining decision threshold and BigQuery location.
PERF_THRESHOLD = config['pipeline_config']['performance_threshold']
BQ_LOCATION = config['pipeline_config']['location']
def get_pipeline_run(pipeline_name, pipeline_spec_uri, pipeline_root, parameter_values, project, region):
    """Construct a Vertex AI pipeline job ready to be submitted.

    Initializes the Vertex AI SDK for the given project/region, then builds a
    PipelineJob from the compiled pipeline spec. Caching is disabled so every
    trigger produces a fresh run.

    Args:
        pipeline_name: display name for the pipeline run.
        pipeline_spec_uri: URI of the compiled pipeline spec (template).
        parameter_values: dict of runtime parameters for the pipeline.
        pipeline_root: GCS URI used as the pipeline root.
        project: GCP project id.
        region: GCP region to run in.

    Returns:
        An unsubmitted vertex_ai.PipelineJob instance.
    """
    vertex_ai.init(project=project, location=region)
    return vertex_ai.PipelineJob(
        display_name=pipeline_name,
        template_path=pipeline_spec_uri,
        pipeline_root=pipeline_root,
        enable_caching=False,
        parameter_values=parameter_values,
    )
def _bq_write_truncate_config(dataset_id, table_id):
    """Return a BigQuery query-job config writing to PROJECT_ID.dataset.table with WRITE_TRUNCATE."""
    return {
        "destinationTable": {"projectId": PROJECT_ID,
                             "datasetId": dataset_id, "tableId": table_id},
        "writeDisposition": "WRITE_TRUNCATE"
    }


def trigger(event, context):
    """
    Triggered from a message on a Cloud Pub/Sub topic.

    Decodes a BigQuery audit-log event, and when rows were inserted into the
    orders table, derives per-run table names and launches the retraining
    pipeline on Vertex AI.

    Args:
        event (dict): Event payload; event['data'] is a base64-encoded JSON
            audit-log entry.
        context (google.cloud.functions.Context): Metadata for the event
            (unused, required by the Cloud Functions signature).
    """
    # Decode the base64-encoded Pub/Sub payload into the audit-log dict.
    pubsub_message = json.loads(base64.b64decode(event['data']).decode('utf-8'))
    print(pubsub_message)

    # Only react when rows were actually inserted — covers both a newly
    # created table and an update to an existing one. Guard-clause form
    # replaces the original single big `if` body.
    n_rows = pubsub_message.get('protoPayload', {}).get('metadata', {}).get('tableDataChange', {}).get('insertedRowsCount')
    if not n_rows or int(n_rows) <= 0:
        return

    # Derive a run id (tid) from the event timestamp as a fallback.
    # BUG FIX: strptime raises TypeError (None) or ValueError (bad format),
    # not AttributeError as the original except clause assumed, so the
    # original handler could never fire.
    tid = None
    try:
        tid = datetime.strptime(pubsub_message.get('timestamp'),
                                "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y%m%d%H%M%S")
    except (TypeError, ValueError) as e:
        print(e)

    # Extract dataset and table ids from the resource path:
    # projects/<project_id>/datasets/<dataset_id>/tables/<table_id>
    try:
        resource_path = p(pubsub_message['protoPayload']['resourceName'])
        bq_dataset = resource_path.parts[3]
        bq_orders_table = resource_path.parts[-1]
        # Prefer the timestamp suffix embedded in the table name as run id.
        tid = bq_orders_table.split('_')[-1]
    except (KeyError, TypeError, IndexError) as e:
        # BUG FIX: the original swallowed this failure and then crashed with a
        # NameError on bq_dataset below; bail out explicitly instead.
        print(e)
        return

    # Per-run pipeline and table names, all suffixed with the run id.
    pipeline_name = f"{PIPELINE_NAME_PREFIX}_{tid}"
    bq_training_table = f"{BQ_TRAINING_TABLE_PREFIX}_{tid}"
    bq_model_table = f"{BQ_MODEL_TABLE_PREFIX}_{tid}"
    bq_evaluate_time_series_table = f"{BQ_EVALUATE_TIME_SERIES_TABLE_PREFIX}_{tid}"
    bq_evaluate_model_table = f"{BQ_EVALUATE_MODEL_TABLE_PREFIX}_{tid}"
    bq_forecast_table = f"{BQ_FORECAST_TABLE_PREFIX}_{tid}"
    bq_explain_forecast_table = f"{BQ_EXPLAIN_FORECAST_TABLE_PREFIX}_{tid}"

    # Build the five identical WRITE_TRUNCATE destination configs via the
    # helper instead of five copy-pasted dict literals.
    bq_train_config = _bq_write_truncate_config(bq_dataset, bq_training_table)
    bq_evaluate_time_series_config = _bq_write_truncate_config(bq_dataset, bq_evaluate_time_series_table)
    bq_evaluate_model_config = _bq_write_truncate_config(bq_dataset, bq_evaluate_model_table)
    bq_forecast_config = _bq_write_truncate_config(bq_dataset, bq_forecast_table)
    bq_explain_forecast_config = _bq_write_truncate_config(bq_dataset, bq_explain_forecast_table)

    # Runtime parameters handed to the Vertex AI pipeline.
    parameter_values = {'bq_dataset': bq_dataset, 'bq_orders_table': bq_orders_table,
                        'bq_training_table': bq_training_table, 'bq_train_configuration': bq_train_config,
                        'bq_model_table': bq_model_table,
                        'bq_evaluate_time_series_configuration': bq_evaluate_time_series_config,
                        'bq_evaluate_model_configuration': bq_evaluate_model_config,
                        'performance_threshold': PERF_THRESHOLD,
                        'bq_forecast_configuration': bq_forecast_config,
                        'bq_explain_forecast_configuration': bq_explain_forecast_config,
                        'project': PROJECT_ID, 'location': BQ_LOCATION}

    # Trigger the pipeline.
    print(f'Pipeline {pipeline_name} configuration: ', parameter_values)
    print('Pipeline run is starting...')
    pipeline_job = get_pipeline_run(pipeline_name=pipeline_name,
                                    parameter_values=parameter_values,
                                    pipeline_spec_uri=PIPELINE_SPEC_URI,
                                    pipeline_root=PIPELINE_ROOT_URI,
                                    project=PROJECT_ID,
                                    region=REGION)
    pipeline_job.submit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment