Build accurate ML training datasets using point-in-time queries with Amazon SageMaker Feature Store and Apache Spark
This post is co-written with Raphey Holmes, Software Engineering Manager, and Jason Mackay, Principal Software Development Engineer, at GoDaddy.
GoDaddy is the world’s largest services platform for entrepreneurs around the globe, empowering their worldwide community of over 20 million customers—and entrepreneurs everywhere—by giving them all the help and tools they need to grow online. GoDaddy needs a robust, validated, and customizable ML feature management solution, and has chosen Amazon SageMaker Feature Store to manage thousands of features across dozens of feature groups with unique data pipelines and update schedules. Feature Store lets GoDaddy use point-in-time queries to support accurate training and deployment of machine learning (ML) models, covering everything from personalizing content, to preventing fraud, to helping customers find the perfect domain name.
Feature Store lets you define groups of features, use batch ingestion and streaming ingestion, retrieve features with as low as single-digit millisecond latency for highly accurate online predictions, and extract point-in-time correct datasets for training. Instead of building and maintaining these infrastructure capabilities, you get a fully managed service that scales as your data grows, enables feature sharing across teams, and lets data scientists focus on building great ML-driven products for game-changing business use cases. Teams can now deliver robust features and reuse them in a variety of models that may be built by different teams.
In this post, we (the joint team of GoDaddy and AWS architects) explain how to use Feature Store and the processing power of Apache Spark to create accurate training datasets using point-in-time queries against reusable feature groups in a scalable fashion.
Avoid data leakage by using point-in-time correct queries
In ML, data leakage or target leakage is accidentally using data in model training that wouldn’t be available at the time of prediction. Leakage can be subtle and difficult to detect, yet the business impact can be significant. Models with leakage perform unrealistically well in development, but they deliver poor model accuracy in production without the benefit of future data.
Leakage with time-dependent features can occur in a wide range of use cases. For example, a model predicting lung disease might use features about a patient’s use of medications or surgical procedures. A recommendation model on a website may use customer orders to predict what offers would be most attractive to that customer. These features are valid when used correctly, but data scientists must ensure that the feature values are built only using data that could be known before the target was observed. For example, if a patient was diagnosed at time t1, any data about medications or hospital visits at times beyond t1 must be excluded when creating a training dataset.
So how do data science teams provide a rich set of ML features, while ensuring they don’t leak future data into their trained models? Companies are increasingly adopting the use of a feature store to help solve this model training challenge. A robust feature store provides an offline store with a complete history of feature values. Feature records include timestamps to support point-in-time correct queries. Data scientists can query for the exact set of feature values that would have been available at a specific time, without the chance of including data from beyond that time.
Let’s use a diagram to explain the concept of a point-in-time feature query. Imagine we’re training a fraud detection model on a set of historical transactions. Each transaction has features associated with various entities involved in the transaction, such as the consumer, merchant, and credit card. Feature values for these entities change over time, and they’re updated on different schedules. To avoid leaking future feature values, a point-in-time query retrieves the state of each feature that was available at each transaction time, and no later. For example, the transaction at time t2 can only use features available before time t2, and the transaction at t1 can’t use features from timestamps greater than t1.
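To make the idea concrete, here is a minimal plain-Python sketch of a point-in-time lookup for a single transaction. It is not the Spark implementation used later in this post. It assumes each entity's feature history is a list of `(timestamp, values)` pairs sorted by timestamp. The entity histories, feature names, and the `as_of` helper are all hypothetical. A record whose timestamp equals the transaction time is treated as available.

```python
from bisect import bisect_right
from datetime import datetime
from typing import Any, Optional, Sequence, Tuple

# Hypothetical feature histories for two entities, each sorted by timestamp
# and refreshed on its own schedule: (timestamp, feature values).
consumer_history = [
    (datetime(2021, 3, 1), {"avg_amount_7d": 40.0}),
    (datetime(2021, 3, 2), {"avg_amount_7d": 55.0}),
]
merchant_history = [
    (datetime(2021, 3, 1, 6), {"fraud_rate_7d": 0.01}),
    (datetime(2021, 3, 1, 18), {"fraud_rate_7d": 0.03}),
]


def as_of(history: Sequence[Tuple[datetime, Any]], t: datetime) -> Optional[Any]:
    """Return the latest value whose timestamp is <= t, or None if there is none."""
    timestamps = [ts for ts, _ in history]
    i = bisect_right(timestamps, t)  # number of records with timestamp <= t
    return history[i - 1][1] if i > 0 else None


# Look up each entity independently at this transaction's own time.
txn_time = datetime(2021, 3, 1, 12)
training_row = {
    "consumer": as_of(consumer_history, txn_time),
    "merchant": as_of(merchant_history, txn_time),
}
```

For this transaction, the consumer's 2021-03-02 value and the merchant's 18:00 value are excluded because neither existed at 12:00 on 2021-03-01.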
The resulting training dataset in the following diagram shows that a point-in-time query returns an accurate set of feature values for each transaction, avoiding values that would have only been known in the future. Reliably retrieving the right set of values from history ensures that model performance won’t suffer when it faces real-world transactions.
To solidify the concept one step further, let’s consider two other types of queries that don’t protect you from data leakage:
Get latest feature values – Consider a query that simply returns the latest feature values available for each feature. Although such a query works well for creating a batch scoring dataset, the query mistakenly leaks data that wasn’t available for the transactions at t1 and t2, providing a very poor training dataset.
Get features as of a specific timestamp – Likewise, a query that returns all feature data as of a single timestamp produces an inappropriate training dataset, because it treats all records uniformly instead of using a distinct timestamp for each training dataset entry. So-called time travel capabilities are great for auditing and reproducing experiments. However, time travel doesn’t solve for accurate training dataset extraction, because it doesn’t account for event timestamps that vary for each training record.
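The difference between these strategies and a true point-in-time query can be sketched in a few lines of plain Python. The feature history, transaction times, and snapshot timestamp below are hypothetical. Each strategy fills the same two training rows, which correspond to transactions at t1 and t2.

```python
from datetime import datetime

# Hypothetical history of one card-level feature: (event_time, txn_count_1d).
history = [
    (datetime(2021, 3, 1), 2),
    (datetime(2021, 3, 3), 7),
    (datetime(2021, 3, 5), 1),
]
# Hypothetical training transactions for that card, at t1 and t2.
txn_times = [datetime(2021, 3, 2), datetime(2021, 3, 4)]


def latest_value(hist):
    """Latest feature value overall, regardless of the transaction time."""
    return max(hist, key=lambda rec: rec[0])[1]


def value_as_of(hist, t):
    """Latest feature value with event_time <= t, or None."""
    eligible = [rec for rec in hist if rec[0] <= t]
    return max(eligible, key=lambda rec: rec[0])[1] if eligible else None


# Hypothetical single "time travel" timestamp shared by every row.
snapshot_time = datetime(2021, 3, 6)

# Strategy 1: latest values (leaks future data into both rows).
latest_rows = [latest_value(history) for _ in txn_times]
# Strategy 2: one fixed timestamp for all rows (also leaks).
fixed_rows = [value_as_of(history, snapshot_time) for _ in txn_times]
# Strategy 3: point-in-time, using each row's own transaction time.
pit_rows = [value_as_of(history, t) for t in txn_times]
```

The first two strategies hand the 2021-03-05 value to transactions that happened before that value existed. Only the per-row lookup returns 2 and 7, the values actually known at t1 and t2.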
Point-in-time queries as part of the overall ML lifecycle
The following diagram shows how point-in-time queries fit into the overall ML lifecycle. The diagram starts with a set of automated feature pipelines that perform feature transformations and ingest feature records into Feature Store. Pipelines for individual feature groups are independent and can run on varying schedules. One feature group may be refreshed nightly, whereas another is updated hourly, and a third may have updates that are triggered as source data arrives on an input stream, such as an Apache Kafka topic or an Amazon Kinesis data stream.
Depending on the configuration of your feature groups, the resulting features are made available in an online store, an offline store, or both with automatic synchronization. The offline store provides a secure and scalable repository of features, letting data scientists create training and validation datasets or batch scoring datasets from a set of feature groups that can be managed centrally with a full history of feature values.
The result of a point-in-time query is a training dataset, which can then be used directly with an ML algorithm, or as input to a SageMaker training job. A point-in-time query can be run interactively in an Apache Spark environment such as a SageMaker notebook or Amazon EMR notebook. For large-scale training datasets, you can run these queries in a SageMaker Processing job, which lets you scale the job to a multi-instance cluster without having to manage any infrastructure.
Although we show an Apache Spark implementation of point-in-time queries in this post, Amazon Athena provides an alternative. With Athena, you can create a temporary table containing your selection criteria, and then use SQL to join those criteria to your offline store feature groups. The SQL query can select the most recent feature values that are older than the targeted event time for each training dataset row.
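The shape of such a SQL query can be tried locally. The following sketch uses an in-memory SQLite database as a stand-in for Athena. It is not the Athena setup itself. The table names, column names, and data are hypothetical, and it assumes SQLite 3.25 or later for `ROW_NUMBER()` window functions, which Athena's SQL dialect also supports. The entity table plays the role of the temporary selection-criteria table.

```python
import sqlite3

# In-memory SQLite stands in for Athena; all names and rows are hypothetical.
# Timestamps are ISO-8601 strings, which sort correctly as text.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE consumer_features (
    consumer_id TEXT, event_time TEXT, write_time TEXT, num_trans_last_7d INTEGER
);
CREATE TABLE entity (consumer_id TEXT, query_time TEXT);
INSERT INTO consumer_features VALUES
    ('c1', '2021-03-01T00:00:00Z', '2021-03-01T01:00:00Z', 3),
    ('c1', '2021-03-03T00:00:00Z', '2021-03-03T01:00:00Z', 5),
    ('c1', '2021-03-05T00:00:00Z', '2021-03-05T01:00:00Z', 9);
INSERT INTO entity VALUES
    ('c1', '2021-03-04T12:00:00Z'),
    ('c1', '2021-03-02T08:00:00Z');
""")

# For each entity row, keep feature records no later than its query time,
# rank them newest first (write_time breaks ties), and keep rank 1.
POINT_IN_TIME_SQL = """
SELECT consumer_id, query_time, num_trans_last_7d
FROM (
    SELECT e.consumer_id, e.query_time, f.num_trans_last_7d,
           ROW_NUMBER() OVER (
               PARTITION BY e.consumer_id, e.query_time
               ORDER BY f.event_time DESC, f.write_time DESC
           ) AS rn
    FROM entity e
    JOIN consumer_features f
      ON f.consumer_id = e.consumer_id AND f.event_time <= e.query_time
) AS ranked
WHERE rn = 1
ORDER BY query_time
"""
rows = conn.execute(POINT_IN_TIME_SQL).fetchall()
```

Each entity row receives the newest feature record that existed at its own query time. The 2021-03-05 record never appears because both query times precede it.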
Before we walk through how to perform point-in-time correct queries against the offline store, it’s important to understand the definition and purpose of the three different timestamps that SageMaker provides in the offline store schema for every feature group:
Event time – The customer-defined timestamp associated with the feature record, such as the transaction time, the time a customer placed an order, or the time a new insurance claim was created. The specific name of this feature is specified when the feature group is created, and the customer’s ingestion code is responsible for populating this timestamp.
API invocation time – The time when SageMaker receives the API call to write a feature record to the feature store. This timestamp is automatically populated by SageMaker as the api_invocation_time feature.
Write time – The time when the feature record is persisted to the offline store. This timestamp is always greater than API invocation time, and is automatically populated by SageMaker as the write_time feature.
Depending on your use case, you can use a combination of the timestamp fields in SageMaker to choose accurate feature values. Each instance in a training dataset has a customer-defined event time and a record identifier. When you join a list of training events against feature history in a point-in-time query, you can ignore all records that happened after the instance-specific event timestamp, and select the most recent of the remaining records. The event time field is the key to this join. In a standard use case, choosing the most recent remaining record is sufficient. However, if a feature value has been corrected or revised as part of a feature backfill, the offline store contains multiple records with the same event time and record identifier. Both the original record and the new record are available, and they each have a distinct write time. For these situations, a point-in-time query can use the write_time feature or the api_invocation_time feature as a tie-breaker, ensuring the corrected feature value is returned.
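The following minimal plain-Python sketch shows the tie-breaking rule under stated assumptions. It filters candidate records on event time only, then picks the maximum by `(event_time, write_time)`, so a backfilled correction wins over the original record with the same event time. The `FeatureRecord` type, the helper, and the data are hypothetical.

```python
from datetime import datetime
from typing import NamedTuple


class FeatureRecord(NamedTuple):
    record_id: str
    event_time: datetime
    write_time: datetime
    value: float


# Hypothetical offline store history: the 2021-03-02 value was later corrected
# by a backfill, so two records share the same record_id and event_time.
records = [
    FeatureRecord("c1", datetime(2021, 3, 1), datetime(2021, 3, 1, 1), 10.0),
    FeatureRecord("c1", datetime(2021, 3, 2), datetime(2021, 3, 2, 1), 99.0),  # original (wrong)
    FeatureRecord("c1", datetime(2021, 3, 2), datetime(2021, 3, 9, 1), 12.0),  # backfilled correction
    FeatureRecord("c1", datetime(2021, 3, 4), datetime(2021, 3, 4, 1), 15.0),
]


def point_in_time_value(records, record_id, query_time):
    """Latest record with event_time <= query_time; write_time breaks ties."""
    candidates = [r for r in records
                  if r.record_id == record_id and r.event_time <= query_time]
    if not candidates:
        return None
    return max(candidates, key=lambda r: (r.event_time, r.write_time)).value


corrected = point_in_time_value(records, "c1", datetime(2021, 3, 3))
```

Because write time is used only as a tie-breaker, the correction is returned even though it was written on 2021-03-09, after the query time. That is usually what you want for training. If you instead need to reproduce exactly what was visible at the time, you would also have to filter on write time.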
Implement a point-in-time correct query
Now that we have explained the concepts, let’s dive into the implementation details of how to efficiently perform point-in-time queries using Feature Store and Apache Spark. You can try this out in your own account using the following GitHub repo. This repository contains three Jupyter notebooks plus some schema files used to create our feature groups. In this section, we show you the inner workings of a query implementation. In the notebook, we also provide a reusable function that makes this as simple as passing a few parameters:
The implementation of our point-in-time query uses SageMaker, Jupyter notebooks, and Apache Spark (PySpark). Most of the intermediate data is stored in Spark DataFrames, which gives us powerful built-in methods to manipulate, filter, and reduce that dataset so that the query runs efficiently. To enable Spark to run properly in our environment, we configure the Spark session to allocate extra driver memory and executor cores:
Next, we load the historical transaction dataset that contains the raw credit card transactions along with the target, meaning the fraud label that we want to predict. This data contains primarily the attributes that are part of the transaction itself, and includes the following columns:
The need for point-in-time queries arises because we don't have perfect information about a given transaction at all times. For example, if we had a complete set of aggregate features computed as of every transaction event, we could use this data directly. Most organizations can't build this type of data for every event. Instead, they run periodic jobs that calculate these aggregates, perhaps on an hourly or daily basis.
For this post, we simulate these daily snapshots by running an aggregation function for each day in our timeframe that spans 1 month. The aggregation code creates two dataframes, one indexed by credit card number, and the other indexed by consumer ID. Each dataframe contains data for lookback periods of 1–7 days. These datasets simulate the periodic job runs that create snapshots of our aggregate features, and they’re written to the offline store. The following screenshot is a sample of the aggregated consumer features that are generated periodically.
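The following is a plain-Python sketch of what one simulated daily job might compute. It is an assumption for illustration and is not the aggregation code from the repository. It covers only the consumer-indexed features. The feature names (`num_trans_last_{N}d`, `total_amount_last_{N}d`) and the transactions are hypothetical. Each snapshot uses only transactions strictly before the snapshot day, as a job running at midnight would.

```python
from datetime import date, timedelta

# Hypothetical raw transactions: (consumer_id, transaction date, amount).
transactions = [
    ("c1", date(2021, 3, 1), 20.0),
    ("c1", date(2021, 3, 3), 50.0),
    ("c2", date(2021, 3, 2), 10.0),
]


def daily_snapshot(transactions, snapshot_day, lookbacks=range(1, 8)):
    """Aggregates that one daily job would compute on snapshot_day.

    For each lookback of N days (1-7), count and sum each consumer's
    transactions in [snapshot_day - N days, snapshot_day).
    """
    consumers = sorted({cid for cid, _, _ in transactions})
    snapshot = {}
    for cid in consumers:
        feats = {}
        for n in lookbacks:
            start = snapshot_day - timedelta(days=n)
            amounts = [amt for c, day, amt in transactions
                       if c == cid and start <= day < snapshot_day]
            feats[f"num_trans_last_{n}d"] = len(amounts)
            feats[f"total_amount_last_{n}d"] = float(sum(amounts))
        snapshot[cid] = feats
    return snapshot


# One snapshot per day; each would be ingested with event_time = snapshot day.
first_day = date(2021, 3, 1)
days = [first_day + timedelta(days=i) for i in range(5)]
snapshots = {day: daily_snapshot(transactions, day) for day in days}
```

Because each snapshot is stamped with its own day, a later point-in-time query can choose the snapshot that was current at each transaction.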
To prepare the input criteria for our point-in-time query, we begin by creating an entity dataframe, which contains one row for each desired training dataset row. The entity dataframe identifies the consumer IDs of interest, each paired with an event time that represents our cutoff time for that training row. The consumer ID is used to join with feature data from the consumer feature group, and the transaction event time helps us filter out newer feature values. For our example, we look up a subset of historical transactions from one specific week:
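last_1w_df = spark.sql('select * from trans where event_time >= "2021-03-25T00:00:00Z" '
                       'and event_time <= "2021-03-31T23:59:59Z"')
# Collect (consumer_id, cc_num, event_time, amount, fraud_label) for each transaction
cid_ts_tuples = last_1w_df.rdd.map(
    lambda r: (r.consumer_id, r.cc_num, r.event_time, r.amount, int(r.fraud_label))
).collect()
# entity_df_schema is assumed to be defined earlier in the notebook;
# later steps refer to the event time column as query_date
entity_df = spark.createDataFrame(cid_ts_tuples, entity_df_schema)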
This produces the following entity dataframe that drives our point-in-time query.
To query against the offline store, we need to know the Amazon Simple Storage Service (Amazon S3) location of our feature group. We use the describe_feature_group method to look up that location:
In the preceding code, we use the S3A filesystem client by addressing the offline store with an s3a:// URI. S3A is the actively maintained Hadoop connector for Amazon S3, so it includes the latest fixes and performance enhancements for accessing S3 objects.
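The original lookup code isn't reproduced here, so the following is a minimal sketch under assumptions. It expects the dictionary returned by boto3's SageMaker client `describe_feature_group(FeatureGroupName=...)`, reads `OfflineStoreConfig` → `S3StorageConfig` → `ResolvedOutputS3Uri`, and swaps the scheme to `s3a://` so Spark uses the S3A client. The helper name and the sample response values are hypothetical.

```python
def offline_store_s3a_uri(feature_group_description: dict) -> str:
    """Return the offline store location as an s3a:// URI for Spark.

    feature_group_description is assumed to be the dict returned by
    boto3.client("sagemaker").describe_feature_group(FeatureGroupName=...).
    """
    s3_uri = (feature_group_description["OfflineStoreConfig"]
              ["S3StorageConfig"]["ResolvedOutputS3Uri"])
    # Swap only the scheme so Spark reads through the Hadoop S3A filesystem client.
    if s3_uri.startswith("s3://"):
        return "s3a://" + s3_uri[len("s3://"):]
    return s3_uri


# Stand-in for a real describe_feature_group response (values are made up).
example_description = {
    "OfflineStoreConfig": {
        "S3StorageConfig": {
            "S3Uri": "s3://my-bucket/offline-store",
            "ResolvedOutputS3Uri": "s3://my-bucket/offline-store/example-feature-group/data",
        }
    }
}
offline_uri = offline_store_s3a_uri(example_description)
```

The resulting URI is what you would pass to `spark.read.parquet(...)` in the next step.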
Now we use Spark to read data from the offline store, which is stored in Parquet format in the S3 location from the preceding code:
The following output shows the schema of the data read from the offline store. It contains several additional fields automatically populated by SageMaker: timestamp fields, as defined earlier in this post (write_time, api_invocation_time), a soft delete flag (is_deleted), and date-time partitioning fields (year, month, day, and hour).
The is_deleted attribute is a Boolean soft delete indicator for the referenced record identifier. If the DeleteRecord method is called, a new record is inserted with the is_deleted flag set to True in the offline store. The date-time partitioning fields are used to segregate the individual data files written to the offline store, and are useful when navigating to a desired timeframe or reading a subset of data.
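The post doesn't spell out how its query treats soft deletes. One possible policy, sketched here in plain Python as an assumption, treats a delete marker as ending the record's validity. Under this policy, if the latest row at or before the query time is a delete marker, the entity has no feature value at that time. The rows and the helper are hypothetical.

```python
from datetime import datetime

# Hypothetical offline store rows for one record identifier:
# (event_time, is_deleted, feature value). Delete markers carry no value.
rows = [
    (datetime(2021, 3, 1), False, 5),
    (datetime(2021, 3, 3), True, None),  # DeleteRecord called with this event time
    (datetime(2021, 3, 6), False, 8),    # record ingested again later
]


def value_as_of(rows, query_time):
    """Latest row with event_time <= query_time; None if that row is a delete marker."""
    eligible = [r for r in rows if r[0] <= query_time]
    if not eligible:
        return None
    _, is_deleted, value = max(eligible, key=lambda r: r[0])
    return None if is_deleted else value
```

Simply dropping every row where is_deleted is true would instead return 5 for a query on 2021-03-04, which revives a value that had already been deleted.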
To optimize the performance of the point-in-time query, we can immediately filter out all records that don’t meet our overall criteria. In many cases, this optimization drastically reduces the number of records we carry forward, and therefore makes the subsequent joins more efficient. We use a min/max time window to drop all data that doesn’t meet our timeframe boundary. We also include a staleness window to ensure that we don’t include records that are too old to be useful. The appropriate length of the staleness window is specific to your use case. See the following code:
from pyspark.sql.functions import datediff, lit
from pyspark.sql.functions import max as sql_max, min as sql_min

# NOTE: This filter is simply a performance optimization.
# Filter out records from after the query max_time and before the
# staleness window prior to min_time.
# Doing this prior to individual {consumer_id, joindate} filtering
# speeds up subsequent filters for large-scale queries.

# Choose a "staleness" window of time before which we want
# to ignore records
allowed_staleness_days = 14

# Eliminate history that is outside of our time window.
# This window represents the {max_time - min_time} delta,
# plus our staleness window.
# entity_df is used to define the bounded time window
minmax_time = entity_df.agg(sql_min("query_date"), sql_max("query_date")).collect()
min_time, max_time = minmax_time[0]["min(query_date)"], minmax_time[0]["max(query_date)"]

# Via the staleness check, we remove items whose
# event_time is MORE than N days before min_time.
# Usage: datediff(enddate, startdate) returns the number of days
filtered = feature_store_active_df.filter(
    (feature_store_active_df.event_time <= max_time)
    & (datediff(lit(min_time), feature_store_active_df.event_time) <= allowed_staleness_days)
)
Now we're ready to join the filtered dataset with the entity dataframe to reduce the results to only those consumer IDs (our entities) that are part of our desired training dataset. This inner join uses consumer_id as the join key, thereby removing feature records for consumers that aren't in the entity dataframe:
This results in an enhanced dataframe with all the aggregate attributes from our consumer feature group, for each targeted training row. We still need to remove feature records that fall outside each row's time window. This window includes times no later than the event time of interest and no earlier than the staleness allowance before it. This filtering is applied to each row in our chosen list of training rows. See the following filter code, with the results named drop_future_and_stale_df:
# Filter out data from after the query time to remove future data leakage.
# Also filter out data that is older than our allowed staleness
# window (days before each query time).
drop_future_and_stale_df = t_joined.filter(
    (t_joined.event_time <= entity_df.query_date)
    & (datediff(entity_df.query_date, t_joined.event_time) <= allowed_staleness_days)
)
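The per-row predicate above can be mirrored in plain Python to show how it treats boundary cases. This sketch assumes Spark's `datediff(end, start)` counts calendar days between the two dates and ignores the time of day. Time zones are ignored here, and the helper names and timestamps are hypothetical.

```python
from datetime import datetime

ALLOWED_STALENESS_DAYS = 14


def spark_style_datediff(end: datetime, start: datetime) -> int:
    """Calendar days from start to end, ignoring time of day, like Spark's datediff."""
    return (end.date() - start.date()).days


def keep_feature_row(query_date: datetime, event_time: datetime,
                     allowed_staleness_days: int = ALLOWED_STALENESS_DAYS) -> bool:
    """Same predicate as drop_future_and_stale_df: no future data, nothing too stale."""
    return (event_time <= query_date
            and spark_style_datediff(query_date, event_time) <= allowed_staleness_days)


# One hypothetical entity row joined with candidate feature records.
query = datetime(2021, 3, 28, 9, 30)
candidates = [
    datetime(2021, 3, 28, 10, 0),   # later the same day: future leakage, dropped
    datetime(2021, 3, 28, 0, 0),    # earlier the same day: kept
    datetime(2021, 3, 14, 23, 59),  # 14 calendar days before: kept
    datetime(2021, 3, 13, 23, 59),  # 15 calendar days before: stale, dropped
]
kept = [t for t in candidates if keep_feature_row(query, t)]
```

Because the staleness check works in calendar days, a record from late in the evening 14 days earlier still passes, even though more than 14 × 24 hours have elapsed.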
In our final training dataset, we want to allow multiple aggregate records per entity ID (think multiple credit card transactions by a single consumer), but keep exactly one record per transaction. Therefore, we assemble a composite key from the consumer ID and the query timestamp, {x.consumer_id}-{x.query_date}, and reduce by that key so that only the latest aggregate record for each composite key remains. Doing this naively with a full sort would be expensive. Instead, we implement it as a custom reduction passed to the Spark RDD reduceByKey() method, which scales very well for large datasets. See the following code:
# Group by record ID and query timestamp, and select only the latest
# remaining record by event time, using write time as a tie-breaker
# to account for any more recent backfills or data corrections.
latest = (
    drop_future_and_stale_df.rdd
    .map(lambda x: (f"{x.consumer_id}-{x.query_date}", x))
    .reduceByKey(
        lambda x, y: x if (x.event_time, x.write_time) > (y.event_time, y.write_time) else y
    )
    .values()
)
latest_df = latest.toDF(drop_future_and_stale_df.schema)
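reduceByKey expects an associative and commutative function, which lets Spark pre-combine values within each partition before shuffling instead of sorting globally. The lambda above meets that requirement when it picks the maximum under `(event_time, write_time)`, assuming no two records for a key share both timestamps. The following plain-Python stand-in is hypothetical and is not Spark code. It checks that shuffling the input, or reducing two "partitions" separately and then merging them, gives the same result.

```python
import random
from collections import namedtuple
from datetime import datetime

Row = namedtuple("Row", ["consumer_id", "query_date", "event_time", "write_time", "value"])


def newer(x, y):
    """The reduceByKey function: keep the record with the larger (event_time, write_time)."""
    return x if (x.event_time, x.write_time) > (y.event_time, y.write_time) else y


def reduce_by_key(rows):
    """Plain-Python stand-in for rdd.map(composite key, row).reduceByKey(newer)."""
    out = {}
    for row in rows:
        key = f"{row.consumer_id}-{row.query_date}"
        out[key] = newer(out[key], row) if key in out else row
    return out


def values(result):
    return {key: row.value for key, row in result.items()}


q = datetime(2021, 3, 28)
rows = [
    Row("c1", q, datetime(2021, 3, 20), datetime(2021, 3, 20, 1), 1),
    Row("c1", q, datetime(2021, 3, 27), datetime(2021, 3, 27, 1), 2),
    Row("c1", q, datetime(2021, 3, 27), datetime(2021, 3, 30, 1), 3),  # backfilled correction
    Row("c2", q, datetime(2021, 3, 25), datetime(2021, 3, 25, 1), 4),
]

# Any input order gives the same answer...
shuffled_results = []
for seed in range(5):
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    shuffled_results.append(values(reduce_by_key(shuffled)))

# ...and so does reducing two "partitions" separately, then merging them.
left, right = reduce_by_key(rows[:2]), reduce_by_key(rows[2:])
merged = values(reduce_by_key(list(left.values()) + list(right.values())))
```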
To view our final results, we can select specific columns for display and reference a test_consumer_id taken from our original dataframe:
The following screenshot is a sample of the final results from our point-in-time query. These results show that we choose features only from the past and don't leak any future values. The event time for each record is no later than its query timestamp, so each row gets the latest features that were available without using features that would only have been known in the future.
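Beyond inspecting a sample, you could check the whole result programmatically. The following is a hypothetical plain-Python helper, not part of the original notebook. It assumes each result row exposes consumer_id, query_date, and the chosen feature event_time. It flags future values, values older than the staleness window, and duplicate rows per composite key.

```python
from collections import Counter
from datetime import datetime


def validate_training_rows(rows, allowed_staleness_days=14):
    """Return a list of problems found in a point-in-time result; empty means OK."""
    problems = []
    for r in rows:
        if r["event_time"] > r["query_date"]:
            problems.append(f"future feature for {r['consumer_id']} at {r['query_date']}")
        elif (r["query_date"].date() - r["event_time"].date()).days > allowed_staleness_days:
            problems.append(f"stale feature for {r['consumer_id']} at {r['query_date']}")
    # Exactly one row is expected per (consumer_id, query_date) composite key.
    counts = Counter((r["consumer_id"], r["query_date"]) for r in rows)
    problems += [f"duplicate rows for {key}" for key, n in counts.items() if n > 1]
    return problems


good = [
    {"consumer_id": "c1", "query_date": datetime(2021, 3, 28), "event_time": datetime(2021, 3, 27)},
    {"consumer_id": "c1", "query_date": datetime(2021, 3, 29), "event_time": datetime(2021, 3, 29)},
]
bad = good + [
    {"consumer_id": "c1", "query_date": datetime(2021, 3, 28), "event_time": datetime(2021, 3, 30)},
]
good_problems = validate_training_rows(good)
bad_problems = validate_training_rows(bad)
```

In the Spark version, you would apply the same checks to latest_df, for example by collecting a sample or by expressing them as DataFrame filters and counts.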
This completes the historical query, and we now have an accurate training dataset that represents a point-in-time query for each individual training transaction.
Conclusion
In this post, we described the concept of point-in-time correct queries and explained why they're important for training effective ML models. We showed an efficient and reproducible way to query historical feature data using Feature Store and Apache Spark. We hope you experiment with the code we've provided and try it out on your own datasets. We always look forward to your feedback, either through your usual AWS Support contacts or on the Amazon SageMaker Discussion Forum.
About the Authors
Paul Hargis has focused his efforts on Machine Learning at several companies, including AWS, Amazon, and Hortonworks. He enjoys building technology solutions and also teaching people how to make the most of it. Prior to his role at AWS, he was lead architect for Amazon Exports and Expansions, helping amazon.com improve the experience for international shoppers. Paul likes to help customers expand their machine learning initiatives to solve real-world problems.
Raphey Holmes is an engineering manager on GoDaddy’s Machine Learning platform team. Prior to changing careers, he worked for a decade as a high school physics teacher, and he still loves all things related to teaching and learning.
Jason Mackay is a Principal SDE on GoDaddy’s Machine Learning team. He has been in the software industry for 25 years, spanning operating systems, parallel/concurrent/distributed systems, formal languages, high performance cryptography, big data, and machine learning.
Mark Roy is a Principal Machine Learning Architect for AWS, helping customers design and build AI/ML solutions. Mark’s work covers a wide range of ML use cases, with a primary interest in computer vision, deep learning, and scaling ML across the enterprise. He has helped companies in many industries, including insurance, financial services, media and entertainment, healthcare, utilities, and manufacturing. Mark holds six AWS certifications, including the ML Specialty Certification. Prior to joining AWS, Mark was an architect, developer, and technology leader for over 25 years, including 19 years in financial services.