Machine Learning Integration
Overview
The Riskify protocol integrates machine learning models to enhance risk assessment and optimization across different pool types. This document details the ML integration architecture, model training, and deployment process, with a focus on how to connect, extend, and deploy ML components within the Riskify ecosystem.
Architecture
1. ML Pipeline
The ML pipeline in Riskify is designed for modularity and extensibility. It covers the full lifecycle from data ingestion to on-chain deployment and monitoring. The diagram below shows the high-level flow:
```mermaid
graph TD
    A[Data Collection] --> B[Feature Engineering]
    B --> C[Model Training]
    C --> D[Model Validation]
    D --> E[Model Deployment]
    E --> F[On-chain Integration]
    F --> G[Performance Monitoring]
    G --> A
```
2. System Components
The ML system is organized into modular classes for each stage of the pipeline. You can extend or replace any component to suit your data or modeling needs. The example below shows a typical system composition:
```python
from __future__ import annotations  # component classes are defined in later sections

from dataclasses import dataclass


@dataclass
class MLSystem:
    """One component per pipeline stage; swap any field for a custom implementation."""
    data_collector: DataCollector
    feature_engineer: FeatureEngineer
    model_trainer: ModelTrainer
    model_validator: ModelValidator
    model_deployer: ModelDeployer
    performance_monitor: PerformanceMonitor
```
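As a rough illustration of how these components compose, the sketch below runs one pass of the loop in the pipeline diagram by threading a payload through an ordered list of stages. `PipelineStage` and `run_pipeline_cycle` are hypothetical names, not part of the Riskify codebase; in practice each stage would wrap the matching `MLSystem` component, and the monitoring stage's output would seed the next data-collection pass.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class PipelineStage:
    """Hypothetical wrapper pairing a stage name with the callable that runs it."""
    name: str
    run: Callable[[Any], Any]


def run_pipeline_cycle(stages: List[PipelineStage], payload: Any) -> Tuple[Any, List[str]]:
    """Run each stage in order, feeding each stage's output into the next.

    Returns the final payload plus the names of the stages that ran (useful for logging).
    """
    executed: List[str] = []
    for stage in stages:
        payload = stage.run(payload)
        executed.append(stage.name)
    return payload, executed


# Toy stand-ins for real components: collect -> feature engineering -> scoring
stages = [
    PipelineStage("collect", lambda _: [1.0, 2.0, 3.0]),
    PipelineStage("features", lambda xs: [x / max(xs) for x in xs]),
    PipelineStage("score", lambda xs: sum(xs) / len(xs)),
]
result, trace = run_pipeline_cycle(stages, None)
print(result, trace)
```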
Data Collection
Data collection is the first step in the pipeline. Riskify supports both on-chain and off-chain (market) data sources. You can implement your own collectors to fetch additional data as needed.
1. On-chain Data
On-chain data collectors interface with smart contracts and blockchain APIs to gather pool and network metrics. Use or extend these classes to connect to your own contracts or add new metrics.
```python
from typing import Dict


class OnChainCollector:
    """Collects pool and network metrics from on-chain sources.

    The _get_* helpers are integration points: implement them against
    your contracts, an RPC node, or an indexer.
    """

    def collect_pool_metrics(self, pool_id: int) -> Dict[str, float]:
        """Collect pool-specific metrics."""
        return {
            'total_staked': self._get_total_staked(pool_id),
            'utilization_rate': self._get_utilization_rate(pool_id),
            'risk_score': self._get_risk_score(pool_id),
            'reward_rate': self._get_reward_rate(pool_id),
        }

    def collect_network_metrics(self) -> Dict[str, float]:
        """Collect network-wide metrics."""
        return {
            'total_tvl': self._get_total_tvl(),
            'network_risk': self._get_network_risk(),
            'average_correlation': self._get_average_correlation(),
        }
```
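Contract calls return unscaled integers, so helpers such as `_get_total_staked` and `_get_utilization_rate` typically convert fixed-point values before returning floats. The sketch below assumes 18-decimal token amounts and defines utilization as active coverage divided by total staked; both are assumptions about the pool contracts, and `from_fixed_point` and `utilization_rate` are hypothetical helpers.

```python
from decimal import Decimal


def from_fixed_point(raw: int, decimals: int = 18) -> float:
    """Convert an unsigned fixed-point integer (e.g. a token amount in wei) to a float."""
    if raw < 0:
        raise ValueError("on-chain amounts are unsigned")
    # Decimal keeps the division exact before the final float conversion
    return float(Decimal(raw) / (Decimal(10) ** decimals))


def utilization_rate(active_coverage_raw: int, total_staked_raw: int) -> float:
    """Share of staked capital currently backing coverage (assumed definition)."""
    if total_staked_raw == 0:
        return 0.0  # an empty pool has nothing to utilize
    return from_fixed_point(active_coverage_raw) / from_fixed_point(total_staked_raw)


print(from_fixed_point(1_500 * 10**18))
print(utilization_rate(250 * 10**18, 1_000 * 10**18))
```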
2. Market Data
Market data collectors fetch external signals such as price, volume, and sentiment. Integrate with APIs or data vendors as needed for your use case.
```python
from typing import Dict


class MarketDataCollector:
    """The _get_* helpers wrap your market data APIs or vendors."""
    def collect_market_metrics(self) -> Dict[str, float]:
        """Collect market-related metrics."""
        return {
            'price_volatility': self._get_price_volatility(),
            'volume': self._get_volume(),
            'liquidity': self._get_liquidity(),
            'market_sentiment': self._get_market_sentiment(),
        }
```
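One common way to fill in `_get_price_volatility` is realized volatility: the sample standard deviation of log returns over a window of recent prices. The sketch below assumes prices arrive as a chronological list from whatever feed you integrate; `price_volatility` is a hypothetical helper and applies no annualization.

```python
import math
import statistics
from typing import Sequence


def price_volatility(prices: Sequence[float]) -> float:
    """Sample standard deviation of log returns for a chronological price series."""
    if len(prices) < 3:
        raise ValueError("need at least three prices for two returns")
    if any(p <= 0 for p in prices):
        raise ValueError("prices must be positive")
    # log(p_t / p_{t-1}) for each consecutive pair
    log_returns = [math.log(curr / prev) for prev, curr in zip(prices, prices[1:])]
    return statistics.stdev(log_returns)


print(price_volatility([100.0, 110.0, 99.0]))
```

For an annualized figure, multiply by the square root of the number of sampling periods per year.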
Feature Engineering
Feature engineering transforms raw data into model-ready features. This step is critical for model performance and interpretability. You can customize feature extraction and selection to match your data and objectives.
1. Feature Extraction
Extract features from pool, network, or market data. Normalize and preprocess as needed for your models.
```python
from typing import Dict

import numpy as np


class FeatureExtractor:
    """Turns raw metric dicts into fixed-order feature vectors.

    _normalize_features is a hook: implement the scaling your model expects.
    """

    def extract_pool_features(
        self,
        pool_data: Dict[str, float]
    ) -> np.ndarray:
        """Extract features from pool data (column order is fixed)."""
        features = np.array([
            pool_data['total_staked'],
            pool_data['utilization_rate'],
            pool_data['risk_score'],
            pool_data['reward_rate'],
        ], dtype=np.float64)
        return self._normalize_features(features)

    def extract_network_features(
        self,
        network_data: Dict[str, float]
    ) -> np.ndarray:
        """Extract features from network data (column order is fixed)."""
        features = np.array([
            network_data['total_tvl'],
            network_data['network_risk'],
            network_data['average_correlation'],
        ], dtype=np.float64)
        return self._normalize_features(features)
```
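`_normalize_features` is left undefined above. Scaling each vector in isolation would mix units (token amounts next to rates), so one reasonable choice is per-feature standardization, with statistics fitted on training data and reused at inference. The `FeatureScaler` class below is a hypothetical sketch of that approach, not Riskify's implementation.

```python
from typing import Optional

import numpy as np


class FeatureScaler:
    """Per-feature z-score scaling fitted on a training matrix (rows = samples)."""

    def __init__(self) -> None:
        self.mean_: Optional[np.ndarray] = None
        self.std_: Optional[np.ndarray] = None

    def fit(self, train_features: np.ndarray) -> "FeatureScaler":
        """Learn per-column mean and standard deviation from training rows."""
        self.mean_ = train_features.mean(axis=0)
        std = train_features.std(axis=0)
        # Constant columns would divide by zero; center them but leave them unscaled
        self.std_ = np.where(std > 0, std, 1.0)
        return self

    def transform(self, features: np.ndarray) -> np.ndarray:
        """Apply the fitted scaling to a single vector or a batch of rows."""
        if self.mean_ is None or self.std_ is None:
            raise RuntimeError("call fit() before transform()")
        return (features - self.mean_) / self.std_


train = np.array([[1.0, 2.0], [3.0, 4.0]])
scaler = FeatureScaler().fit(train)
print(scaler.transform(np.array([1.0, 2.0])))
```

Persist the fitted `mean_` and `std_` alongside the model version so inference uses exactly the scaling seen in training.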
2. Feature Selection
Select the most relevant features for your model. This can improve performance and reduce overfitting. You can implement custom selection logic or use built-in methods.
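```python
import numpy as np


class FeatureSelector:
    def __init__(self, threshold: float):
        # Features whose importance score exceeds this value are kept
        self.threshold = threshold

    def select_features(
        self,
        features: np.ndarray,
        target: np.ndarray
    ) -> np.ndarray:
        """Select most relevant columns of a (n_samples, n_features) matrix."""
        # _calculate_importance is a hook returning one score per feature column
        importance = self._calculate_importance(features, target)
        # Keep the mask so the same columns can be selected at inference time
        self.selected_mask_ = importance > self.threshold
        return features[:, self.selected_mask_]
```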
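`_calculate_importance` can be any scoring function that returns one non-negative value per feature column. A simple option is the absolute Pearson correlation between each column and the target; the sketch below implements it as a standalone function (`correlation_importance` is a hypothetical name) and scores constant columns as 0. Correlation only captures linear relationships, so permutation or tree-based importance may suit nonlinear models better.

```python
import numpy as np


def correlation_importance(features: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Absolute Pearson correlation of each feature column with the target."""
    centered_x = features - features.mean(axis=0)
    centered_y = target - target.mean()
    covariance = centered_x.T @ centered_y
    scale = np.sqrt((centered_x ** 2).sum(axis=0) * (centered_y ** 2).sum())
    # Constant columns (or a constant target) have zero scale; score them 0
    return np.divide(np.abs(covariance), scale,
                     out=np.zeros_like(scale), where=scale > 0)


features = np.array([
    [1.0, 5.0, 0.3],
    [2.0, 5.0, 0.1],
    [3.0, 5.0, 0.2],
    [4.0, 5.0, 0.4],
])
target = np.array([1.0, 2.0, 3.0, 4.0])
importance = correlation_importance(features, target)
selected = features[:, importance > 0.5]  # 0.5 plays the role of self.threshold
```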
Model Training
Model training is where you fit your ML models to the engineered features. Riskify provides templates for risk assessment and optimization models, but you can substitute your own architectures as needed.
1. Risk Assessment Model
This model predicts risk scores for pools or the network. You can adjust the architecture, loss function, or training loop to fit your requirements.
```python
import torch
import torch.nn as nn
import torch.optim as optim


class RiskAssessmentModel:
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int
    ):
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()  # bounds risk scores to [0, 1]
        )

    def train(
        self,
        features: torch.Tensor,
        targets: torch.Tensor,
        epochs: int = 100
    ) -> None:
        """Train the risk assessment model with full-batch gradient descent."""
        self.model.train()  # enable dropout while training
        optimizer = optim.Adam(self.model.parameters())
        criterion = nn.MSELoss()
        for _ in range(epochs):
            optimizer.zero_grad()
            outputs = self.model(features)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
```
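The training loop above always runs for a fixed number of epochs. A common refinement is early stopping: track validation loss each epoch and stop once it has not improved for `patience` epochs. The framework-agnostic `EarlyStopping` helper below is a hypothetical sketch; inside the loop you would compute validation loss under `torch.no_grad()` and break when `step()` returns `True`.

```python
class EarlyStopping:
    """Signals when validation loss has stopped improving."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.9, 0.5, 0.4, 0.45, 0.41, 0.3]
stopped_at = next(i for i, loss in enumerate(losses) if stopper.step(loss))
```

Pair it with saving the best weights (for example a copy of `self.model.state_dict()` at each improvement) so they can be restored after stopping.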
2. Optimization Model
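The optimization model recommends how to allocate capital or rebalance pools. The template below is a feed-forward network that maps four metrics per pool to allocation weights summing to 1 (via softmax); you can replace it with reinforcement learning, classical optimization, or another approach.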
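```python
import torch
import torch.nn as nn


class PoolOptimizationModel:
    def __init__(
        self,
        n_pools: int,
        hidden_dim: int
    ):
        # Input: 4 metrics per pool (e.g. the pool metrics collected above),
        # flattened to n_pools * 4 values
        self.model = nn.Sequential(
            nn.Linear(n_pools * 4, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_pools),
            nn.Softmax(dim=1)  # per-row allocation weights that sum to 1
        )

    def optimize_allocation(
        self,
        pool_metrics: torch.Tensor
    ) -> torch.Tensor:
        """Optimize risk allocation across pools.

        pool_metrics must be 2-D with shape (batch_size, n_pools * 4).
        """
        self.model.eval()
        with torch.no_grad():
            allocation = self.model(pool_metrics)
        return allocation
```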
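The softmax output is a vector of fractional weights, while token amounts on-chain are integers, so converting with `int(total * w)` can leave dust unallocated. The sketch below uses the largest-remainder method with exact rational arithmetic so the integer amounts always sum to the total; `weights_to_amounts` is a hypothetical helper, and the rounding rule is a design choice rather than something the Riskify contracts prescribe.

```python
import math
from fractions import Fraction
from typing import List, Sequence


def weights_to_amounts(total: int, weights: Sequence[float]) -> List[int]:
    """Split an integer total across pools in proportion to weights (largest remainder)."""
    exact = [Fraction(w) for w in weights]  # exact values avoid float drift on wei-scale totals
    weight_sum = sum(exact)
    if total < 0 or weight_sum <= 0 or any(w < 0 for w in exact):
        raise ValueError("need a non-negative total and non-negative weights with a positive sum")
    shares = [total * w / weight_sum for w in exact]
    amounts = [math.floor(s) for s in shares]
    leftover = total - sum(amounts)  # always between 0 and len(weights) - 1
    # Hand the remaining units to the pools with the largest fractional parts
    order = sorted(range(len(shares)), key=lambda i: shares[i] - amounts[i], reverse=True)
    for i in order[:leftover]:
        amounts[i] += 1
    return amounts


print(weights_to_amounts(1000, [0.5, 0.3, 0.2]))
```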
Model Validation
Validation ensures your models are accurate and robust before deployment. Use the provided tools for metric calculation and cross-validation, or integrate your own validation logic.
1. Performance Metrics
Calculate standard regression metrics (MSE, MAE, R², and explained variance) to evaluate model predictions. Extend this class to add custom metrics as needed.
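```python
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import (
    explained_variance_score,
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)


class ModelValidator:
    def calculate_metrics(
        self,
        predictions: np.ndarray,
        targets: np.ndarray
    ) -> Dict[str, float]:
        """Calculate model performance metrics."""
        return {
            'mse': mean_squared_error(targets, predictions),
            'mae': mean_absolute_error(targets, predictions),
            'r2': r2_score(targets, predictions),
            'explained_variance': explained_variance_score(targets, predictions)
        }

    def validate_model(
        self,
        model: nn.Module,
        val_features: torch.Tensor,
        val_targets: torch.Tensor
    ) -> Dict[str, float]:
        """Validate model performance on held-out data.

        Pass the underlying nn.Module, e.g. RiskAssessmentModel(...).model.
        """
        model.eval()  # disable dropout for evaluation
        with torch.no_grad():
            predictions = model(val_features)
        # Move to CPU before converting, in case the model runs on a GPU
        return self.calculate_metrics(predictions.cpu().numpy(), val_targets.cpu().numpy())
```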
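To make the metric definitions concrete, the sketch below computes them directly with NumPy on a three-point example; for 1-D inputs these formulas should agree with the scikit-learn functions used in `calculate_metrics`. `regression_metrics` is a hypothetical helper for illustration only and assumes non-constant targets.

```python
from typing import Dict

import numpy as np


def regression_metrics(targets: np.ndarray, predictions: np.ndarray) -> Dict[str, float]:
    residuals = targets - predictions
    ss_res = float(np.sum(residuals ** 2))                    # residual sum of squares
    ss_tot = float(np.sum((targets - targets.mean()) ** 2))   # total sum of squares
    return {
        "mse": float(np.mean(residuals ** 2)),
        "mae": float(np.mean(np.abs(residuals))),
        "r2": 1.0 - ss_res / ss_tot,
        "explained_variance": 1.0 - float(np.var(residuals)) / float(np.var(targets)),
    }


metrics = regression_metrics(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 4.0]))
print(metrics)
```

In this example R² (0.5) is lower than explained variance (about 0.67) because the residuals have a nonzero mean: explained variance ignores a constant bias, while R² penalizes it.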
2. Cross-Validation
Cross-validation helps prevent overfitting and gives a more reliable estimate of model performance. Adjust the number of folds or add stratification as needed.
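```python
import copy
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import KFold


class CrossValidator(ModelValidator):
    """K-fold cross-validation; scoring reuses ModelValidator.validate_model."""

    def cross_validate(
        self,
        model: nn.Module,
        features: torch.Tensor,
        targets: torch.Tensor,
        n_folds: int = 5
    ) -> List[Dict[str, float]]:
        """Perform k-fold cross-validation. Pass an untrained model."""
        kf = KFold(n_splits=n_folds, shuffle=True)
        metrics = []
        for train_idx, val_idx in kf.split(np.arange(len(features))):
            train_idx = torch.as_tensor(train_idx)
            val_idx = torch.as_tensor(val_idx)
            # Start each fold from the same initial weights so folds don't leak into each other
            fold_model = copy.deepcopy(model)
            fold_model.train()
            # _train_fold is a training hook to implement for your model
            self._train_fold(fold_model, features[train_idx], targets[train_idx])
            fold_metrics = self.validate_model(
                fold_model,
                features[val_idx],
                targets[val_idx]
            )
            metrics.append(fold_metrics)
        return metrics
```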
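Shuffled k-fold assumes samples are independent. Pool and market metrics are time series, so shuffling can let a model train on data from after the period it is validated on. If that leakage matters for your use case, a forward-chaining split (the scheme scikit-learn's `TimeSeriesSplit` implements) is a common alternative. The generator below is a minimal dependency-free sketch of that split; `forward_chaining_splits` is a hypothetical name.

```python
from typing import Iterator, List, Tuple


def forward_chaining_splits(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, val_indices) where validation always follows training in time.

    Assumes samples are ordered oldest to newest.
    """
    test_size = n_samples // (n_splits + 1)
    if test_size == 0:
        raise ValueError("too few samples for the requested number of splits")
    first_test_start = n_samples - n_splits * test_size
    for test_start in range(first_test_start, n_samples, test_size):
        train_idx = list(range(test_start))  # everything before the validation window
        val_idx = list(range(test_start, test_start + test_size))
        yield train_idx, val_idx


folds = list(forward_chaining_splits(n_samples=10, n_splits=3))
```

Each pair can stand in for the `kf.split(...)` indices in `cross_validate`; unlike shuffled k-fold, the training window grows with each fold.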
Model Deployment
Deployment covers exporting trained models and integrating them with the Riskify platform. This includes converting models to ONNX format for an off-chain inference service whose results feed the on-chain contracts, and preparing them for API or batch inference.
1. Model Export
Export your trained PyTorch models to ONNX format for portable, cross-platform inference. The exported model runs in an off-chain inference service, which is the bridge between ML predictions and smart contracts.
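```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class ModelExporter:
    def export_model(
        self,
        model: nn.Module,
        path: str,
        input_dim: int
    ) -> bool:
        """Export model to ONNX format.

        nn.Module does not expose its input size, so input_dim (the number of
        input features) is passed explicitly to build the tracing input.
        """
        try:
            model.eval()  # export inference behavior (dropout disabled)
            dummy_input = torch.randn(1, input_dim)
            torch.onnx.export(
                model,
                dummy_input,
                path,
                export_params=True,
                opset_version=11,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }
            )
            return True
        except Exception as e:
            logger.error(f"Model export failed: {e}")
            return False
```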
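After exporting, it is worth confirming that the ONNX graph reproduces the PyTorch outputs before deploying it. With ONNX Runtime installed, the candidate side would typically be `onnxruntime.InferenceSession(path).run(None, {'input': batch})[0]`, using the input name set in `export_model`. The sketch below keeps the comparison framework-agnostic so it runs with NumPy alone; `outputs_match` and its tolerances are illustrative assumptions.

```python
from typing import Callable

import numpy as np


def outputs_match(
    reference: Callable[[np.ndarray], np.ndarray],
    candidate: Callable[[np.ndarray], np.ndarray],
    batch: np.ndarray,
    rtol: float = 1e-3,
    atol: float = 1e-5,
) -> bool:
    """Return True if both inference functions agree on the batch within tolerance."""
    expected = np.asarray(reference(batch))
    actual = np.asarray(candidate(batch))
    # Check shapes first: np.allclose would silently broadcast (8, 1) against (8,)
    return expected.shape == actual.shape and bool(np.allclose(expected, actual, rtol=rtol, atol=atol))


batch = np.random.default_rng(0).normal(size=(8, 4)).astype(np.float32)
weights = np.ones((4, 1), dtype=np.float32)


def reference(x: np.ndarray) -> np.ndarray:
    """Toy stand-in for the PyTorch model: a sigmoid over a linear layer."""
    return 1 / (1 + np.exp(-(x @ weights)))
```

Using a batch larger than the dummy input's batch size of 1 also exercises the dynamic `batch_size` axis configured in the export.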
2. On-chain Integration
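To connect ML predictions to smart contracts, export your model to ONNX and deploy it to an off-chain inference service. Contracts cannot make HTTP requests, so predictions reach the chain through an oracle that submits them in a transaction. The contract takes input features as arrays of uint256 values (Solidity has no floating-point type, so metrics must be scaled to fixed-point integers) and returns risk scores or allocations.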
Integration Steps:
- Train and validate your model in Python.
- Export the model to ONNX using the provided exporter.
- Deploy the ONNX model to an inference server or oracle node.
- An oracle node runs inference on the pool/network/market metrics and writes the result to the smart contract (see below) in a transaction; the contract cannot call the inference endpoint directly.
- The contract receives predictions and uses them to automate risk management logic on-chain.
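```solidity
pragma solidity ^0.8.0;

abstract contract MLOptimizer {
    // Solidity has no floating-point type: metrics are scaled
    // fixed-point integers.
    struct ModelInput {
        uint256[] poolMetrics;
        uint256[] networkMetrics;
        uint256[] marketMetrics;
    }

    struct ModelOutput {
        uint256 riskScore;
        uint256 confidence;
        bytes extraData;
    }

    function predict(
        ModelInput memory input
    ) external view returns (ModelOutput memory) {
        // Preprocess input
        bytes memory preprocessed = _preprocess(input);
        // Fetch the model result. A view function cannot reach an off-chain
        // service; it can only read state, e.g. a prediction an oracle has
        // already written on-chain.
        bytes memory rawOutput = _callModel(preprocessed);
        // Postprocess output
        return _postprocess(rawOutput);
    }

    // Hooks for concrete implementations
    function _preprocess(ModelInput memory input) internal view virtual returns (bytes memory);
    function _callModel(bytes memory preprocessed) internal view virtual returns (bytes memory);
    function _postprocess(bytes memory rawOutput) internal view virtual returns (ModelOutput memory);
}
```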
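Because `ModelInput` carries `uint256` arrays, the off-chain side has to encode float features as scaled integers before submitting them. The sketch below assumes an 18-decimal scale (a common Ethereum convention, not one the contract above specifies) and rejects negative or non-finite values, which `uint256` cannot represent; `to_fixed_point` and `encode_model_input` are hypothetical helpers.

```python
import math
from decimal import Decimal
from typing import Dict, List, Sequence

SCALE = 10 ** 18  # assumed fixed-point scale; must match how the contract decodes values


def to_fixed_point(value: float) -> int:
    """Encode a non-negative float as a uint256-compatible scaled integer."""
    if not math.isfinite(value) or value < 0:
        raise ValueError("uint256 can only hold finite, non-negative values")
    # str() gives the shortest repr, so 0.1 encodes as exactly 10**17
    return int(Decimal(str(value)) * SCALE)


def encode_model_input(
    pool: Sequence[float], network: Sequence[float], market: Sequence[float]
) -> Dict[str, List[int]]:
    """Build the three metric arrays expected by MLOptimizer.ModelInput."""
    return {
        "poolMetrics": [to_fixed_point(v) for v in pool],
        "networkMetrics": [to_fixed_point(v) for v in network],
        "marketMetrics": [to_fixed_point(v) for v in market],
    }


print(encode_model_input([0.5, 0.1], [1.0], [0.0]))
```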
Performance Monitoring
Monitoring tracks model health and performance in production. Use these tools to log metrics, detect drift, and trigger retraining as needed.
1. Metrics Tracking
Track and store model performance metrics over time. Integrate with your preferred database or monitoring stack.
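```python
from datetime import datetime, timezone
from typing import Any, Dict, List


class PerformanceMonitor:
    def __init__(self, collection):
        # A MongoDB-style collection (e.g. a pymongo Collection)
        # supporting insert_one() and find()
        self.db = collection

    def track_metrics(
        self,
        model_id: str,
        metrics: Dict[str, float]
    ) -> None:
        """Track model performance metrics."""
        self.db.insert_one({
            'model_id': model_id,
            'timestamp': datetime.now(timezone.utc),  # store UTC to avoid timezone ambiguity
            'metrics': metrics
        })

    def get_metrics_history(
        self,
        model_id: str,
        start_time: datetime,
        end_time: datetime
    ) -> List[Dict[str, Any]]:
        """Get stored records (model_id, timestamp, metrics) within a time range."""
        return list(self.db.find({
            'model_id': model_id,
            'timestamp': {
                '$gte': start_time,
                '$lte': end_time
            }
        }))
```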
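Tracked metrics show when performance has already degraded; drift detection on inputs can flag problems earlier. One widely used statistic is the Population Stability Index (PSI), which compares a feature's current distribution with a reference window such as the training data. The sketch below bins both samples on the reference range; the function name, bin count, and epsilon are assumptions, and thresholds like 0.1 (moderate shift) and 0.25 (significant shift) are rules of thumb rather than Riskify settings.

```python
import numpy as np


def population_stability_index(
    reference: np.ndarray, current: np.ndarray, bins: int = 10, eps: float = 1e-6
) -> float:
    """PSI = sum((cur% - ref%) * ln(cur% / ref%)) over bins fitted on the reference sample."""
    if reference.max() == reference.min():
        raise ValueError("reference sample must not be constant")
    edges = np.linspace(reference.min(), reference.max(), bins + 1)
    # Clip so out-of-range current values land in the edge bins instead of being dropped
    current = np.clip(current, edges[0], edges[-1])
    ref_counts, _ = np.histogram(reference, bins=edges)
    cur_counts, _ = np.histogram(current, bins=edges)
    # eps keeps empty bins from producing log(0) or division by zero
    ref_pct = np.maximum(ref_counts / ref_counts.sum(), eps)
    cur_pct = np.maximum(cur_counts / cur_counts.sum(), eps)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))


reference = np.linspace(0.0, 1.0, 1000)
print(population_stability_index(reference, reference + 0.5))
```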
2. Alerts
Set up alerting to notify you when model performance degrades or anomalies are detected. Customize thresholds and notification channels as needed.
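```python
import logging
from typing import Dict

logger = logging.getLogger(__name__)


class AlertSystem:
    def __init__(self, mse_threshold: float, r2_threshold: float):
        self.mse_threshold = mse_threshold  # alert when MSE rises above this
        self.r2_threshold = r2_threshold    # alert when R2 falls below this

    def check_performance(self, metrics: Dict[str, float]) -> None:
        """Check if performance metrics trigger alerts."""
        if metrics['mse'] > self.mse_threshold:
            self._send_alert('High MSE detected', metrics)
        if metrics['r2'] < self.r2_threshold:
            self._send_alert('Low R2 Score detected', metrics)

    def _send_alert(self, message: str, metrics: Dict[str, float]) -> None:
        # Placeholder channel: replace with chat, paging, or email integration
        logger.warning("%s: %s", message, metrics)
```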
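Threshold checks like these fire on every evaluation while a problem persists. A per-alert cooldown is one simple way to keep notifications actionable. The `AlertThrottle` class below is a hypothetical sketch that `_send_alert` could consult before notifying; it takes timestamps as arguments so it can be tested without a real clock (in production, `now` might come from `time.monotonic()`).

```python
from typing import Dict


class AlertThrottle:
    """Suppress repeats of the same alert within a cooldown window."""

    def __init__(self, cooldown_seconds: float) -> None:
        self.cooldown_seconds = cooldown_seconds
        self._last_sent: Dict[str, float] = {}  # alert key -> time of last delivery

    def should_send(self, key: str, now: float) -> bool:
        """Return True (and record the send) if `key` is outside its cooldown."""
        last = self._last_sent.get(key)
        if last is not None and now - last < self.cooldown_seconds:
            return False
        self._last_sent[key] = now
        return True


throttle = AlertThrottle(cooldown_seconds=300)
decisions = [
    throttle.should_send("High MSE detected", now=0),
    throttle.should_send("High MSE detected", now=120),      # suppressed: within cooldown
    throttle.should_send("Low R2 Score detected", now=120),  # different alert key
    throttle.should_send("High MSE detected", now=301),      # cooldown elapsed
]
```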
Best Practices
Follow these best practices to ensure robust, secure, and maintainable ML integration:
Data Quality
- Regular data validation
- Outlier detection
- Missing value handling
- Feature normalization
Model Management
- Version control for models
- Regular retraining
- A/B testing
- Fallback mechanisms
Security
- Input validation
- Output verification
- Access control
- Audit logging
Monitoring
- Real-time performance tracking
- Resource utilization
- Error rates
- Prediction latency
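Several of these practices (input validation, output verification, fallback mechanisms) can be combined in a thin wrapper around inference. The sketch below rejects missing or non-finite features, checks that the score lies in the [0, 1] range produced by the risk model's sigmoid output, and otherwise returns a conservative default. `safe_risk_score`, the required feature names, and the fallback of 1.0 (treat as maximum risk) are illustrative assumptions.

```python
import logging
import math
from typing import Callable, Dict, Iterable

logger = logging.getLogger(__name__)


def safe_risk_score(
    predict: Callable[[Dict[str, float]], float],
    features: Dict[str, float],
    required: Iterable[str],
    fallback: float = 1.0,
) -> float:
    """Run `predict` only on valid input; return `fallback` if anything looks wrong."""
    missing = [name for name in required if name not in features]
    if missing or not all(math.isfinite(v) for v in features.values()):
        logger.warning("invalid model input, using fallback (missing=%s)", missing)
        return fallback
    try:
        score = float(predict(features))
    except Exception:
        logger.exception("model inference failed, using fallback")
        return fallback
    if not (0.0 <= score <= 1.0):  # NaN also fails this check
        logger.warning("risk score %r out of range, using fallback", score)
        return fallback
    return score
```

Logging each fallback also feeds the audit-logging and error-rate practices listed above.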