Network-accelerated Distributed Machine Learning for Multi-Tenant Settings
https://dl.acm.org/doi/pdf/10.1145/3419111.3421296?casa_token=eX-ZTtI2CTwAAAAA:7TPapL35POUlmFtA46isDCmIbtMy6gtXADpQzZPUcNdywSFA1_qjYQLmHfOLP7uXAVBjrCL1ZFhE
Presentation 
- Background:
  - ML is becoming ubiquitous in the software industry
  - ML models are trained on large volumes of data, typically over shared infrastructure
- DML:
  - Uses more compute cycles
  - Processes lots of distributed data
  - But requires frequent synchronization (e.g., SGD synchronizes after each iteration; sketch below)
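A minimal sketch (toy least-squares example, not code from the paper) of the per-iteration synchronization that makes DML communication-heavy: each worker computes a gradient on its own shard, and the gradients are averaged over the network before the shared model advances.

```python
import numpy as np

def local_gradient(w, X, y):
    # Least-squares gradient on one worker's data shard.
    return 2 * X.T @ (X @ w - y) / len(y)

def sync_sgd_step(w, shards, lr=0.1):
    grads = [local_gradient(w, X, y) for X, y in shards]  # compute phase (runs in parallel)
    avg_grad = np.mean(grads, axis=0)                     # synchronization phase: one network round per iteration
    return w - lr * avg_grad

rng = np.random.default_rng(0)
shards = [(rng.normal(size=(32, 4)), rng.normal(size=32)) for _ in range(4)]  # 4 "workers"
w = np.zeros(4)
for _ in range(50):
    w = sync_sgd_step(w, shards)
```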
 
- Architectures:
  - Parameter server: Google DistBelief, Microsoft Adam
    - Async vs. sync updates
  - P2P MPI
    - Synchronous
 
 
 
- Key factors affecting the performance of distributed ML training
  - Async vs. sync SGD
  - A-SGD:
    - Delay: the version of the model being updated vs. the version used to compute the gradients
    - Communication-intensive: computing an update takes ~100 ms, but fetching the model takes ~600 ms (see the worker-loop sketch below)
  - R1: Reduce network load at the server
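A sketch of an A-SGD worker's loop against a parameter server (illustrative Python, not the paper's code): each iteration pulls the full model and pushes one gradient, so with ~100 ms of compute against ~600 ms of model fetch, the server's NIC rather than worker compute becomes the bottleneck, which is what R1 targets.

```python
import numpy as np

class ParameterServer:
    def __init__(self, dim):
        self.version = 0
        self.model = np.zeros(dim)

    def pull(self):
        # Every worker reads the full model each iteration (the ~600 ms step).
        return self.version, self.model.copy()

    def push(self, grad, lr=0.1):
        # Updates are applied as they arrive, regardless of the version they were computed on.
        self.model -= lr * grad
        self.version += 1

def worker_loop(ps, X, y, iters=20):
    for _ in range(iters):
        _, w = ps.pull()                          # network-heavy step
        grad = 2 * X.T @ (X @ w - y) / len(y)     # compute-light step (the ~100 ms part)
        ps.push(grad)

rng = np.random.default_rng(1)
ps = ParameterServer(dim=4)
worker_loop(ps, rng.normal(size=(64, 4)), rng.normal(size=64))
```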
 
- Stragglers affect convergence
  - Halving the bandwidth of 10% of the workers leads to ~35% more iterations to convergence
- R2: Bound delays in the presence of stragglers
 
- Fault tolerance has huge overhead
  - Chain replication: the server forwards every worker update to the replica
  - The server's outgoing NIC therefore carries both the models pulled by the workers and the model updates forwarded to the replica (rough accounting below)
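A back-of-envelope sketch of why chain replication loads the server's outgoing NIC; the sizes and counts below are assumptions for illustration, not numbers from the paper.

```python
# All numbers are illustrative assumptions.
model_size_gb = 1.0      # full model, assumed equal in size to one update
num_workers = 32
updates_per_iter = 32    # one update per worker per iteration

model_traffic = model_size_gb * num_workers            # models shipped to workers
replica_traffic = model_size_gb * updates_per_iter     # every update forwarded to the replica

total = model_traffic + replica_traffic
print(f"outgoing traffic per iteration: {total:.0f} GB "
      f"({replica_traffic / total:.0%} of it is replication overhead)")
```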
 
 
- Alternative: workers forward updates directly to the replica, but asynchrony + a stateful server leads to inconsistency between server and replica
 
- R3: Non-divergent replication without server overhead 
 
 
- S-SGD
  - With the PS architecture:
    - Server NIC overload is a bottleneck for large models
    - Stragglers increase per-iteration time
  - With the MPI architecture:
    - Ring-reduce algorithms are typically used, but they assume a homogeneous network infrastructure
    - Compute and network stragglers increase per-iteration time
 
- R4: Bandwidth-aware aggregation (sketched below)
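A sketch of what bandwidth-aware aggregation (R4) might look like at its simplest: pick as the group's aggregator the worker with the most available bandwidth to the server, instead of assuming a homogeneous ring. The bandwidth values are hypothetical inputs (e.g., from monitoring), and this is not MLFabric's actual placement algorithm.

```python
def pick_aggregator(bw_to_server):
    """bw_to_server: dict worker_id -> currently available bandwidth to the server (Gbps)."""
    return max(bw_to_server, key=bw_to_server.get)

# Hypothetical measurements on a contended, heterogeneous network.
bw = {"w1": 2.0, "w2": 9.5, "w3": 4.0, "w4": 1.0}
agg = pick_aggregator(bw)  # -> "w2"
# Workers send their updates to `agg`, which sums them and forwards a single
# model-sized transfer upstream, so the slow links (w1, w4) each carry only
# one worker-sized transfer instead of participating in every ring step.
```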
 
 
Existing work
- R1: Infrequent updates to the server
  - Workers aggregate locally over several steps, then transmit to the server (sketch below)
  - But: this cannot aggregate updates across multiple workers
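A sketch of that baseline (toy code, illustrative names): the worker accumulates gradients over k local steps and pushes a single summed update, cutting its traffic to the server by roughly k, but by construction it can only combine its own updates, never those of other workers.

```python
import numpy as np

def accumulate_then_send(push_to_server, w, X, y, k=8, lr=0.1):
    """Run k local SGD steps, then transmit one accumulated update."""
    acc = np.zeros_like(w)
    for _ in range(k):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        acc += grad
        w = w - lr * grad        # local-only progress between pushes
    push_to_server(acc)          # one network transfer instead of k
    return w

rng = np.random.default_rng(2)
w = accumulate_then_send(lambda update: None, np.zeros(4),
                         rng.normal(size=(64, 4)), rng.normal(size=64))
```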
 
 
- R2: Dropping delayed updates
  - But the updates are transmitted to the server anyway, so this does not reduce server load
 
 
- R3: No prior work
 
- R4: Hierarchical AllReduce, but it assumes a static bandwidth setting
 
MLFabric 
- Contributions
  - Proves that bounded delay helps convergence
  - Network-aware ordering and aggregation of updates
    - Helps bound delay
    - Reduces network load at the parameter server
 
- Model replication strategy
  - Guarantees bounded consistency between the server and the replica
 
 
- Network-aware ordering and aggregation of updates
  - Re-ordering and buffering updates gives bounded delay (R2) at the cost of stale reads
 
- Time-shared schedules
  - An update scheduled for later can be aggregated off-server
  - SJF + time-sharing satisfies R1 and R2 (see the scheduling sketch below)
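A sketch of SJF + time-sharing on the server's link (illustrative, not MLFabric's scheduler code): transfers are serialized shortest-first with the full link reserved for each, which yields predictable completion times; an update whose slot is far in the future is exactly the one worth aggregating off-server first.

```python
def sjf_schedule(transfer_sizes_mb, link_gbps=10.0):
    """Return (worker, start_s, finish_s) for each transfer, shortest first."""
    mb_per_s = link_gbps * 1000 / 8
    schedule, t = [], 0.0
    for worker, size in sorted(transfer_sizes_mb.items(), key=lambda kv: kv[1]):
        dur = size / mb_per_s
        schedule.append((worker, t, t + dur))  # non-overlapping: full bandwidth per update
        t += dur
    return schedule

for worker, start, finish in sjf_schedule({"w1": 400, "w2": 100, "w3": 250}):
    print(f"{worker}: {start:.3f}s -> {finish:.3f}s")
```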
 
 
 
- Architecture 

- Ordering and aggregation algorithm
  - The scheduler batches transfer requests to the server
    - Information per request: model version, size of the update, norm of the gradients
 
  - Orders them iteratively in SJF fashion toward the server
    - Completion time is determined in a network-aware fashion
    - Scheduler state is updated in each iteration
    - Updates that cannot meet their deadline are dropped
 
  - Queued updates are grouped, aggregated off-server, and then sent to the server
    - Goal: minimize total completion time
    - The aggregator for each group is chosen randomly (combined sketch below)
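A combined sketch of the ordering-and-aggregation step as described in these notes (my paraphrase in Python; the names and the grouping policy are illustrative, not the paper's code): order requests shortest-first with network-aware completion times, drop anything that would blow the delay budget, then group what remains and hand each group to a randomly chosen off-server aggregator.

```python
import random

def schedule_batch(requests, link_mb_per_s, delay_budget_s, group_size=4):
    """requests: list of dicts with 'worker', 'version', 'size_mb', 'grad_norm'
    (the gradient norm is carried for convergence bookkeeping; unused here)."""
    ordered = sorted(requests, key=lambda r: r["size_mb"])      # SJF toward the server
    kept, t = [], 0.0
    for r in ordered:
        finish = t + r["size_mb"] / link_mb_per_s               # network-aware completion time
        if finish > delay_budget_s:
            continue                                            # cannot meet its deadline -> dropped
        kept.append(r)
        t = finish
    # Group the surviving updates; each group is summed at a random aggregator,
    # so the server receives one combined transfer per group.
    groups = [kept[i:i + group_size] for i in range(0, len(kept), group_size)]
    return [{"aggregator": random.choice([r["worker"] for r in g]),
             "members": [r["worker"] for r in g]} for g in groups]

plans = schedule_batch(
    [{"worker": f"w{i}", "version": 41, "size_mb": 100 * (i + 1), "grad_norm": 1.0}
     for i in range(6)],
    link_mb_per_s=1250, delay_budget_s=1.0)
```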
 
 
Paper 
- DML suffers from compute and network contention, resulting in stragglers
  - Current systems take too simplistic a view of the network: either a fixed bandwidth between all workers, or a black box with unknown inter-worker bandwidth
 
- MLFabric: a contention-aware DML system
  - Orders transfers to improve convergence
  - Opportunistically aggregates them at idle DML workers to improve resource efficiency
  - Replicates them to support new notions of fault tolerance
  - Systematically accounts for compute stragglers and network contention
  - Implemented as a communication layer between DML applications and AllReduce libraries
 
- Key ideas
  - Control update delays: transfer available updates at non-overlapping times, reserving bandwidth per update and explicitly ordering them
  - A worker's delay = the difference between the server model version the worker pulled and computed on, and the server model version at the time the worker's computed gradient is applied (tiny example below)
 
  - Dynamically aggregating or dropping updates
  - Replicating updates for fault tolerance
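A tiny illustration of the delay definition above; the version numbers are hypothetical.

```python
def update_delay(version_pulled: int, version_when_applied: int) -> int:
    # Delay = how many server model versions elapsed between the model the worker
    # read (and computed its gradient on) and the model its update lands on.
    return version_when_applied - version_pulled

assert update_delay(version_pulled=17, version_when_applied=20) == 3  # 3 other updates landed first
```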
 