1- #
2- # Copyright 2018-2023 - Swiss Data Science Center (SDSC)
3- # A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4- # Eidgenössische Technische Hochschule Zürich (ETHZ).
1+ # Copyright Swiss Data Science Center (SDSC). A partnership between
2+ # École Polytechnique Fédérale de Lausanne (EPFL) and
3+ # Eidgenössische Technische Hochschule Zürich (ETHZ).
54#
65# Licensed under the Apache License, Version 2.0 (the "License");
76# you may not use this file except in compliance with the License.
2625
2726from renku .core import errors
2827from renku .core .config import get_value
28+ from renku .core .constant import ProviderPriority
2929from renku .core .login import read_renku_token
3030from renku .core .plugin import hookimpl
3131from renku .core .session .utils import get_renku_project_name , get_renku_url
3434from renku .core .util .jwt import is_token_expired
3535from renku .core .util .ssh import SystemSSHConfig
3636from renku .domain_model .project_context import project_context
37- from renku .domain_model .session import ISessionProvider , Session
37+ from renku .domain_model .session import ISessionProvider , Session , SessionStopStatus
3838
3939if TYPE_CHECKING :
4040 from renku .core .dataset .providers .models import ProviderParameter
@@ -44,6 +44,8 @@ class RenkulabSessionProvider(ISessionProvider):
4444 """A session provider that uses the notebook service API to launch sessions."""
4545
4646 DEFAULT_TIMEOUT_SECONDS = 300
47+ # NOTE: Give the renkulab provider the lowest priority so that it's checked last
48+ priority : ProviderPriority = ProviderPriority .LOWEST
4749
4850 def __init__ (self ):
4951 self .__renku_url : Optional [str ] = None
@@ -187,7 +189,7 @@ def _cleanup_ssh_connection_configs(
187189 gotten from the server.
188190 """
189191 if not running_sessions :
190- running_sessions = self .session_list ("" , None , ssh_garbage_collection = False )
192+ running_sessions = self .session_list (project_name = "" , ssh_garbage_collection = False )
191193
192194 system_config = SystemSSHConfig ()
193195
@@ -199,7 +201,8 @@ def _cleanup_ssh_connection_configs(
199201 if path not in session_config_paths :
200202 path .unlink ()
201203
202- def _remote_head_hexsha (self ):
204+ @staticmethod
205+ def _remote_head_hexsha ():
203206 remote = get_remote (repository = project_context .repository )
204207
205208 if remote is None :
@@ -221,7 +224,8 @@ def _send_renku_request(self, req_type: str, *args, **kwargs):
221224 )
222225 return res
223226
224- def _project_name_from_full_project_name (self , project_name : str ) -> str :
227+ @staticmethod
228+ def _project_name_from_full_project_name (project_name : str ) -> str :
225229 """Get just project name of project name if in owner/name form."""
226230 if "/" not in project_name :
227231 return project_name
@@ -282,9 +286,7 @@ def get_open_parameters(self) -> List["ProviderParameter"]:
282286 ProviderParameter ("ssh" , help = "Open a remote terminal through SSH." , is_flag = True ),
283287 ]
284288
285- def session_list (
286- self , project_name : str , config : Optional [Dict [str , Any ]], ssh_garbage_collection : bool = True
287- ) -> List [Session ]:
289+ def session_list (self , project_name : str , ssh_garbage_collection : bool = True ) -> List [Session ]:
288290 """Lists all the sessions currently running by the given session provider.
289291
290292 Returns:
@@ -398,45 +400,67 @@ def session_start(
398400 )
399401 raise errors .RenkulabSessionError ("Cannot start session via the notebook service because " + res .text )
400402
401- def session_stop (self , project_name : str , session_name : Optional [str ], stop_all : bool ) -> bool :
403+ def session_stop (self , project_name : str , session_name : Optional [str ], stop_all : bool ) -> SessionStopStatus :
402404 """Stops all sessions (for the given project) or a specific interactive session."""
403405 responses = []
406+ sessions = self .session_list (project_name = project_name )
407+ n_sessions = len (sessions )
408+
409+ if n_sessions == 0 :
410+ return SessionStopStatus .NO_ACTIVE_SESSION
411+
404412 if stop_all :
405- sessions = self .session_list (project_name = project_name , config = None )
406413 for session in sessions :
407414 responses .append (
408415 self ._send_renku_request (
409416 "delete" , f"{ self ._notebooks_url ()} /servers/{ session .id } " , headers = self ._auth_header ()
410417 )
411418 )
412419 self ._wait_for_session_status (session .id , "stopping" )
413- else :
420+ elif session_name :
414421 responses .append (
415422 self ._send_renku_request (
416423 "delete" , f"{ self ._notebooks_url ()} /servers/{ session_name } " , headers = self ._auth_header ()
417424 )
418425 )
419426 self ._wait_for_session_status (session_name , "stopping" )
427+ elif n_sessions == 1 :
428+ responses .append (
429+ self ._send_renku_request (
430+ "delete" , f"{ self ._notebooks_url ()} /servers/{ sessions [0 ].id } " , headers = self ._auth_header ()
431+ )
432+ )
433+ self ._wait_for_session_status (sessions [0 ].id , "stopping" )
434+ else :
435+ return SessionStopStatus .NAME_NEEDED
420436
421437 self ._cleanup_ssh_connection_configs (project_name )
422438
423- return all ([ response . status_code == 204 for response in responses ]) if responses else False
439+ n_successfully_stopped = len ([ r for r in responses if r . status_code == 204 ])
424440
425- def session_open (self , project_name : str , session_name : str , ssh : bool = False , ** kwargs ) -> bool :
441+ return SessionStopStatus .SUCCESSFUL if n_successfully_stopped == n_sessions else SessionStopStatus .FAILED
442+
443+ def session_open (self , project_name : str , session_name : Optional [str ], ssh : bool = False , ** kwargs ) -> bool :
426444 """Open a given interactive session.
427445
428446 Args:
429447 project_name(str): Renku project name.
430- session_name(str): The unique id of the interactive session.
448+ session_name(Optional[ str] ): The unique id of the interactive session.
431449 ssh(bool): Whether to open an SSH connection or a normal browser interface.
432450 """
433- sessions = self .session_list ("" , None )
451+ sessions = self .session_list (project_name = "" )
434452 system_config = SystemSSHConfig ()
435453 name = self ._project_name_from_full_project_name (project_name )
436454 ssh_prefix = f"{ system_config .renku_host } -{ name } -"
437455
456+ if not session_name :
457+ if len (sessions ) == 1 :
458+ session_name = sessions [0 ].id
459+ else :
460+ return False
461+
438462 if session_name .startswith (ssh_prefix ):
439- # NOTE: use passed in ssh connection name instead of session id by accident
463+ # NOTE: User passed in ssh connection name instead of session id by accident
440464 session_name = session_name .replace (ssh_prefix , "" , 1 )
441465
442466 if not any (s .id == session_name for s in sessions ):
0 commit comments