# Network-accelerated Distributed Machine Learning for Multi-Tenant Settings

### Presentation

* Background:
  * ML is becoming ubiquitous in the software industry
  * ML models are trained on large volumes of data, typically over shared infrastructure
  * DML:
    * More compute cycles
    * Processes lots of distributed data
    * But requires frequent synchronization (e.g., SGD synchronizes after each iteration)
  * Architecture:
    * Parameter server: Google DistBelief, Microsoft Adam
      * Async vs. sync
    * P2P MPI
      * Synchronous
* Key factors affecting performance of distributed ML training
  * Async vs. sync SGD
  * A-SGD:
    * Version of the model being updated vs. the version used to compute the gradients
    * Communication intensive
      * Compute update (\~100 ms)
      * Fetch model (\~600 ms)
      * **R1**: Reduce network load at the server
    * Stragglers affect convergence
      * Halving bandwidth for 10% of the workers --> \~35% more iterations to converge
      * **R2**: Bound delays in the presence of stragglers
    * Fault tolerance has huge overhead
      * Chain replication (the server forwards every worker update to the replica)
        * Outgoing NIC:
          * Carries both models to the workers
          * And model updates to the replica
      * Directly forwarding worker updates to the replica
        * Asynchrony + stateful servers lead to inconsistency
      * **R3:** Non-divergent replication without server overhead
  * S-SGD
    * With PS architecture
      * Server NIC overload is a bottleneck for large models
      * Stragglers increase time per iteration
    * MPI architecture
      * Ring-reduce algorithms are typically used, but they assume homogeneous network infrastructure
      * Compute & network stragglers increase time per iteration
    * **R4**: Bandwidth-aware aggregation
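The staleness problem motivating **R2** can be shown with a toy simulation; this is my own sketch (transit times and worker counts are assumptions, not from the paper). A worker's delay is measured as the number of server model versions applied between its read of the model and the application of its gradient:

```python
# Toy A-SGD staleness simulation (assumed parameters, illustrative only).
# Each step, one worker computes a gradient on the current model version;
# its update arrives after a random transit time. Straggling (slow)
# transfers cause updates to be applied against much newer models.
import heapq
import random

random.seed(0)

server_version = 0
arrivals = []  # min-heap of (arrival_time, model_version_read)
delays = []

for t in range(1000):
    transit = random.choice([1, 1, 1, 2, 8])  # 20% of transfers straggle
    heapq.heappush(arrivals, (t + transit, server_version))
    # apply every update that has arrived by time t
    while arrivals and arrivals[0][0] <= t:
        _, version_read = heapq.heappop(arrivals)
        delays.append(server_version - version_read)  # staleness of update
        server_version += 1

print(f"mean delay: {sum(delays)/len(delays):.2f}, max delay: {max(delays)}")
```

Straggling transfers push the maximum delay well above the mean, which is exactly the gap that explicit ordering of transfers is meant to bound.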

#### Existing works

* R1
  * Infrequent updates to the server
    * Workers aggregate locally, then transmit to the server
    * But: cannot aggregate updates across multiple workers
* R2
  * Dropping delayed updates
    * Updates are transmitted to the servers anyway, so this does not reduce server load
* R3
  * No prior work
* R4
  * Hierarchical AllReduce, but it assumes a static bandwidth setting

#### MLFabric

* Contributions
  * Prove that bounded delay helps convergence
  * Network-aware ordering and aggregation of updates
    * Helps bound delay
    * Reduces network load at the parameter server
  * Model replication strategy
    * Guarantees bounded consistency between server and replica
* **Network-aware ordering and aggregation of updates**
  * Re-ordering updates
    * Buffering updates
      * Bounded delay (R2) at the cost of stale reads
    * Time-shared schedules
      * An update scheduled for later can be aggregated off-server
      * SJF + time-sharing satisfies R1 and R2
* Architecture

![](https://2097630930-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MVORxAomcgtzVVUqmws%2Fuploads%2F7pw3CKh7QNR9uqrmgFjK%2Fimage.png?alt=media\&token=8cc5eb22-f77c-47de-a6a0-f7e453f8e9bb)
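As a toy illustration of why shortest-job-first ordering over the bottleneck server NIC helps (the sizes and bandwidth below are assumed values, not from the paper):

```python
# Toy comparison of FIFO vs. SJF transfer ordering on one bottleneck
# link (the server NIC). Sizes and bandwidth are assumed values.

def completion_times(sizes_mb, bandwidth_mb_s):
    """Finish times when transfers are served back-to-back in order."""
    t, finish = 0.0, []
    for size in sizes_mb:
        t += size / bandwidth_mb_s
        finish.append(t)
    return finish

updates = [400, 50, 200, 50]  # update sizes in MB, in arrival order
fifo = completion_times(updates, 100)
sjf = completion_times(sorted(updates), 100)

print("FIFO mean completion:", sum(fifo) / len(fifo))  # 5.5 s
print("SJF  mean completion:", sum(sjf) / len(sjf))    # 2.875 s
```

SJF halves the mean completion time here without changing the total time to drain the queue, so small updates reach the server (and bump the model version) sooner.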

* Ordering and aggregation algorithm
  * The scheduler batches transfer requests to the server
    * Information: model version, size of the update, norm of the gradients
  * Orders them iteratively to the server in SJF fashion
    * Completion time is determined in a network-aware fashion
    * State is updated in each iteration
    * Updates that cannot meet the deadline are dropped
  * Queued updates are grouped, aggregated, and then sent to the server
    * Minimizes total completion time
    * The aggregator is chosen randomly
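A minimal sketch of this ordering loop as I read it from the notes above; the function names, the single-bottleneck-link model, and the deadline value are my assumptions, not MLFabric's actual implementation:

```python
# Hypothetical sketch of SJF ordering with deadline-based dropping (R2)
# and off-server aggregation (R1). All structures/parameters are assumed.
import random

def schedule(requests, bandwidth, deadline):
    """requests: list of (worker_id, update_size_mb).

    Serve the smallest updates first; any update whose completion time on
    the bottleneck link would exceed the delay deadline is dropped."""
    order, dropped, t = [], [], 0.0
    for worker, size in sorted(requests, key=lambda r: r[1]):
        finish = t + size / bandwidth
        if finish > deadline:
            dropped.append(worker)
        else:
            order.append(worker)
            t = finish
    return order, dropped

def pick_aggregator(queued_workers):
    """Updates queued for later are combined at a randomly chosen worker,
    so the server receives one aggregated transfer instead of many."""
    return random.choice(queued_workers)

requests = [(0, 400), (1, 50), (2, 200), (3, 50), (4, 900)]
order, dropped = schedule(requests, bandwidth=100, deadline=8.0)
print("send order:", order, "dropped:", dropped)
# → send order: [1, 3, 2, 0] dropped: [4]
```

The real scheduler is network-aware (per-link bandwidths, contention) rather than assuming one shared link, but the SJF-plus-deadline structure is the same.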

### Paper

* DML: compute and network contention, resulting in stragglers
  * Current systems take too simplistic a view of the network (either fixed bandwidth between all workers, or a black box with unknown inter-worker bandwidth)
* MLFabric: contention-aware DML system
  * Orders transfers to improve convergence
  * Opportunistically aggregates them at idle DML workers to improve resource efficiency
  * Replicates them to support new notions of fault tolerance
  * Systematically accounts for compute stragglers and network contention
  * Implemented as a communication layer between DML applications and AllReduce libraries
* Key idea
  * Control update delays
    * Transfer available updates at non-overlapping times, reserving bandwidth per update, and explicitly ordering them
    * A worker's delay = the difference between the server model version it pulled and computed on, and the server model version at the time the worker's gradient is applied
  * Dynamically aggregate or drop updates
  * Replicate updates for fault tolerance
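The non-divergent replication idea (R3) can be sketched as order-preserving log replay; this is an assumed, simplified model, not the paper's actual protocol:

```python
# Hypothetical sketch: if server and replica apply the same update log
# in the same order, the replica lags but never diverges, and the
# inconsistency is bounded by the number of not-yet-replayed updates.

class Model:
    def __init__(self):
        self.version = 0
        self.weights = [0.0, 0.0]

    def apply(self, grad, lr=0.1):
        # plain SGD step; determinism makes replay exact
        self.weights = [w - lr * g for w, g in zip(self.weights, grad)]
        self.version += 1

server, replica = Model(), Model()
log = [[1.0, 0.0], [0.0, 2.0], [0.5, 0.5]]  # ordered update log (assumed)

for grad in log:
    server.apply(grad)
for grad in log[:2]:                 # replica is one update behind
    replica.apply(grad)
lag = server.version - replica.version
print("replica lag (bounded by unreplayed updates):", lag)

for grad in log[replica.version:]:   # catching up restores exact equality
    replica.apply(grad)
assert replica.weights == server.weights
```

Asynchrony breaks this only when updates reach the replica in a different order than the server applied them, which is why ordering the replication stream is the key constraint.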
