使用 Docker 容器运行源神开源的 Grok-1 🤔

查看 26|回复 1
作者:mayooot   
最近工作中让我尝试跑一跑马斯克开源的 Grok-1 ,正好自己业务时间也在看这个,但是跑的过程中大部分都是依赖、环境相关的问题,而且在 Github 的 issue/discussions 里都没找到怎么用 Docker 运行或者提供一个基础环境的镜像。
所以我打算献丑一下,给大家提供一个镜像,在我们的环境里是可以正常运行的,但是一批服务器都大差不差,所以在别的环境运行起来可能有错误,大家有空测试的话,欢迎来讨论讨论。
下面就是 README 的内容了,如果对你有帮助,请点一个 Star⭐(最近打算看看新的工作机会,Star 多一点面试就能吹牛逼了😄)
概述
最近源神开源了 Grok-1 大模型,想着跑起来看看是什么样子。Grok 的 GitHub 里写的非常清楚了,首先 clone 代码,然后下载模型(大概
300 个 G ),然后执行:
pip install -r requirements.txt
python run.py
听起来很简单,就像把大象塞进冰箱需要几步一样。但是实际上模型要依赖
jax 、jaxlib ,这俩对环境要求还是比较苛刻的,所以尝试在服务器上运行了一下,各种报错,无奈只能使用容器一个个环境的尝试,最后成功构建出一个可以运行的镜像(下面会展示宿主机和容器的环境)。这个镜像是适用于我们的环境的,在别的环境下不知道能否正常运行,所以欢迎你使用后给出一点反馈。
我做了什么
首先模型文件非常大,不适合每次都 docker cp 进基础环境的容器中,而且如果这个容器经过调试后可用,那么 commit
时也会把模型顺带着保存,那么这个镜像的体积可就太大了。所以模型文件,使用 -v 挂载进容器的 /root 下。
而代码比较小,大概
900MB ,调试中免不了要修改一些代码,并且这些是希望调试好后直接保存进容器的,所以我将程序代码通过 docker cp
复制到了容器里,并且提交的镜像里也有,方便你直接使用。
然后就是安装各种环境,遇到一个报错解决一个。
GitHub
项目地址: https://github.com/mayooot/grok-docker
欢迎 ✨
快速启动
首先拉取镜像,大概 8 个 G 。
docker pull mayooot/grok-docker:v1
然后要将下载的模型文件 ckpt-0
目录挂载进容器,下载教程可以参考这篇文章:Grok-1 本地部署过程。
最后启动容器。

  • 注意要将 $your-dir/ckpt-0 替换成你的实际模型地址。
  • 共享内存设置为了 600g ,应该是够用的,如果不够,请自行调整。
  • 要跑起来模型大概需要 8 张 A800/A100 。所以这里使用 --gpus all 将所有 gpu 挂载进去。


    docker run -d -it \
    --network=host \
    --shm-size 600g \
    --name=grok-docker \
    --gpus all \
    -v $your-dir/ckpt-0:/root/ckpt-0 \
    mayooot/grok-docker:v1
    训练
    程序代码已经存在于容器中,并且修改了模型的加载路径,所以只要你正确的把 ckpt-0 挂载进容器,那么直接执行下面代码,然后等待结果。
    docker exec -it grok-docker bash
    cd /root/grok-1/
    python run.py
    运行结果:

    环境
    宿主机环境
  • OS: Ubuntu 20.04.4
  • Physical Storage: 1TB
  • Physical Memory: 2TB
  • GPU: 8 * NVIDIA A100 80GB
  • Docker: 24.0.5
  • Nvidia Driver: 525.85.12

    容器环境
    $ cat /etc/issue
    Ubuntu 22.04.1 LTS \n \l
    $ python --version
    Python 3.10.8
    $ nvcc --version
    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2022 NVIDIA Corporation
    Built on Wed_Sep_21_10:33:58_PDT_2022
    Cuda compilation tools, release 11.8, V11.8.89
    Build cuda_11.8.r11.8/compiler.31833905_0
    $ pip show jax
    Name: jax
    Version: 0.4.26
    Summary: Differentiate, compile, and transform Numpy code.
    Home-page: https://github.com/google/jax
    Author: JAX team
    Author-email: [email protected]
    License: Apache-2.0
    Location: /root/miniconda3/lib/python3.10/site-packages
    Requires: ml-dtypes, numpy, opt-einsum, scipy
    Required-by: chex, flax, optax, orbax-checkpoint
    $ pip show jaxlib
    Name: jaxlib
    Version: 0.4.26+cuda12.cudnn89
    Summary: XLA library for JAX
    Home-page: https://github.com/google/jax
    Author: JAX team
    Author-email: [email protected]
    License: Apache-2.0
    Location: /root/miniconda3/lib/python3.10/site-packages
    Requires: ml-dtypes, numpy, scipy
    Required-by: chex, optax, orbax-checkpoint
  • crackidz   
    其他的不说,你这个配置很可...
    如果 GGUF 格式的话,直接 llama.cpp 之类的直接跑吧
    您需要登录后才可以回帖 登录 | 立即注册

    返回顶部