Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/abstractions/pod/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ func (i *podInstance) startContainers(containersToRun int) error {
fmt.Sprintf("STUB_ID=%s", i.Stub.ExternalId),
fmt.Sprintf("STUB_TYPE=%s", i.Stub.Type),
fmt.Sprintf("KEEP_WARM_SECONDS=%d", i.StubConfig.KeepWarmSeconds),
fmt.Sprintf("CHECKPOINT_ENABLED=%t", i.StubConfig.CheckpointEnabled),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding CHECKPOINT_ENABLED to request.Env duplicates the worker-controlled flag and can override it due to merge order; it also ignores the computed effective value (GPU>1), risking inconsistent checkpoint behavior.

Prompt for AI agents
Address the following comment on pkg/abstractions/pod/instance.go at line 36:

<comment>Adding CHECKPOINT_ENABLED to request.Env duplicates the worker-controlled flag and can override it due to merge order; it also ignores the computed effective value (GPU&gt;1), risking inconsistent checkpoint behavior.</comment>

<file context>
@@ -33,6 +33,8 @@ func (i *podInstance) startContainers(containersToRun int) error {
 		fmt.Sprintf(&quot;STUB_ID=%s&quot;, i.Stub.ExternalId),
 		fmt.Sprintf(&quot;STUB_TYPE=%s&quot;, i.Stub.Type),
 		fmt.Sprintf(&quot;KEEP_WARM_SECONDS=%d&quot;, i.StubConfig.KeepWarmSeconds),
+		fmt.Sprintf(&quot;CHECKPOINT_ENABLED=%t&quot;, i.StubConfig.CheckpointEnabled),
+		fmt.Sprintf(&quot;CHECKPOINT_CONDITION=%s&quot;, i.StubConfig.CheckpointCondition),
 	}...)
</file context>

fmt.Sprintf("CHECKPOINT_CONDITION=%s", i.StubConfig.CheckpointCondition),
}...)

gpuRequest := types.GpuTypesToStrings(i.StubConfig.Runtime.Gpus)
Expand Down
2 changes: 2 additions & 0 deletions pkg/abstractions/pod/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ func (s *GenericPodService) run(ctx context.Context, authInfo *auth.AuthInfo, st
fmt.Sprintf("STUB_ID=%s", stub.ExternalId),
fmt.Sprintf("STUB_TYPE=%s", stub.Type),
fmt.Sprintf("KEEP_WARM_SECONDS=%d", stubConfig.KeepWarmSeconds),
fmt.Sprintf("CHECKPOINT_ENABLED=%t", stubConfig.CheckpointEnabled),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECKPOINT_ENABLED env var reflects stubConfig instead of the computed effective value, causing mismatches when checkpointing is disabled (e.g., multi-GPU).

Prompt for AI agents
Address the following comment on pkg/abstractions/pod/pod.go at line 285:

<comment>CHECKPOINT_ENABLED env var reflects stubConfig instead of the computed effective value, causing mismatches when checkpointing is disabled (e.g., multi-GPU).</comment>

<file context>
@@ -282,6 +282,8 @@ func (s *GenericPodService) run(ctx context.Context, authInfo *auth.AuthInfo, st
 		fmt.Sprintf(&quot;STUB_ID=%s&quot;, stub.ExternalId),
 		fmt.Sprintf(&quot;STUB_TYPE=%s&quot;, stub.Type),
 		fmt.Sprintf(&quot;KEEP_WARM_SECONDS=%d&quot;, stubConfig.KeepWarmSeconds),
+		fmt.Sprintf(&quot;CHECKPOINT_ENABLED=%t&quot;, stubConfig.CheckpointEnabled),
+		fmt.Sprintf(&quot;CHECKPOINT_CONDITION=%s&quot;, stubConfig.CheckpointCondition),
 	}...)
</file context>

fmt.Sprintf("CHECKPOINT_CONDITION=%s", stubConfig.CheckpointCondition),
}...)

gpuRequest := types.GpuTypesToStrings(stubConfig.Runtime.Gpus)
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/gateway.proto
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ message GetOrCreateStubRequest {
Schema inputs = 35;
Schema outputs = 36;
bool tcp = 37;
string checkpoint_condition = 38;
}

