From f00dec3bb1b69540eba0042734cb8100d60de2bf Mon Sep 17 00:00:00 2001 From: Sebastien Lepers Date: Fri, 25 Nov 2022 14:31:47 +0100 Subject: [PATCH] Added user_checker support --- .../Security/Factory/SamlFactory.php | 1 + Resources/config/services.yml | 2 +- .../Authentication/Provider/SamlProvider.php | 23 +++++++++++-- .../Provider/SamlProviderTest.php | 34 +++++++++++++++++-- .../Security/Factory/SamlFactoryTest.php | 3 +- 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/DependencyInjection/Security/Factory/SamlFactory.php b/DependencyInjection/Security/Factory/SamlFactory.php index d3277b1..cafed11 100644 --- a/DependencyInjection/Security/Factory/SamlFactory.php +++ b/DependencyInjection/Security/Factory/SamlFactory.php @@ -91,6 +91,7 @@ protected function createAuthProvider(ContainerBuilder $container, $id, $config, $definition = $container->setDefinition($providerId, new $definitionClassname('hslavich_onelogin_saml.saml_provider')) ->replaceArgument(0, new Reference($userProviderId)) ->replaceArgument(1, new Reference('event_dispatcher', ContainerInterface::NULL_ON_INVALID_REFERENCE)) + ->replaceArgument(2, new Reference('security.user_checker.'.$id)) ; if ($config['user_factory']) { diff --git a/Resources/config/services.yml b/Resources/config/services.yml index d571a18..95d4613 100644 --- a/Resources/config/services.yml +++ b/Resources/config/services.yml @@ -15,7 +15,7 @@ services: hslavich_onelogin_saml.saml_provider: class: Hslavich\OneloginSamlBundle\Security\Authentication\Provider\SamlProvider - arguments: ["", ""] + arguments: ["", "", ""] hslavich_onelogin_saml.saml_token_factory: class: Hslavich\OneloginSamlBundle\Security\Authentication\Token\SamlTokenFactory diff --git a/Security/Authentication/Provider/SamlProvider.php b/Security/Authentication/Provider/SamlProvider.php index 177fc4d..6fac9a9 100644 --- a/Security/Authentication/Provider/SamlProvider.php +++ b/Security/Authentication/Provider/SamlProvider.php @@ -12,6 +12,7 @@ use Symfony\Component\Security\Core\Authentication\Token\TokenInterface; use Symfony\Component\Security\Core\Exception\AuthenticationException; use Symfony\Component\Security\Core\Exception\UsernameNotFoundException; +use Symfony\Component\Security\Core\User\UserCheckerInterface; use Symfony\Component\Security\Core\User\UserProviderInterface; class SamlProvider implements AuthenticationProviderInterface @@ -20,11 +21,13 @@ class SamlProvider implements AuthenticationProviderInterface protected $userFactory; protected $tokenFactory; protected $eventDispatcher; + protected $userChecker; - public function __construct(UserProviderInterface $userProvider, $eventDispatcher) + public function __construct(UserProviderInterface $userProvider, $eventDispatcher, $userChecker) { $this->userProvider = $userProvider; $this->eventDispatcher = $eventDispatcher; + $this->userChecker = $userChecker; } public function setUserFactory(SamlUserFactoryInterface $userFactory) @@ -71,16 +74,30 @@ public function supports(TokenInterface $token) protected function retrieveUser($token) { try { - return $this->userProvider->loadUserByUsername($token->getUsername()); + $user = $this->userProvider->loadUserByUsername($token->getUsername()); + + return $this->checkUser($user); } catch (UsernameNotFoundException $e) { if ($this->userFactory instanceof SamlUserFactoryInterface) { - return $this->generateUser($token); + $user = $this->generateUser($token); + + return $this->checkUser($user); } throw $e; } } + protected function checkUser($user) + { + if ($user && $this->userChecker instanceof UserCheckerInterface) { + $this->userChecker->checkPreAuth($user); + $this->userChecker->checkPostAuth($user); + } + + return $user; + } + protected function generateUser($token) { $user = $this->userFactory->createUser($token); diff --git a/Tests/Authentication/Provider/SamlProviderTest.php b/Tests/Authentication/Provider/SamlProviderTest.php index 832b182..9db927a 100644 --- a/Tests/Authentication/Provider/SamlProviderTest.php +++ b/Tests/Authentication/Provider/SamlProviderTest.php @@ -49,6 +49,36 @@ public function testAuthenticateInvalidUser() $provider->authenticate($this->getSamlToken()); } + public function testAuthenticateCheckerInvalidUser() + { + $user = $this->createMock('Symfony\Component\Security\Core\User\UserInterface'); + + $userChecker = $this->createMock('Symfony\Component\Security\Core\User\UserCheckerInterface'); + $exception = new \Exception('This user is valid in SSO but invalid in app'); + $userChecker->expects($this->once())->method('checkPreAuth')->willThrowException($exception); + + $provider = $this->getProvider($user, null, null, $userChecker); + + $this->expectExceptionMessage('This user is valid in SSO but invalid in app'); + + $provider->authenticate($this->getSamlToken()); + } + + public function testAuthenticateUserCheckerPostAuth() + { + $user = $this->createMock('Symfony\Component\Security\Core\User\UserInterface'); + $user->expects($this->once())->method('getRoles')->willReturn(array()); + + $userChecker = $this->createMock('Symfony\Component\Security\Core\User\UserCheckerInterface'); + $userChecker->expects($this->once())->method('checkPostAuth'); + + $provider = $this->getProvider($user, null, null, $userChecker); + + $token = $provider->authenticate($this->getSamlToken()); + + $this->assertSame($user, $token->getUser()); + } + public function testAuthenticateWithUserFactory() { $user = $this->createMock('Symfony\Component\Security\Core\User\UserInterface'); @@ -117,7 +147,7 @@ protected function getSamlToken() return $token; } - protected function getProvider($user = null, $userFactory = null, $eventDispatcher = null) + protected function getProvider($user = null, $userFactory = null, $eventDispatcher = null, $userChecker = null) { $userProvider = $this->createMock('Symfony\Component\Security\Core\User\UserProviderInterface'); if ($user) { @@ -126,7 +156,7 @@ protected function getProvider($user = null, $userFactory = null, $eventDispatch $userProvider->method('loadUserByUsername')->will($this->throwException(new UsernameNotFoundException())); } - $provider = new SamlProvider($userProvider, $eventDispatcher); + $provider = new SamlProvider($userProvider, $eventDispatcher, $userChecker); $provider->setTokenFactory(new SamlTokenFactory()); if ($userFactory) { diff --git a/Tests/DependencyInjection/Security/Factory/SamlFactoryTest.php b/Tests/DependencyInjection/Security/Factory/SamlFactoryTest.php index b731fab..dae0483 100644 --- a/Tests/DependencyInjection/Security/Factory/SamlFactoryTest.php +++ b/Tests/DependencyInjection/Security/Factory/SamlFactoryTest.php @@ -86,7 +86,8 @@ public function testBasicCreate() $providerDefinition = $container->getDefinition('security.authentication.provider.saml.test_firewall'); $this->assertEquals(array( 'index_0' => new Reference('my_user_provider'), - 'index_1' => new Reference('event_dispatcher', ContainerInterface::NULL_ON_INVALID_REFERENCE) + 'index_1' => new Reference('event_dispatcher', ContainerInterface::NULL_ON_INVALID_REFERENCE), + 'index_2' => new Reference('security.user_checker.test_firewall') ), $providerDefinition->getArguments()); } }