diff --git a/src/basic_memory/cli/commands/cloud/cloud_utils.py b/src/basic_memory/cli/commands/cloud/cloud_utils.py index b4c1158a..566ca5aa 100644 --- a/src/basic_memory/cli/commands/cloud/cloud_utils.py +++ b/src/basic_memory/cli/commands/cloud/cloud_utils.py @@ -16,12 +16,23 @@ class CloudUtilsError(Exception): pass +def _workspace_headers(workspace: str | None = None) -> dict[str, str]: + """Build workspace header if workspace is specified.""" + if workspace: + return {"X-Workspace-ID": workspace} + return {} + + async def fetch_cloud_projects( *, + workspace: str | None = None, api_request=make_api_request, ) -> CloudProjectList: """Fetch list of projects from cloud API. + Args: + workspace: Cloud workspace tenant_id to list projects from + Returns: CloudProjectList with projects from cloud """ @@ -30,7 +41,11 @@ async def fetch_cloud_projects( config = config_manager.config host_url = config.cloud_host.rstrip("/") - response = await api_request(method="GET", url=f"{host_url}/proxy/v2/projects/") + response = await api_request( + method="GET", + url=f"{host_url}/proxy/v2/projects/", + headers=_workspace_headers(workspace), + ) return CloudProjectList.model_validate(response.json()) except Exception as e: @@ -40,12 +55,14 @@ async def fetch_cloud_projects( async def create_cloud_project( project_name: str, *, + workspace: str | None = None, api_request=make_api_request, ) -> CloudProjectCreateResponse: """Create a new project on cloud. Args: project_name: Name of project to create + workspace: Cloud workspace tenant_id to create project in Returns: CloudProjectCreateResponse with project details from API @@ -64,10 +81,13 @@ async def create_cloud_project( set_default=False, ) + headers = {"Content-Type": "application/json"} + headers.update(_workspace_headers(workspace)) + response = await api_request( method="POST", url=f"{host_url}/proxy/v2/projects/", - headers={"Content-Type": "application/json"}, + headers=headers, json_data=project_data.model_dump(), ) @@ -91,17 +111,23 @@ async def sync_project(project_name: str, force_full: bool = False) -> None: raise CloudUtilsError(f"Failed to sync project '{project_name}': {e}") from e -async def project_exists(project_name: str, *, api_request=make_api_request) -> bool: +async def project_exists( + project_name: str, + *, + workspace: str | None = None, + api_request=make_api_request, +) -> bool: """Check if a project exists on cloud. Args: project_name: Name of project to check + workspace: Cloud workspace tenant_id to check in Returns: True if project exists, False otherwise """ try: - projects = await fetch_cloud_projects(api_request=api_request) + projects = await fetch_cloud_projects(workspace=workspace, api_request=api_request) project_names = {p.name for p in projects.projects} return project_name in project_names except Exception: diff --git a/src/basic_memory/cli/commands/cloud/upload_command.py b/src/basic_memory/cli/commands/cloud/upload_command.py index b27c83ff..220c1ef3 100644 --- a/src/basic_memory/cli/commands/cloud/upload_command.py +++ b/src/basic_memory/cli/commands/cloud/upload_command.py @@ -13,6 +13,7 @@ sync_project, ) from basic_memory.cli.commands.cloud.upload import upload_path +from basic_memory.config import ConfigManager from basic_memory.mcp.async_client import get_cloud_control_plane_client console = Console() @@ -73,12 +74,21 @@ def upload( """ async def _upload(): + # Resolve workspace: per-project workspace_id, then default_workspace + config = ConfigManager().config + workspace = None + entry = config.projects.get(project) + if entry and entry.workspace_id: + workspace = entry.workspace_id + elif config.default_workspace: + workspace = config.default_workspace + # Check if project exists - if not await project_exists(project): + if not await project_exists(project, workspace=workspace): if create_project: console.print(f"[blue]Creating cloud project '{project}'...[/blue]") try: - await create_cloud_project(project) + await create_cloud_project(project, workspace=workspace) console.print(f"[green]Created project '{project}'[/green]") except Exception as e: console.print(f"[red]Failed to create project: {e}[/red]") @@ -93,6 +103,8 @@ async def _upload(): raise typer.Exit(1) # Perform upload (or dry run) + if workspace: + console.print(f"[dim]Using workspace: {workspace}[/dim]") if dry_run: console.print( f"[yellow]DRY RUN: Showing what would be uploaded to '{project}'[/yellow]" @@ -100,13 +112,16 @@ async def _upload(): else: console.print(f"[blue]Uploading {path} to project '{project}'...[/blue]") + def _client_factory(): + return get_cloud_control_plane_client(workspace=workspace) + success = await upload_path( path, project, verbose=verbose, use_gitignore=not no_gitignore, dry_run=dry_run, - client_cm_factory=get_cloud_control_plane_client, + client_cm_factory=_client_factory, ) if not success: console.print("[red]Upload failed[/red]") diff --git a/src/basic_memory/mcp/async_client.py b/src/basic_memory/mcp/async_client.py index 5b794c23..e191e1f9 100644 --- a/src/basic_memory/mcp/async_client.py +++ b/src/basic_memory/mcp/async_client.py @@ -88,15 +88,20 @@ async def _cloud_client( @asynccontextmanager -async def get_cloud_control_plane_client() -> AsyncIterator[AsyncClient]: +async def get_cloud_control_plane_client( + workspace: Optional[str] = None, +) -> AsyncIterator[AsyncClient]: """Create a control-plane cloud client for endpoints outside /proxy.""" config = ConfigManager().config timeout = _build_timeout() token = await _resolve_cloud_token(config) + headers: dict[str, str] = {"Authorization": f"Bearer {token}"} + if workspace: + headers["X-Workspace-ID"] = workspace logger.info(f"Creating HTTP client for cloud control plane at: {config.cloud_host}") async with AsyncClient( base_url=config.cloud_host, - headers={"Authorization": f"Bearer {token}"}, + headers=headers, timeout=timeout, ) as client: yield client @@ -179,8 +184,16 @@ async def get_client( project_mode = config.get_project_mode(project_name) if project_mode == ProjectMode.CLOUD: logger.debug(f"Project '{project_name}' is cloud mode - using cloud proxy client") + # Resolve workspace from project config if not explicitly provided + effective_workspace = workspace + if effective_workspace is None: + entry = config.projects.get(project_name) + if entry and entry.workspace_id: + effective_workspace = entry.workspace_id + elif config.default_workspace: + effective_workspace = config.default_workspace try: - async with _cloud_client(config, timeout, workspace=workspace) as client: + async with _cloud_client(config, timeout, workspace=effective_workspace) as client: yield client except RuntimeError as exc: raise RuntimeError( diff --git a/tests/cli/cloud/test_cloud_api_client_and_utils.py b/tests/cli/cloud/test_cloud_api_client_and_utils.py index 65f2f99e..0247a835 100644 --- a/tests/cli/cloud/test_cloud_api_client_and_utils.py +++ b/tests/cli/cloud/test_cloud_api_client_and_utils.py @@ -10,6 +10,7 @@ make_api_request, ) from basic_memory.cli.commands.cloud.cloud_utils import ( + _workspace_headers, create_cloud_project, fetch_cloud_projects, project_exists, @@ -165,6 +166,81 @@ async def api_request(**kwargs): assert seen["create_payload"]["path"] == "my-project" +def test_workspace_headers_with_workspace(): + """_workspace_headers returns X-Workspace-ID when workspace is provided.""" + assert _workspace_headers("tenant-123") == {"X-Workspace-ID": "tenant-123"} + + +def test_workspace_headers_without_workspace(): + """_workspace_headers returns empty dict when no workspace.""" + assert _workspace_headers(None) == {} + assert _workspace_headers() == {} + + +@pytest.mark.asyncio +async def test_cloud_utils_pass_workspace_header(config_home, config_manager): + """fetch_cloud_projects, project_exists, and create_cloud_project pass workspace header.""" + config = config_manager.load_config() + config.cloud_host = "https://cloud.example.test" + config_manager.save_config(config) + + auth = CLIAuth(client_id="cid", authkit_domain="https://auth.example.test") + auth.token_file.parent.mkdir(parents=True, exist_ok=True) + auth.token_file.write_text( + '{"access_token":"token-123","refresh_token":null,"expires_at":9999999999,"token_type":"Bearer"}', + encoding="utf-8", + ) + + seen_headers: list[dict] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + seen_headers.append(dict(request.headers)) + if request.method == "GET": + return httpx.Response(200, json={"projects": []}) + if request.method == "POST": + payload = json.loads(request.content.decode("utf-8")) + return httpx.Response( + 200, + json={ + "message": "created", + "status": "success", + "default": False, + "old_project": None, + "new_project": {"name": payload["name"], "path": payload["path"]}, + }, + ) + raise AssertionError(f"Unexpected: {request.method}") + + transport = httpx.MockTransport(handler) + + @asynccontextmanager + async def http_client_factory(): + async with httpx.AsyncClient( + transport=transport, base_url="https://cloud.example.test" + ) as client: + yield client + + async def api_request(**kwargs): + return await make_api_request(auth=auth, http_client_factory=http_client_factory, **kwargs) + + # fetch with workspace + await fetch_cloud_projects(workspace="tenant-abc", api_request=api_request) + assert seen_headers[-1].get("x-workspace-id") == "tenant-abc" + + # project_exists with workspace + await project_exists("test", workspace="tenant-abc", api_request=api_request) + assert seen_headers[-1].get("x-workspace-id") == "tenant-abc" + + # create with workspace + await create_cloud_project("new-proj", workspace="tenant-abc", api_request=api_request) + assert seen_headers[-1].get("x-workspace-id") == "tenant-abc" + + # Without workspace — header should not be present + seen_headers.clear() + await fetch_cloud_projects(api_request=api_request) + assert "x-workspace-id" not in seen_headers[-1] + + @pytest.mark.asyncio async def test_make_api_request_prefers_api_key_over_oauth(config_home, config_manager): """API key in config should be used without needing an OAuth token on disk.""" diff --git a/tests/cli/cloud/test_upload_command_routing.py b/tests/cli/cloud/test_upload_command_routing.py index 6bb0115b..71586c17 100644 --- a/tests/cli/cloud/test_upload_command_routing.py +++ b/tests/cli/cloud/test_upload_command_routing.py @@ -10,7 +10,7 @@ runner = CliRunner() -def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path): +def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path, config_manager): """Upload command should use control-plane cloud client for WebDAV PUT operations.""" import basic_memory.cli.commands.cloud.upload_command as upload_command @@ -20,11 +20,11 @@ def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path): seen: dict[str, str] = {} - async def fake_project_exists(_project_name: str) -> bool: + async def fake_project_exists(_project_name: str, workspace: str | None = None) -> bool: return True @asynccontextmanager - async def fake_get_client(): + async def fake_get_client(workspace=None): async with httpx.AsyncClient(base_url="https://cloud.example.test") as client: yield client diff --git a/tests/mcp/test_async_client_modes.py b/tests/mcp/test_async_client_modes.py index f65c4f2e..402fab0b 100644 --- a/tests/mcp/test_async_client_modes.py +++ b/tests/mcp/test_async_client_modes.py @@ -274,6 +274,64 @@ async def test_get_cloud_control_plane_client_uses_oauth_token(config_manager): assert client.headers.get("Authorization") == "Bearer oauth-control-123" +@pytest.mark.asyncio +async def test_get_cloud_control_plane_client_with_workspace(config_manager): + """Control plane client passes X-Workspace-ID header when workspace is provided.""" + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + config_manager.save_config(cfg) + + async with get_cloud_control_plane_client(workspace="tenant-abc") as client: + assert client.headers.get("X-Workspace-ID") == "tenant-abc" + + # Without workspace, header should not be present + async with get_cloud_control_plane_client() as client: + assert "X-Workspace-ID" not in client.headers + + +@pytest.mark.asyncio +async def test_get_client_auto_resolves_workspace_from_project_config(config_manager): + """get_client resolves workspace from project entry when not explicitly passed.""" + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + cfg.set_project_mode("research", ProjectMode.CLOUD) + cfg.projects["research"].workspace_id = "tenant-from-config" + config_manager.save_config(cfg) + + async with get_client(project_name="research") as client: + assert client.headers.get("X-Workspace-ID") == "tenant-from-config" + + +@pytest.mark.asyncio +async def test_get_client_auto_resolves_workspace_from_default(config_manager): + """get_client falls back to default_workspace when project has no workspace_id.""" + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + cfg.set_project_mode("research", ProjectMode.CLOUD) + cfg.default_workspace = "default-tenant-456" + config_manager.save_config(cfg) + + async with get_client(project_name="research") as client: + assert client.headers.get("X-Workspace-ID") == "default-tenant-456" + + +@pytest.mark.asyncio +async def test_get_client_explicit_workspace_overrides_config(config_manager): + """Explicit workspace param takes priority over project config.""" + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + cfg.set_project_mode("research", ProjectMode.CLOUD) + cfg.projects["research"].workspace_id = "tenant-from-config" + config_manager.save_config(cfg) + + async with get_client(project_name="research", workspace="explicit-tenant") as client: + assert client.headers.get("X-Workspace-ID") == "explicit-tenant" + + @pytest.mark.asyncio async def test_get_cloud_control_plane_client_raises_without_credentials(config_manager): cfg = config_manager.load_config()