message GetOrCreateStubResponse {
Expand Down
54 changes: 29 additions & 25 deletions pkg/gateway/services/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"log"
"math"
"os"
"path"
Expand Down Expand Up @@ -124,31 +125,34 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
Memory: in.Memory,
ImageId: in.ImageId,
},
Handler: in.Handler,
OnStart: in.OnStart,
OnDeploy: in.OnDeploy,
OnDeployStubId: in.OnDeployStubId,
CallbackUrl: in.CallbackUrl,
PythonVersion: in.PythonVersion,
TaskPolicy: gws.configureTaskPolicy(in.TaskPolicy, types.StubType(in.StubType)),
KeepWarmSeconds: uint(in.KeepWarmSeconds),
Workers: uint(in.Workers),
ConcurrentRequests: uint(in.ConcurrentRequests),
MaxPendingTasks: uint(in.MaxPendingTasks),
Volumes: in.Volumes,
Secrets: []types.Secret{},
Authorized: in.Authorized,
Autoscaler: autoscaler,
Extra: json.RawMessage(in.Extra),
CheckpointEnabled: in.CheckpointEnabled,
EntryPoint: in.Entrypoint,
Ports: in.Ports,
Env: in.Env,
Pricing: pricing,
Inputs: inputs,
Outputs: outputs,
TCP: in.Tcp,
}
Handler: in.Handler,
OnStart: in.OnStart,
OnDeploy: in.OnDeploy,
OnDeployStubId: in.OnDeployStubId,
CallbackUrl: in.CallbackUrl,
PythonVersion: in.PythonVersion,
TaskPolicy: gws.configureTaskPolicy(in.TaskPolicy, types.StubType(in.StubType)),
KeepWarmSeconds: uint(in.KeepWarmSeconds),
Workers: uint(in.Workers),
ConcurrentRequests: uint(in.ConcurrentRequests),
MaxPendingTasks: uint(in.MaxPendingTasks),
Volumes: in.Volumes,
Secrets: []types.Secret{},
Authorized: in.Authorized,
Autoscaler: autoscaler,
Extra: json.RawMessage(in.Extra),
CheckpointEnabled: in.CheckpointEnabled,
CheckpointCondition: in.CheckpointCondition,
EntryPoint: in.Entrypoint,
Ports: in.Ports,
Env: in.Env,
Pricing: pricing,
Inputs: inputs,
Outputs: outputs,
TCP: in.Tcp,
}

log.Print(stubConfig)

