diff --git a/thefuck/shells/__init__.py b/thefuck/shells/__init__.py index 9a62225..0709f4f 100644 --- a/thefuck/shells/__init__.py +++ b/thefuck/shells/__init__.py @@ -19,7 +19,16 @@ shells = {'bash': Bash, 'powershell': Powershell} -def _get_shell(): +def _get_shell_from_env(): + path = os.environ.get('SHELL', '') + base_name = os.path.basename(path) + name = os.path.splitext(base_name)[0] + + if name in shells: + return shells[name]() + + +def _get_shell_from_proc(): proc = Process(os.getpid()) while proc is not None and proc.pid > 0: @@ -41,4 +50,4 @@ def _get_shell(): return Generic() -shell = _get_shell() +shell = _get_shell_from_env() or _get_shell_from_proc()