Bug 303657. Fix for session management of Vista. Admin users have 2 LUIDs. Make these the same userId

This commit is contained in:
Jim Norman 2007-08-23 16:58:51 +00:00
parent 2c9bdbb3c1
commit 4cb470084c
5 changed files with 81 additions and 18 deletions

View File

@ -189,9 +189,9 @@ namespace AppModule.NamedPipes {
return (int)this.Handle.Handle; return (int)this.Handle.Handle;
} }
} }
public int GetLocalUserID(ref int lowPart, ref int highPart, ref string sSIDString) public int GetLocalUserID(ref int lowPart, ref int highPart, ref string sSIDString, ref int lowPartElevated, ref int highPartElevated)
{ {
return ImpersonateWrapper.GetLocalUserID(this.Handle, ref lowPart, ref highPart, ref sSIDString); return ImpersonateWrapper.GetLocalUserID(this.Handle, ref lowPart, ref highPart, ref sSIDString, ref lowPartElevated, ref highPartElevated);
} }
} }
} }

View File

@ -181,10 +181,11 @@ namespace AppModule.NamedPipes
linkedToken = (ImpersonateNative.TOKEN_LINKED_TOKEN)Marshal.PtrToStructure(ptrLinkedToken, typeof(ImpersonateNative.TOKEN_LINKED_TOKEN)); linkedToken = (ImpersonateNative.TOKEN_LINKED_TOKEN)Marshal.PtrToStructure(ptrLinkedToken, typeof(ImpersonateNative.TOKEN_LINKED_TOKEN));
} }
} }
catch (OutOfMemoryException e) catch (Exception e)
{ {
System.Diagnostics.Trace.WriteLine(e.ToString()); System.Diagnostics.Trace.WriteLine(e.ToString());
} }
finally finally
{ {
if (ptrLinkedToken != IntPtr.Zero) if (ptrLinkedToken != IntPtr.Zero)
@ -197,7 +198,7 @@ namespace AppModule.NamedPipes
return TokenInfoSuccess; return TokenInfoSuccess;
} }
public static int GetLocalUserID(PipeHandle handle, ref int lowPart, ref int highPart, ref string SidString) public static int GetLocalUserID(PipeHandle handle, ref int lowPart, ref int highPart, ref string SidString, ref int lowPartElevated, ref int highPartElevated)
{ {
int rcode = -1; int rcode = -1;
// get client userID // get client userID
@ -214,6 +215,7 @@ namespace AppModule.NamedPipes
IntPtr hThread = ImpersonateNative.GetCurrentThread(); IntPtr hThread = ImpersonateNative.GetCurrentThread();
uint iDesiredInfo = 24; //TOKEN_QUERY | TOKEN_QUERY_SOURCE; uint iDesiredInfo = 24; //TOKEN_QUERY | TOKEN_QUERY_SOURCE;
IntPtr userToken = Marshal.AllocHGlobal(4); IntPtr userToken = Marshal.AllocHGlobal(4);
IntPtr userTokenElevated = IntPtr.Zero;
if (ImpersonateNative.OpenThreadToken(hThread, iDesiredInfo, true, out userToken)) if (ImpersonateNative.OpenThreadToken(hThread, iDesiredInfo, true, out userToken))
{ {
@ -225,7 +227,7 @@ namespace AppModule.NamedPipes
// on Vista use the elevated token if there is one. // on Vista use the elevated token if there is one.
System.OperatingSystem os = System.Environment.OSVersion; System.OperatingSystem os = System.Environment.OSVersion;
System.Diagnostics.Trace.WriteLine("OS Version: {0}", os.Version.ToString()); System.Diagnostics.Trace.WriteLine("OS Version: " + os.Version.ToString());
if (os.Version.Major > 5) if (os.Version.Major > 5)
{ {
if (ImpersonateNative.GetTokenInformation(userToken, ImpersonateNative.TOKEN_INFORMATION_CLASS.TokenElevationType, tu, cb, ref cb)) if (ImpersonateNative.GetTokenInformation(userToken, ImpersonateNative.TOKEN_INFORMATION_CLASS.TokenElevationType, tu, cb, ref cb))
@ -233,20 +235,24 @@ namespace AppModule.NamedPipes
int iTokenType; int iTokenType;
iTokenType = (int)Marshal.PtrToStructure(tu, typeof(int)); iTokenType = (int)Marshal.PtrToStructure(tu, typeof(int));
System.Diagnostics.Trace.WriteLine("Token Type : {0}", iTokenType.ToString()); System.Diagnostics.Trace.WriteLine("Token Type: " + iTokenType.ToString());
if (iTokenType == 3) //.ToString().Equals(ImpersonateNative.TOKEN_ELEVATION_TYPE.TokenElevationTypeLimited)) if (iTokenType == (int)ImpersonateNative.TOKEN_ELEVATION_TYPE.TokenElevationTypeLimited)
{ {
System.Diagnostics.Trace.WriteLine("Getting linked token");
ImpersonateNative.TOKEN_LINKED_TOKEN newLinkedToken; ImpersonateNative.TOKEN_LINKED_TOKEN newLinkedToken;
if (GetLinkedToken(userToken, out newLinkedToken)) if (GetLinkedToken(userToken, out newLinkedToken))
{ {
userToken = newLinkedToken.LinkedToken; //userToken = newLinkedToken.LinkedToken;
userTokenElevated = Marshal.AllocHGlobal(4);
userTokenElevated = newLinkedToken.LinkedToken;
} }
} }
} }
else else
{ {
uint error = ImpersonateNative.GetLastError(); uint error = ImpersonateNative.GetLastError();
System.Diagnostics.Trace.WriteLine("linked token error: {0}", error.ToString()); System.Diagnostics.Trace.WriteLine("linked token error: " + error.ToString());
} }
} }
@ -273,6 +279,20 @@ namespace AppModule.NamedPipes
highPart = stats.AuthenticationId.HighPart; highPart = stats.AuthenticationId.HighPart;
rcode = -1; rcode = -1;
} }
// get elevated token stats
if (userTokenElevated != IntPtr.Zero)
{
cb = bufLength;
if (ImpersonateNative.GetTokenInformation(userTokenElevated, ImpersonateNative.TOKEN_INFORMATION_CLASS.TokenStatistics, tu, cb, ref cb))
{
stats = (ImpersonateNative.TOKEN_STATISTICS)Marshal.PtrToStructure(tu, typeof(ImpersonateNative.TOKEN_STATISTICS));
// copy low and high part
lowPartElevated = stats.AuthenticationId.LowPart;
highPartElevated = stats.AuthenticationId.HighPart;
rcode = -1;
}
}
} }
else else
{ {
@ -295,6 +315,10 @@ namespace AppModule.NamedPipes
} }
Marshal.FreeHGlobal(userToken); Marshal.FreeHGlobal(userToken);
if (userTokenElevated != IntPtr.Zero)
{
Marshal.FreeHGlobal(userTokenElevated);
}
} }
catch (Exception ex) catch (Exception ex)
{ {

View File

@ -410,6 +410,11 @@ NPLogonNotify (
extension.version = 0x00010000; // 1.0.0 extension.version = 0x00010000; // 1.0.0
extension.ext = (void *)lpLogonId; extension.ext = (void *)lpLogonId;
#ifdef _DEBUG
DebugPrint("Setting credential for: %d (high) %d (low)\r\n", lpLogonId->HighPart, lpLogonId->LowPart);
#endif
ccode = (*pCASASetCredential)( ccode = (*pCASASetCredential)(
0, 0,
&desktopCredential, &desktopCredential,

View File

@ -28,8 +28,24 @@ namespace sscs.common
{ {
private int uidLow; private int uidLow;
private int uidHigh; private int uidHigh;
private int elevatedUidLow = 0;
private int elevatedUidHigh = 0;
private string m_sSID = ""; private string m_sSID = "";
internal WinUserIdentifier(int uidLowPart, int uidHighPart, string sSID, int elevatedUidLow, int elevatedUidHigh)
{
this.uidLow = uidLowPart;
this.uidHigh = uidHighPart;
this.m_sSID = sSID;
if (elevatedUidLow != null)
this.elevatedUidLow = elevatedUidLow;
if (elevatedUidHigh != null)
this.elevatedUidHigh = elevatedUidHigh;
}
internal WinUserIdentifier(int uidLowPart, int uidHighPart, string sSID) internal WinUserIdentifier(int uidLowPart, int uidHighPart, string sSID)
{ {
this.uidLow = uidLowPart; this.uidLow = uidLowPart;
@ -52,7 +68,9 @@ namespace sscs.common
public override bool Equals(Object obj) public override bool Equals(Object obj)
{ {
WinUserIdentifier u = (WinUserIdentifier)obj; WinUserIdentifier u = (WinUserIdentifier)obj;
if ((u.uidLow == uidLow) && (u.uidHigh == uidHigh)) if (((u.uidLow == uidLow) && (u.uidHigh == uidHigh)) ||
((u.uidLow == elevatedUidLow) && (u.uidHigh == elevatedUidHigh)) ||
((u.elevatedUidLow == uidLow) && (u.elevatedUidHigh == uidHigh)))
{ {
// we have a match, set the SID if we can // we have a match, set the SID if we can
if ((this.m_sSID.Length < 1) && (u.GetSID().Length>0)) if ((this.m_sSID.Length < 1) && (u.GetSID().Length>0))

View File

@ -68,6 +68,8 @@ namespace sscs.communication
{ {
int localUserIDLow = 0; int localUserIDLow = 0;
int localUserIDHigh = 0; int localUserIDHigh = 0;
int localUserIDLowElevated = 0;
int localUserIDHighElevated = 0;
string sSIDString = ""; string sSIDString = "";
byte[] incoming = null; byte[] incoming = null;
@ -77,12 +79,26 @@ namespace sscs.communication
incoming = m_serverPipeConnection.ReadBytes(); incoming = m_serverPipeConnection.ReadBytes();
// get local Userid and SID // get local Userid and SID
m_serverPipeConnection.GetLocalUserID(ref localUserIDLow, ref localUserIDHigh, ref sSIDString); m_serverPipeConnection.GetLocalUserID(ref localUserIDLow,
ref localUserIDHigh,
ref sSIDString,
ref localUserIDLowElevated,
ref localUserIDHighElevated);
if (localUserIDLow != 0 || localUserIDHigh !=0) if (localUserIDLowElevated != 0 || localUserIDHighElevated != 0)
{ {
userId = new WinUserIdentifier(localUserIDLow, localUserIDHigh, sSIDString); if (localUserIDLow != 0 || localUserIDHigh != 0)
} {
userId = new WinUserIdentifier(localUserIDLow, localUserIDHigh, sSIDString, localUserIDLowElevated, localUserIDHighElevated);
}
}
else
{
if (localUserIDLow != 0 || localUserIDHigh != 0)
{
userId = new WinUserIdentifier(localUserIDLow, localUserIDHigh, sSIDString);
}
}
return incoming; return incoming;
} }