// Ensure GPU count is at least 1 if a GPU is required
if stubConfig.RequiresGPU() && in.GpuCount == 0 {
Expand Down
53 changes: 27 additions & 26 deletions pkg/types/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,32 +369,33 @@ const (
)

type StubConfigV1 struct {
Runtime Runtime `json:"runtime"`
Handler string `json:"handler"`
OnStart string `json:"on_start"`
OnDeploy string `json:"on_deploy"`
OnDeployStubId string `json:"on_deploy_stub_id"`
PythonVersion string `json:"python_version"`
KeepWarmSeconds uint `json:"keep_warm_seconds"`
MaxPendingTasks uint `json:"max_pending_tasks"`
CallbackUrl string `json:"callback_url"`
TaskPolicy TaskPolicy `json:"task_policy"`
Workers uint `json:"workers"`
ConcurrentRequests uint `json:"concurrent_requests"`
Authorized bool `json:"authorized"`
Volumes []*pb.Volume `json:"volumes"`
Secrets []Secret `json:"secrets,omitempty"`
Env []string `json:"env,omitempty"`
Autoscaler *Autoscaler `json:"autoscaler"`
Extra json.RawMessage `json:"extra"`
CheckpointEnabled bool `json:"checkpoint_enabled"`
WorkDir string `json:"work_dir"`
EntryPoint []string `json:"entry_point"`
Ports []uint32 `json:"ports"`
Pricing *PricingPolicy `json:"pricing"`
Inputs *Schema `json:"inputs"`
Outputs *Schema `json:"outputs"`
TCP bool `json:"tcp"`
Runtime Runtime `json:"runtime"`
Handler string `json:"handler"`
OnStart string `json:"on_start"`
OnDeploy string `json:"on_deploy"`
OnDeployStubId string `json:"on_deploy_stub_id"`
PythonVersion string `json:"python_version"`
KeepWarmSeconds uint `json:"keep_warm_seconds"`
MaxPendingTasks uint `json:"max_pending_tasks"`
CallbackUrl string `json:"callback_url"`
TaskPolicy TaskPolicy `json:"task_policy"`
Workers uint `json:"workers"`
ConcurrentRequests uint `json:"concurrent_requests"`
Authorized bool `json:"authorized"`
Volumes []*pb.Volume `json:"volumes"`
Secrets []Secret `json:"secrets,omitempty"`
Env []string `json:"env,omitempty"`
Autoscaler *Autoscaler `json:"autoscaler"`
Extra json.RawMessage `json:"extra"`
CheckpointEnabled bool `json:"checkpoint_enabled"`
CheckpointCondition string `json:"checkpoint_condition"`
WorkDir string `json:"work_dir"`
EntryPoint []string `json:"entry_point"`
Ports []uint32 `json:"ports"`
Pricing *PricingPolicy `json:"pricing"`
Inputs *Schema `json:"inputs"`
Outputs *Schema `json:"outputs"`
TCP bool `json:"tcp"`
}

type StubConfigLimitedValues struct {
Expand Down
42 changes: 26 additions & 16 deletions pkg/worker/criu.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
Expand Down Expand Up @@ -84,14 +85,14 @@ func InitializeCRIUManager(ctx context.Context, config types.CRIUConfig) (CRIUMa
return criuManager, nil
}

func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.ContainerRequest, outputLogger *slog.Logger, outputWriter io.Writer, startedChan chan int, checkpointPIDChan chan int, configPath string) (int, string, error) {
func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.ContainerRequest, outputLogger *slog.Logger, outputWriter io.Writer, startedChan chan int, checkpointPIDChan chan int, configPath string, exposeNetwork func() error) (int, string, error) {
state, createCheckpoint := s.shouldCreateCheckpoint(request)

// If checkpointing is enabled, attempt to create a checkpoint
if createCheckpoint {
outputLogger.Info("Attempting to create container checkpoint...")

exitCode, err := s.createCheckpoint(ctx, request, outputWriter, outputLogger, startedChan, checkpointPIDChan, configPath)
exitCode, err := s.createCheckpoint(ctx, request, outputWriter, outputLogger, startedChan, checkpointPIDChan, configPath, exposeNetwork)
if err != nil {
return -1, "", err
}
Expand All @@ -113,6 +114,11 @@ func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.
}
defer f.Close()

err = exposeNetwork()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Network is exposed before a successful restore, potentially allowing premature connections; expose after a successful restore instead.

(Based on the PR description stating the worker only exposes ports after checkpoint or restore.)

Prompt for AI agents
Address the following comment on pkg/worker/criu.go at line 117:

<comment>Network is exposed before a successful restore, potentially allowing premature connections; expose after a successful restore instead.

(Based on the PR description stating the worker only exposes ports after checkpoint or restore.)</comment>

<file context>
@@ -113,6 +114,11 @@ func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.
 		}
 		defer f.Close()
 
+		err = exposeNetwork()
+		if err != nil {
+			return -1, &quot;&quot;, fmt.Errorf(&quot;failed to expose network: %v&quot;, err)
</file context>

if err != nil {
return -1, "", fmt.Errorf("failed to expose network: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use %w to wrap underlying error for exposeNetwork failure to preserve error chain.

Prompt for AI agents
Address the following comment on pkg/worker/criu.go at line 119:

<comment>Use %w to wrap underlying error for exposeNetwork failure to preserve error chain.</comment>

<file context>
@@ -113,6 +114,11 @@ func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.
 
+		err = exposeNetwork()
+		if err != nil {
+			return -1, &quot;&quot;, fmt.Errorf(&quot;failed to expose network: %v&quot;, err)
+		}
+
</file context>
Suggested change
return -1, "", fmt.Errorf("failed to expose network: %v", err)
return -1, "", fmt.Errorf("failed to expose network: %w", err)

}

exitCode, err := s.criuManager.RestoreCheckpoint(ctx, &RestoreOpts{
request: request,
state: state,
Expand All @@ -123,31 +129,30 @@ func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.
configPath: configPath,
})
if err != nil {
updateStateErr := s.updateCheckpointState(request, types.CheckpointStatusRestoreFailed)
if updateStateErr != nil {
log.Error().Str("container_id", request.ContainerId).Msgf("failed to update checkpoint state: %v", updateStateErr)
var e *runc.ExitError
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[TODO]: This logic is not complete. @luke-lombardi I want to chat with you about stop and error conditions related to a running RestoreCheckpoint process.

A SIGKILL from container stop event from user would put the checkpoint into a restore_failed state which is not what we want but this SIGKILL can also occur for other factors that are related to a failed restore. We need to hash about better indicators of failure in a restore checkpoint process.

if errors.As(err, &e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restore failures that are not runc.ExitError won't update checkpoint state or be logged, leaving stale/incorrect state.

Prompt for AI agents
Address the following comment on pkg/worker/criu.go at line 133:

<comment>Restore failures that are not runc.ExitError won&#39;t update checkpoint state or be logged, leaving stale/incorrect state.</comment>

<file context>
@@ -123,31 +129,30 @@ func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.
-			if updateStateErr != nil {
-				log.Error().Str(&quot;container_id&quot;, request.ContainerId).Msgf(&quot;failed to update checkpoint state: %v&quot;, updateStateErr)
+			var e *runc.ExitError
+			if errors.As(err, &amp;e) {
+				code := e.Status
+
</file context>

code := e.Status

if code != 137 {
log.Error().Str("container_id", request.ContainerId).Msgf("failed to restore checkpoint: %v", err)
updateStateErr := s.updateCheckpointState(request, types.CheckpointStatusRestoreFailed)
if updateStateErr != nil {
log.Error().Str("container_id", request.ContainerId).Msgf("failed to update checkpoint state: %v", updateStateErr)
}
}
}
return exitCode, "", err
}

outputLogger.Info("Checkpoint found and restored")
return exitCode, request.ContainerId, nil
}

// If a checkpoint exists but is not available (previously failed), run the container normally
bundlePath := filepath.Dir(configPath)

exitCode, err := s.runcHandle.Run(s.ctx, request.ContainerId, bundlePath, &runc.CreateOpts{
OutputWriter: outputWriter,
Started: startedChan,
})

return exitCode, request.ContainerId, err
return -1, "", fmt.Errorf("checkpoint not found")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parent function already handles this exact intent of running the code without checkpoints as a fallback. This is just duplicated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing fallback to normal run when checkpoint is not found; this now returns an error instead of starting the container normally.

(Based on the PR description that the worker falls back to a normal run if no checkpoint is found.)

Prompt for AI agents
Address the following comment on pkg/worker/criu.go at line 150:

<comment>Missing fallback to normal run when checkpoint is not found; this now returns an error instead of starting the container normally.

(Based on the PR description that the worker falls back to a normal run if no checkpoint is found.)</comment>

<file context>
@@ -123,31 +129,30 @@ func (s *Worker) attemptCheckpointOrRestore(ctx context.Context, request *types.
-	})
-
-	return exitCode, request.ContainerId, err
+	return -1, &quot;&quot;, fmt.Errorf(&quot;checkpoint not found&quot;)
 }
 
</file context>

}

// Waits for the container to be ready to checkpoint at the desired point in execution, ie.
// after all processes within a container have reached a checkpointable state
func (s *Worker) createCheckpoint(ctx context.Context, request *types.ContainerRequest, outputWriter io.Writer, outputLogger *slog.Logger, startedChan chan int, checkpointPIDChan chan int, configPath string) (int, error) {
func (s *Worker) createCheckpoint(ctx context.Context, request *types.ContainerRequest, outputWriter io.Writer, outputLogger *slog.Logger, startedChan chan int, checkpointPIDChan chan int, configPath string, exposeNetwork func() error) (int, error) {
bundlePath := filepath.Dir(configPath)

go func() {
Expand Down Expand Up @@ -214,6 +219,11 @@ func (s *Worker) createCheckpoint(ctx context.Context, request *types.ContainerR
if updateStateErr != nil {
log.Error().Str("container_id", request.ContainerId).Msgf("failed to update checkpoint state: %v", updateStateErr)
}

err = exposeNetwork()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we are checkpointing, we do not want requests to taint the checkpoint. It should only checkpoint the ready state of the pod and not any running requests. In the case of an API server, the running request would be a connection.

if err != nil {
log.Error().Str("container_id", request.ContainerId).Msgf("failed to expose network: %v", err)
}
}()

exitCode, err := s.criuManager.Run(ctx, request, bundlePath, &runc.CreateOpts{
Expand Down
35 changes: 22 additions & 13 deletions pkg/worker/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *Worker) stopContainer(containerId string, kill bool) error {
return nil
}

log.Info().Str("container_id", containerId).Msg("container stopped")
log.Info().Str("container_id", containerId).Msgf("container stopped with signal %d", signal)
return nil
}

Expand Down Expand Up @@ -542,6 +542,16 @@ func (s *Worker) getContainerEnvironment(request *types.ContainerRequest, option
return env
}

func (s *Worker) exposeBindPorts(containerId string, request *types.ContainerRequest, opts *ContainerOptions) error {
for idx, bindPort := range opts.BindPorts {
err := s.containerNetworkManager.ExposePort(containerId, bindPort, int(request.Ports[idx]))
if err != nil {
return err
}
}
return nil
}

// spawn a container using runc binary
func (s *Worker) spawn(request *types.ContainerRequest, spec *specs.Spec, outputLogger *slog.Logger, opts *ContainerOptions) {
ctx, cancel := context.WithCancel(s.ctx)
Expand Down Expand Up @@ -668,15 +678,6 @@ func (s *Worker) spawn(request *types.ContainerRequest, spec *specs.Spec, output
}
}

// Expose the bind ports
for idx, bindPort := range opts.BindPorts {
err = s.containerNetworkManager.ExposePort(containerId, bindPort, int(request.Ports[idx]))
if err != nil {
log.Error().Str("container_id", containerId).Msgf("failed to expose container bind port: %v", err)
return
}
}

// Write runc config spec to disk
configContents, err := json.MarshalIndent(spec, "", " ")
if err != nil {
Expand Down Expand Up @@ -726,7 +727,7 @@ func (s *Worker) spawn(request *types.ContainerRequest, spec *specs.Spec, output
go s.watchOOMEvents(ctx, request, outputLogger, &isOOMKilled) // Watch for OOM events
}()

exitCode, containerId, _ = s.runContainer(ctx, request, configPath, outputLogger, outputWriter, startedChan, checkpointPIDChan)
exitCode, containerId, _ = s.runContainer(ctx, request, configPath, outputLogger, outputWriter, startedChan, checkpointPIDChan, opts)

stopReason := types.StopContainerReasonUnknown
containerInstance, exists = s.containerInstances.Get(containerId)
Expand Down Expand Up @@ -762,16 +763,24 @@ func (s *Worker) spawn(request *types.ContainerRequest, spec *specs.Spec, output
}
}

func (s *Worker) runContainer(ctx context.Context, request *types.ContainerRequest, configPath string, outputLogger *slog.Logger, outputWriter *common.OutputWriter, startedChan chan int, checkpointPIDChan chan int) (int, string, error) {
func (s *Worker) runContainer(ctx context.Context, request *types.ContainerRequest, configPath string, outputLogger *slog.Logger, outputWriter *common.OutputWriter, startedChan chan int, checkpointPIDChan chan int, opts *ContainerOptions) (int, string, error) {
// Handle checkpoint creation & restore if applicable
if s.IsCRIUAvailable(request.GpuCount) && request.CheckpointEnabled {
exitCode, containerId, err := s.attemptCheckpointOrRestore(ctx, request, outputLogger, outputWriter, startedChan, checkpointPIDChan, configPath)
exitCode, containerId, err := s.attemptCheckpointOrRestore(ctx, request, outputLogger, outputWriter, startedChan, checkpointPIDChan, configPath, func() error {
return s.exposeBindPorts(request.ContainerId, request, opts)
})
if err == nil {
return exitCode, containerId, err
}
log.Warn().Str("container_id", request.ContainerId).Err(err).Msgf("error running container from checkpoint/restore, exit code %d", exitCode)
}

err := s.exposeBindPorts(request.ContainerId, request, opts)
if err != nil {
log.Error().Str("container_id", request.ContainerId).Msgf("failed to expose container bind ports: %v", err)
return -1, "", err
}

bundlePath := filepath.Dir(configPath)
exitCode, err := s.runcHandle.Run(ctx, request.ContainerId, bundlePath, &runc.CreateOpts{
OutputWriter: outputWriter,
Expand Down
Loading
Loading