diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 7c99380..d6dc5a5 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -531,10 +531,6 @@ def parallelize_model( for i in range(len(module.layers)): module.layers[i] = checkpoint_wrapper( module.layers[i], - context_fn=partial( - create_selective_checkpoint_contexts, - get_default_policy(no_recompute_ops), - ), ) if distributed_args.compile: