diff --git a/Example/KnitExample.xcodeproj/project.pbxproj b/Example/KnitExample.xcodeproj/project.pbxproj index 09fbeb9d..183a8206 100644 --- a/Example/KnitExample.xcodeproj/project.pbxproj +++ b/Example/KnitExample.xcodeproj/project.pbxproj @@ -399,7 +399,7 @@ INFOPLIST_KEY_UILaunchScreen_Generation = YES; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - IPHONEOS_DEPLOYMENT_TARGET = 15.0; + IPHONEOS_DEPLOYMENT_TARGET = 16.2; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -429,7 +429,7 @@ INFOPLIST_KEY_UILaunchScreen_Generation = YES; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - IPHONEOS_DEPLOYMENT_TARGET = 15.0; + IPHONEOS_DEPLOYMENT_TARGET = 16.2; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -450,7 +450,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 15.0; + IPHONEOS_DEPLOYMENT_TARGET = 16.2; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = com.cashapp.KnitExampleTests; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -468,7 +468,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 15.0; + IPHONEOS_DEPLOYMENT_TARGET = 16.2; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = com.cashapp.KnitExampleTests; PRODUCT_NAME = "$(TARGET_NAME)"; diff --git a/Example/KnitExample/ContentView.swift b/Example/KnitExample/ContentView.swift index 1b595ade..d24f724d 100644 --- a/Example/KnitExample/ContentView.swift +++ b/Example/KnitExample/ContentView.swift @@ -21,8 +21,8 @@ struct ContentView: View { } struct ContentView_Previews: PreviewProvider { + static let assembler = ScopedModuleAssembler([KnitExampleAssembly()]) static var previews: some View { - let resolver = ModuleAssembler([KnitExampleAssembly()]).resolver - return ContentView(resolver: resolver) + return ContentView(resolver: assembler.resolver) } } diff --git a/Example/KnitExample/KnitExampleApp.swift b/Example/KnitExample/KnitExampleApp.swift index 7b34baf6..6a4b99bb 100644 --- a/Example/KnitExample/KnitExampleApp.swift +++ b/Example/KnitExample/KnitExampleApp.swift @@ -8,11 +8,14 @@ import Knit @main struct KnitExampleApp: App { - let resolver: Resolver + let assembler: ScopedModuleAssembler + var resolver: Resolver { assembler.resolver } @MainActor init() { - resolver = ModuleAssembler([KnitExampleAssembly()]).resolver + assembler = ScopedModuleAssembler( + [KnitExampleAssembly()] + ) } var body: some Scene { diff --git a/Example/KnitExample/KnitExampleAssembly.swift b/Example/KnitExample/KnitExampleAssembly.swift index 440d79b6..7f20012e 100644 --- a/Example/KnitExample/KnitExampleAssembly.swift +++ b/Example/KnitExample/KnitExampleAssembly.swift @@ -13,9 +13,7 @@ final class KnitExampleAssembly: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { - container.addBehavior(ServiceCollector()) - + func assemble(container: Container) { container.register(ExampleService.self) { ExampleService.make(resolver: $0) } // @knit alias("example") diff --git a/Example/KnitExample/KnitExampleUserAssembly.swift b/Example/KnitExample/KnitExampleUserAssembly.swift index 786073af..41cbe817 100644 --- a/Example/KnitExample/KnitExampleUserAssembly.swift +++ b/Example/KnitExample/KnitExampleUserAssembly.swift @@ -13,7 +13,7 @@ final class KnitExampleUserAssembly: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { + func assemble(container: Container) { container.register(UserService.self) { _ in UserService() } } } diff --git a/Example/KnitExampleTests/TestModuleAssembler.swift b/Example/KnitExampleTests/TestModuleAssembler.swift index 37a12904..7209896b 100644 --- a/Example/KnitExampleTests/TestModuleAssembler.swift +++ b/Example/KnitExampleTests/TestModuleAssembler.swift @@ -9,7 +9,9 @@ import Knit extension KnitExampleAssembly { @MainActor static func makeAssemblerForTests() -> ModuleAssembler { - ModuleAssembler([KnitExampleAssembly()]) + ModuleAssembler( + [KnitExampleAssembly()] + ) } static func makeArgumentsForTests() -> KnitExampleRegistrationTestArguments { @@ -26,6 +28,8 @@ extension KnitExampleAssembly { extension KnitExampleUserAssembly { @MainActor static func makeAssemblerForTests() -> ModuleAssembler { - ModuleAssembler([KnitExampleUserAssembly(), KnitExampleAssembly()]) + ModuleAssembler( + [KnitExampleUserAssembly(), KnitExampleAssembly()] + ) } } diff --git a/Package.swift b/Package.swift index ee258b2b..0963531d 100644 --- a/Package.swift +++ b/Package.swift @@ -8,7 +8,7 @@ let package = Package( name: "Knit", platforms: [ .macOS(.v14), - .iOS(.v15), + .iOS(.v16), ], products: [ .library(name: "Knit", targets: ["Knit"]), diff --git a/Sources/Knit/Container+MainActor.swift b/Sources/Knit/Container+MainActorRegistration.swift similarity index 76% rename from Sources/Knit/Container+MainActor.swift rename to Sources/Knit/Container+MainActorRegistration.swift index 9aa58e99..7181955b 100644 --- a/Sources/Knit/Container+MainActor.swift +++ b/Sources/Knit/Container+MainActorRegistration.swift @@ -4,11 +4,9 @@ import Swinject -// This code should move into the Swinject library. -// There is an open pull request to make this change https://github.com/Swinject/Swinject/pull/570 +// MARK: - MainActor Registration Methods -// MARK: - MainActor registration -extension Container { +extension Knit.Container { /// Adds a registration for the specified service with the factory closure to specify how the service is /// resolved with dependencies which must be resolved on the main actor. @@ -27,18 +25,20 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { r in - MainActor.assumeIsolated { - return mainActorFactory(r) + return _unwrappedSwinjectContainer.register(serviceType, name: name) { r in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { + return mainActorFactory(resolver) } } } } -// MARK: - MainActor registration with Arguments -extension Container { +// MARK: - MainActor Registration Methods with Arguments + +extension Knit.Container { /// Adds a registration for the specified service with the factory closure to specify how the service is /// resolved with dependencies which must be resolved on the main actor. /// @@ -56,10 +56,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1) } } @@ -82,10 +83,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2) } } @@ -108,10 +110,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3) } } @@ -134,10 +137,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3, Arg4) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3, Arg4) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3, arg4) } } @@ -160,10 +164,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3, Arg4, Arg5) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3, arg4, arg5) } } @@ -186,10 +191,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3, arg4, arg5, arg6) } } @@ -212,10 +218,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7) } } @@ -238,10 +245,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7, arg8: Arg8) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7, arg8: Arg8) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } } @@ -264,10 +272,11 @@ extension Container { public func register( _ serviceType: Service.Type, name: String? = nil, - mainActorFactory: @escaping @MainActor (Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9) -> Service + mainActorFactory: @escaping @MainActor (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9) -> Service ) -> ServiceEntry { - return register(serviceType, name: name) { (resolver: Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7, arg8: Arg8, arg9: Arg9) in - MainActor.assumeIsolated { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7, arg8: Arg8, arg9: Arg9) in + let resolver = r.resolve(Container.self)!.resolver + return MainActor.assumeIsolated { return mainActorFactory(resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) } } diff --git a/Sources/Knit/Container+Registration.swift b/Sources/Knit/Container+Registration.swift new file mode 100644 index 00000000..cf4e9396 --- /dev/null +++ b/Sources/Knit/Container+Registration.swift @@ -0,0 +1,265 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import Swinject + +// MARK: - Registration Methods + +extension Knit.Container { + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a ``Resolver`` to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { r in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver) + } + } +} + +// MARK: - Registration Methods with Arguments + +extension Container { + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 1 argument to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 2 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 3 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 4 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3, Arg4) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3, arg4) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 5 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3, arg4, arg5) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 6 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3, arg4, arg5, arg6) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 7 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 8 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7, arg8: Arg8) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + } + } + + /// Adds a registration for the specified service with the factory closure to specify how the service is + /// resolved with dependencies. + /// + /// - Parameters: + /// - serviceType: The service type to register. + /// - name: A registration name, which is used to differentiate from other registrations + /// that have the same service and factory types. + /// - factory: The closure to specify how the service type is resolved with the dependencies of the type. + /// It is invoked when the ``Container`` needs to instantiate the instance. + /// It takes a `Resolver` instance and 9 arguments to inject dependencies to the instance, + /// and returns the instance of the component type for the service. + /// + /// - Returns: A registered ``ServiceEntry`` to configure more settings with method chaining. + @discardableResult + public func register( + _ serviceType: Service.Type, + name: String? = nil, + factory: @escaping (TargetResolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9) -> Service + ) -> ServiceEntry { + return _unwrappedSwinjectContainer.register(serviceType, name: name) { (r: Swinject.Resolver, arg1: Arg1, arg2: Arg2, arg3: Arg3, arg4: Arg4, arg5: Arg5, arg6: Arg6, arg7: Arg7, arg8: Arg8, arg9: Arg9) in + let resolver = r.resolve(Container.self)!.resolver + return factory(resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) + } + } + +} diff --git a/Sources/Knit/Container.swift b/Sources/Knit/Container.swift new file mode 100644 index 00000000..64454ebd --- /dev/null +++ b/Sources/Knit/Container.swift @@ -0,0 +1,74 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import Foundation +import Swinject + +/** + A light-weight wrapper around the `Swinject.Container` that adds type information about the `TargetResolver`. + This allows us to provide registration APIs that are specified to the `TargetResolver`. + + The Knit.Container also performs the function of a weak wrapper of the `Swinject.Container`. + */ +public class Container: Knit.Resolver { + + // MARK: - Knit.Resolver + + public var resolver: TargetResolver { + self as! TargetResolver + } + + /// Returns `true` if the backing container is still available in memory, otherwise `false`. + public var isAvailable: Bool { + _swinjectContainer != nil + } + + // MARK: - Swinject.Resolver + + public var unsafeResolver: Swinject.Resolver { + _unwrappedSwinjectContainer + } + + // MARK: - Private Properties + + private weak var _swinjectContainer: Swinject.Container? + + // MARK: - Life Cycle + + // This should not be promoted from `fileprivate` access level. + fileprivate init(_swinjectContainer: Swinject.Container) { + self._swinjectContainer = _swinjectContainer + } + + // **NOTE**: The only place this should be called is from the ModuleAssembler which + // is responsible for creating Containers. + // This should not be promoted from `internal` access level. + @discardableResult + internal static func _instantiateAndRegister(_swinjectContainer: Swinject.Container) -> Container { + let container = Container(_swinjectContainer: _swinjectContainer) + + // We don't want to make multiple copies of this Knit.Container wrapper, + // so store an instance of it in the wrapped container. + // This class only holds a weak reference to the wrapped container so no retain cycle is created. + _swinjectContainer.register( + Container.self, + factory: { _ in container } + ) + + return container + } + +} + +extension Container { + + // Force unwraps the weak Container + var _unwrappedSwinjectContainer: Swinject.Container { + guard let _swinjectContainer else { + fatalError("Attempting to resolve using a container which has been released") + } + return _swinjectContainer + } + +} diff --git a/Sources/Knit/DuplicateRegistrationDetector.swift b/Sources/Knit/DuplicateRegistrationDetector.swift index 3e445b08..0cb8c40d 100644 --- a/Sources/Knit/DuplicateRegistrationDetector.swift +++ b/Sources/Knit/DuplicateRegistrationDetector.swift @@ -39,7 +39,7 @@ public final class DuplicateRegistrationDetector { extension DuplicateRegistrationDetector: Behavior { public func container( - _ container: Container, + _ container: Swinject.Container, didRegisterType type: Type.Type, toService entry: ServiceEntry, withName name: String? diff --git a/Sources/Knit/Exports.swift b/Sources/Knit/Exports.swift deleted file mode 100644 index 030e304b..00000000 --- a/Sources/Knit/Exports.swift +++ /dev/null @@ -1 +0,0 @@ -@_exported import Swinject diff --git a/Sources/Knit/Module/Container+AbstractRegistration.swift b/Sources/Knit/Module/Container+AbstractRegistration.swift index c6a2345f..ed79f122 100644 --- a/Sources/Knit/Module/Container+AbstractRegistration.swift +++ b/Sources/Knit/Module/Container+AbstractRegistration.swift @@ -3,6 +3,7 @@ // import Foundation +import Swinject extension Container { @@ -16,7 +17,7 @@ extension Container { file: String = #fileID ) { let registration = RealAbstractRegistration(name: name, file: file, concurrency: concurrency) - addAbstractRegistration(registration) + _unwrappedSwinjectContainer.addAbstractRegistration(registration) } /// Register that a service is expected to exist but no implementation is currently available @@ -31,13 +32,17 @@ extension Container { file: String = #fileID ) { let registration = OptionalAbstractRegistration(name: name, file: file, concurrency: concurrency) - addAbstractRegistration(registration) + _unwrappedSwinjectContainer.addAbstractRegistration(registration) } +} + +extension Swinject.Container { + // Must be called before using `registerAbstract` func registerAbstractContainer() -> AbstractRegistrationContainer { let registrations = AbstractRegistrationContainer() - register(Container.AbstractRegistrationContainer.self, factory: { _ in registrations }) + register(AbstractRegistrationContainer.self, factory: { _ in registrations }) .inObjectScope(.container) addBehavior(registrations) return registrations @@ -47,7 +52,7 @@ extension Container { abstractRegistrations().abstractRegistrations.append(registration) } - private func abstractRegistrations() -> AbstractRegistrationContainer { + fileprivate func abstractRegistrations() -> AbstractRegistrationContainer { return resolve(AbstractRegistrationContainer.self)! } } @@ -75,7 +80,7 @@ public protocol AbstractRegistration { /// Register a placeholder registration to fill the unfulfilled abstract registration /// This placeholder cannot be resolved func registerPlaceholder( - container: Container, + container: Swinject.Container, errorFormatter: ModuleAssemblerErrorFormatter, dependencyTree: DependencyTree ) @@ -83,8 +88,8 @@ public protocol AbstractRegistration { extension AbstractRegistration { // Convert the key into an error - var error: Container.AbstractRegistrationError { - return Container.AbstractRegistrationError( + var error: AbstractRegistrationError { + AbstractRegistrationError( serviceType: serviceDescription, file: file, name: name @@ -110,7 +115,7 @@ fileprivate struct RealAbstractRegistration: AbstractRegistration { let concurrency: ConcurrencyAttribute func registerPlaceholder( - container: Container, + container: Swinject.Container, errorFormatter: ModuleAssemblerErrorFormatter, dependencyTree: DependencyTree ) { @@ -132,7 +137,7 @@ fileprivate struct OptionalAbstractRegistration: AbstractR typealias ServiceType = Optional func registerPlaceholder( - container: Container, + container: Swinject.Container, errorFormatter: ModuleAssemblerErrorFormatter, dependencyTree: DependencyTree ) { @@ -142,32 +147,36 @@ fileprivate struct OptionalAbstractRegistration: AbstractR } } -// MARK: - Inner types +// MARK: - -extension Container { - - public struct AbstractRegistrationError: LocalizedError { - public let serviceType: String - public let file: String - public let name: String? +public struct AbstractRegistrationError: LocalizedError { + public let serviceType: String + public let file: String + public let name: String? - public var errorDescription: String? { - var string = "Unsatisfied abstract registration. Service: \(serviceType), File: \(file)" - if let name = name { - string += ", Name: \(name)" - } - return string + public var errorDescription: String? { + var string = "Unsatisfied abstract registration. Service: \(serviceType), File: \(file)" + if let name = name { + string += ", Name: \(name)" } + return string } +} - // Array of abstract registration errors - public struct AbstractRegistrationErrors: LocalizedError { - public let errors: [AbstractRegistrationError] +// MARK: - - public var errorDescription: String? { - return errors.map { $0.localizedDescription }.joined(separator: "\n") - } +// Array of abstract registration errors +public struct AbstractRegistrationErrors: LocalizedError { + public let errors: [AbstractRegistrationError] + + public var errorDescription: String? { + return errors.map { $0.localizedDescription }.joined(separator: "\n") } +} + +// MARK: - + +extension Swinject.Container { final class AbstractRegistrationContainer: Behavior { @@ -180,7 +189,7 @@ extension Container { } func container( - _ container: Container, + _ container: Swinject.Container, didRegisterType type: Type.Type, toService entry: ServiceEntry, withName name: String? diff --git a/Sources/Knit/Module/ModuleAssembler.swift b/Sources/Knit/Module/ModuleAssembler.swift index dc2ec52b..8adb50c4 100644 --- a/Sources/Knit/Module/ModuleAssembler.swift +++ b/Sources/Knit/Module/ModuleAssembler.swift @@ -2,17 +2,20 @@ // Copyright © Block, Inc. All rights reserved. // -/// ModuleAssembler wraps the Swinject Assembler to resolves the full tree of module dependencies. +import Swinject + +/// ModuleAssembler wraps the Swinject Assembler to resolve the full tree of module dependencies. /// If dependencies are missing from the tree then the resolution will fail and indicate the missing module public final class ModuleAssembler { /// The container that registrations have been placed in. Prefer using resolver unless mutable access is required - let _container: Container + let _swinjectContainer: Swinject.Container let parent: ModuleAssembler? let serviceCollector: ServiceCollector + private let autoConfigureContainers: Bool - /// The resolver for this ModuleAssemblers container - public var resolver: Resolver { _container } + /// The unsafe resolver for this ModuleAssembler's container + public var resolver: Swinject.Resolver { _swinjectContainer } // Module types that were registered into the container owned by this ModuleAssembler var registeredModules: [any ModuleAssembly.Type] { @@ -21,11 +24,16 @@ public final class ModuleAssembler { let builder: DependencyBuilder - /** The created ModuleAssembler will create a `Container` which references the optional parent. - A depth first search will find all dependencies which will be registered + /** + The created ModuleAssembler will manage finding and applying registrations from all module assemblies. + A depth first search will find all dependencies which will be registered. + + - NOTE: Direct use of ModuleAssembler in your app will not provide safety for separation of TargetResolvers. + Specified generic Containers will be *automatically* created for any TargetResolver that is later used. + Using ModuleAssembler in this way also disallows parent-child container configuration. + If your app has multiple TargetResolvers you should only use the ScopedModuleAssembler instead. - Parameters: - - parent: A ModuleAssembler that has already been setup with some dependencies. - modules: Array of modules to register - overrideBehavior: Behavior of default override usage. - assemblyValidation: An optional closure to perform custom validation on module assemblies for this assembler. @@ -34,12 +42,11 @@ public final class ModuleAssembler { - postAssemble: Hook after all assemblies are registered to make changes to the container. */ @MainActor public convenience init( - parent: ModuleAssembler? = nil, _ modules: [any ModuleAssembly], overrideBehavior: OverrideBehavior = .defaultOverridesWhenTesting, assemblyValidation: ((any ModuleAssembly.Type) throws -> Void)? = nil, errorFormatter: ModuleAssemblerErrorFormatter = DefaultModuleAssemblerErrorFormatter(), - postAssemble: ((Container) -> Void)? = nil, + postAssemble: ((Swinject.Container) -> Void)? = nil, file: StaticString = #fileID, line: UInt = #line ) { @@ -47,12 +54,14 @@ public final class ModuleAssembler { var createdBuilder: DependencyBuilder? do { try self.init( - parent: parent, + parent: nil, _modules: modules, overrideBehavior: overrideBehavior, assemblyValidation: assemblyValidation, - errorFormatter: errorFormatter, - postAssemble: postAssemble + errorFormatter: errorFormatter, + preAssemble: nil, + postAssemble: postAssemble, + autoConfigureContainers: true ) createdBuilder = self.builder } catch { @@ -73,7 +82,9 @@ public final class ModuleAssembler { assemblyValidation: ((any ModuleAssembly.Type) throws -> Void)? = nil, errorFormatter: ModuleAssemblerErrorFormatter = DefaultModuleAssemblerErrorFormatter(), behaviors: [Behavior] = [], - postAssemble: ((Container) -> Void)? = nil + preAssemble: ((Swinject.Container) -> Void)?, + postAssemble: ((Swinject.Container) -> Void)? = nil, + autoConfigureContainers: Bool ) throws { self.builder = try DependencyBuilder( modules: modules, @@ -85,27 +96,30 @@ public final class ModuleAssembler { ) self.parent = parent - self._container = Container( - parent: parent?._container, + let _swinjectContainer = Swinject.Container( + parent: parent?._swinjectContainer, behaviors: behaviors ) - self.serviceCollector = .init(parent: parent?.serviceCollector) - self._container.addBehavior(serviceCollector) - let abstractRegistrations = self._container.registerAbstractContainer() + self._swinjectContainer = _swinjectContainer + self.autoConfigureContainers = autoConfigureContainers + preAssemble?(_swinjectContainer) + self.serviceCollector = ServiceCollector(parent: parent?.serviceCollector) + self._swinjectContainer.addBehavior(serviceCollector) + let abstractRegistrations = self._swinjectContainer.registerAbstractContainer() // Expose the dependency tree for debugging let dependencyTree = builder.dependencyTree - self._container.register(DependencyTree.self) { _ in dependencyTree } + self._swinjectContainer.register(DependencyTree.self) { _ in dependencyTree } for assembly in builder.assemblies { - assembly.assemble(container: self._container) + assembly._assemble(swinjectContainer: _swinjectContainer, autoConfigureContainers: autoConfigureContainers) } - postAssemble?(_container) + postAssemble?(_swinjectContainer) if overrideBehavior.useAbstractPlaceholders { for registration in abstractRegistrations.unfulfilledRegistrations { registration.registerPlaceholder( - container: _container, + container: _swinjectContainer, errorFormatter: errorFormatter, dependencyTree: dependencyTree ) @@ -132,8 +146,42 @@ public final class ModuleAssembler { } // Publicly expose the dependency tree so it can be used for debugging -public extension Resolver { +public extension Swinject.Resolver { func _dependencyTree(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> DependencyTree { return knitUnwrap(resolve(DependencyTree.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } + +// MARK: - + +private extension ModuleAssembly { + + @MainActor + func _assemble( + swinjectContainer: Swinject.Container, + autoConfigureContainers: Bool + ) { + let container = getContainer( + swinjectContainer: swinjectContainer, + autoConfigureContainers: autoConfigureContainers + ) + + assemble(container: container) + } + + private func getContainer( + swinjectContainer: Swinject.Container, + autoConfigureContainers: Bool + ) -> Container { + if let container = swinjectContainer.resolve(Container.self) { + return container + } + if autoConfigureContainers { + return Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + } else { + // This ModuleAssembler is being used internally by a ScopedModuleAssembler + fatalError("ModuleAssembler failed to locate appropriate Container for \(String(describing: TargetResolver.self))") + } + } + +} diff --git a/Sources/Knit/Module/ModuleAssemblerErrorFormatter.swift b/Sources/Knit/Module/ModuleAssemblerErrorFormatter.swift index 1d007d6d..0be25aa1 100644 --- a/Sources/Knit/Module/ModuleAssemblerErrorFormatter.swift +++ b/Sources/Knit/Module/ModuleAssemblerErrorFormatter.swift @@ -3,6 +3,7 @@ // import Foundation +import Swinject public protocol ModuleAssemblerErrorFormatter { func format(knitError: KnitAssemblyError, dependencyTree: DependencyTree?) -> String @@ -11,9 +12,9 @@ public protocol ModuleAssemblerErrorFormatter { extension ModuleAssemblerErrorFormatter { func format(error: Error, dependencyTree: DependencyTree?) -> String { - if let abstract = error as? Container.AbstractRegistrationErrors { + if let abstract = error as? AbstractRegistrationErrors { return format(knitError: .abstractList(abstract), dependencyTree: dependencyTree) - } else if let abstract = error as? Container.AbstractRegistrationError { + } else if let abstract = error as? AbstractRegistrationError { return format(knitError: .abstract(abstract), dependencyTree: dependencyTree) } else if let scoped = error as? ScopedModuleAssemblerError { return format(knitError: .scoped(scoped), dependencyTree: dependencyTree) @@ -59,10 +60,10 @@ public enum KnitAssemblyError { case scoped(ScopedModuleAssemblerError) /// List of errors related to abstract registrations - case abstractList(Container.AbstractRegistrationErrors) - + case abstractList(AbstractRegistrationErrors) + /// A single abstract registration error - case abstract(Container.AbstractRegistrationError) + case abstract(AbstractRegistrationError) public var localizedDescription: String { switch self { diff --git a/Sources/Knit/Module/ModuleAssembly.swift b/Sources/Knit/Module/ModuleAssembly.swift index 44b7d380..4e922025 100644 --- a/Sources/Knit/Module/ModuleAssembly.swift +++ b/Sources/Knit/Module/ModuleAssembly.swift @@ -5,7 +5,7 @@ import Foundation import Swinject -public protocol ModuleAssembly { +public protocol ModuleAssembly { associatedtype TargetResolver @@ -15,7 +15,7 @@ public protocol ModuleAssembly { // Knit will always call assemble on the MainActor. The annotation is not included because Swift is unable to // correctly check the concurrency of closures inside a @MainActor function - func assemble(container: Container) + func assemble(container: Container) /// A ModuleAssembly can replace any number of other module assemblies. /// If this assembly replaces another it is expected to provide all registrations from the replaced assemblies. diff --git a/Sources/Knit/Module/ScopedModuleAssembler.swift b/Sources/Knit/Module/ScopedModuleAssembler.swift index dbf8dfd2..846a600a 100644 --- a/Sources/Knit/Module/ScopedModuleAssembler.swift +++ b/Sources/Knit/Module/ScopedModuleAssembler.swift @@ -3,30 +3,30 @@ // import Foundation +import Swinject /// Module assembly which only allows registering assemblies which target a particular resolver type. -public final class ScopedModuleAssembler { +public final class ScopedModuleAssembler { public let internalAssembler: ModuleAssembler - public var resolver: ScopedResolver { - // swiftlint:disable:next force_cast - internalAssembler.resolver as! ScopedResolver + public var resolver: TargetResolver { + internalAssembler.resolver.resolve(Knit.Container.self)!.resolver } - /// The container that registrations have been placed in. Prefer using resolver unless mutable access is required - public var _container: Container { - return internalAssembler._container + /// Access the underlying Swinject.Resolver to resolve without type safety. + var unsafeResolver: Swinject.Resolver { + internalAssembler.resolver } @MainActor public convenience init( parent: ModuleAssembler? = nil, - _ modules: [any ModuleAssembly], + _ modules: [any ModuleAssembly], overrideBehavior: OverrideBehavior = .defaultOverridesWhenTesting, errorFormatter: ModuleAssemblerErrorFormatter = DefaultModuleAssemblerErrorFormatter(), behaviors: [Behavior] = [], - postAssemble: ((Container) -> Void)? = nil, + postAssemble: ((Container) -> Void)? = nil, file: StaticString = #fileID, line: UInt = #line ) { @@ -52,18 +52,18 @@ public final class ScopedModuleAssembler { @MainActor required init( parent: ModuleAssembler? = nil, - _modules modules: [any ModuleAssembly], + _modules modules: [any ModuleAssembly], overrideBehavior: OverrideBehavior = .defaultOverridesWhenTesting, errorFormatter: ModuleAssemblerErrorFormatter = DefaultModuleAssemblerErrorFormatter(), behaviors: [Behavior] = [], - postAssemble: ((Container) -> Void)? = nil + postAssemble: ((Container) -> Void)? = nil ) throws { // For provided modules, fail early if they are scoped incorrectly for assembly in modules { let moduleAssemblyType = type(of: assembly) - if moduleAssemblyType.resolverType != ScopedResolver.self { + if moduleAssemblyType.resolverType != TargetResolver.self { let scopingError = ScopedModuleAssemblerError.incorrectTargetResolver( - expected: String(describing: ScopedResolver.self), + expected: String(describing: TargetResolver.self), actual: String(describing: moduleAssemblyType.resolverType) ) @@ -75,16 +75,24 @@ public final class ScopedModuleAssembler { _modules: modules, overrideBehavior: overrideBehavior, assemblyValidation: { moduleAssemblyType in - guard moduleAssemblyType.resolverType == ScopedResolver.self else { + guard moduleAssemblyType.resolverType == TargetResolver.self else { throw ScopedModuleAssemblerError.incorrectTargetResolver( - expected: String(describing: ScopedResolver.self), + expected: String(describing: TargetResolver.self), actual: String(describing: moduleAssemblyType.resolverType) ) } }, errorFormatter: errorFormatter, behaviors: behaviors, - postAssemble: postAssemble + preAssemble: { container in + // Register a Container for the the current-scoped `TargetResolver` + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + postAssemble: { swinjectContainer in + let container = swinjectContainer.resolve(Container.self)! + postAssemble?(container) + }, + autoConfigureContainers: false ) } diff --git a/Sources/Knit/Resolver+Additions.swift b/Sources/Knit/Resolver+Additions.swift index b109969b..3b1b20f9 100644 --- a/Sources/Knit/Resolver+Additions.swift +++ b/Sources/Knit/Resolver+Additions.swift @@ -4,7 +4,7 @@ import Swinject -public extension Resolver { +public extension Swinject.Resolver { /// Force unwrap that improves single line error logging to help track test failures /// This is used in knit generated type safe functions @@ -30,3 +30,32 @@ public extension Resolver { return unwrapped } } + +// MARK: - + +public extension Knit.Resolver { + + /// Force unwrap that improves single line error logging to help track test failures + /// This is used in knit generated type safe functions + func knitUnwrap( + _ value: T?, + file: StaticString = #fileID, + function: StaticString = #function, + line: UInt = #line, + // Allow for an additional frame of the call-stack for better context + callsiteFile: StaticString, + callsiteFunction: StaticString, + callsiteLine: UInt + ) -> T { + self.unsafeResolver.knitUnwrap( + value, + file: file, + function: function, + line: line, + callsiteFile: callsiteFile, + callsiteFunction: callsiteFunction, + callsiteLine: callsiteLine + ) + } + +} diff --git a/Sources/Knit/Resolver.swift b/Sources/Knit/Resolver.swift new file mode 100644 index 00000000..029d0076 --- /dev/null +++ b/Sources/Knit/Resolver.swift @@ -0,0 +1,16 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import Foundation +import Swinject + +/// This effectively removes all the unsafe resolve methods from the publicly available API. +public protocol Resolver: AnyObject { + + /// Returns `true` if the backing container is still available in memory, otherwise `false`. + var isAvailable: Bool { get } + + var unsafeResolver: Swinject.Resolver { get } + +} diff --git a/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift b/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift index 4c175b94..6159bd1c 100644 --- a/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift +++ b/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift @@ -1,4 +1,5 @@ import Foundation +import Swinject extension Container { @@ -24,13 +25,14 @@ extension Container { @discardableResult public func registerIntoCollection( _ service: Service.Type, - factory: @escaping @MainActor (Resolver) -> Service + factory: @escaping @MainActor (TargetResolver) -> Service ) -> ServiceEntry { - self.register( + self._unwrappedSwinjectContainer.register( service, name: makeUniqueCollectionRegistrationName(), - factory: { resolver in + factory: { r in MainActor.assumeIsolated { + let resolver = r.resolve(Container.self)! as! TargetResolver return factory(resolver) } } diff --git a/Sources/Knit/ServiceCollection/Resolver+ServiceCollection.swift b/Sources/Knit/ServiceCollection/Resolver+ServiceCollection.swift index d6b2bb98..fed4d314 100644 --- a/Sources/Knit/ServiceCollection/Resolver+ServiceCollection.swift +++ b/Sources/Knit/ServiceCollection/Resolver+ServiceCollection.swift @@ -2,7 +2,9 @@ // Copyright © Block, Inc. All rights reserved. // -extension Resolver { +import Swinject + +extension Swinject.Resolver { /// Resolves a collection of all services registered using /// ``Container/registerIntoCollection(_:factory:)`` or diff --git a/Sources/Knit/ServiceCollection/ServiceCollector.swift b/Sources/Knit/ServiceCollection/ServiceCollector.swift index bdd3ff0f..335d0fea 100644 --- a/Sources/Knit/ServiceCollection/ServiceCollector.swift +++ b/Sources/Knit/ServiceCollection/ServiceCollector.swift @@ -24,7 +24,7 @@ public final class ServiceCollector: Behavior { /// Maps a service type to an array of service factories /// Note: We use `ObjectIdentifier` to represent the service type since `Any.Type` isn't Hashable. - private var factoriesByService: [ObjectIdentifier: [(Resolver) -> Any]] = [:] + private var factoriesByService: [ObjectIdentifier: [(Swinject.Resolver) -> Any]] = [:] private let parent: ServiceCollector? @@ -33,7 +33,7 @@ public final class ServiceCollector: Behavior { } public func container( - _ container: Container, + _ container: Swinject.Container, didRegisterType type: Type.Type, toService entry: ServiceEntry, withName name: String? @@ -60,7 +60,7 @@ public final class ServiceCollector: Behavior { factoriesByService[ObjectIdentifier(Service.self)] = factories } - private func resolveServices(resolver: Resolver) -> ServiceCollection { + private func resolveServices(resolver: Swinject.Resolver) -> ServiceCollection { let parentCollection: ServiceCollection? = parent?.resolveServices(resolver: resolver) let factories = self.factoriesByService[ObjectIdentifier(Service.self)] ?? [] let entries = factories.map { $0(resolver) as! Service } diff --git a/Sources/Knit/WeakResolver.swift b/Sources/Knit/WeakResolver.swift index 6515a8e4..e1ea6d2c 100644 --- a/Sources/Knit/WeakResolver.swift +++ b/Sources/Knit/WeakResolver.swift @@ -4,14 +4,14 @@ import Swinject -/// A resolver that weakly holds onto the container. This allows keeping a reference without the risk of leaking the container +/// A resolver that weakly holds onto the `Swinject.Container`. This allows keeping a reference without the risk of leaking the container /// Classes holding onto a WeakResolver do not take ownership of the DI graph /// This allows the container to be deallocated even if services still have references to it public final class WeakResolver { - private weak var container: Container? + private weak var container: Swinject.Container? - public init(container: Container) { + public init(container: Swinject.Container) { self.container = container } @@ -20,19 +20,29 @@ public final class WeakResolver { /// Only provide a resolver if it is still available in memory, otherwise return `nil`. /// Syntax sugar to allow optional chaining on the instance. - public var optionalResolver: Resolver? { + public var optionalResolver: Swinject.Resolver? { // We are returning `self` rather than the container to maintain weak semantics isAvailable ? self : nil } } -// MARK: - Resolver conformance +// MARK: - Knit.Resolver conformance -extension WeakResolver: Resolver { +extension WeakResolver: Knit.Resolver { + + public var unsafeResolver: any Swinject.Resolver { + unwrapped + } + +} + +// MARK: - Swinject.Resolver conformance + +extension WeakResolver: Swinject.Resolver { // Force unwraps the weak Container // Convenience accessor for private implementation - private var unwrapped: Resolver { + private var unwrapped: Swinject.Resolver { guard let container else { fatalError("Attempting to resolve using a container which has been released") } diff --git a/Sources/KnitCodeGen/TypeSafetySourceFile.swift b/Sources/KnitCodeGen/TypeSafetySourceFile.swift index e1d45771..327f5e44 100644 --- a/Sources/KnitCodeGen/TypeSafetySourceFile.swift +++ b/Sources/KnitCodeGen/TypeSafetySourceFile.swift @@ -123,7 +123,7 @@ public enum TypeSafetySourceFile { usages: String ) throws -> FunctionDeclSyntax { try FunctionDeclSyntax("\(raw: modifier)func \(raw: functionName)(\(raw: inputs)) -> \(raw: registration.service)") { - "knitUnwrap(resolve(\(raw: usages)), callsiteFile: file, callsiteFunction: function, callsiteLine: line)" + "knitUnwrap(unsafeResolver.resolve(\(raw: usages)), callsiteFile: file, callsiteFunction: function, callsiteLine: line)" } } diff --git a/Sources/KnitTesting/Resolver+Asserts.swift b/Sources/KnitTesting/Resolver+Asserts.swift index 926fc8f0..03c01cc5 100644 --- a/Sources/KnitTesting/Resolver+Asserts.swift +++ b/Sources/KnitTesting/Resolver+Asserts.swift @@ -6,7 +6,7 @@ import Knit import Swinject import XCTest -public extension Resolver { +public extension Swinject.Resolver { func assertTypeResolved( _ result: T?, diff --git a/Tests/KnitCodeGenTests/ConfigurationSetTests.swift b/Tests/KnitCodeGenTests/ConfigurationSetTests.swift index 7496d05b..468bc14a 100644 --- a/Tests/KnitCodeGenTests/ConfigurationSetTests.swift +++ b/Tests/KnitCodeGenTests/ConfigurationSetTests.swift @@ -29,7 +29,7 @@ final class ConfigurationSetTests: XCTestCase { /// Generated from ``Module1Assembly`` extension Resolver { public func service1(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service1 { - knitUnwrap(resolve(Service1.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(Service1.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } extension Module1Assembly { @@ -43,10 +43,10 @@ final class ConfigurationSetTests: XCTestCase { /// Generated from ``Module2Assembly`` extension Resolver { public func service2(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service2 { - knitUnwrap(resolve(Service2.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(Service2.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } func argumentService(string: String, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ArgumentService { - knitUnwrap(resolve(ArgumentService.self, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ArgumentService.self, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } extension Module2Assembly { @@ -60,7 +60,7 @@ final class ConfigurationSetTests: XCTestCase { /// Generated from ``Module3Assembly`` extension Resolver { public func service3(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service3 { - knitUnwrap(resolve(Service3.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(Service3.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } extension Module3Assembly { @@ -343,7 +343,7 @@ final class ConfigurationSetTests: XCTestCase { /// Generated from ``CustomAssembly`` extension Resolver { func service1(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> Service1 { - knitUnwrap(resolve(Service1.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(Service1.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } """ diff --git a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift index 3881c0a5..030f0644 100644 --- a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift +++ b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift @@ -35,25 +35,25 @@ final class TypeSafetySourceFileTests: XCTestCase { /// Generated from ``ModuleAssembly`` extension Resolve { func serviceA(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceA { - knitUnwrap(resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func serviceD(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceD { - knitUnwrap(resolve(ServiceD.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceD.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func serviceDAlias(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceD { - knitUnwrap(resolve(ServiceD.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceD.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func serviceE(closure1: @escaping () -> Void, closure2: @escaping @Sendable (Bool) -> Void, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceE { - knitUnwrap(resolve(ServiceE.self, arguments: closure1, closure2), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceE.self, arguments: closure1, closure2), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func serviceF(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceF { - knitUnwrap(resolve(ServiceF.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceF.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func stringInt(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> (String, Int?) { - knitUnwrap(resolve((String, Int?).self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve((String, Int?).self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } func serviceB(name: ModuleAssembly.ServiceB_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceB { - knitUnwrap(resolve(ServiceB.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceB.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } extension ModuleAssembly { @@ -84,7 +84,7 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ public func a(string: String, url: URL, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self, arguments: string, url), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self, arguments: string, url), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -99,7 +99,7 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ public func a(string: String, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -114,7 +114,7 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ public func a(string1: String, string2: String, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self, arguments: string1, string2), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self, arguments: string1, string2), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -129,7 +129,7 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ public func a(name: MyAssembly.A_ResolutionKey, string: String, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self, name: name.rawValue, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self, name: name.rawValue, argument: string), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -144,7 +144,7 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ public func a(arg: String, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self, argument: arg), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self, argument: arg), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -161,7 +161,7 @@ final class TypeSafetySourceFileTests: XCTestCase { """ #if SOME_FLAG public func a(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } #endif """ @@ -179,10 +179,10 @@ final class TypeSafetySourceFileTests: XCTestCase { """ #if SOME_FLAG public func a(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func fooAlias(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } #endif """ @@ -198,7 +198,7 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ @_spi(Testing) public func a(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -212,10 +212,10 @@ final class TypeSafetySourceFileTests: XCTestCase { ), """ public func a(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } public func fooAlias(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> A { - knitUnwrap(resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(A.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } """ ) @@ -400,7 +400,7 @@ final class TypeSafetySourceFileTests: XCTestCase { /// Generated from ``MainActorAssembly`` extension Resolver { @MainActor func serviceA(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceA { - knitUnwrap(resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + knitUnwrap(unsafeResolver.resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line) } } extension MainActorAssembly { diff --git a/Tests/KnitMacrosTests/SwinjectResolutionTests.swift b/Tests/KnitMacrosTests/SwinjectResolutionTests.swift index d5a243be..80ff74d3 100644 --- a/Tests/KnitMacrosTests/SwinjectResolutionTests.swift +++ b/Tests/KnitMacrosTests/SwinjectResolutionTests.swift @@ -60,7 +60,7 @@ private struct Service1 { let string: String let value: Int - @Resolvable + @Resolvable init(string: String, value: Int) { self.string = string self.value = value @@ -70,7 +70,7 @@ private struct Service1 { private struct Service2 { let closure: () -> Void - @Resolvable + @Resolvable init(closure: @escaping () -> Void) { self.closure = closure } @@ -80,12 +80,12 @@ private struct Service3 { let value: Int - @Resolvable + @Resolvable init(@UseDefault defaultedValue: Int = 2) { self.value = defaultedValue } - @Resolvable + @Resolvable static func makeService() -> Service3 { return .init(defaultedValue: 5) } @@ -93,7 +93,7 @@ private struct Service3 { private struct Service4 { let value: Float - @Resolvable + @Resolvable init(@Argument value: Float) { self.value = value } @@ -101,14 +101,14 @@ private struct Service4 { private struct Service5 { let value: Float - @Resolvable + @Resolvable init(@Named("float2") value: Float) { self.value = value } } private enum Factory { - static var container: Container { + static var container: Swinject.Container { let container = Container() container.register(String.self) { _ in "Test" } container.register(Int.self) { _ in 5 } @@ -128,7 +128,7 @@ enum FloatName: String { case float2 } -private extension Resolver { +private extension Swinject.Resolver { func float(name: FloatName) -> Float { resolve(Float.self, name: name.rawValue)! diff --git a/Tests/KnitTests/AbstractRegistrationTests.swift b/Tests/KnitTests/AbstractRegistrationTests.swift index 1afa143a..18399404 100644 --- a/Tests/KnitTests/AbstractRegistrationTests.swift +++ b/Tests/KnitTests/AbstractRegistrationTests.swift @@ -4,13 +4,15 @@ import Combine @testable import Knit +import Swinject import XCTest final class AbstractRegistrationTests: XCTestCase { func testMissingRegistration() { - let container = Container() - let abstractRegistrations = container.registerAbstractContainer() + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + let abstractRegistrations = container._unwrappedSwinjectContainer.registerAbstractContainer() container.registerAbstract(String.self) container.registerAbstract(String.self, name: "test") container.registerAbstract(Optional.self) @@ -28,8 +30,9 @@ final class AbstractRegistrationTests: XCTestCase { } func testFilledRegistrations() { - let container = Container() - let abstractRegistrations = container.registerAbstractContainer() + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + let abstractRegistrations = container._unwrappedSwinjectContainer.registerAbstractContainer() container.registerAbstract(String.self) container.register(String.self) { _ in "Test" } @@ -38,13 +41,14 @@ final class AbstractRegistrationTests: XCTestCase { container.register(Optional.self) { _ in 1 } XCTAssertNoThrow(try abstractRegistrations.validate()) - XCTAssertEqual(container.resolve(String.self), "Test") - XCTAssertEqual(container.resolve(Optional.self), 1) + XCTAssertEqual(container._unwrappedSwinjectContainer.resolve(String.self), "Test") + XCTAssertEqual(container._unwrappedSwinjectContainer.resolve(Optional.self), 1) } func testNamedRegistrations() { - let container = Container() - let abstractRegistrations = container.registerAbstractContainer() + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + let abstractRegistrations = container._unwrappedSwinjectContainer.registerAbstractContainer() container.registerAbstract(String.self) container.registerAbstract(String.self, name: "test") @@ -59,8 +63,9 @@ final class AbstractRegistrationTests: XCTestCase { } func testPreRegistered() { - let container = Container() - let abstractRegistrations = container.registerAbstractContainer() + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + let abstractRegistrations = container._unwrappedSwinjectContainer.registerAbstractContainer() container.register(String.self) { _ in "Test" } container.registerAbstract(String.self) XCTAssertNoThrow(try abstractRegistrations.validate()) @@ -68,8 +73,8 @@ final class AbstractRegistrationTests: XCTestCase { func testAbstractErrorFormatting() throws { let builder = try DependencyBuilder(modules: [Assembly1()]) - let error = Container.AbstractRegistrationError(serviceType: "String", file: "Assembly2.swift", name: nil) - let errors = Container.AbstractRegistrationErrors(errors: [error]) + let error = AbstractRegistrationError(serviceType: "String", file: "Assembly2.swift", name: nil) + let errors = AbstractRegistrationErrors(errors: [error]) let formatter = DefaultModuleAssemblerErrorFormatter() let result = formatter.format(error: errors, dependencyTree: builder.dependencyTree) XCTAssertEqual( @@ -83,8 +88,8 @@ final class AbstractRegistrationTests: XCTestCase { } @MainActor - func testOptionalAbstractRegistrations() { - let assembler = ModuleAssembler([Assembly3()]) + func testOptionalAbstractRegistrations() throws { + let assembler = try ModuleAssembler(_modules: [Assembly3()]) let string = assembler.resolver.resolve(String?.self) ?? nil XCTAssertNil(string) @@ -93,8 +98,8 @@ final class AbstractRegistrationTests: XCTestCase { } @MainActor - func testAdditionalAbstractRegistration() { - let assembler = ModuleAssembler([Assembly4()]) + func testAdditionalAbstractRegistration() throws { + let assembler = try ModuleAssembler(_modules: [Assembly4()]) _ = assembler.resolver.resolve(AnyPublisher.self) } @@ -102,19 +107,19 @@ final class AbstractRegistrationTests: XCTestCase { private struct Assembly1: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [ Assembly2.self] } - func assemble(container: Container) {} + func assemble(container: Knit.Container) {} } private struct Assembly2: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { + func assemble(container: Knit.Container) { container.registerAbstract(String.self) } } private struct Assembly3: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { + func assemble(container: Knit.Container) { container.registerAbstract(Optional.self) container.registerAbstract(Int?.self) } @@ -122,7 +127,7 @@ private struct Assembly3: AutoInitModuleAssembly { private struct Assembly4: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { + func assemble(container: Knit.Container) { // Custom handling for AnyPublisher abstract registrations is defined below container.registerAbstract(AnyPublisher.self) } @@ -149,7 +154,7 @@ private struct AnyPublisherAbstractRegistration: AbstractR } } -private extension Container { +private extension Knit.Container { // The new abstract registration type requires an additional registerAbstract function with a more explicit type func registerAbstract( @@ -163,6 +168,6 @@ private extension Container { file: file, concurrency: concurrency ) - addAbstractRegistration(registration) + _unwrappedSwinjectContainer.addAbstractRegistration(registration) } } diff --git a/Tests/KnitTests/ComplexDependencyTests.swift b/Tests/KnitTests/ComplexDependencyTests.swift index 0c6c0c6e..e4cb24a1 100644 --- a/Tests/KnitTests/ComplexDependencyTests.swift +++ b/Tests/KnitTests/ComplexDependencyTests.swift @@ -51,36 +51,36 @@ final class ComplexDependencyTests: XCTestCase { // Assembly1 depends on Assembly2 and Assembly3 private struct Assembly1: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [ Assembly3.self, Assembly2.self ] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly2 has no dependencies private struct Assembly2: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly2Fake overrides Assembly2 private struct Assembly2Fake: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } static var replaces: [any ModuleAssembly.Type] { [Assembly2.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly3 depends on a fake version of Assembly2 private struct Assembly3: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly2Fake.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly4 has a cycle with Assembly5 private struct Assembly4: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly5.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly5 has a cycle with Assembly4 private struct Assembly5: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly4.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } diff --git a/Tests/KnitTests/DependencyBuilderTests.swift b/Tests/KnitTests/DependencyBuilderTests.swift index 7a7ed83e..4da10a40 100644 --- a/Tests/KnitTests/DependencyBuilderTests.swift +++ b/Tests/KnitTests/DependencyBuilderTests.swift @@ -193,7 +193,7 @@ private struct Assembly1: ModuleAssembly { ] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly2 has no dependencies but replaces Assembly8 @@ -205,7 +205,7 @@ private struct Assembly2: AutoInitModuleAssembly { static var replaces: [any ModuleAssembly.Type] { [Assembly8.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly3 depends on Assembly1 @@ -214,7 +214,7 @@ private struct Assembly3: ModuleAssembly { return [Assembly1.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly4 depends on Assembly1 and Assembly2 @@ -223,11 +223,11 @@ private struct Assembly4: ModuleAssembly { static let dependencies: [any ModuleAssembly.Type] = [ Assembly2.self, Assembly1.self ] - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly5: ModuleAssembly, DefaultModuleAssemblyOverride { - func assemble(container: Container) {} + func assemble(container: Container) {} static var dependencies: [any ModuleAssembly.Type] { [] } // This is not valid because Assembly6 does not implement Assembly5 @@ -235,23 +235,23 @@ private struct Assembly5: ModuleAssembly, DefaultModuleAssemblyOverride { } private struct Assembly6: AutoInitModuleAssembly { - func assemble(container: Container) {} + func assemble(container: Container) {} static var dependencies: [any ModuleAssembly.Type] { [] } } private struct Assembly7: ModuleAssembly { - func assemble(container: Container) {} + func assemble(container: Container) {} static var dependencies: [any ModuleAssembly.Type] { [Assembly5.self] } } private struct Assembly8: AutoInitModuleAssembly, DefaultModuleAssemblyOverride { - func assemble(container: Container) {} + func assemble(container: Container) {} static var dependencies: [any ModuleAssembly.Type] { [] } typealias OverrideType = Assembly2 } private struct Assembly9: AutoInitModuleAssembly { - func assemble(container: Container) {} + func assemble(container: Container) {} static var dependencies: [any ModuleAssembly.Type] { [Assembly8.self] } } diff --git a/Tests/KnitTests/DependencyTreeTests.swift b/Tests/KnitTests/DependencyTreeTests.swift index ab2ed197..550595eb 100644 --- a/Tests/KnitTests/DependencyTreeTests.swift +++ b/Tests/KnitTests/DependencyTreeTests.swift @@ -42,20 +42,20 @@ final class DependencyTreeTests: XCTestCase { private struct Assembly1: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { return [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly2: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { return [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly3: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { return [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly4: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { return [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } diff --git a/Tests/KnitTests/DuplicateRegistrationDetectorTests.swift b/Tests/KnitTests/DuplicateRegistrationDetectorTests.swift index 5c14b4ad..93a1af5f 100644 --- a/Tests/KnitTests/DuplicateRegistrationDetectorTests.swift +++ b/Tests/KnitTests/DuplicateRegistrationDetectorTests.swift @@ -3,6 +3,7 @@ // @testable import Knit +import Swinject import XCTest final class DuplicateRegistrationDetectorTests: XCTestCase { @@ -12,7 +13,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { let duplicateRegistrationDetector = DuplicateRegistrationDetector(duplicateWasDetected: { key in reportedDuplicates.append(key) }) - let container = Container( + let container = Container( behaviors: [duplicateRegistrationDetector] ) @@ -24,7 +25,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { XCTAssertEqual(reportedDuplicates.count, 1) let firstReport = try XCTUnwrap(reportedDuplicates.first) XCTAssert(firstReport.serviceType == String.self) - XCTAssert(firstReport.argumentsType == (Resolver).self) + XCTAssert(firstReport.argumentsType == (Swinject.Resolver).self) XCTAssertEqual(firstReport.name, nil) container.register(String.self, factory: { _ in "three" }) @@ -50,7 +51,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { XCTAssertEqual(reportedDuplicates.count, 1) let firstReport = try XCTUnwrap(reportedDuplicates.first) XCTAssert(firstReport.serviceType == String.self) - XCTAssert(firstReport.argumentsType == (Resolver).self) + XCTAssert(firstReport.argumentsType == (Swinject.Resolver).self) XCTAssertEqual(firstReport.name, "nameOne") } @@ -73,7 +74,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { XCTAssertEqual(reportedDuplicates.count, 1) let firstReport = try XCTUnwrap(reportedDuplicates.first) XCTAssert(firstReport.serviceType == String.self) - XCTAssert(firstReport.argumentsType == (Resolver, Int).self) + XCTAssert(firstReport.argumentsType == (Swinject.Resolver, Int).self) XCTAssertEqual(firstReport.name, nil) } @@ -135,7 +136,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { func testCustomStringDescription() throws { assertCustomStringDescription(key: DuplicateRegistrationDetector.Key( serviceType: String.self, - argumentsType: ((Resolver)).self, + argumentsType: ((Knit.Resolver)).self, name: nil ), expectedDescription: """ @@ -148,7 +149,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { assertCustomStringDescription(key: DuplicateRegistrationDetector.Key( serviceType: Int.self, - argumentsType: (Resolver, Bool).self, + argumentsType: (Knit.Resolver, Bool).self, name: nil ), expectedDescription: """ @@ -161,7 +162,7 @@ final class DuplicateRegistrationDetectorTests: XCTestCase { assertCustomStringDescription(key: DuplicateRegistrationDetector.Key( serviceType: String.self, - argumentsType: ((Resolver)).self, + argumentsType: ((Knit.Resolver)).self, name: "namedRegistration" ), expectedDescription: """ diff --git a/Tests/KnitTests/FakeAssemblyTests.swift b/Tests/KnitTests/FakeAssemblyTests.swift index a67d0d4f..9c7e0e93 100644 --- a/Tests/KnitTests/FakeAssemblyTests.swift +++ b/Tests/KnitTests/FakeAssemblyTests.swift @@ -13,12 +13,12 @@ final class FakeAssemblyTests: XCTestCase { private final class RealAssembly: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Swinject.Container) {} + func assemble(container: Container) {} } private final class FakeTestAssembly: FakeAssembly { typealias ReplacedAssembly = RealAssembly - func assemble(container: Swinject.Container) {} + func assemble(container: Container) {} } extension RealAssembly: DefaultModuleAssemblyOverride { diff --git a/Tests/KnitTests/GeneratedModuleAssemblyTests.swift b/Tests/KnitTests/GeneratedModuleAssemblyTests.swift index 01d63b15..85ca135e 100644 --- a/Tests/KnitTests/GeneratedModuleAssemblyTests.swift +++ b/Tests/KnitTests/GeneratedModuleAssemblyTests.swift @@ -26,7 +26,7 @@ final class GeneratedModuleAssemblyTests: XCTestCase { } private struct Assembly1: AutoInitModuleAssembly { - func assemble(container: Container) { } + func assemble(container: Container) { } } extension Assembly1: GeneratedModuleAssembly { @@ -34,7 +34,7 @@ extension Assembly1: GeneratedModuleAssembly { } private struct Assembly2: AutoInitModuleAssembly { - func assemble(container: Container) { } + func assemble(container: Container) { } // Assembly2 explicitly sets dependencies so ignores generatedDependencies static var dependencies: [any ModuleAssembly.Type] { [] } } @@ -44,6 +44,6 @@ extension Assembly2: GeneratedModuleAssembly { } private struct Assembly3: AutoInitModuleAssembly { - func assemble(container: Container) { } + func assemble(container: Container) { } static var dependencies: [any ModuleAssembly.Type] { [] } } diff --git a/Tests/KnitTests/ModuleAssemblerTests.swift b/Tests/KnitTests/ModuleAssemblerTests.swift index ca1f2f54..87f4234a 100644 --- a/Tests/KnitTests/ModuleAssemblerTests.swift +++ b/Tests/KnitTests/ModuleAssemblerTests.swift @@ -9,34 +9,53 @@ import XCTest final class ModuleAssemblerTests: XCTestCase { @MainActor - func test_auto_assembler() { - let resolver = ModuleAssembler([Assembly1()]).resolver + func test_auto_assembler() throws { + let resolver = try ModuleAssembler( + _modules: [Assembly1()] + ).resolver XCTAssertNotNil(resolver.resolve(Service1.self)) } @MainActor - func test_non_auto_assembler() { - let resolver = ModuleAssembler([ - Assembly3(), - Assembly1(), - ]).resolver + func test_non_auto_assembler() throws { + let resolver = try ModuleAssembler( + _modules: [ + Assembly3(), + Assembly1(), + ] + ).resolver XCTAssertNotNil(resolver.resolve(Service1.self)) XCTAssertNotNil(resolver.resolve(Service3.self)) } @MainActor - func test_registered_modules() { - let assembler = ModuleAssembler([Assembly1()]) + func test_registered_modules() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly1()] + ) XCTAssertTrue(assembler.registeredModules.contains(where: {$0 == Assembly1.self})) XCTAssertTrue(assembler.registeredModules.contains(where: {$0 == Assembly2.self})) XCTAssertFalse(assembler.registeredModules.contains(where: {$0 == Assembly3.self})) } @MainActor - func test_parent_assembler() { + func test_parent_assembler() throws { // Put some modules in the parent and some in the child. - let parent = ModuleAssembler([Assembly1()]) - let child = ModuleAssembler(parent: parent, [Assembly3()]) + let parent = try ModuleAssembler( + _modules: [Assembly1()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) + let child = try ModuleAssembler( + parent: parent, + _modules: [Assembly3()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) XCTAssertTrue(child.isRegistered(Assembly1.self)) XCTAssertTrue(child.isRegistered(Assembly3.self)) XCTAssertTrue(child.isRegistered(Assembly2.self)) @@ -56,7 +75,7 @@ final class ModuleAssemblerTests: XCTestCase { ), "Should throw an error for missing concrete registration to fulfill abstract registration", { error in - guard let abstractRegistrationErrors = error as? Container.AbstractRegistrationErrors else { + guard let abstractRegistrationErrors = error as? AbstractRegistrationErrors else { XCTFail("Incorrect error type \(error)") return } @@ -81,28 +100,28 @@ final class ModuleAssemblerTests: XCTestCase { overrideBehavior: .init(allowDefaultOverrides: true, useAbstractPlaceholders: true) ) - var services = assembler._container.services.filter { (key, value) in + var services = assembler._swinjectContainer.services.filter { (key, value) in // Filter out registrations for `AbstractRegistrationContainer` and `DependencyTree` return key.serviceType != Container.AbstractRegistrationContainer.self && key.serviceType != DependencyTree.self } - XCTAssertEqual(services.count, 2) - + XCTAssertEqual(services.count, 3) + XCTAssertNotNil(services.removeValue(forKey: .init( serviceType: Assembly5Protocol.self, - argumentsType: (Resolver).self, + argumentsType: (Swinject.Resolver).self, name: nil )), "Service entry for Assembly5Protocol without name should exist") - XCTAssertEqual(services.count, 1) + XCTAssertEqual(services.count, 2) XCTAssertNotNil(services.removeValue(forKey: .init( serviceType: Assembly5Protocol.self, - argumentsType: (Resolver).self, + argumentsType: (Swinject.Resolver).self, name: "testName" )), "Service entry for Assembly5Protocol with name should exist") - // No more registrations left - XCTAssertEqual(services.count, 0) + // The last registration is for the Knit container + XCTAssertEqual(services.count, 1) } } @@ -115,8 +134,13 @@ private struct Assembly1: ModuleAssembly { ] } - func assemble(container: Container) { - container.register(Service1.self) { Service1(service2: $0.resolve(Service2.self)!) } + func assemble(container: Knit.Container) { + container.register( + Service1.self, + factory: { resolver in + Service1(service2: resolver.service2()) + } + ) } } @@ -127,8 +151,13 @@ private struct Assembly2: AutoInitModuleAssembly { return [] } - func assemble(container: Container) { - container.register(Service2.self) { _ in Service2() } + func assemble(container: Knit.Container) { + container.register( + Service2.self, + factory: { _ in + Service2() + } + ) } } @@ -138,8 +167,8 @@ private struct Assembly3: ModuleAssembly { return [Assembly1.self] } - func assemble(container: Container) { - container.register(Service3.self) { _ in Service3()} + func assemble(container: Knit.Container) { + container.register(Service3.self, factory: { _ in Service3() }) } } @@ -161,7 +190,7 @@ private struct Service3 {} private struct Assembly4: AutoInitModuleAssembly { - func assemble(container: Swinject.Container) { + func assemble(container: Knit.Container) { // None } @@ -179,7 +208,7 @@ private struct AbstractAssembly5: AbstractAssembly { [] } - func assemble(container: Swinject.Container) { + func assemble(container: Knit.Container) { container.registerAbstract(Assembly5Protocol.self) container.registerAbstract(Assembly5Protocol.self, name: "testName") } @@ -192,7 +221,7 @@ private struct Assembly5: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Swinject.Container) { + func assemble(container: Knit.Container) { // Missing a concrete registration for `Assembly5Protocol` } @@ -201,3 +230,11 @@ private struct Assembly5: AutoInitModuleAssembly { } } + +private extension TestResolver { + + func service2() -> Service2 { + unsafeResolver.resolve(Service2.self)! + } + +} diff --git a/Tests/KnitTests/ModuleAssemblyOverrideTests.swift b/Tests/KnitTests/ModuleAssemblyOverrideTests.swift index 04bf1248..6afba90c 100644 --- a/Tests/KnitTests/ModuleAssemblyOverrideTests.swift +++ b/Tests/KnitTests/ModuleAssemblyOverrideTests.swift @@ -36,20 +36,23 @@ final class ModuleAssemblyOverrideTests: XCTestCase { } @MainActor - func test_serviceRegisteredWithoutFakes() { - let resolver = ModuleAssembler([Assembly2()]).resolver + func test_serviceRegisteredWithoutFakes() throws { + let resolver = try ModuleAssembler(_modules: [Assembly2()]).resolver XCTAssertTrue(resolver.resolve(Service2Protocol.self) is Service2) } @MainActor - func test_servicesRegisteredWithFakes() { - let resolver = ModuleAssembler([Assembly2(), Assembly2Fake()]).resolver + func test_servicesRegisteredWithFakes() throws { + let resolver = try ModuleAssembler(_modules: [Assembly2(), Assembly2Fake()]).resolver XCTAssertTrue(resolver.resolve(Service2Protocol.self) is Service2Fake) } @MainActor - func test_assemblerWithDefaultOverrides() { - let assembler = ModuleAssembler([Assembly2()], overrideBehavior: .useDefaultOverrides) + func test_assemblerWithDefaultOverrides() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly2()], + overrideBehavior: .useDefaultOverrides + ) XCTAssertTrue(assembler.registeredModules.contains(where: {$0 == Assembly1Fake.self})) XCTAssertTrue(assembler.isRegistered(Assembly1Fake.self)) // Treat Assembly1 as being registered because the mock is @@ -57,55 +60,68 @@ final class ModuleAssemblyOverrideTests: XCTestCase { } @MainActor - func test_noDefaultOverrideForInputModules() { - let assembler = ModuleAssembler([Assembly1()], overrideBehavior: .useDefaultOverrides) + func test_noDefaultOverrideForInputModules() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly1()], + overrideBehavior: .useDefaultOverrides + ) XCTAssertTrue(assembler.isRegistered(Assembly1.self)) // The fake is not automatically registered XCTAssertFalse(assembler.isRegistered(Assembly1Fake.self)) } @MainActor - func test_explicitInputOverride() { - let assembler = ModuleAssembler([Assembly1(), Assembly1Fake()], overrideBehavior: .useDefaultOverrides) + func test_explicitInputOverride() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly1(), Assembly1Fake()], + overrideBehavior: .useDefaultOverrides + ) XCTAssertTrue(assembler.isRegistered(Assembly1.self)) XCTAssertTrue(assembler.isRegistered(Assembly1Fake.self)) } @MainActor - func test_assemblerWithoutDefaultOverrides() { - let assembler = ModuleAssembler([Assembly2()], overrideBehavior: .disableDefaultOverrides) + func test_assemblerWithoutDefaultOverrides() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly2()], + overrideBehavior: .disableDefaultOverrides + ) XCTAssertTrue(assembler.isRegistered(Assembly1.self)) XCTAssertFalse(assembler.isRegistered(Assembly1Fake.self)) } @MainActor - func test_assemblerWithFakes() { - let assembler = ModuleAssembler([Assembly2Fake()]) + func test_assemblerWithFakes() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly2Fake()] + ) XCTAssertFalse(assembler.registeredModules.contains(where: {$0 == Assembly2.self})) XCTAssertTrue(assembler.isRegistered(Assembly2.self)) XCTAssertTrue(assembler.isRegistered(Assembly2Fake.self)) } @MainActor - func test_parentFakes() { - let parent = ModuleAssembler([Assembly1Fake()]) - let child = ModuleAssembler(parent: parent, [Assembly2()]) + func test_parentFakes() throws { + let parent = try ModuleAssembler(_modules: [Assembly1Fake()]) + let child = try ModuleAssembler(parent: parent, _modules: [Assembly2()]) XCTAssertTrue(child.isRegistered(Assembly1.self)) XCTAssertTrue(child.isRegistered(Assembly1Fake.self)) } @MainActor - func test_autoFake() { - let assembler = ModuleAssembler([Assembly5()]) + func test_autoFake() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly5()] + ) XCTAssertTrue(assembler.isRegistered(Assembly4.self)) XCTAssertTrue(assembler.isRegistered(Assembly4Fake.self)) XCTAssertTrue(assembler.isRegistered(Assembly5.self)) } @MainActor - func test_overrideDefaultOverride() { - let assembler = ModuleAssembler( - [Assembly4(), Assembly4Fake2()], + func test_overrideDefaultOverride() throws { + let assembler = try ModuleAssembler( + _modules: [Assembly4(), Assembly4Fake2()], overrideBehavior: .useDefaultOverrides ) XCTAssertTrue(assembler.isRegistered(Assembly4.self)) @@ -122,9 +138,9 @@ final class ModuleAssemblyOverrideTests: XCTestCase { } @MainActor - func test_parentNonAutoOverride() { - let parent = ModuleAssembler([NonAutoOverride()]) - let child = ModuleAssembler(parent: parent, [Assembly1()], overrideBehavior: .disableDefaultOverrides) + func test_parentNonAutoOverride() throws { + let parent = try ModuleAssembler(_modules: [NonAutoOverride()]) + let child = try ModuleAssembler(parent: parent, _modules: [Assembly1()], overrideBehavior: .disableDefaultOverrides) XCTAssertTrue(child.isRegistered(Assembly1.self)) XCTAssertTrue(child.registeredModules.isEmpty) @@ -137,9 +153,9 @@ final class ModuleAssemblyOverrideTests: XCTestCase { } @MainActor - func test_multipleOverrides() { - let assembler = ModuleAssembler( - [MultipleDependencyAssembly(), MultipleOverrideAssembly()], + func test_multipleOverrides() throws { + let assembler = try ModuleAssembler( + _modules: [MultipleDependencyAssembly(), MultipleOverrideAssembly()], overrideBehavior: .disableDefaultOverrides ) @@ -167,7 +183,7 @@ private struct Assembly1: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { return [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Depends on Assembly1 and registers Service2Protocol @@ -176,17 +192,17 @@ private struct Assembly2: ModuleAssembly { return [Assembly1.self] } - func assemble(container: Container) { - container.register(Service2Protocol.self) { _ in Service2() } + func assemble(container: Container) { + container.register(Service2Protocol.self, factory: { _ in Service2() }) } } // Mock implementation of Assembly2. Adds an extra dependency on Assembly3 private struct Assembly2Fake: AutoInitModuleAssembly { - func assemble(container: Container) { + func assemble(container: Container) { Assembly2().assemble(container: container) - container.register(Service2Protocol.self) { _ in Service2Fake() } + container.register(Service2Protocol.self, factory: { _ in Service2Fake() }) } static var replaces: [any ModuleAssembly.Type] { [Assembly2.self] } @@ -196,7 +212,7 @@ private struct Assembly2Fake: AutoInitModuleAssembly { } private struct Assembly1Fake: AutoInitModuleAssembly { - func assemble(container: Container) {} + func assemble(container: Container) {} static var dependencies: [any ModuleAssembly.Type] { [] } static var replaces: [any ModuleAssembly.Type] { [Assembly1.self] } } @@ -206,14 +222,14 @@ extension Assembly1: DefaultModuleAssemblyOverride { } private struct FakeAssembly3: AutoInitModuleAssembly { - func assemble(container: Container) { } + func assemble(container: Container) { } static var dependencies: [any ModuleAssembly.Type] { [] } } // An Assembly that is *not* AutoInit private struct Assembly4: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Container) { } } extension Assembly4: DefaultModuleAssemblyOverride { @@ -223,35 +239,35 @@ extension Assembly4: DefaultModuleAssemblyOverride { // The fake is AutoInit so can be created even when Assembly4 is unavailable private struct Assembly4Fake: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Container) { } static var replaces: [any ModuleAssembly.Type] { [Assembly4.self] } } private struct Assembly4Fake2: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Container) { } static var replaces: [any ModuleAssembly.Type] { [Assembly4.self] } } private struct NonAutoOverride: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Container) { } static var replaces: [any ModuleAssembly.Type] { [Assembly1.self] } } private struct Assembly5: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly4.self] } - func assemble(container: Container) { } + func assemble(container: Container) { } } private struct MultipleDependencyAssembly: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly1.self, Assembly5.self] } - func assemble(container: Container) { } + func assemble(container: Container) { } } private struct MultipleOverrideAssembly: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Container) { } static var replaces: [any ModuleAssembly.Type] { [Assembly1.self, Assembly4.self, Assembly5.self] } } diff --git a/Tests/KnitTests/ModuleAssemblyScopingTests.swift b/Tests/KnitTests/ModuleAssemblyScopingTests.swift index 689e5de9..73db5969 100644 --- a/Tests/KnitTests/ModuleAssemblyScopingTests.swift +++ b/Tests/KnitTests/ModuleAssemblyScopingTests.swift @@ -36,25 +36,25 @@ extension ChildResolver { private struct Assembly1: GeneratedModuleAssembly { typealias TargetResolver = ParentResolver static var generatedDependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly2: GeneratedModuleAssembly { typealias TargetResolver = ChildResolver static var generatedDependencies: [any ModuleAssembly.Type] { [Assembly1.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly3: GeneratedModuleAssembly { typealias TargetResolver = OtherResolver static var generatedDependencies: [any ModuleAssembly.Type] { [Assembly2.self, Assembly1.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly4: GeneratedModuleAssembly { typealias TargetResolver = ParentResolver static var generatedDependencies: [any ModuleAssembly.Type] { [Assembly1.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private extension ModuleAssembly { diff --git a/Tests/KnitTests/ModuleCycleTests.swift b/Tests/KnitTests/ModuleCycleTests.swift index fb85da7b..37986ce1 100644 --- a/Tests/KnitTests/ModuleCycleTests.swift +++ b/Tests/KnitTests/ModuleCycleTests.swift @@ -8,8 +8,8 @@ import XCTest final class ModuleCycleTests: XCTestCase { @MainActor - func test_cycleResolution() { - let assembler = ModuleAssembler([Assembly1()]) + func test_cycleResolution() throws { + let assembler = try ModuleAssembler(_modules: [Assembly1()]) XCTAssertTrue(assembler.isRegistered(Assembly1.self)) XCTAssertTrue(assembler.isRegistered(Assembly2.self)) XCTAssertTrue(assembler.isRegistered(Assembly3.self)) @@ -27,8 +27,8 @@ final class ModuleCycleTests: XCTestCase { } @MainActor - func test_sourceCycle() { - let assembler = ModuleAssembler([Assembly5()]) + func test_sourceCycle() throws { + let assembler = try ModuleAssembler(_modules: [Assembly5()]) XCTAssertEqual( assembler.builder.dependencyTree.sourcePath(moduleType: Assembly5.self), ["\(Assembly5.self)"] @@ -46,13 +46,13 @@ final class ModuleCycleTests: XCTestCase { private struct Assembly1: ModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly2.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly2 is overriden by default by Assembly3 and requires Assembly4 private struct Assembly2: ModuleAssembly, DefaultModuleAssemblyOverride { static var dependencies: [any ModuleAssembly.Type] { [Assembly4.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} typealias OverrideType = Assembly3 } @@ -60,28 +60,28 @@ private struct Assembly2: ModuleAssembly, DefaultModuleAssemblyOverride { private struct Assembly3: AutoInitModuleAssembly { init() {} static var dependencies: [any ModuleAssembly.Type] { [Assembly2.self, Assembly4.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} static var replaces: [any ModuleAssembly.Type] { [Assembly2.self] } } private struct Assembly4: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) {} + func assemble(container: Container) {} } // Assembly 5-6-7 form a dependency circle private struct Assembly5: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly6.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly6: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly7.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } private struct Assembly7: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] { [Assembly5.self] } - func assemble(container: Container) {} + func assemble(container: Container) {} } diff --git a/Tests/KnitTests/ScopedModuleAssemblerTests.swift b/Tests/KnitTests/ScopedModuleAssemblerTests.swift index 051e9666..73be42d2 100644 --- a/Tests/KnitTests/ScopedModuleAssemblerTests.swift +++ b/Tests/KnitTests/ScopedModuleAssemblerTests.swift @@ -31,50 +31,7 @@ final class ScopedModuleAssemblerTests: XCTestCase { let assembler = try ScopedModuleAssembler(_modules: [Assembly1()]) { container in container.register(String.self) { _ in "string" } } - XCTAssertEqual(assembler.resolver.resolve(String.self), "string") - } - - @MainActor - func testOutOfScopeAssemblyThrows() { - XCTAssertThrowsError( - try ScopedModuleAssembler( - _modules: [ Assembly2() ] - ), - "Assembly2 with target OutsideResolver should throw an error", - { error in - XCTAssertEqual( - error.localizedDescription, - """ - Assembly2 did not pass assembly validation check: The ModuleAssembly's TargetResolver is incorrect. - Expected: TestResolver - Actual: OutsideResolver - """ - ) - } - ) - } - - @MainActor - func testIncorrectInputScope() throws { - let parent = try ScopedModuleAssembler(_modules: [Assembly1()]) - // Even though Assembly1 is already registered, because it was explicitly provided the validation should fail - XCTAssertThrowsError( - try ScopedModuleAssembler( - parent: parent.internalAssembler, - _modules: [Assembly3(), Assembly1()] - ), - "Assembly1 with target TestResolver should throw an error", - { error in - XCTAssertEqual( - error.localizedDescription, - """ - Assembly1 did not pass assembly validation check: The ModuleAssembly's TargetResolver is incorrect. - Expected: OutsideResolver - Actual: TestResolver - """ - ) - } - ) + XCTAssertEqual(assembler.unsafeResolver.resolve(String.self), "string") } @MainActor @@ -87,7 +44,7 @@ final class ScopedModuleAssemblerTests: XCTestCase { [], behaviors: [testBehavior] ) - let container = scopedModuleAssembler._container + let container = scopedModuleAssembler.internalAssembler._swinjectContainer // ModuleAssembler automatically adds behaviors for ServiceCollector and AbstractRegistrationContainer // so first filter those out let foundBehaviors = container.behaviors.filter { behavior in @@ -104,28 +61,29 @@ final class ScopedModuleAssemblerTests: XCTestCase { } private struct Assembly1: AutoInitModuleAssembly { + typealias TargetResolver = TestResolver static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Knit.Container) { } } -protocol OutsideResolver: Resolver { } +protocol OutsideResolver: Swinject.Resolver { } private struct Assembly2: AutoInitModuleAssembly { typealias TargetResolver = OutsideResolver static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { } + func assemble(container: Knit.Container) { } } private struct Assembly3: AutoInitModuleAssembly { typealias TargetResolver = OutsideResolver static var dependencies: [any ModuleAssembly.Type] { [Assembly1.self] } - func assemble(container: Container) { } + func assemble(container: Knit.Container) { } } private final class TestBehavior: Behavior { func container( - _ container: Container, + _ container: Swinject.Container, didRegisterType type: Type.Type, toService entry: Swinject.ServiceEntry, withName name: String? diff --git a/Tests/KnitTests/ServiceCollectorTests.swift b/Tests/KnitTests/ServiceCollectorTests.swift index 5437400f..2777c4b3 100644 --- a/Tests/KnitTests/ServiceCollectorTests.swift +++ b/Tests/KnitTests/ServiceCollectorTests.swift @@ -2,7 +2,8 @@ // Copyright © Block, Inc. All rights reserved. // -import Knit +@testable import Knit +import Swinject import XCTest protocol ServiceProtocol {} @@ -14,7 +15,7 @@ struct ServiceB: ServiceProtocol {} struct AssemblyA: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] = [] - func assemble(container: Container) { + func assemble(container: Knit.Container) { container.registerIntoCollection(ServiceProtocol.self, factory: { _ in ServiceA() }) } } @@ -22,7 +23,7 @@ struct AssemblyA: AutoInitModuleAssembly { struct AssemblyB: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] = [] - func assemble(container: Container) { + func assemble(container: Knit.Container) { container.registerIntoCollection(ServiceProtocol.self, factory: { _ in ServiceB() }) } } @@ -30,7 +31,7 @@ struct AssemblyB: AutoInitModuleAssembly { struct AssemblyC: AutoInitModuleAssembly { static var dependencies: [any ModuleAssembly.Type] = [] - func assemble(container: Container) { } + func assemble(container: Knit.Container) { } } final class CustomService: ServiceProtocol { @@ -55,8 +56,9 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection() { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register some services into a collection container.registerIntoCollection(ServiceProtocol.self) { _ in ServiceA() } @@ -67,12 +69,12 @@ final class ServiceCollectorTests: XCTestCase { container.registerIntoCollection(CustomService.self) { _ in CustomService(name: "Custom 2") } // Resolving each collection should produce the expected services - let serviceProtocolCollection = container.resolveCollection(ServiceProtocol.self) + let serviceProtocolCollection = container._unwrappedSwinjectContainer.resolveCollection(ServiceProtocol.self) XCTAssertEqual(serviceProtocolCollection.entries.count, 2) XCTAssert(serviceProtocolCollection.entries.first is ServiceA) XCTAssert(serviceProtocolCollection.entries.last is ServiceB) - let customServiceCollection = container.resolveCollection(CustomService.self) + let customServiceCollection = container._unwrappedSwinjectContainer.resolveCollection(CustomService.self) XCTAssertEqual( customServiceCollection.entries.map(\.name), ["Custom 1", "Custom 2"] @@ -81,18 +83,20 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_emptyWithBehavior() { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container._unwrappedSwinjectContainer.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 0) } @MainActor func test_registerIntoCollection_emptyWithoutBehavior() { - let container = Container() + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container._unwrappedSwinjectContainer.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 0) } @@ -100,8 +104,9 @@ final class ServiceCollectorTests: XCTestCase { /// A conflict here would be confusing and surprising to the user. @MainActor func test_registerIntoCollection_doesntConflictWithArray() throws { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register A into a collection container.registerIntoCollection(ServiceProtocol.self) { _ in ServiceA() } @@ -110,20 +115,21 @@ final class ServiceCollectorTests: XCTestCase { container.register([ServiceProtocol].self) { _ in [ServiceB()] } // Resolving the collection should produce A - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container._unwrappedSwinjectContainer.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 1) XCTAssert(collection.entries.first is ServiceA) // Resolving the array should produce B - let array = try XCTUnwrap(container.resolve([ServiceProtocol].self)) + let array = try XCTUnwrap(container._unwrappedSwinjectContainer.resolve([ServiceProtocol].self)) XCTAssertEqual(array.count, 1) XCTAssert(array.first is ServiceB) } @MainActor func test_registerIntoCollection_doesntImplicitlyAggregateInstances() throws { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register A and B into a collection _ = container.registerIntoCollection(ServiceProtocol.self) { _ in ServiceA() } @@ -133,19 +139,20 @@ final class ServiceCollectorTests: XCTestCase { _ = container.register(ServiceProtocol.self) { _ in ServiceB() } // Resolving the collection should produce A and B - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container._unwrappedSwinjectContainer.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 2) XCTAssert(collection.entries.first is ServiceA) XCTAssert(collection.entries.last is ServiceB) // Resolving the service individually should produce B - XCTAssert(container.resolve(ServiceProtocol.self) is ServiceB) + XCTAssert(container._unwrappedSwinjectContainer.resolve(ServiceProtocol.self) is ServiceB) } @MainActor func test_registerIntoCollection_allowsDuplicates() { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register some duplicate services _ = container.registerIntoCollection(ServiceProtocol.self) { _ in CustomService(name: "Dry Cleaning") } @@ -153,7 +160,7 @@ final class ServiceCollectorTests: XCTestCase { _ = container.registerIntoCollection(ServiceProtocol.self) { _ in CustomService(name: "Car Repair") } // Resolving the collection should produce all services - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container._unwrappedSwinjectContainer.resolveCollection(ServiceProtocol.self) XCTAssertEqual( collection.entries.compactMap { ($0 as? CustomService)?.name }, ["Dry Cleaning", "Car Repair", "Car Repair"] @@ -164,8 +171,9 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_supportsTransientScopedObjects() throws { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register a service with the `transient` scope. // It should be recreated each time the ServiceCollection is resolved. @@ -173,8 +181,8 @@ final class ServiceCollectorTests: XCTestCase { .registerIntoCollection(CustomService.self) { _ in CustomService(name: "service") } .inObjectScope(.transient) - let collection1 = container.resolveCollection(CustomService.self) - let collection2 = container.resolveCollection(CustomService.self) + let collection1 = container._unwrappedSwinjectContainer.resolveCollection(CustomService.self) + let collection2 = container._unwrappedSwinjectContainer.resolveCollection(CustomService.self) let instance1 = try XCTUnwrap(collection1.entries.first) let instance2 = try XCTUnwrap(collection2.entries.first) @@ -184,8 +192,9 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_supportsContainerScopedObjects() throws { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register a service with the `container` scope. // The same instance should be shared, even if the collection is resolved many times. @@ -193,8 +202,8 @@ final class ServiceCollectorTests: XCTestCase { .registerIntoCollection(CustomService.self) { _ in CustomService(name: "service") } .inObjectScope(.container) - let collection1 = container.resolveCollection(CustomService.self) - let collection2 = container.resolveCollection(CustomService.self) + let collection1 = container._unwrappedSwinjectContainer.resolveCollection(CustomService.self) + let collection2 = container._unwrappedSwinjectContainer.resolveCollection(CustomService.self) let instance1 = try XCTUnwrap(collection1.entries.first) let instance2 = try XCTUnwrap(collection2.entries.first) @@ -204,8 +213,9 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_supportsWeakScopedObjects() throws { - let container = Container() - container.addBehavior(ServiceCollector()) + let swinjectContainer = Swinject.Container() + let container = Knit.Container._instantiateAndRegister(_swinjectContainer: swinjectContainer) + container._unwrappedSwinjectContainer.addBehavior(ServiceCollector()) // Register a service with the `weak` scope. // The same instance should be shared while the instance is alive. @@ -219,26 +229,39 @@ final class ServiceCollectorTests: XCTestCase { .inObjectScope(.weak) // Resolve the initial instance - var instance1: CustomService? = try XCTUnwrap(container.resolveCollection(CustomService.self).entries.first) + var instance1: CustomService? = try XCTUnwrap(container._unwrappedSwinjectContainer.resolveCollection(CustomService.self).entries.first) XCTAssertEqual(factoryCallCount, 1) // Resolving again shouldn't increase `factoryCallCount` since `instance1` is still retained. - var instance2: CustomService? = try XCTUnwrap(container.resolveCollection(CustomService.self).entries.first) + var instance2: CustomService? = try XCTUnwrap(container._unwrappedSwinjectContainer.resolveCollection(CustomService.self).entries.first) XCTAssertEqual(factoryCallCount, 1) XCTAssert(instance2 === instance1) // Release our instances and resolve again. This time a new instance should be created. instance1 = nil instance2 = nil - _ = container.resolveCollection(CustomService.self) + _ = container._unwrappedSwinjectContainer.resolveCollection(CustomService.self) XCTAssertEqual(factoryCallCount, 2) } @MainActor - func test_parentChildContainersWithAssemblers() { - let parent = ModuleAssembler([AssemblyA()]) - let child = ModuleAssembler(parent: parent, [AssemblyB()]) - + func test_parentChildContainersWithAssemblers() throws { + let parent = try ModuleAssembler( + _modules: [AssemblyA()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) + let child = try ModuleAssembler( + parent: parent, + _modules: [AssemblyB()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) + // When resolving from the parent resolver we only get services from AssemblyA XCTAssertEqual( parent.resolver.resolveCollection(ServiceProtocol.self).entries.count, @@ -260,9 +283,22 @@ final class ServiceCollectorTests: XCTestCase { } @MainActor - func test_childWithEmptyParent() { - let parent = ModuleAssembler([AssemblyC()]) - let child = ModuleAssembler(parent: parent, [AssemblyB()]) + func test_childWithEmptyParent() throws { + let parent = try ModuleAssembler( + _modules: [AssemblyC()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) + let child = try ModuleAssembler( + parent: parent, + _modules: [AssemblyB()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) // Parent has no services registered XCTAssertEqual( @@ -277,10 +313,10 @@ final class ServiceCollectorTests: XCTestCase { } @MainActor - func test_emptyChildWithParent() { - let parent = ModuleAssembler([AssemblyB()]) - let child = ModuleAssembler(parent: parent, [AssemblyC()]) - + func test_emptyChildWithParent() throws { + let parent = try ModuleAssembler(_modules: [AssemblyB()]) + let child = try ModuleAssembler(parent: parent, _modules: [AssemblyC()]) + // The parent itself has no services so they come from the child XCTAssertEqual( child.resolver.resolveCollection(ServiceProtocol.self).entries.count, @@ -295,10 +331,30 @@ final class ServiceCollectorTests: XCTestCase { } @MainActor - func test_grandparentRelationship() { - let grandParent = ModuleAssembler([AssemblyA()]) - let parent = ModuleAssembler(parent: grandParent, [AssemblyC()]) - let child = ModuleAssembler(parent: parent, [AssemblyB()]) + func test_grandparentRelationship() throws { + let grandParent = try ModuleAssembler( + _modules: [AssemblyA()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) + let parent = try ModuleAssembler( + parent: grandParent, + _modules: [AssemblyC()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) + let child = try ModuleAssembler( + parent: parent, + _modules: [AssemblyB()], + preAssemble: { container in + Knit.Container._instantiateAndRegister(_swinjectContainer: container) + }, + autoConfigureContainers: false + ) // The child has access to all services XCTAssertEqual( diff --git a/Tests/KnitTests/SynchronizationTests.swift b/Tests/KnitTests/SynchronizationTests.swift index 591fd495..68f66e6c 100644 --- a/Tests/KnitTests/SynchronizationTests.swift +++ b/Tests/KnitTests/SynchronizationTests.swift @@ -11,16 +11,16 @@ final class SynchronizationTests: XCTestCase { @MainActor func testMultiThreadResolving() async throws { // Use a parent/child relationship to test synchronization between containers - let parent = ModuleAssembler([Assembly1()]) - let assembler = ModuleAssembler(parent: parent, [Assembly2()]) + let parent = ScopedModuleAssembler([Assembly1()]) + let assembler = ScopedModuleAssembler(parent: parent.internalAssembler, [Assembly2()]) // Resolve the same service in 2 separate tasks async let task1 = try Task { - return assembler.resolver.resolve(Service2.self)! + return assembler.resolver.service2() }.result.get() async let task2 = try Task { - return assembler.resolver.resolve(Service2.self)! + return assembler.resolver.service2() }.result.get() let result = try await (task1, task2) @@ -35,11 +35,11 @@ final class SynchronizationTests: XCTestCase { // Resolve the same service in 2 separate tasks async let task1 = try Task { - return assembler.resolver.resolve(Service2.self)! + return assembler.resolver.service2() }.result.get() async let task2 = try Task { - return assembler.resolver.resolve(Service2.self)! + return assembler.resolver.service2() }.result.get() let result = try await (task1, task2) @@ -53,17 +53,22 @@ final class SynchronizationTests: XCTestCase { private struct Assembly1: AutoInitModuleAssembly { typealias TargetResolver = TestScopedResolver static var dependencies: [any ModuleAssembly.Type] { [] } - func assemble(container: Container) { - container.register(Service1.self) { _ in Service1() } + func assemble(container: Container) { + container.register(Service1.self, factory: { _ in Service1() }) } } private struct Assembly2: ModuleAssembly { typealias TargetResolver = TestScopedResolver static var dependencies: [any ModuleAssembly.Type] { [Assembly1.self] } - func assemble(container: Container) { - container.register(Service2.self) { Service2(service1: $0.resolve(Service1.self)! )} - .inObjectScope(.weak) + func assemble(container: Container) { + container.register( + Service2.self, + factory: { resolver in + Service2(service1: resolver.service1()) + } + ) + .inObjectScope(.weak) } } @@ -79,5 +84,17 @@ private final class Service2 { } } -public protocol TestScopedResolver: Resolver { } -extension Container: TestScopedResolver {} +private protocol TestScopedResolver: Knit.Resolver { + func service1() -> Service1 + func service2() -> Service2 +} +extension TestScopedResolver { + fileprivate func service1() -> Service1 { + self.unsafeResolver.resolve(Service1.self)! + } + + fileprivate func service2() -> Service2 { + self.unsafeResolver.resolve(Service2.self)! + } +} +extension Container: TestScopedResolver {} diff --git a/Tests/KnitTests/TestResolver.swift b/Tests/KnitTests/TestResolver.swift index 801da3a4..99b1ad1f 100644 --- a/Tests/KnitTests/TestResolver.swift +++ b/Tests/KnitTests/TestResolver.swift @@ -5,9 +5,9 @@ @testable import Knit import Swinject -protocol TestResolver: Resolver { } +protocol TestResolver: Knit.Resolver { } -extension Container: TestResolver {} +extension Knit.Container: TestResolver {} extension ModuleAssembly { @@ -15,3 +15,30 @@ extension ModuleAssembly { typealias TargetResolver = TestResolver } + +extension ModuleAssembler { + + // Convenience throwing init that fills in preAssemble and autoConfigureContainers + @MainActor convenience init( + parent: ModuleAssembler? = nil, + _modules modules: [any ModuleAssembly], + overrideBehavior: OverrideBehavior = .defaultOverridesWhenTesting, + assemblyValidation: ((any ModuleAssembly.Type) throws -> Void)? = nil, + errorFormatter: ModuleAssemblerErrorFormatter = DefaultModuleAssemblerErrorFormatter(), + behaviors: [Behavior] = [], + postAssemble: ((Swinject.Container) -> Void)? = nil + ) throws { + try self.init( + parent: parent, + _modules: modules, + overrideBehavior: overrideBehavior, + assemblyValidation: assemblyValidation, + errorFormatter: errorFormatter, + behaviors: behaviors, + preAssemble: nil, + postAssemble: postAssemble, + autoConfigureContainers: true + ) + } + +} diff --git a/Tests/KnitTests/WeakResolverTests.swift b/Tests/KnitTests/WeakResolverTests.swift index d2a58770..c11b1508 100644 --- a/Tests/KnitTests/WeakResolverTests.swift +++ b/Tests/KnitTests/WeakResolverTests.swift @@ -2,13 +2,14 @@ // Copyright © Block, Inc. All rights reserved. // -import Knit +@testable import Knit +import Swinject import XCTest final class WeakResolverTests: XCTestCase { func test_weakResolver() { - var container: Container? = Container() + var container: Swinject.Container? = Swinject.Container() weak var weakContainer = container container?.register(String.self) { _ in "Test" } @@ -30,7 +31,7 @@ final class WeakResolverTests: XCTestCase { // It is probably unusual if a consumer retains the result of the `optionalResolver` property, // but in case that happens we don't want to accidentally leak the container. - var container: Container? = Container() + var container: Swinject.Container? = Swinject.Container() weak var weakConatiner = container container?.register(String.self) { _ in "Test" }